confidential_client.go raw
1 //go:build go1.18
2 // +build go1.18
3
4 // Copyright (c) Microsoft Corporation. All rights reserved.
5 // Licensed under the MIT License.
6
7 package azidentity
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "net/http"
14 "os"
15 "strings"
16 "sync"
17
18 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
19 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
20 "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
21 "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
22 "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
23 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
24 )
25
26 type confidentialClientOptions struct {
27 azcore.ClientOptions
28
29 AdditionallyAllowedTenants []string
30 // Assertion for on-behalf-of authentication
31 Assertion string
32 Cache Cache
33 DisableInstanceDiscovery, SendX5C bool
34 }
35
36 // confidentialClient wraps the MSAL confidential client
37 type confidentialClient struct {
38 cae, noCAE msalConfidentialClient
39 caeMu, noCAEMu, clientMu *sync.Mutex
40 clientID, tenantID string
41 cred confidential.Credential
42 host string
43 name string
44 opts confidentialClientOptions
45 region string
46 azClient *azcore.Client
47 }
48
49 func newConfidentialClient(tenantID, clientID, name string, cred confidential.Credential, opts confidentialClientOptions) (*confidentialClient, error) {
50 if !validTenantID(tenantID) {
51 return nil, errInvalidTenantID
52 }
53 host, err := setAuthorityHost(opts.Cloud)
54 if err != nil {
55 return nil, err
56 }
57 client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
58 Tracing: runtime.TracingOptions{
59 Namespace: traceNamespace,
60 },
61 }, &opts.ClientOptions)
62 if err != nil {
63 return nil, err
64 }
65 opts.AdditionallyAllowedTenants = resolveAdditionalTenants(opts.AdditionallyAllowedTenants)
66 return &confidentialClient{
67 caeMu: &sync.Mutex{},
68 clientID: clientID,
69 clientMu: &sync.Mutex{},
70 cred: cred,
71 host: host,
72 name: name,
73 noCAEMu: &sync.Mutex{},
74 opts: opts,
75 region: os.Getenv(azureRegionalAuthorityName),
76 tenantID: tenantID,
77 azClient: client,
78 }, nil
79 }
80
81 // GetToken requests an access token from MSAL, checking the cache first.
82 func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
83 if len(tro.Scopes) < 1 {
84 return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", c.name)
85 }
86 // we don't resolve the tenant for managed identities because they acquire tokens only from their home tenants
87 if c.name != credNameManagedIdentity {
88 tenant, err := c.resolveTenant(tro.TenantID)
89 if err != nil {
90 return azcore.AccessToken{}, err
91 }
92 tro.TenantID = tenant
93 }
94 client, mu, err := c.client(tro)
95 if err != nil {
96 return azcore.AccessToken{}, err
97 }
98 mu.Lock()
99 defer mu.Unlock()
100 var ar confidential.AuthResult
101 if c.opts.Assertion != "" {
102 ar, err = client.AcquireTokenOnBehalfOf(ctx, c.opts.Assertion, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
103 } else {
104 ar, err = client.AcquireTokenSilent(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
105 if err != nil {
106 ar, err = client.AcquireTokenByCredential(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
107 }
108 }
109 if err != nil {
110 var (
111 authFailedErr *AuthenticationFailedError
112 unavailableErr credentialUnavailable
113 )
114 if !(errors.As(err, &unavailableErr) || errors.As(err, &authFailedErr)) {
115 err = newAuthenticationFailedErrorFromMSAL(c.name, err)
116 }
117 } else {
118 msg := fmt.Sprintf(scopeLogFmt, c.name, strings.Join(ar.GrantedScopes, ", "))
119 log.Write(EventAuthentication, msg)
120 }
121 return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC(), RefreshOn: ar.Metadata.RefreshOn.UTC()}, err
122 }
123
124 func (c *confidentialClient) client(tro policy.TokenRequestOptions) (msalConfidentialClient, *sync.Mutex, error) {
125 c.clientMu.Lock()
126 defer c.clientMu.Unlock()
127 if tro.EnableCAE {
128 if c.cae == nil {
129 client, err := c.newMSALClient(true)
130 if err != nil {
131 return nil, nil, err
132 }
133 c.cae = client
134 }
135 return c.cae, c.caeMu, nil
136 }
137 if c.noCAE == nil {
138 client, err := c.newMSALClient(false)
139 if err != nil {
140 return nil, nil, err
141 }
142 c.noCAE = client
143 }
144 return c.noCAE, c.noCAEMu, nil
145 }
146
147 func (c *confidentialClient) newMSALClient(enableCAE bool) (msalConfidentialClient, error) {
148 cache, err := internal.ExportReplace(c.opts.Cache, enableCAE)
149 if err != nil {
150 return nil, err
151 }
152 authority := runtime.JoinPaths(c.host, c.tenantID)
153 o := []confidential.Option{
154 confidential.WithAzureRegion(c.region),
155 confidential.WithCache(cache),
156 confidential.WithHTTPClient(c),
157 }
158 if enableCAE {
159 o = append(o, confidential.WithClientCapabilities(cp1))
160 }
161 if c.opts.SendX5C {
162 o = append(o, confidential.WithX5C())
163 }
164 if c.opts.DisableInstanceDiscovery || strings.ToLower(c.tenantID) == "adfs" {
165 o = append(o, confidential.WithInstanceDiscovery(false))
166 }
167 return confidential.New(authority, c.clientID, c.cred, o...)
168 }
169
170 // resolveTenant returns the correct WithTenantID() argument for a token request given the client's
171 // configuration, or an error when that configuration doesn't allow the specified tenant
172 func (c *confidentialClient) resolveTenant(specified string) (string, error) {
173 return resolveTenant(c.tenantID, specified, c.name, c.opts.AdditionallyAllowedTenants)
174 }
175
176 // these methods satisfy the MSAL ops.HTTPClient interface
177
178 func (c *confidentialClient) CloseIdleConnections() {
179 // do nothing
180 }
181
182 func (c *confidentialClient) Do(r *http.Request) (*http.Response, error) {
183 return doForClient(c.azClient, r)
184 }
185