token_cache.go raw

   1  package bearer
   2  
   3  import (
   4  	"context"
   5  	"fmt"
   6  	"sync/atomic"
   7  	"time"
   8  
   9  	smithycontext "github.com/aws/smithy-go/context"
  10  	"github.com/aws/smithy-go/internal/sync/singleflight"
  11  )
  12  
  13  // package variable that can be override in unit tests.
  14  var timeNow = time.Now
  15  
  16  // TokenCacheOptions provides a set of optional configuration options for the
  17  // TokenCache TokenProvider.
  18  type TokenCacheOptions struct {
  19  	// The duration before the token will expire when the credentials will be
  20  	// refreshed. If DisableAsyncRefresh is true, the RetrieveBearerToken calls
  21  	// will be blocking.
  22  	//
  23  	// Asynchronous refreshes are deduplicated, and only one will be in-flight
  24  	// at a time. If the token expires while an asynchronous refresh is in
  25  	// flight, the next call to RetrieveBearerToken will block on that refresh
  26  	// to return.
  27  	RefreshBeforeExpires time.Duration
  28  
  29  	// The timeout the underlying TokenProvider's RetrieveBearerToken call must
  30  	// return within, or will be canceled. Defaults to 0, no timeout.
  31  	//
  32  	// If 0 timeout, its possible for the underlying tokenProvider's
  33  	// RetrieveBearerToken call to block forever. Preventing subsequent
  34  	// TokenCache attempts to refresh the token.
  35  	//
  36  	// If this timeout is reached all pending deduplicated calls to
  37  	// TokenCache RetrieveBearerToken will fail with an error.
  38  	RetrieveBearerTokenTimeout time.Duration
  39  
  40  	// The minimum duration between asynchronous refresh attempts. If the next
  41  	// asynchronous recent refresh attempt was within the minimum delay
  42  	// duration, the call to retrieve will return the current cached token, if
  43  	// not expired.
  44  	//
  45  	// The asynchronous retrieve is deduplicated across multiple calls when
  46  	// RetrieveBearerToken is called. The asynchronous retrieve is not a
  47  	// periodic task. It is only performed when the token has not yet expired,
  48  	// and the current item is within the RefreshBeforeExpires window, and the
  49  	// TokenCache's RetrieveBearerToken method is called.
  50  	//
  51  	// If 0, (default) there will be no minimum delay between asynchronous
  52  	// refresh attempts.
  53  	//
  54  	// If DisableAsyncRefresh is true, this option is ignored.
  55  	AsyncRefreshMinimumDelay time.Duration
  56  
  57  	// Sets if the TokenCache will attempt to refresh the token in the
  58  	// background asynchronously instead of blocking for credentials to be
  59  	// refreshed. If disabled token refresh will be blocking.
  60  	//
  61  	// The first call to RetrieveBearerToken will always be blocking, because
  62  	// there is no cached token.
  63  	DisableAsyncRefresh bool
  64  }
  65  
  66  // TokenCache provides an utility to cache Bearer Authentication tokens from a
  67  // wrapped TokenProvider. The TokenCache can be has options to configure the
  68  // cache's early and asynchronous refresh of the token.
  69  type TokenCache struct {
  70  	options  TokenCacheOptions
  71  	provider TokenProvider
  72  
  73  	cachedToken            atomic.Value
  74  	lastRefreshAttemptTime atomic.Value
  75  	sfGroup                singleflight.Group
  76  }
  77  
  78  // NewTokenCache returns a initialized TokenCache that implements the
  79  // TokenProvider interface. Wrapping the provider passed in. Also taking a set
  80  // of optional functional option parameters to configure the token cache.
  81  func NewTokenCache(provider TokenProvider, optFns ...func(*TokenCacheOptions)) *TokenCache {
  82  	var options TokenCacheOptions
  83  	for _, fn := range optFns {
  84  		fn(&options)
  85  	}
  86  
  87  	return &TokenCache{
  88  		options:  options,
  89  		provider: provider,
  90  	}
  91  }
  92  
  93  // RetrieveBearerToken returns the token if it could be obtained, or error if a
  94  // valid token could not be retrieved.
  95  //
  96  // The passed in Context's cancel/deadline/timeout will impacting only this
  97  // individual retrieve call and not any other already queued up calls. This
  98  // means underlying provider's RetrieveBearerToken calls could block for ever,
  99  // and not be canceled with the Context. Set RetrieveBearerTokenTimeout to
 100  // provide a timeout, preventing the underlying TokenProvider blocking forever.
 101  //
 102  // By default, if the passed in Context is canceled, all of its values will be
 103  // considered expired. The wrapped TokenProvider will not be able to lookup the
 104  // values from the Context once it is expired. This is done to protect against
 105  // expired values no longer being valid. To disable this behavior, use
 106  // smithy-go's context.WithPreserveExpiredValues to add a value to the Context
 107  // before calling RetrieveBearerToken to enable support for expired values.
 108  //
 109  // Without RetrieveBearerTokenTimeout there is the potential for a underlying
 110  // Provider's RetrieveBearerToken call to sit forever. Blocking in subsequent
 111  // attempts at refreshing the token.
 112  func (p *TokenCache) RetrieveBearerToken(ctx context.Context) (Token, error) {
 113  	cachedToken, ok := p.getCachedToken()
 114  	if !ok || cachedToken.Expired(timeNow()) {
 115  		return p.refreshBearerToken(ctx)
 116  	}
 117  
 118  	// Check if the token should be refreshed before it expires.
 119  	refreshToken := cachedToken.Expired(timeNow().Add(p.options.RefreshBeforeExpires))
 120  	if !refreshToken {
 121  		return cachedToken, nil
 122  	}
 123  
 124  	if p.options.DisableAsyncRefresh {
 125  		return p.refreshBearerToken(ctx)
 126  	}
 127  
 128  	p.tryAsyncRefresh(ctx)
 129  
 130  	return cachedToken, nil
 131  }
 132  
 133  // tryAsyncRefresh attempts to asynchronously refresh the token returning the
 134  // already cached token. If it AsyncRefreshMinimumDelay option is not zero, and
 135  // the duration since the last refresh is less than that value, nothing will be
 136  // done.
 137  func (p *TokenCache) tryAsyncRefresh(ctx context.Context) {
 138  	if p.options.AsyncRefreshMinimumDelay != 0 {
 139  		var lastRefreshAttempt time.Time
 140  		if v := p.lastRefreshAttemptTime.Load(); v != nil {
 141  			lastRefreshAttempt = v.(time.Time)
 142  		}
 143  
 144  		if timeNow().Before(lastRefreshAttempt.Add(p.options.AsyncRefreshMinimumDelay)) {
 145  			return
 146  		}
 147  	}
 148  
 149  	// Ignore the returned channel so this won't be blocking, and limit the
 150  	// number of additional goroutines created.
 151  	p.sfGroup.DoChan("async-refresh", func() (interface{}, error) {
 152  		res, err := p.refreshBearerToken(ctx)
 153  		if p.options.AsyncRefreshMinimumDelay != 0 {
 154  			var refreshAttempt time.Time
 155  			if err != nil {
 156  				refreshAttempt = timeNow()
 157  			}
 158  			p.lastRefreshAttemptTime.Store(refreshAttempt)
 159  		}
 160  
 161  		return res, err
 162  	})
 163  }
 164  
 165  func (p *TokenCache) refreshBearerToken(ctx context.Context) (Token, error) {
 166  	resCh := p.sfGroup.DoChan("refresh-token", func() (interface{}, error) {
 167  		ctx := smithycontext.WithSuppressCancel(ctx)
 168  		if v := p.options.RetrieveBearerTokenTimeout; v != 0 {
 169  			var cancel func()
 170  			ctx, cancel = context.WithTimeout(ctx, v)
 171  			defer cancel()
 172  		}
 173  		return p.singleRetrieve(ctx)
 174  	})
 175  
 176  	select {
 177  	case res := <-resCh:
 178  		return res.Val.(Token), res.Err
 179  	case <-ctx.Done():
 180  		return Token{}, fmt.Errorf("retrieve bearer token canceled, %w", ctx.Err())
 181  	}
 182  }
 183  
 184  func (p *TokenCache) singleRetrieve(ctx context.Context) (interface{}, error) {
 185  	token, err := p.provider.RetrieveBearerToken(ctx)
 186  	if err != nil {
 187  		return Token{}, fmt.Errorf("failed to retrieve bearer token, %w", err)
 188  	}
 189  
 190  	p.cachedToken.Store(&token)
 191  	return token, nil
 192  }
 193  
 194  // getCachedToken returns the currently cached token and true if found. Returns
 195  // false if no token is cached.
 196  func (p *TokenCache) getCachedToken() (Token, bool) {
 197  	v := p.cachedToken.Load()
 198  	if v == nil {
 199  		return Token{}, false
 200  	}
 201  
 202  	t := v.(*Token)
 203  	if t == nil || t.Value == "" {
 204  		return Token{}, false
 205  	}
 206  
 207  	return *t, true
 208  }
 209