provider.go raw
1 // Package logincreds implements AWS credential provision for sessions created
2 // via an `aws login` command.
3 package logincreds
4
5 import (
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "os"
12
13 "github.com/aws/aws-sdk-go-v2/aws"
14 "github.com/aws/aws-sdk-go-v2/internal/sdk"
15 "github.com/aws/aws-sdk-go-v2/service/signin"
16 "github.com/aws/aws-sdk-go-v2/service/signin/types"
17 )
18
19 // ProviderName is the name that identifies the login credentials provider.
//
// NOTE(review): the constant is not referenced elsewhere in this file;
// presumably it is attached to the credentials built from the login token —
// confirm against the token's Credentials implementation.
20 const ProviderName = "LoginProvider"
21
22 // TokenAPIClient is the part of the signin service client that Provider
23 // depends on: the login session's token retrieval operation.
//
// Retrieve calls it with a "refresh_token" grant once the cached access token
// is no longer valid.
24 type TokenAPIClient interface {
// CreateOAuth2Token performs the token request described by the input.
// The trailing functional options adjust the signin client options for
// this single call.
25 CreateOAuth2Token(context.Context, *signin.CreateOAuth2TokenInput, ...func(*signin.Options)) (*signin.CreateOAuth2TokenOutput, error)
26 }
27
28 // Provider supplies credentials for an `aws login` session.
//
// It reads the login token from the configured cache file and, when the
// access token has expired, refreshes it through the configured
// TokenAPIClient and writes the refreshed token back to the same file.
// Construct a Provider with New.
29 type Provider struct {
// options holds the configuration assembled by New.
30 options Options
31 }
32
// Compile-time check that *Provider satisfies aws.CredentialsProvider.
33 var _ aws.CredentialsProvider = (*Provider)(nil)
34
35 // Options configures the Provider.
36 type Options struct {
// Client performs the CreateOAuth2Token call used to refresh an expired
// login token.
37 Client TokenAPIClient
38
39 // ClientOptions are passed to the underlying CreateOAuth2Token operation.
// Retrieve places them after its own addSignDPOP option.
40 ClientOptions []func(*signin.Options)
41
42 // CachedTokenFilepath is the path to the cached login token.
// The token is read from this file on every Retrieve and written back to
// it after a successful refresh.
43 CachedTokenFilepath string
44
45 // CredentialSources is the chain of providers that was used to create this provider.
46 //
47 // These values are for reporting purposes and are not meant to be set
48 // directly. When nil, ProviderSources reports aws.CredentialSourceLogin.
49 CredentialSources []aws.CredentialSource
50 }
51
52 // New returns a new login session credentials provider.
53 func New(client TokenAPIClient, path string, opts ...func(*Options)) *Provider {
54 options := Options{
55 Client: client,
56 CachedTokenFilepath: path,
57 }
58
59 for _, opt := range opts {
60 opt(&options)
61 }
62
63 return &Provider{options}
64 }
65
66 // Retrieve generates a new set of temporary credentials using an `aws login`
67 // session.
68 func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
69 token, err := p.loadToken()
70 if err != nil {
71 return aws.Credentials{}, fmt.Errorf("load login token: %w", err)
72 }
73 if err := token.Validate(); err != nil {
74 return aws.Credentials{}, fmt.Errorf("validate login token: %w", err)
75 }
76
77 // the token may have been refreshed elsewhere or the login session might
78 // have just been created
79 if sdk.NowTime().Before(token.AccessToken.ExpiresAt) {
80 return token.Credentials(), nil
81 }
82
83 opts := make([]func(*signin.Options), len(p.options.ClientOptions)+1)
84 opts[0] = addSignDPOP(token)
85 copy(opts[1:], p.options.ClientOptions)
86
87 out, err := p.options.Client.CreateOAuth2Token(ctx, &signin.CreateOAuth2TokenInput{
88 TokenInput: &types.CreateOAuth2TokenRequestBody{
89 ClientId: aws.String(token.ClientID),
90 GrantType: aws.String("refresh_token"),
91 RefreshToken: aws.String(token.RefreshToken),
92 },
93 }, opts...)
94 if err != nil {
95 var terr *types.AccessDeniedException
96 if errors.As(err, &terr) {
97 err = toAccessDeniedError(terr)
98 }
99 return aws.Credentials{}, fmt.Errorf("create oauth2 token: %w", err)
100 }
101
102 token.Update(out)
103 if err := p.saveToken(token); err != nil {
104 return aws.Credentials{}, fmt.Errorf("save token: %w", err)
105 }
106
107 return token.Credentials(), nil
108 }
109
110 // ProviderSources returns the credential chain that was used to construct this
111 // provider.
112 func (p *Provider) ProviderSources() []aws.CredentialSource {
113 if p.options.CredentialSources == nil {
114 return []aws.CredentialSource{aws.CredentialSourceLogin}
115 }
116 return p.options.CredentialSources
117 }
118
119 func (p *Provider) loadToken() (*loginToken, error) {
120 f, err := openFile(p.options.CachedTokenFilepath)
121 if err != nil && os.IsNotExist(err) {
122 return nil, fmt.Errorf("token file not found, please reauthenticate")
123 }
124 if err != nil {
125 return nil, err
126 }
127 defer f.Close()
128
129 j, err := io.ReadAll(f)
130 if err != nil {
131 return nil, err
132 }
133
134 var t *loginToken
135 if err := json.Unmarshal(j, &t); err != nil {
136 return nil, err
137 }
138
139 return t, nil
140 }
141
142 func (p *Provider) saveToken(token *loginToken) error {
143 j, err := json.Marshal(token)
144 if err != nil {
145 return err
146 }
147
148 f, err := createFile(p.options.CachedTokenFilepath)
149 if err != nil {
150 return err
151 }
152 defer f.Close()
153
154 if _, err := f.Write(j); err != nil {
155 return err
156 }
157
158 return nil
159 }
160
161 func toAccessDeniedError(err *types.AccessDeniedException) error {
162 switch err.Error_ {
163 case types.OAuth2ErrorCodeTokenExpired:
164 return fmt.Errorf("login session has expired, please reauthenticate")
165 case types.OAuth2ErrorCodeUserCredentialsChanged:
166 return fmt.Errorf("login session password has changed, please reauthenticate")
167 case types.OAuth2ErrorCodeInsufficientPermissions:
168 return fmt.Errorf("insufficient permissions, you may be missing permissions for the CreateOAuth2Token action")
169 default:
170 return err
171 }
172 }
173