provider.go raw

// Package logincreds implements an AWS credentials provider for sessions
// created via the `aws login` command.
   3  package logincreds
   4  
   5  import (
   6  	"context"
   7  	"encoding/json"
   8  	"errors"
   9  	"fmt"
  10  	"io"
  11  	"os"
  12  
  13  	"github.com/aws/aws-sdk-go-v2/aws"
  14  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
  15  	"github.com/aws/aws-sdk-go-v2/service/signin"
  16  	"github.com/aws/aws-sdk-go-v2/service/signin/types"
  17  )
  18  
  19  // ProviderName identifies the login provider.
  20  const ProviderName = "LoginProvider"
  21  
  22  // TokenAPIClient provides the interface for the login session's token
  23  // retrieval operation.
  24  type TokenAPIClient interface {
  25  	CreateOAuth2Token(context.Context, *signin.CreateOAuth2TokenInput, ...func(*signin.Options)) (*signin.CreateOAuth2TokenOutput, error)
  26  }
  27  
  28  // Provider supplies credentials for an `aws login` session.
  29  type Provider struct {
  30  	options Options
  31  }
  32  
  33  var _ aws.CredentialsProvider = (*Provider)(nil)
  34  
  35  // Options configures the Provider.
  36  type Options struct {
  37  	Client TokenAPIClient
  38  
  39  	// APIOptions to pass to the underlying CreateOAuth2Token operation.
  40  	ClientOptions []func(*signin.Options)
  41  
  42  	// The path to the cached login token.
  43  	CachedTokenFilepath string
  44  
  45  	// The chain of providers that was used to create this provider.
  46  	//
  47  	// These values are for reporting purposes and are not meant to be set up
  48  	// directly.
  49  	CredentialSources []aws.CredentialSource
  50  }
  51  
  52  // New returns a new login session credentials provider.
  53  func New(client TokenAPIClient, path string, opts ...func(*Options)) *Provider {
  54  	options := Options{
  55  		Client:              client,
  56  		CachedTokenFilepath: path,
  57  	}
  58  
  59  	for _, opt := range opts {
  60  		opt(&options)
  61  	}
  62  
  63  	return &Provider{options}
  64  }
  65  
  66  // Retrieve generates a new set of temporary credentials using an `aws login`
  67  // session.
  68  func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
  69  	token, err := p.loadToken()
  70  	if err != nil {
  71  		return aws.Credentials{}, fmt.Errorf("load login token: %w", err)
  72  	}
  73  	if err := token.Validate(); err != nil {
  74  		return aws.Credentials{}, fmt.Errorf("validate login token: %w", err)
  75  	}
  76  
  77  	// the token may have been refreshed elsewhere or the login session might
  78  	// have just been created
  79  	if sdk.NowTime().Before(token.AccessToken.ExpiresAt) {
  80  		return token.Credentials(), nil
  81  	}
  82  
  83  	opts := make([]func(*signin.Options), len(p.options.ClientOptions)+1)
  84  	opts[0] = addSignDPOP(token)
  85  	copy(opts[1:], p.options.ClientOptions)
  86  
  87  	out, err := p.options.Client.CreateOAuth2Token(ctx, &signin.CreateOAuth2TokenInput{
  88  		TokenInput: &types.CreateOAuth2TokenRequestBody{
  89  			ClientId:     aws.String(token.ClientID),
  90  			GrantType:    aws.String("refresh_token"),
  91  			RefreshToken: aws.String(token.RefreshToken),
  92  		},
  93  	}, opts...)
  94  	if err != nil {
  95  		var terr *types.AccessDeniedException
  96  		if errors.As(err, &terr) {
  97  			err = toAccessDeniedError(terr)
  98  		}
  99  		return aws.Credentials{}, fmt.Errorf("create oauth2 token: %w", err)
 100  	}
 101  
 102  	token.Update(out)
 103  	if err := p.saveToken(token); err != nil {
 104  		return aws.Credentials{}, fmt.Errorf("save token: %w", err)
 105  	}
 106  
 107  	return token.Credentials(), nil
 108  }
 109  
 110  // ProviderSources returns the credential chain that was used to construct this
 111  // provider.
 112  func (p *Provider) ProviderSources() []aws.CredentialSource {
 113  	if p.options.CredentialSources == nil {
 114  		return []aws.CredentialSource{aws.CredentialSourceLogin}
 115  	}
 116  	return p.options.CredentialSources
 117  }
 118  
 119  func (p *Provider) loadToken() (*loginToken, error) {
 120  	f, err := openFile(p.options.CachedTokenFilepath)
 121  	if err != nil && os.IsNotExist(err) {
 122  		return nil, fmt.Errorf("token file not found, please reauthenticate")
 123  	}
 124  	if err != nil {
 125  		return nil, err
 126  	}
 127  	defer f.Close()
 128  
 129  	j, err := io.ReadAll(f)
 130  	if err != nil {
 131  		return nil, err
 132  	}
 133  
 134  	var t *loginToken
 135  	if err := json.Unmarshal(j, &t); err != nil {
 136  		return nil, err
 137  	}
 138  
 139  	return t, nil
 140  }
 141  
 142  func (p *Provider) saveToken(token *loginToken) error {
 143  	j, err := json.Marshal(token)
 144  	if err != nil {
 145  		return err
 146  	}
 147  
 148  	f, err := createFile(p.options.CachedTokenFilepath)
 149  	if err != nil {
 150  		return err
 151  	}
 152  	defer f.Close()
 153  
 154  	if _, err := f.Write(j); err != nil {
 155  		return err
 156  	}
 157  
 158  	return nil
 159  }
 160  
 161  func toAccessDeniedError(err *types.AccessDeniedException) error {
 162  	switch err.Error_ {
 163  	case types.OAuth2ErrorCodeTokenExpired:
 164  		return fmt.Errorf("login session has expired, please reauthenticate")
 165  	case types.OAuth2ErrorCodeUserCredentialsChanged:
 166  		return fmt.Errorf("login session password has changed, please reauthenticate")
 167  	case types.OAuth2ErrorCodeInsufficientPermissions:
 168  		return fmt.Errorf("insufficient permissions, you may be missing permissions for the CreateOAuth2Token action")
 169  	default:
 170  		return err
 171  	}
 172  }
 173