token_provider.go raw
1 package imds
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "github.com/aws/aws-sdk-go-v2/aws"
8 "github.com/aws/smithy-go"
9 "github.com/aws/smithy-go/logging"
10 "net/http"
11 "sync"
12 "sync/atomic"
13 "time"
14
15 "github.com/aws/smithy-go/middleware"
16 smithyhttp "github.com/aws/smithy-go/transport/http"
17 )
18
// Constants used when requesting and attaching EC2 IMDS API tokens.
const (
	// tokenHeader is the HTTP request header that carries the API token.
	tokenHeader = "x-aws-ec2-metadata-token"

	// defaultTokenTTL is the default token lifetime, five minutes.
	defaultTokenTTL = 5 * time.Minute
)
24
25 type tokenProvider struct {
26 client *Client
27 tokenTTL time.Duration
28
29 token *apiToken
30 tokenMux sync.RWMutex
31
32 disabled uint32 // Atomic updated
33 }
34
35 func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
36 return &tokenProvider{
37 client: client,
38 tokenTTL: ttl,
39 }
40 }
41
// apiToken holds the API token used by all operation calls for the EC2
// Instance metadata service, together with the time it stops being valid.
type apiToken struct {
	token   string
	expires time.Time
}

// timeNow returns the current time. It is a variable so that it can be
// replaced with a different clock.
var timeNow = time.Now

