token_provider.go raw

   1  package imds
   2  
   3  import (
   4  	"context"
   5  	"errors"
   6  	"fmt"
   7  	"github.com/aws/aws-sdk-go-v2/aws"
   8  	"github.com/aws/smithy-go"
   9  	"github.com/aws/smithy-go/logging"
  10  	"net/http"
  11  	"sync"
  12  	"sync/atomic"
  13  	"time"
  14  
  15  	"github.com/aws/smithy-go/middleware"
  16  	smithyhttp "github.com/aws/smithy-go/transport/http"
  17  )
  18  
  19  const (
  20  	// Headers for Token and TTL
  21  	tokenHeader     = "x-aws-ec2-metadata-token"
  22  	defaultTokenTTL = 5 * time.Minute
  23  )
  24  
  25  type tokenProvider struct {
  26  	client   *Client
  27  	tokenTTL time.Duration
  28  
  29  	token    *apiToken
  30  	tokenMux sync.RWMutex
  31  
  32  	disabled uint32 // Atomic updated
  33  }
  34  
  35  func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
  36  	return &tokenProvider{
  37  		client:   client,
  38  		tokenTTL: ttl,
  39  	}
  40  }
  41  
  42  // apiToken provides the API token used by all operation calls for th EC2
  43  // Instance metadata service.
  44  type apiToken struct {
  45  	token   string
  46  	expires time.Time
  47  }
  48  
  49  var timeNow = time.Now
  50  
  51  // Expired returns if the token is expired.
  52  func (t *apiToken) Expired() bool {
  53  	// Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry
  54  	// time is always based on reported wall-clock time.
  55  	return timeNow().Round(0).After(t.expires)
  56  }
  57  
  58  func (t *tokenProvider) ID() string { return "APITokenProvider" }
  59  
  60  // HandleFinalize is the finalize stack middleware, that if the token provider is
  61  // enabled, will attempt to add the cached API token to the request. If the API
  62  // token is not cached, it will be retrieved in a separate API call, getToken.
  63  //
  64  // For retry attempts, handler must be added after attempt retryer.
  65  //
  66  // If request for getToken fails the token provider may be disabled from future
  67  // requests, depending on the response status code.
  68  func (t *tokenProvider) HandleFinalize(
  69  	ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
  70  ) (
  71  	out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
  72  ) {
  73  	if t.fallbackEnabled() && !t.enabled() {
  74  		// short-circuits to insecure data flow if token provider is disabled.
  75  		return next.HandleFinalize(ctx, input)
  76  	}
  77  
  78  	req, ok := input.Request.(*smithyhttp.Request)
  79  	if !ok {
  80  		return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
  81  	}
  82  
  83  	tok, err := t.getToken(ctx)
  84  	if err != nil {
  85  		// If the error allows the token to downgrade to insecure flow allow that.
  86  		var bypassErr *bypassTokenRetrievalError
  87  		if errors.As(err, &bypassErr) {
  88  			return next.HandleFinalize(ctx, input)
  89  		}
  90  
  91  		return out, metadata, fmt.Errorf("failed to get API token, %w", err)
  92  	}
  93  
  94  	req.Header.Set(tokenHeader, tok.token)
  95  
  96  	return next.HandleFinalize(ctx, input)
  97  }
  98  
  99  // HandleDeserialize is the deserialize stack middleware for determining if the
 100  // operation the token provider is decorating failed because of a 401
 101  // unauthorized status code. If the operation failed for that reason the token
 102  // provider needs to be re-enabled so that it can start adding the API token to
 103  // operation calls.
 104  func (t *tokenProvider) HandleDeserialize(
 105  	ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
 106  ) (
 107  	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
 108  ) {
 109  	out, metadata, err = next.HandleDeserialize(ctx, input)
 110  	if err == nil {
 111  		return out, metadata, err
 112  	}
 113  
 114  	resp, ok := out.RawResponse.(*smithyhttp.Response)
 115  	if !ok {
 116  		return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
 117  	}
 118  
 119  	if resp.StatusCode == http.StatusUnauthorized { // unauthorized
 120  		t.enable()
 121  		err = &retryableError{Err: err, isRetryable: true}
 122  	}
 123  
 124  	return out, metadata, err
 125  }
 126  
 127  func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
 128  	if t.fallbackEnabled() && !t.enabled() {
 129  		return nil, &bypassTokenRetrievalError{
 130  			Err: fmt.Errorf("cannot get API token, provider disabled"),
 131  		}
 132  	}
 133  
 134  	t.tokenMux.RLock()
 135  	tok = t.token
 136  	t.tokenMux.RUnlock()
 137  
 138  	if tok != nil && !tok.Expired() {
 139  		return tok, nil
 140  	}
 141  
 142  	tok, err = t.updateToken(ctx)
 143  	if err != nil {
 144  		return nil, err
 145  	}
 146  
 147  	return tok, nil
 148  }
 149  
 150  func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
 151  	t.tokenMux.Lock()
 152  	defer t.tokenMux.Unlock()
 153  
 154  	// Prevent multiple requests to update retrieving the token.
 155  	if t.token != nil && !t.token.Expired() {
 156  		tok := t.token
 157  		return tok, nil
 158  	}
 159  
 160  	result, err := t.client.getToken(ctx, &getTokenInput{
 161  		TokenTTL: t.tokenTTL,
 162  	})
 163  	if err != nil {
 164  		var statusErr interface{ HTTPStatusCode() int }
 165  		if errors.As(err, &statusErr) {
 166  			switch statusErr.HTTPStatusCode() {
 167  			// Disable future get token if failed because of 403, 404, or 405
 168  			case http.StatusForbidden,
 169  				http.StatusNotFound,
 170  				http.StatusMethodNotAllowed:
 171  
 172  				if t.fallbackEnabled() {
 173  					logger := middleware.GetLogger(ctx)
 174  					logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err)
 175  					t.disable()
 176  				}
 177  
 178  			// 400 errors are terminal, and need to be upstreamed
 179  			case http.StatusBadRequest:
 180  				return nil, err
 181  			}
 182  		}
 183  
 184  		// Disable if request send failed or timed out getting response
 185  		var re *smithyhttp.RequestSendError
 186  		var ce *smithy.CanceledError
 187  		if errors.As(err, &re) || errors.As(err, &ce) {
 188  			atomic.StoreUint32(&t.disabled, 1)
 189  		}
 190  
 191  		if !t.fallbackEnabled() {
 192  			// NOTE: getToken() is an implementation detail of some outer operation
 193  			// (e.g. GetMetadata). It has its own retries that have already been exhausted.
 194  			// Mark the underlying error as a terminal error.
 195  			err = &retryableError{Err: err, isRetryable: false}
 196  			return nil, err
 197  		}
 198  
 199  		// Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request
 200  		// and allow the request to proceed. Future requests _may_ re-attempt fetching a
 201  		// token if not disabled.
 202  		return nil, &bypassTokenRetrievalError{Err: err}
 203  	}
 204  
 205  	tok := &apiToken{
 206  		token:   result.Token,
 207  		expires: timeNow().Add(result.TokenTTL),
 208  	}
 209  	t.token = tok
 210  
 211  	return tok, nil
 212  }
 213  
 214  // enabled returns if the token provider is current enabled or not.
 215  func (t *tokenProvider) enabled() bool {
 216  	return atomic.LoadUint32(&t.disabled) == 0
 217  }
 218  
 219  // fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise
 220  func (t *tokenProvider) fallbackEnabled() bool {
 221  	switch t.client.options.EnableFallback {
 222  	case aws.FalseTernary:
 223  		return false
 224  	default:
 225  		return true
 226  	}
 227  }
 228  
 229  // disable disables the token provider and it will no longer attempt to inject
 230  // the token, nor request updates.
 231  func (t *tokenProvider) disable() {
 232  	atomic.StoreUint32(&t.disabled, 1)
 233  }
 234  
 235  // enable enables the token provide to start refreshing tokens, and adding them
 236  // to the pending request.
 237  func (t *tokenProvider) enable() {
 238  	t.tokenMux.Lock()
 239  	t.token = nil
 240  	t.tokenMux.Unlock()
 241  	atomic.StoreUint32(&t.disabled, 0)
 242  }
 243  
 244  type bypassTokenRetrievalError struct {
 245  	Err error
 246  }
 247  
 248  func (e *bypassTokenRetrievalError) Error() string {
 249  	return fmt.Sprintf("bypass token retrieval, %v", e.Err)
 250  }
 251  
 252  func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err }
 253  
 254  type retryableError struct {
 255  	Err         error
 256  	isRetryable bool
 257  }
 258  
 259  func (e *retryableError) RetryableError() bool { return e.isRetryable }
 260  
 261  func (e *retryableError) Error() string { return e.Err.Error() }
 262