sso_credentials_provider.go raw

   1  package ssocreds
   2  
import (
	"context"
	"errors"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/internal/sdk"
	"github.com/aws/aws-sdk-go-v2/service/sso"
)
  11  
  12  // ProviderName is the name of the provider used to specify the source of
  13  // credentials.
  14  const ProviderName = "SSOProvider"
  15  
  16  // GetRoleCredentialsAPIClient is a API client that implements the
  17  // GetRoleCredentials operation.
  18  type GetRoleCredentialsAPIClient interface {
  19  	GetRoleCredentials(context.Context, *sso.GetRoleCredentialsInput, ...func(*sso.Options)) (
  20  		*sso.GetRoleCredentialsOutput, error,
  21  	)
  22  }
  23  
  24  // Options is the Provider options structure.
  25  type Options struct {
  26  	// The Client which is configured for the AWS Region where the AWS SSO user
  27  	// portal is located.
  28  	Client GetRoleCredentialsAPIClient
  29  
  30  	// The AWS account that is assigned to the user.
  31  	AccountID string
  32  
  33  	// The role name that is assigned to the user.
  34  	RoleName string
  35  
  36  	// The URL that points to the organization's AWS Single Sign-On (AWS SSO)
  37  	// user portal.
  38  	StartURL string
  39  
  40  	// The filepath the cached token will be retrieved from. If unset Provider will
  41  	// use the startURL to determine the filepath at.
  42  	//
  43  	//    ~/.aws/sso/cache/<sha1-hex-encoded-startURL>.json
  44  	//
  45  	// If custom cached token filepath is used, the Provider's startUrl
  46  	// parameter will be ignored.
  47  	CachedTokenFilepath string
  48  
  49  	// Used by the SSOCredentialProvider if a token configuration
  50  	// profile is used in the shared config
  51  	SSOTokenProvider *SSOTokenProvider
  52  
  53  	// The chain of providers that was used to create this provider.
  54  	// These values are for reporting purposes and are not meant to be set up directly
  55  	CredentialSources []aws.CredentialSource
  56  }
  57  
  58  // Provider is an AWS credential provider that retrieves temporary AWS
  59  // credentials by exchanging an SSO login token.
  60  type Provider struct {
  61  	options Options
  62  
  63  	cachedTokenFilepath string
  64  }
  65  
  66  // New returns a new AWS Single Sign-On (AWS SSO) credential provider. The
  67  // provided client is expected to be configured for the AWS Region where the
  68  // AWS SSO user portal is located.
  69  func New(client GetRoleCredentialsAPIClient, accountID, roleName, startURL string, optFns ...func(options *Options)) *Provider {
  70  	options := Options{
  71  		Client:    client,
  72  		AccountID: accountID,
  73  		RoleName:  roleName,
  74  		StartURL:  startURL,
  75  	}
  76  
  77  	for _, fn := range optFns {
  78  		fn(&options)
  79  	}
  80  
  81  	return &Provider{
  82  		options:             options,
  83  		cachedTokenFilepath: options.CachedTokenFilepath,
  84  	}
  85  }
  86  
  87  // Retrieve retrieves temporary AWS credentials from the configured Amazon
  88  // Single Sign-On (AWS SSO) user portal by exchanging the accessToken present
  89  // in ~/.aws/sso/cache. However, if a token provider configuration exists
  90  // in the shared config, then we ought to use the token provider rather then
  91  // direct access on the cached token.
  92  func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
  93  	var accessToken *string
  94  	if p.options.SSOTokenProvider != nil {
  95  		token, err := p.options.SSOTokenProvider.RetrieveBearerToken(ctx)
  96  		if err != nil {
  97  			return aws.Credentials{}, err
  98  		}
  99  		accessToken = &token.Value
 100  	} else {
 101  		if p.cachedTokenFilepath == "" {
 102  			cachedTokenFilepath, err := StandardCachedTokenFilepath(p.options.StartURL)
 103  			if err != nil {
 104  				return aws.Credentials{}, &InvalidTokenError{Err: err}
 105  			}
 106  			p.cachedTokenFilepath = cachedTokenFilepath
 107  		}
 108  
 109  		tokenFile, err := loadCachedToken(p.cachedTokenFilepath)
 110  		if err != nil {
 111  			return aws.Credentials{}, &InvalidTokenError{Err: err}
 112  		}
 113  
 114  		if tokenFile.ExpiresAt == nil || sdk.NowTime().After(time.Time(*tokenFile.ExpiresAt)) {
 115  			return aws.Credentials{}, &InvalidTokenError{}
 116  		}
 117  		accessToken = &tokenFile.AccessToken
 118  	}
 119  
 120  	output, err := p.options.Client.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{
 121  		AccessToken: accessToken,
 122  		AccountId:   &p.options.AccountID,
 123  		RoleName:    &p.options.RoleName,
 124  	})
 125  	if err != nil {
 126  		return aws.Credentials{}, err
 127  	}
 128  
 129  	return aws.Credentials{
 130  		AccessKeyID:     aws.ToString(output.RoleCredentials.AccessKeyId),
 131  		SecretAccessKey: aws.ToString(output.RoleCredentials.SecretAccessKey),
 132  		SessionToken:    aws.ToString(output.RoleCredentials.SessionToken),
 133  		CanExpire:       true,
 134  		Expires:         time.Unix(0, output.RoleCredentials.Expiration*int64(time.Millisecond)).UTC(),
 135  		Source:          ProviderName,
 136  		AccountID:       p.options.AccountID,
 137  	}, nil
 138  }
 139  
 140  // ProviderSources returns the credential chain that was used to construct this provider
 141  func (p *Provider) ProviderSources() []aws.CredentialSource {
 142  	if p.options.CredentialSources == nil {
 143  		return []aws.CredentialSource{aws.CredentialSourceSSO}
 144  	}
 145  	return p.options.CredentialSources
 146  }
 147  
 148  // InvalidTokenError is the error type that is returned if loaded token has
 149  // expired or is otherwise invalid. To refresh the SSO session run AWS SSO
 150  // login with the corresponding profile.
 151  type InvalidTokenError struct {
 152  	Err error
 153  }
 154  
 155  func (i *InvalidTokenError) Unwrap() error {
 156  	return i.Err
 157  }
 158  
 159  func (i *InvalidTokenError) Error() string {
 160  	const msg = "the SSO session has expired or is invalid"
 161  	if i.Err == nil {
 162  		return msg
 163  	}
 164  	return msg + ": " + i.Err.Error()
 165  }
 166