1 //go:build go1.18
2 // +build go1.18
3 4 // Copyright (c) Microsoft Corporation. All rights reserved.
5 // Licensed under the MIT License.
6 7 package azidentity
8 9 import (
10 "context"
11 "errors"
12 "fmt"
13 "strings"
14 "sync"
15 16 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
17 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
18 "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
19 )
// ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential.
type ChainedTokenCredentialOptions struct {
	// RetrySources configures how the credential uses its sources. When true, the credential always attempts to
	// authenticate through each source in turn, stopping when one succeeds. When false (the default), the credential
	// authenticates only through the first successful source--once a source succeeds, it never again tries the other
	// sources. Until some source succeeds, each authentication attempt tries the sources again.
	RetrySources bool
}
// ChainedTokenCredential links together multiple credentials and tries them sequentially when authenticating. By default,
// it tries all the credentials until one authenticates, after which it always uses that credential. For more information,
// see [ChainedTokenCredential overview].
//
// [ChainedTokenCredential overview]: https://aka.ms/azsdk/go/identity/credential-chains#chainedtokencredential-overview
type ChainedTokenCredential struct {
	// cond's lock guards iterating and successfulCredential. GetToken waits on cond while another
	// goroutine iterates the sources, and broadcasts on it when that iteration finishes.
	cond                 *sync.Cond
	// iterating is true while some goroutine is iterating the sources. It is only set when
	// retrySources is false.
	iterating            bool
	// name identifies this credential in log messages and in the errors GetToken returns.
	name                 string
	// retrySources is copied from ChainedTokenCredentialOptions.RetrySources at construction.
	retrySources         bool
	// sources is the constructor's own copy of the caller's slice; GetToken tries them in order.
	sources              []azcore.TokenCredential
	// successfulCredential is the source that authenticated, recorded only when retrySources is false.
	// It is nil until a source succeeds, and is set back to nil when an iteration finds no working source.
	successfulCredential azcore.TokenCredential
}
// NewChainedTokenCredential creates a ChainedTokenCredential. Pass nil for options to accept defaults.
//
// sources must contain at least one credential and must not contain nil. The chain tries the sources in
// the given order. The slice is copied, so later changes to it don't affect the chain.
func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) {
	if len(sources) == 0 {
		return nil, errors.New("sources must contain at least one TokenCredential")
	}
	// Validate every source before modifying any of them, so a rejected call leaves the caller's
	// credentials untouched.
	for _, source := range sources {
		// cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil
		if source == nil {
			return nil, errors.New("sources cannot contain nil")
		}
		// A typed nil *ManagedIdentityCredential is a non-nil interface value, so the check above misses it,
		// and marking it chained below would dereference nil and panic.
		if mc, ok := source.(*ManagedIdentityCredential); ok && mc == nil {
			return nil, errors.New("sources cannot contain nil")
		}
	}
	for _, source := range sources {
		if mc, ok := source.(*ManagedIdentityCredential); ok {
			// NOTE(review): the effect of mic.chained isn't visible here; presumably it adjusts the
			// managed identity client's behavior when it is one link in a chain. Confirm in its implementation.
			mc.mic.chained = true
		}
	}
	cp := make([]azcore.TokenCredential, len(sources))
	copy(cp, sources)
	if options == nil {
		options = &ChainedTokenCredentialOptions{}
	}
	return &ChainedTokenCredential{
		cond:         sync.NewCond(&sync.Mutex{}),
		name:         "ChainedTokenCredential",
		retrySources: options.RetrySources,
		sources:      cp,
	}, nil
}
// GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token.
// This method is called automatically by Azure SDK clients.
//
// A source that reports itself unavailable (an error satisfying credentialUnavailable) makes the chain
// move on to the next source; any other error ends the attempt. The message of the returned error
// lists the errors from every source that was tried.
//
// Unless the chain was created with RetrySources set, the first source to succeed is remembered and
// later calls go straight to it. In that mode only one goroutine at a time iterates the sources;
// concurrent callers wait for it to finish.
func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
	// successfulCredential is the source that returned a token during this call, if any
	var successfulCredential azcore.TokenCredential
	if !c.retrySources {
		// ensure only one goroutine at a time iterates the sources and perhaps sets c.successfulCredential
		c.cond.L.Lock()
		for {
			// read the shared field while still holding the lock
			if cred := c.successfulCredential; cred != nil {
				c.cond.L.Unlock()
				return cred.GetToken(ctx, opts)
			}
			if !c.iterating {
				c.iterating = true
				// allow other goroutines to wait while this one iterates
				c.cond.L.Unlock()
				break
			}
			c.cond.Wait()
		}
		// Give up the iterating role in a defer, so it's released even when a source panics and the
		// panic is recovered further up the stack (net/http does this for handlers, for example).
		// Otherwise c.iterating would stay true and every later caller would wait forever.
		defer func() {
			c.cond.L.Lock()
			// this is nil when no source returned a token
			c.successfulCredential = successfulCredential
			c.iterating = false
			c.cond.L.Unlock()
			c.cond.Broadcast()
		}()
	}

	var (
		err            error
		errs           []error
		token          azcore.AccessToken
		unavailableErr credentialUnavailable
	)
	for _, cred := range c.sources {
		token, err = cred.GetToken(ctx, opts)
		if err == nil {
			log.Writef(EventAuthentication, "%s authenticated with %s", c.name, extractCredentialName(cred))
			successfulCredential = cred
			break
		}
		errs = append(errs, err)
		// continue to the next source iff this one returned credentialUnavailableError
		if !errors.As(err, &unavailableErr) {
			break
		}
	}
	// err is the error returned by the last GetToken call. It will be nil when that call succeeds
	if err != nil {
		// return credentialUnavailableError iff all sources did so; return AuthenticationFailedError otherwise
		msg := createChainedErrorMessage(errs)
		var authFailedErr *AuthenticationFailedError
		switch {
		case errors.As(err, &authFailedErr):
			err = newAuthenticationFailedError(c.name, msg, authFailedErr.RawResponse)
			if af, ok := err.(*AuthenticationFailedError); ok {
				// stop Error() printing the response again; it's already in msg
				af.omitResponse = true
			}
		case errors.As(err, &unavailableErr):
			err = newCredentialUnavailableError(c.name, msg)
		default:
			res := getResponseFromError(err)
			err = newAuthenticationFailedError(c.name, msg, res)
		}
	}
	return token, err
}
139 140 func createChainedErrorMessage(errs []error) string {
141 msg := "failed to acquire a token.\nAttempted credentials:"
142 for _, err := range errs {
143 msg += fmt.Sprintf("\n\t%s", strings.ReplaceAll(err.Error(), "\n", "\n\t\t"))
144 }
145 return msg
146 }
147 148 func extractCredentialName(credential azcore.TokenCredential) string {
149 return strings.TrimPrefix(fmt.Sprintf("%T", credential), "*azidentity.")
150 }
// Compile-time check that *ChainedTokenCredential implements azcore.TokenCredential.
var _ azcore.TokenCredential = (*ChainedTokenCredential)(nil)
153