public_client.go raw
1 //go:build go1.18
2 // +build go1.18
3
4 // Copyright (c) Microsoft Corporation. All rights reserved.
5 // Licensed under the MIT License.
6
7 package azidentity
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "net/http"
14 "strings"
15 "sync"
16
17 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
18 "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
19 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
20 "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
21 "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
22 "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
23 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
24
25 // this import ensures well-known configurations in azcore/cloud have ARM audiences for Authenticate()
26 _ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime"
27 )
28
29 type publicClientOptions struct {
30 azcore.ClientOptions
31
32 AdditionallyAllowedTenants []string
33 Cache Cache
34 DeviceCodePrompt func(context.Context, DeviceCodeMessage) error
35 DisableAutomaticAuthentication bool
36 DisableInstanceDiscovery bool
37 LoginHint, RedirectURL string
38 Record AuthenticationRecord
39 Username, Password string
40 }
41
42 // publicClient wraps the MSAL public client
43 type publicClient struct {
44 cae, noCAE msalPublicClient
45 caeMu, noCAEMu, clientMu *sync.Mutex
46 clientID, tenantID string
47 defaultScope []string
48 host string
49 name string
50 opts publicClientOptions
51 record AuthenticationRecord
52 azClient *azcore.Client
53 }
54
var (
	// errScopeRequired is returned by Authenticate when the caller gave no scopes and
	// no default (ARM audience) scope could be derived for the configured cloud.
	errScopeRequired = errors.New("authenticating in this environment requires specifying a scope in TokenRequestOptions")
)
56
57 func newPublicClient(tenantID, clientID, name string, o publicClientOptions) (*publicClient, error) {
58 if !validTenantID(tenantID) {
59 return nil, errInvalidTenantID
60 }
61 host, err := setAuthorityHost(o.Cloud)
62 if err != nil {
63 return nil, err
64 }
65 // if the application specified a cloud configuration, use its ARM audience as the default scope for Authenticate()
66 audience := o.Cloud.Services[cloud.ResourceManager].Audience
67 if audience == "" {
68 // no cloud configuration, or no ARM audience, specified; try to map the host to a well-known one (all of which have a trailing slash)
69 if !strings.HasSuffix(host, "/") {
70 host += "/"
71 }
72 switch host {
73 case cloud.AzureChina.ActiveDirectoryAuthorityHost:
74 audience = cloud.AzureChina.Services[cloud.ResourceManager].Audience
75 case cloud.AzureGovernment.ActiveDirectoryAuthorityHost:
76 audience = cloud.AzureGovernment.Services[cloud.ResourceManager].Audience
77 case cloud.AzurePublic.ActiveDirectoryAuthorityHost:
78 audience = cloud.AzurePublic.Services[cloud.ResourceManager].Audience
79 }
80 }
81 // if we didn't come up with an audience, the application will have to specify a scope for Authenticate()
82 var defaultScope []string
83 if audience != "" {
84 defaultScope = []string{audience + defaultSuffix}
85 }
86 client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
87 Tracing: runtime.TracingOptions{
88 Namespace: traceNamespace,
89 },
90 }, &o.ClientOptions)
91 if err != nil {
92 return nil, err
93 }
94 o.AdditionallyAllowedTenants = resolveAdditionalTenants(o.AdditionallyAllowedTenants)
95 return &publicClient{
96 caeMu: &sync.Mutex{},
97 clientID: clientID,
98 clientMu: &sync.Mutex{},
99 defaultScope: defaultScope,
100 host: host,
101 name: name,
102 noCAEMu: &sync.Mutex{},
103 opts: o,
104 record: o.Record,
105 tenantID: tenantID,
106 azClient: client,
107 }, nil
108 }
109
110 func (p *publicClient) Authenticate(ctx context.Context, tro *policy.TokenRequestOptions) (AuthenticationRecord, error) {
111 if tro == nil {
112 tro = &policy.TokenRequestOptions{}
113 }
114 if len(tro.Scopes) == 0 {
115 if p.defaultScope == nil {
116 return AuthenticationRecord{}, errScopeRequired
117 }
118 tro.Scopes = p.defaultScope
119 }
120 client, mu, err := p.client(*tro)
121 if err != nil {
122 return AuthenticationRecord{}, err
123 }
124 mu.Lock()
125 defer mu.Unlock()
126 _, err = p.reqToken(ctx, client, *tro)
127 if err == nil {
128 scope := strings.Join(tro.Scopes, ", ")
129 msg := fmt.Sprintf("%s.Authenticate() acquired a token for scope %q", p.name, scope)
130 log.Write(EventAuthentication, msg)
131 }
132 return p.record, err
133 }
134
135 // GetToken requests an access token from MSAL, checking the cache first.
136 func (p *publicClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
137 if len(tro.Scopes) < 1 {
138 return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", p.name)
139 }
140 tenant, err := p.resolveTenant(tro.TenantID)
141 if err != nil {
142 return azcore.AccessToken{}, err
143 }
144 client, mu, err := p.client(tro)
145 if err != nil {
146 return azcore.AccessToken{}, err
147 }
148 mu.Lock()
149 defer mu.Unlock()
150 ar, err := client.AcquireTokenSilent(ctx, tro.Scopes, public.WithSilentAccount(p.record.account()), public.WithClaims(tro.Claims), public.WithTenantID(tenant))
151 if err == nil {
152 return p.token(ar, err)
153 }
154 if p.opts.DisableAutomaticAuthentication {
155 return azcore.AccessToken{}, newAuthenticationRequiredError(p.name, tro)
156 }
157 return p.reqToken(ctx, client, tro)
158 }
159
160 // reqToken requests a token from the MSAL public client. It's separate from GetToken() to enable Authenticate() to bypass the cache.
161 func (p *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
162 tenant, err := p.resolveTenant(tro.TenantID)
163 if err != nil {
164 return azcore.AccessToken{}, err
165 }
166 var ar public.AuthResult
167 switch p.name {
168 case credNameBrowser:
169 ar, err = c.AcquireTokenInteractive(ctx, tro.Scopes,
170 public.WithClaims(tro.Claims),
171 public.WithLoginHint(p.opts.LoginHint),
172 public.WithRedirectURI(p.opts.RedirectURL),
173 public.WithTenantID(tenant),
174 )
175 case credNameDeviceCode:
176 dc, e := c.AcquireTokenByDeviceCode(ctx, tro.Scopes, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
177 if e != nil {
178 return azcore.AccessToken{}, e
179 }
180 err = p.opts.DeviceCodePrompt(ctx, DeviceCodeMessage{
181 Message: dc.Result.Message,
182 UserCode: dc.Result.UserCode,
183 VerificationURL: dc.Result.VerificationURL,
184 })
185 if err == nil {
186 ar, err = dc.AuthenticationResult(ctx)
187 }
188 case credNameUserPassword:
189 ar, err = c.AcquireTokenByUsernamePassword(ctx, tro.Scopes, p.opts.Username, p.opts.Password, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
190 default:
191 return azcore.AccessToken{}, fmt.Errorf("unknown credential %q", p.name)
192 }
193 return p.token(ar, err)
194 }
195
196 func (p *publicClient) client(tro policy.TokenRequestOptions) (msalPublicClient, *sync.Mutex, error) {
197 p.clientMu.Lock()
198 defer p.clientMu.Unlock()
199 if tro.EnableCAE {
200 if p.cae == nil {
201 client, err := p.newMSALClient(true)
202 if err != nil {
203 return nil, nil, err
204 }
205 p.cae = client
206 }
207 return p.cae, p.caeMu, nil
208 }
209 if p.noCAE == nil {
210 client, err := p.newMSALClient(false)
211 if err != nil {
212 return nil, nil, err
213 }
214 p.noCAE = client
215 }
216 return p.noCAE, p.noCAEMu, nil
217 }
218
219 func (p *publicClient) newMSALClient(enableCAE bool) (msalPublicClient, error) {
220 c, err := internal.ExportReplace(p.opts.Cache, enableCAE)
221 if err != nil {
222 return nil, err
223 }
224 o := []public.Option{
225 public.WithAuthority(runtime.JoinPaths(p.host, p.tenantID)),
226 public.WithCache(c),
227 public.WithHTTPClient(p),
228 }
229 if enableCAE {
230 o = append(o, public.WithClientCapabilities(cp1))
231 }
232 if p.opts.DisableInstanceDiscovery || strings.ToLower(p.tenantID) == "adfs" {
233 o = append(o, public.WithInstanceDiscovery(false))
234 }
235 return public.New(p.clientID, o...)
236 }
237
238 func (p *publicClient) token(ar public.AuthResult, err error) (azcore.AccessToken, error) {
239 if err == nil {
240 msg := fmt.Sprintf(scopeLogFmt, p.name, strings.Join(ar.GrantedScopes, ", "))
241 log.Write(EventAuthentication, msg)
242 p.record, err = newAuthenticationRecord(ar)
243 } else {
244 err = newAuthenticationFailedErrorFromMSAL(p.name, err)
245 }
246 return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC(), RefreshOn: ar.Metadata.RefreshOn.UTC()}, err
247 }
248
249 // resolveTenant returns the correct WithTenantID() argument for a token request given the client's
250 // configuration, or an error when that configuration doesn't allow the specified tenant
251 func (p *publicClient) resolveTenant(specified string) (string, error) {
252 t, err := resolveTenant(p.tenantID, specified, p.name, p.opts.AdditionallyAllowedTenants)
253 if t == p.tenantID {
254 // callers pass this value to MSAL's WithTenantID(). There's no need to redundantly specify
255 // the client's default tenant and doing so is an error when that tenant is "organizations"
256 t = ""
257 }
258 return t, err
259 }
260
261 // these methods satisfy the MSAL ops.HTTPClient interface
262
263 func (p *publicClient) CloseIdleConnections() {
264 // do nothing
265 }
266
267 func (p *publicClient) Do(r *http.Request) (*http.Response, error) {
268 return doForClient(p.azClient, r)
269 }
270