chained_token_credential.go raw

   1  //go:build go1.18
   2  // +build go1.18
   3  
   4  // Copyright (c) Microsoft Corporation. All rights reserved.
   5  // Licensed under the MIT License.
   6  
   7  package azidentity
   8  
   9  import (
  10  	"context"
  11  	"errors"
  12  	"fmt"
  13  	"strings"
  14  	"sync"
  15  
  16  	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
  17  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
  18  	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
  19  )
  20  
  21  // ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential.
  22  type ChainedTokenCredentialOptions struct {
  23  	// RetrySources configures how the credential uses its sources. When true, the credential always attempts to
  24  	// authenticate through each source in turn, stopping when one succeeds. When false, the credential authenticates
  25  	// only through this first successful source--it never again tries the sources which failed.
  26  	RetrySources bool
  27  }
  28  
  29  // ChainedTokenCredential links together multiple credentials and tries them sequentially when authenticating. By default,
  30  // it tries all the credentials until one authenticates, after which it always uses that credential. For more information,
  31  // see [ChainedTokenCredential overview].
  32  //
  33  // [ChainedTokenCredential overview]: https://aka.ms/azsdk/go/identity/credential-chains#chainedtokencredential-overview
  34  type ChainedTokenCredential struct {
  35  	cond                 *sync.Cond
  36  	iterating            bool
  37  	name                 string
  38  	retrySources         bool
  39  	sources              []azcore.TokenCredential
  40  	successfulCredential azcore.TokenCredential
  41  }
  42  
  43  // NewChainedTokenCredential creates a ChainedTokenCredential. Pass nil for options to accept defaults.
  44  func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) {
  45  	if len(sources) == 0 {
  46  		return nil, errors.New("sources must contain at least one TokenCredential")
  47  	}
  48  	for _, source := range sources {
  49  		if source == nil { // cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil
  50  			return nil, errors.New("sources cannot contain nil")
  51  		}
  52  		if mc, ok := source.(*ManagedIdentityCredential); ok {
  53  			mc.mic.chained = true
  54  		}
  55  	}
  56  	cp := make([]azcore.TokenCredential, len(sources))
  57  	copy(cp, sources)
  58  	if options == nil {
  59  		options = &ChainedTokenCredentialOptions{}
  60  	}
  61  	return &ChainedTokenCredential{
  62  		cond:         sync.NewCond(&sync.Mutex{}),
  63  		name:         "ChainedTokenCredential",
  64  		retrySources: options.RetrySources,
  65  		sources:      cp,
  66  	}, nil
  67  }
  68  
  69  // GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token.
  70  // This method is called automatically by Azure SDK clients.
  71  func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
  72  	if !c.retrySources {
  73  		// ensure only one goroutine at a time iterates the sources and perhaps sets c.successfulCredential
  74  		c.cond.L.Lock()
  75  		for {
  76  			if c.successfulCredential != nil {
  77  				c.cond.L.Unlock()
  78  				return c.successfulCredential.GetToken(ctx, opts)
  79  			}
  80  			if !c.iterating {
  81  				c.iterating = true
  82  				// allow other goroutines to wait while this one iterates
  83  				c.cond.L.Unlock()
  84  				break
  85  			}
  86  			c.cond.Wait()
  87  		}
  88  	}
  89  
  90  	var (
  91  		err                  error
  92  		errs                 []error
  93  		successfulCredential azcore.TokenCredential
  94  		token                azcore.AccessToken
  95  		unavailableErr       credentialUnavailable
  96  	)
  97  	for _, cred := range c.sources {
  98  		token, err = cred.GetToken(ctx, opts)
  99  		if err == nil {
 100  			log.Writef(EventAuthentication, "%s authenticated with %s", c.name, extractCredentialName(cred))
 101  			successfulCredential = cred
 102  			break
 103  		}
 104  		errs = append(errs, err)
 105  		// continue to the next source iff this one returned credentialUnavailableError
 106  		if !errors.As(err, &unavailableErr) {
 107  			break
 108  		}
 109  	}
 110  	if c.iterating {
 111  		c.cond.L.Lock()
 112  		// this is nil when all credentials returned an error
 113  		c.successfulCredential = successfulCredential
 114  		c.iterating = false
 115  		c.cond.L.Unlock()
 116  		c.cond.Broadcast()
 117  	}
 118  	// err is the error returned by the last GetToken call. It will be nil when that call succeeds
 119  	if err != nil {
 120  		// return credentialUnavailableError iff all sources did so; return AuthenticationFailedError otherwise
 121  		msg := createChainedErrorMessage(errs)
 122  		var authFailedErr *AuthenticationFailedError
 123  		switch {
 124  		case errors.As(err, &authFailedErr):
 125  			err = newAuthenticationFailedError(c.name, msg, authFailedErr.RawResponse)
 126  			if af, ok := err.(*AuthenticationFailedError); ok {
 127  				// stop Error() printing the response again; it's already in msg
 128  				af.omitResponse = true
 129  			}
 130  		case errors.As(err, &unavailableErr):
 131  			err = newCredentialUnavailableError(c.name, msg)
 132  		default:
 133  			res := getResponseFromError(err)
 134  			err = newAuthenticationFailedError(c.name, msg, res)
 135  		}
 136  	}
 137  	return token, err
 138  }
 139  
 140  func createChainedErrorMessage(errs []error) string {
 141  	msg := "failed to acquire a token.\nAttempted credentials:"
 142  	for _, err := range errs {
 143  		msg += fmt.Sprintf("\n\t%s", strings.ReplaceAll(err.Error(), "\n", "\n\t\t"))
 144  	}
 145  	return msg
 146  }
 147  
 148  func extractCredentialName(credential azcore.TokenCredential) string {
 149  	return strings.TrimPrefix(fmt.Sprintf("%T", credential), "*azidentity.")
 150  }
 151  
 152  var _ azcore.TokenCredential = (*ChainedTokenCredential)(nil)
 153