// Expired reports whether the current time is after the token's expiry.
func (t *apiToken) Expired() bool {
	// Round(0) strips the monotonic clock reading from the current time, so
	// the comparison is always made against the reported wall-clock time.
	wallNow := timeNow().Round(0)
	return t.expires.Before(wallNow)
}
57
58 func (t *tokenProvider) ID() string { return "APITokenProvider" }
59
60 // HandleFinalize is the finalize stack middleware, that if the token provider is
61 // enabled, will attempt to add the cached API token to the request. If the API
62 // token is not cached, it will be retrieved in a separate API call, getToken.
63 //
64 // For retry attempts, handler must be added after attempt retryer.
65 //
66 // If request for getToken fails the token provider may be disabled from future
67 // requests, depending on the response status code.
68 func (t *tokenProvider) HandleFinalize(
69 ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
70 ) (
71 out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
72 ) {
73 if t.fallbackEnabled() && !t.enabled() {
74 // short-circuits to insecure data flow if token provider is disabled.
75 return next.HandleFinalize(ctx, input)
76 }
77
78 req, ok := input.Request.(*smithyhttp.Request)
79 if !ok {
80 return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
81 }
82
83 tok, err := t.getToken(ctx)
84 if err != nil {
85 // If the error allows the token to downgrade to insecure flow allow that.
86 var bypassErr *bypassTokenRetrievalError
87 if errors.As(err, &bypassErr) {
88 return next.HandleFinalize(ctx, input)
89 }
90
91 return out, metadata, fmt.Errorf("failed to get API token, %w", err)
92 }
93
94 req.Header.Set(tokenHeader, tok.token)
95
96 return next.HandleFinalize(ctx, input)
97 }
98
99 // HandleDeserialize is the deserialize stack middleware for determining if the
100 // operation the token provider is decorating failed because of a 401
101 // unauthorized status code. If the operation failed for that reason the token
102 // provider needs to be re-enabled so that it can start adding the API token to
103 // operation calls.
104 func (t *tokenProvider) HandleDeserialize(
105 ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
106 ) (
107 out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
108 ) {
109 out, metadata, err = next.HandleDeserialize(ctx, input)
110 if err == nil {
111 return out, metadata, err
112 }
113
114 resp, ok := out.RawResponse.(*smithyhttp.Response)
115 if !ok {
116 return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
117 }
118
119 if resp.StatusCode == http.StatusUnauthorized { // unauthorized
120 t.enable()
121 err = &retryableError{Err: err, isRetryable: true}
122 }
123
124 return out, metadata, err
125 }
126
127 func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
128 if t.fallbackEnabled() && !t.enabled() {
129 return nil, &bypassTokenRetrievalError{
130 Err: fmt.Errorf("cannot get API token, provider disabled"),
131 }
132 }
133
134 t.tokenMux.RLock()
135 tok = t.token
136 t.tokenMux.RUnlock()
137
138 if tok != nil && !tok.Expired() {
139 return tok, nil
140 }
141
142 tok, err = t.updateToken(ctx)
143 if err != nil {
144 return nil, err
145 }
146
147 return tok, nil
148 }
149
150 func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
151 t.tokenMux.Lock()
152 defer t.tokenMux.Unlock()
153
154 // Prevent multiple requests to update retrieving the token.
155 if t.token != nil && !t.token.Expired() {
156 tok := t.token
157 return tok, nil
158 }
159
160 result, err := t.client.getToken(ctx, &getTokenInput{
161 TokenTTL: t.tokenTTL,
162 })
163 if err != nil {
164 var statusErr interface{ HTTPStatusCode() int }
165 if errors.As(err, &statusErr) {
166 switch statusErr.HTTPStatusCode() {
167 // Disable future get token if failed because of 403, 404, or 405
168 case http.StatusForbidden,
169 http.StatusNotFound,
170 http.StatusMethodNotAllowed:
171
172 if t.fallbackEnabled() {
173 logger := middleware.GetLogger(ctx)
174 logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err)
175 t.disable()
176 }
177
178 // 400 errors are terminal, and need to be upstreamed
179 case http.StatusBadRequest:
180 return nil, err
181 }
182 }
183
184 // Disable if request send failed or timed out getting response
185 var re *smithyhttp.RequestSendError
186 var ce *smithy.CanceledError
187 if errors.As(err, &re) || errors.As(err, &ce) {
188 atomic.StoreUint32(&t.disabled, 1)
189 }
190
191 if !t.fallbackEnabled() {
192 // NOTE: getToken() is an implementation detail of some outer operation
193 // (e.g. GetMetadata). It has its own retries that have already been exhausted.
194 // Mark the underlying error as a terminal error.
195 err = &retryableError{Err: err, isRetryable: false}
196 return nil, err
197 }
198
199 // Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request
200 // and allow the request to proceed. Future requests _may_ re-attempt fetching a
201 // token if not disabled.
202 return nil, &bypassTokenRetrievalError{Err: err}
203 }
204
205 tok := &apiToken{
206 token: result.Token,
207 expires: timeNow().Add(result.TokenTTL),
208 }
209 t.token = tok
210
211 return tok, nil
212 }
213
214 // enabled returns if the token provider is current enabled or not.
215 func (t *tokenProvider) enabled() bool {
216 return atomic.LoadUint32(&t.disabled) == 0
217 }
218
219 // fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise
220 func (t *tokenProvider) fallbackEnabled() bool {
221 switch t.client.options.EnableFallback {
222 case aws.FalseTernary:
223 return false
224 default:
225 return true
226 }
227 }
228
229 // disable disables the token provider and it will no longer attempt to inject
230 // the token, nor request updates.
231 func (t *tokenProvider) disable() {
232 atomic.StoreUint32(&t.disabled, 1)
233 }
234
235 // enable enables the token provide to start refreshing tokens, and adding them
236 // to the pending request.
237 func (t *tokenProvider) enable() {
238 t.tokenMux.Lock()
239 t.token = nil
240 t.tokenMux.Unlock()
241 atomic.StoreUint32(&t.disabled, 0)
242 }
243
// bypassTokenRetrievalError signals that an API token could not be obtained
// but the request is allowed to continue without one. Err holds the
// underlying cause.
type bypassTokenRetrievalError struct {
	Err error
}

// Unwrap returns the underlying cause so that errors.Is and errors.As can
// inspect it.
func (e *bypassTokenRetrievalError) Unwrap() error {
	return e.Err
}

// Error returns the message of the underlying cause prefixed with a note that
// token retrieval is being bypassed.
func (e *bypassTokenRetrievalError) Error() string {
	const format = "bypass token retrieval, %v"
	return fmt.Sprintf(format, e.Err)
}
253
// retryableError decorates an error with an explicit retry decision. Err is
// the decorated error and isRetryable is the value reported by
// RetryableError.
type retryableError struct {
	Err         error
	isRetryable bool
}

// Error returns the message of the decorated error unchanged.
func (e *retryableError) Error() string {
	msg := e.Err.Error()
	return msg
}

// RetryableError reports whether the decorated error should be retried.
func (e *retryableError) RetryableError() bool {
	return e.isRetryable
}
262