oauth.go raw
1 // Copyright (c) Microsoft Corporation.
2 // Licensed under the MIT license.
3
4 package oauth
5
6 import (
7 "context"
8 "encoding/json"
9 "fmt"
10 "io"
11 "time"
12
13 "github.com/google/uuid"
14
15 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
16 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
17 internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
18 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
19 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
20 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
21 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
22 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
23 )
24
// ResolveEndpointer contains the methods for resolving authority endpoints.
type ResolveEndpointer interface {
	// ResolveEndpoints returns the authority.Endpoints for authorityInfo.
	// Callers in this file pass an empty userPrincipalName except when
	// resolving an ADFS authority for a username/password request.
	ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error)
}
29
// AccessTokens contains the methods for fetching tokens from different sources.
// Client and DeviceCode in this file call through this interface rather than
// talking to the token endpoint directly.
type AccessTokens interface {
	// DeviceCodeResult is called by Client.DeviceCode to obtain the device code result.
	DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (accesstokens.DeviceCodeResult, error)
	// FromUsernamePassword is called by Client.UsernamePassword for ADFS authorities and Managed accounts.
	FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (accesstokens.TokenResponse, error)
	// FromAuthCode is called by Client.AuthCode.
	FromAuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error)
	// FromRefreshToken is called by Client.Refresh with the refresh token's secret.
	FromRefreshToken(ctx context.Context, appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) (accesstokens.TokenResponse, error)
	// FromClientSecret is called by Client.Credential when the credential has a non-empty secret.
	FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (accesstokens.TokenResponse, error)
	// FromAssertion is called by Client.Credential with the JWT produced by the credential.
	FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (accesstokens.TokenResponse, error)
	// FromUserAssertionClientSecret is called by Client.OnBehalfOf when the credential has a non-empty secret.
	FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (accesstokens.TokenResponse, error)
	// FromUserAssertionClientCertificate is called by Client.OnBehalfOf with the JWT produced by the credential.
	FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion string, assertion string) (accesstokens.TokenResponse, error)
	// FromDeviceCodeResult is polled by DeviceCode.Token until it stops returning a "wait" error.
	FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult accesstokens.DeviceCodeResult) (accesstokens.TokenResponse, error)
	// FromSamlGrant is called by Client.UsernamePassword for Federated accounts.
	FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (accesstokens.TokenResponse, error)
}
43
// FetchAuthority will be implemented by authority.Authority.
type FetchAuthority interface {
	// UserRealm is called by Client.UsernamePassword to learn the account type
	// (Federated or Managed) before choosing a grant.
	UserRealm(context.Context, authority.AuthParams) (authority.UserRealm, error)
	// AADInstanceDiscovery is called by Client.AADInstanceDiscovery.
	AADInstanceDiscovery(context.Context, authority.Info) (authority.InstanceDiscoveryResponse, error)
}
49
// FetchWSTrust contains the methods for interacting with WSTrust endpoints.
// Both methods are used by Client.UsernamePassword for Federated accounts.
type FetchWSTrust interface {
	// Mex fetches the metadata exchange document from federationMetadataURL.
	Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error)
	// SAMLTokenInfo requests SAML token information from endpoint; the result is
	// then handed to AccessTokens.FromSamlGrant.
	SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (wstrust.SamlTokenInfo, error)
}
55
// Client provides tokens for various types of token requests.
// The fields are exported interfaces; New fills all four from one ops client.
type Client struct {
	// Resolver resolves the endpoints stored on AuthParams before a token call.
	Resolver ResolveEndpointer
	// AccessTokens performs the actual token requests.
	AccessTokens AccessTokens
	// Authority provides user realm lookup and AAD instance discovery.
	Authority FetchAuthority
	// WSTrust is used for the Federated username/password flow.
	WSTrust FetchWSTrust
}
63
64 // New is the constructor for Token.
65 func New(httpClient ops.HTTPClient) *Client {
66 r := ops.New(httpClient)
67 return &Client{
68 Resolver: newAuthorityEndpoint(r),
69 AccessTokens: r.AccessTokens(),
70 Authority: r.Authority(),
71 WSTrust: r.WSTrust(),
72 }
73 }
74
75 // ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance.
76 func (t *Client) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
77 return t.Resolver.ResolveEndpoints(ctx, authorityInfo, userPrincipalName)
78 }
79
80 // AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint).
81 // This is done by AAD which allows for aliasing of tenants (windows.sts.net is the same as login.windows.com).
82 func (t *Client) AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) {
83 return t.Authority.AADInstanceDiscovery(ctx, authorityInfo)
84 }
85
86 // AuthCode returns a token based on an authorization code.
87 func (t *Client) AuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error) {
88 if err := scopeError(req.AuthParams); err != nil {
89 return accesstokens.TokenResponse{}, err
90 }
91 if err := t.resolveEndpoint(ctx, &req.AuthParams, ""); err != nil {
92 return accesstokens.TokenResponse{}, err
93 }
94
95 tResp, err := t.AccessTokens.FromAuthCode(ctx, req)
96 if err != nil {
97 return accesstokens.TokenResponse{}, fmt.Errorf("could not retrieve token from auth code: %w", err)
98 }
99 return tResp, nil
100 }
101
102 // Credential acquires a token from the authority using a client credentials grant.
103 func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
104 if cred.TokenProvider != nil {
105 now := time.Now()
106 scopes := make([]string, len(authParams.Scopes))
107 copy(scopes, authParams.Scopes)
108 params := exported.TokenProviderParameters{
109 Claims: authParams.Claims,
110 CorrelationID: uuid.New().String(),
111 Scopes: scopes,
112 TenantID: authParams.AuthorityInfo.Tenant,
113 }
114 pr, err := cred.TokenProvider(ctx, params)
115 if err != nil {
116 if len(scopes) == 0 {
117 err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err)
118 return accesstokens.TokenResponse{}, err
119 }
120 return accesstokens.TokenResponse{}, err
121 }
122 tr := accesstokens.TokenResponse{
123 TokenType: authParams.AuthnScheme.AccessTokenType(),
124 AccessToken: pr.AccessToken,
125 ExpiresOn: now.Add(time.Duration(pr.ExpiresInSeconds) * time.Second),
126 GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes},
127 }
128 if pr.RefreshInSeconds > 0 {
129 tr.RefreshOn = internalTime.DurationTime{
130 T: now.Add(time.Duration(pr.RefreshInSeconds) * time.Second),
131 }
132 }
133 return tr, nil
134 }
135
136 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
137 return accesstokens.TokenResponse{}, err
138 }
139
140 if cred.Secret != "" {
141 return t.AccessTokens.FromClientSecret(ctx, authParams, cred.Secret)
142 }
143 jwt, err := cred.JWT(ctx, authParams)
144 if err != nil {
145 return accesstokens.TokenResponse{}, err
146 }
147 return t.AccessTokens.FromAssertion(ctx, authParams, jwt)
148 }
149
150 // Credential acquires a token from the authority using a client credentials grant.
151 func (t *Client) OnBehalfOf(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
152 if err := scopeError(authParams); err != nil {
153 return accesstokens.TokenResponse{}, err
154 }
155 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
156 return accesstokens.TokenResponse{}, err
157 }
158
159 if cred.Secret != "" {
160 return t.AccessTokens.FromUserAssertionClientSecret(ctx, authParams, authParams.UserAssertion, cred.Secret)
161 }
162 jwt, err := cred.JWT(ctx, authParams)
163 if err != nil {
164 return accesstokens.TokenResponse{}, err
165 }
166 tr, err := t.AccessTokens.FromUserAssertionClientCertificate(ctx, authParams, authParams.UserAssertion, jwt)
167 if err != nil {
168 return accesstokens.TokenResponse{}, err
169 }
170 return tr, nil
171 }
172
173 func (t *Client) Refresh(ctx context.Context, reqType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken accesstokens.RefreshToken) (accesstokens.TokenResponse, error) {
174 if err := scopeError(authParams); err != nil {
175 return accesstokens.TokenResponse{}, err
176 }
177 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
178 return accesstokens.TokenResponse{}, err
179 }
180
181 tr, err := t.AccessTokens.FromRefreshToken(ctx, reqType, authParams, cc, refreshToken.Secret)
182 if err != nil {
183 return accesstokens.TokenResponse{}, err
184 }
185 return tr, nil
186 }
187
188 // UsernamePassword retrieves a token where a username and password is used. However, if this is
189 // a user realm of "Federated", this uses SAML tokens. If "Managed", uses normal username/password.
190 func (t *Client) UsernamePassword(ctx context.Context, authParams authority.AuthParams) (accesstokens.TokenResponse, error) {
191 if err := scopeError(authParams); err != nil {
192 return accesstokens.TokenResponse{}, err
193 }
194
195 if authParams.AuthorityInfo.AuthorityType == authority.ADFS {
196 if err := t.resolveEndpoint(ctx, &authParams, authParams.Username); err != nil {
197 return accesstokens.TokenResponse{}, err
198 }
199 return t.AccessTokens.FromUsernamePassword(ctx, authParams)
200 }
201 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
202 return accesstokens.TokenResponse{}, err
203 }
204
205 userRealm, err := t.Authority.UserRealm(ctx, authParams)
206 if err != nil {
207 return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm from authority: %w", err)
208 }
209
210 switch userRealm.AccountType {
211 case authority.Federated:
212 mexDoc, err := t.WSTrust.Mex(ctx, userRealm.FederationMetadataURL)
213 if err != nil {
214 err = fmt.Errorf("problem getting mex doc from federated url(%s): %w", userRealm.FederationMetadataURL, err)
215 return accesstokens.TokenResponse{}, err
216 }
217
218 saml, err := t.WSTrust.SAMLTokenInfo(ctx, authParams, userRealm.CloudAudienceURN, mexDoc.UsernamePasswordEndpoint)
219 if err != nil {
220 err = fmt.Errorf("problem getting SAML token info: %w", err)
221 return accesstokens.TokenResponse{}, err
222 }
223 tr, err := t.AccessTokens.FromSamlGrant(ctx, authParams, saml)
224 if err != nil {
225 return accesstokens.TokenResponse{}, err
226 }
227 return tr, nil
228 case authority.Managed:
229 if len(authParams.Scopes) == 0 {
230 err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err)
231 return accesstokens.TokenResponse{}, err
232 }
233 return t.AccessTokens.FromUsernamePassword(ctx, authParams)
234 }
235 return accesstokens.TokenResponse{}, errors.New("unknown account type")
236 }
237
// DeviceCode is the result of a call to Client.DeviceCode().
type DeviceCode struct {
	// Result is the device code result from the first call in the device code flow. This allows
	// the caller to retrieve the displayed code that is used to authorize on the second device.
	Result accesstokens.DeviceCodeResult
	// authParams are the parameters, with endpoints already resolved by
	// Client.DeviceCode, that Token reuses when polling for the token.
	authParams authority.AuthParams

	// accessTokens performs the polling calls. It is nil on a zero-value
	// DeviceCode, which Token reports as an invalid DeviceCode.
	accessTokens AccessTokens
}
247
248 // Token returns a token AFTER the user uses the user code on the second device. This will block
249 // until either: (1) the code is input by the user and the service releases a token, (2) the token
250 // expires, (3) the Context passed to .DeviceCode() is cancelled or expires, (4) some other service
251 // error occurs.
252 func (d DeviceCode) Token(ctx context.Context) (accesstokens.TokenResponse, error) {
253 if d.accessTokens == nil {
254 return accesstokens.TokenResponse{}, fmt.Errorf("DeviceCode was either created outside its package or the creating method had an error. DeviceCode is not valid")
255 }
256
257 var cancel context.CancelFunc
258 if deadline, ok := ctx.Deadline(); !ok || d.Result.ExpiresOn.Before(deadline) {
259 ctx, cancel = context.WithDeadline(ctx, d.Result.ExpiresOn)
260 } else {
261 ctx, cancel = context.WithCancel(ctx)
262 }
263 defer cancel()
264
265 var interval = 50 * time.Millisecond
266 timer := time.NewTimer(interval)
267 defer timer.Stop()
268
269 for {
270 timer.Reset(interval)
271 select {
272 case <-ctx.Done():
273 return accesstokens.TokenResponse{}, ctx.Err()
274 case <-timer.C:
275 interval += interval * 2
276 if interval > 5*time.Second {
277 interval = 5 * time.Second
278 }
279 }
280
281 token, err := d.accessTokens.FromDeviceCodeResult(ctx, d.authParams, d.Result)
282 if err != nil && isWaitDeviceCodeErr(err) {
283 continue
284 }
285 return token, err // This handles if it was a non-wait error or success
286 }
287 }
288
// deviceCodeError is the subset of an error response body that
// isWaitDeviceCodeErr decodes to decide whether polling should continue.
type deviceCodeError struct {
	// Error is the value of the "error" field in the JSON body.
	Error string `json:"error"`
}
292
293 func isWaitDeviceCodeErr(err error) bool {
294 var c errors.CallErr
295 if !errors.As(err, &c) {
296 return false
297 }
298 if c.Resp.StatusCode != 400 {
299 return false
300 }
301 var dCErr deviceCodeError
302 defer c.Resp.Body.Close()
303 body, err := io.ReadAll(c.Resp.Body)
304 if err != nil {
305 return false
306 }
307 err = json.Unmarshal(body, &dCErr)
308 if err != nil {
309 return false
310 }
311 if dCErr.Error == "authorization_pending" || dCErr.Error == "slow_down" {
312 return true
313 }
314 return false
315 }
316
317 // DeviceCode returns a DeviceCode object that can be used to get the code that must be entered on the second
318 // device and optionally the token once the code has been entered on the second device.
319 func (t *Client) DeviceCode(ctx context.Context, authParams authority.AuthParams) (DeviceCode, error) {
320 if err := scopeError(authParams); err != nil {
321 return DeviceCode{}, err
322 }
323
324 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
325 return DeviceCode{}, err
326 }
327
328 dcr, err := t.AccessTokens.DeviceCodeResult(ctx, authParams)
329 if err != nil {
330 return DeviceCode{}, err
331 }
332
333 return DeviceCode{Result: dcr, authParams: authParams, accessTokens: t.AccessTokens}, nil
334 }
335
336 func (t *Client) resolveEndpoint(ctx context.Context, authParams *authority.AuthParams, userPrincipalName string) error {
337 endpoints, err := t.Resolver.ResolveEndpoints(ctx, authParams.AuthorityInfo, userPrincipalName)
338 if err != nil {
339 return fmt.Errorf("unable to resolve an endpoint: %w", err)
340 }
341 authParams.Endpoints = endpoints
342 return nil
343 }
344
345 // scopeError takes an authority.AuthParams and returns an error
346 // if len(AuthParams.Scope) == 0.
347 func scopeError(a authority.AuthParams) error {
348 // TODO(someone): we could look deeper at the message to determine if
349 // it's a scope error, but this is a good start.
350 /*
351 {error":"invalid_scope","error_description":"AADSTS1002012: The provided value for scope
352 openid offline_access profile is not valid. Client credential flows must have a scope value
353 with /.default suffixed to the resource identifier (application ID URI)...}
354 */
355 if len(a.Scopes) == 0 {
356 return fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which is invalid")
357 }
358 return nil
359 }
360