token.go raw
1 package adal
2
3 // Copyright 2017 Microsoft Corporation
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 import (
18 "context"
19 "crypto/rand"
20 "crypto/rsa"
21 "crypto/sha1"
22 "crypto/x509"
23 "encoding/base64"
24 "encoding/json"
25 "errors"
26 "fmt"
27 "io"
28 "io/ioutil"
29 "math"
30 "net/http"
31 "net/url"
32 "os"
33 "strconv"
34 "strings"
35 "sync"
36 "time"
37
38 "github.com/Azure/go-autorest/autorest/date"
39 "github.com/Azure/go-autorest/logger"
40 "github.com/golang-jwt/jwt/v4"
41 )
42
// defaultRefresh is the RefreshWithin window assigned to tokens built by the
// ServicePrincipalToken constructors in this file.
const defaultRefresh = 5 * time.Minute

// OAuth 2.0 "grant_type" values, one per supported authentication flow.
const (
	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
	OAuthGrantTypeAuthorizationCode = "authorization_code"
)

// Settings for the IMDS (virtual machine) managed identity endpoint.
const (
	// metadataHeader is the header required by the MSI extension.
	metadataHeader = "Metadata"

	// msiEndpoint is the well-known IMDS endpoint for obtaining MSI authentication tokens.
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// msiAPIVersion is the api-version query value sent to the IMDS endpoint.
	msiAPIVersion = "2018-02-01"

	// defaultMaxMSIRefreshAttempts is the default number of attempts to refresh an MSI authentication token.
	defaultMaxMSIRefreshAttempts = 5
)

// Settings for the App Service / Functions and Cloud Shell managed identity environments.
const (
	// msiEndpointEnv names the environment variable holding the managed identity
	// endpoint on App Service and Functions; when it is set without msiSecretEnv
	// the environment is treated as Cloud Shell.
	msiEndpointEnv = "MSI_ENDPOINT"

	// msiSecretEnv names the environment variable holding the request secret on App Service and Functions.
	msiSecretEnv = "MSI_SECRET"

	// appServiceAPIVersion2017 is the api-version used for the legacy App Service MSI endpoint.
	appServiceAPIVersion2017 = "2017-09-01"

	// secretHeader is the header carrying the secret when authenticating against the App Service MSI endpoint.
	secretHeader = "Secret"
)

// Go time layouts (reference date 1/2/2006) for a date-formatted expires_on value in UTC.
const (
	// expiresOnDateFormatPM is the expires_on layout that includes an AM/PM marker.
	expiresOnDateFormatPM = "1/2/2006 15:04:05 PM +00:00"

	// expiresOnDateFormat is the expires_on layout without an AM/PM marker.
	expiresOnDateFormat = "1/2/2006 15:04:05 +00:00"
)
91
// OAuthTokenProvider is implemented by any type that can hand out an OAuth access token.
type OAuthTokenProvider interface {
	OAuthToken() string
}

// MultitenantOAuthTokenProvider supplies the tokens used for multi-tenant
// authorization: a single primary token plus a list of auxiliary tokens.
type MultitenantOAuthTokenProvider interface {
	PrimaryOAuthToken() string
	AuxiliaryOAuthTokens() []string
}

// TokenRefreshError is the error interface used for failures during token
// refresh; besides the error text it exposes the associated HTTP response.
type TokenRefreshError interface {
	error
	Response() *http.Response
}

// Refresher is an interface for token refresh functionality. Its method set
// mirrors RefresherWithContext without the context.Context parameters.
type Refresher interface {
	Refresh() error
	RefreshExchange(resource string) error
	EnsureFresh() error
}

// RefresherWithContext is the context-aware interface for token refresh
// functionality; every method takes a context.Context as its first argument.
type RefresherWithContext interface {
	RefreshWithContext(ctx context.Context) error
	RefreshExchangeWithContext(ctx context.Context, resource string) error
	EnsureFreshWithContext(ctx context.Context) error
}
122
// TokenRefreshCallback is a hook run after a token refresh has succeeded; it
// receives the refreshed Token and may report an error.
type TokenRefreshCallback func(Token) error

// TokenRefresh is a caller-supplied function that produces a new Token for the
// given resource, used in place of the built-in refresh logic.
type TokenRefresh func(ctx context.Context, resource string) (*Token, error)

// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`

	// Numeric fields are kept as json.Number so the raw service value is preserved.
	ExpiresIn json.Number `json:"expires_in"`
	ExpiresOn json.Number `json:"expires_on"`
	NotBefore json.Number `json:"not_before"`

	Resource string `json:"resource"`
	Type     string `json:"token_type"`
}

// newToken returns an empty Token whose numeric fields hold "0" rather than
// the empty string, so they still parse as numbers.
func newToken() Token {
	const zero json.Number = "0"
	var tok Token
	tok.ExpiresIn, tok.ExpiresOn, tok.NotBefore = zero, zero, zero
	return tok
}

// IsZero reports whether every field of t holds its zero value.
func (t Token) IsZero() bool {
	var empty Token
	return t == empty
}
156
157 // Expires returns the time.Time when the Token expires.
158 func (t Token) Expires() time.Time {
159 s, err := t.ExpiresOn.Float64()
160 if err != nil {
161 s = -3600
162 }
163
164 expiration := date.NewUnixTimeFromSeconds(s)
165
166 return time.Time(expiration).UTC()
167 }
168
169 // IsExpired returns true if the Token is expired, false otherwise.
170 func (t Token) IsExpired() bool {
171 return t.WillExpireIn(0)
172 }
173
174 // WillExpireIn returns true if the Token will expire after the passed time.Duration interval
175 // from now, false otherwise.
176 func (t Token) WillExpireIn(d time.Duration) bool {
177 return !t.Expires().After(time.Now().Add(d))
178 }
179
180 // OAuthToken return the current access token
181 func (t *Token) OAuthToken() string {
182 return t.AccessToken
183 }
184
185 // ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
186 // that is submitted when acquiring an oAuth token.
187 type ServicePrincipalSecret interface {
188 SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
189 }
190
191 // ServicePrincipalNoSecret represents a secret type that contains no secret
192 // meaning it is not valid for fetching a fresh token. This is used by Manual
193 type ServicePrincipalNoSecret struct {
194 }
195
196 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret
197 // It only returns an error for the ServicePrincipalNoSecret type
198 func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
199 return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
200 }
201
202 // MarshalJSON implements the json.Marshaler interface.
203 func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
204 type tokenType struct {
205 Type string `json:"type"`
206 }
207 return json.Marshal(tokenType{
208 Type: "ServicePrincipalNoSecret",
209 })
210 }
211
212 // ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
213 type ServicePrincipalTokenSecret struct {
214 ClientSecret string `json:"value"`
215 }
216
217 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
218 // It will populate the form submitted during oAuth Token Acquisition using the client_secret.
219 func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
220 v.Set("client_secret", tokenSecret.ClientSecret)
221 return nil
222 }
223
224 // MarshalJSON implements the json.Marshaler interface.
225 func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
226 type tokenType struct {
227 Type string `json:"type"`
228 Value string `json:"value"`
229 }
230 return json.Marshal(tokenType{
231 Type: "ServicePrincipalTokenSecret",
232 Value: tokenSecret.ClientSecret,
233 })
234 }
235
236 // ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
237 type ServicePrincipalCertificateSecret struct {
238 Certificate *x509.Certificate
239 PrivateKey *rsa.PrivateKey
240 }
241
242 // SignJwt returns the JWT signed with the certificate's private key.
243 func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
244 hasher := sha1.New()
245 _, err := hasher.Write(secret.Certificate.Raw)
246 if err != nil {
247 return "", err
248 }
249
250 thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
251
252 // The jti (JWT ID) claim provides a unique identifier for the JWT.
253 jti := make([]byte, 20)
254 _, err = rand.Read(jti)
255 if err != nil {
256 return "", err
257 }
258
259 token := jwt.New(jwt.SigningMethodRS256)
260 token.Header["x5t"] = thumbprint
261 x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
262 token.Header["x5c"] = x5c
263 token.Claims = jwt.MapClaims{
264 "aud": spt.inner.OauthConfig.TokenEndpoint.String(),
265 "iss": spt.inner.ClientID,
266 "sub": spt.inner.ClientID,
267 "jti": base64.URLEncoding.EncodeToString(jti),
268 "nbf": time.Now().Unix(),
269 "exp": time.Now().Add(24 * time.Hour).Unix(),
270 }
271
272 signedString, err := token.SignedString(secret.PrivateKey)
273 return signedString, err
274 }
275
276 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
277 // It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
278 func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
279 jwt, err := secret.SignJwt(spt)
280 if err != nil {
281 return err
282 }
283
284 v.Set("client_assertion", jwt)
285 v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
286 return nil
287 }
288
289 // MarshalJSON implements the json.Marshaler interface.
290 func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
291 return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
292 }
293
294 // ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
295 type ServicePrincipalMSISecret struct {
296 msiType msiType
297 clientResourceID string
298 }
299
300 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
301 func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
302 return nil
303 }
304
305 // MarshalJSON implements the json.Marshaler interface.
306 func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
307 return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
308 }
309
310 // ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
311 type ServicePrincipalUsernamePasswordSecret struct {
312 Username string `json:"username"`
313 Password string `json:"password"`
314 }
315
316 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
317 func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
318 v.Set("username", secret.Username)
319 v.Set("password", secret.Password)
320 return nil
321 }
322
323 // MarshalJSON implements the json.Marshaler interface.
324 func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
325 type tokenType struct {
326 Type string `json:"type"`
327 Username string `json:"username"`
328 Password string `json:"password"`
329 }
330 return json.Marshal(tokenType{
331 Type: "ServicePrincipalUsernamePasswordSecret",
332 Username: secret.Username,
333 Password: secret.Password,
334 })
335 }
336
337 // ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
338 type ServicePrincipalAuthorizationCodeSecret struct {
339 ClientSecret string `json:"value"`
340 AuthorizationCode string `json:"authCode"`
341 RedirectURI string `json:"redirect"`
342 }
343
344 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
345 func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
346 v.Set("code", secret.AuthorizationCode)
347 v.Set("client_secret", secret.ClientSecret)
348 v.Set("redirect_uri", secret.RedirectURI)
349 return nil
350 }
351
352 // MarshalJSON implements the json.Marshaler interface.
353 func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
354 type tokenType struct {
355 Type string `json:"type"`
356 Value string `json:"value"`
357 AuthCode string `json:"authCode"`
358 Redirect string `json:"redirect"`
359 }
360 return json.Marshal(tokenType{
361 Type: "ServicePrincipalAuthorizationCodeSecret",
362 Value: secret.ClientSecret,
363 AuthCode: secret.AuthorizationCode,
364 Redirect: secret.RedirectURI,
365 })
366 }
367
368 // ServicePrincipalFederatedSecret implements ServicePrincipalSecret for Federated JWTs.
369 type ServicePrincipalFederatedSecret struct {
370 jwt string
371 }
372
373 // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
374 // It will populate the form submitted during OAuth Token Acquisition using a JWT signed by an OIDC issuer.
375 func (secret *ServicePrincipalFederatedSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
376
377 v.Set("client_assertion", secret.jwt)
378 v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
379 return nil
380 }
381
382 // MarshalJSON implements the json.Marshaler interface.
383 func (secret ServicePrincipalFederatedSecret) MarshalJSON() ([]byte, error) {
384 return nil, errors.New("marshalling ServicePrincipalFederatedSecret is not supported")
385 }
386
387 // ServicePrincipalToken encapsulates a Token created for a Service Principal.
388 type ServicePrincipalToken struct {
389 inner servicePrincipalToken
390 refreshLock *sync.RWMutex
391 sender Sender
392 customRefreshFunc TokenRefresh
393 refreshCallbacks []TokenRefreshCallback
394 // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
395 // Settings this to a value less than 1 will use the default value.
396 MaxMSIRefreshAttempts int
397 }
398
399 // MarshalTokenJSON returns the marshalled inner token.
400 func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
401 return json.Marshal(spt.inner.Token)
402 }
403
404 // SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
405 func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
406 spt.refreshCallbacks = callbacks
407 }
408
409 // SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
410 func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
411 spt.customRefreshFunc = customRefreshFunc
412 }
413
414 // MarshalJSON implements the json.Marshaler interface.
415 func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
416 return json.Marshal(spt.inner)
417 }
418
419 // UnmarshalJSON implements the json.Unmarshaler interface.
420 func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
421 // need to determine the token type
422 raw := map[string]interface{}{}
423 err := json.Unmarshal(data, &raw)
424 if err != nil {
425 return err
426 }
427 secret := raw["secret"].(map[string]interface{})
428 switch secret["type"] {
429 case "ServicePrincipalNoSecret":
430 spt.inner.Secret = &ServicePrincipalNoSecret{}
431 case "ServicePrincipalTokenSecret":
432 spt.inner.Secret = &ServicePrincipalTokenSecret{}
433 case "ServicePrincipalCertificateSecret":
434 return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
435 case "ServicePrincipalMSISecret":
436 return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
437 case "ServicePrincipalUsernamePasswordSecret":
438 spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
439 case "ServicePrincipalAuthorizationCodeSecret":
440 spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
441 case "ServicePrincipalFederatedSecret":
442 return errors.New("unmarshalling ServicePrincipalFederatedSecret is not supported")
443 default:
444 return fmt.Errorf("unrecognized token type '%s'", secret["type"])
445 }
446 err = json.Unmarshal(data, &spt.inner)
447 if err != nil {
448 return err
449 }
450 // Don't override the refreshLock or the sender if those have been already set.
451 if spt.refreshLock == nil {
452 spt.refreshLock = &sync.RWMutex{}
453 }
454 if spt.sender == nil {
455 spt.sender = sender()
456 }
457 return nil
458 }
459
460 // internal type used for marshalling/unmarshalling
461 type servicePrincipalToken struct {
462 Token Token `json:"token"`
463 Secret ServicePrincipalSecret `json:"secret"`
464 OauthConfig OAuthConfig `json:"oauth"`
465 ClientID string `json:"clientID"`
466 Resource string `json:"resource"`
467 AutoRefresh bool `json:"autoRefresh"`
468 RefreshWithin time.Duration `json:"refreshWithin"`
469 }
470
471 func validateOAuthConfig(oac OAuthConfig) error {
472 if oac.IsZero() {
473 return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
474 }
475 return nil
476 }
477
478 // NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
479 func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
480 if err := validateOAuthConfig(oauthConfig); err != nil {
481 return nil, err
482 }
483 if err := validateStringParam(id, "id"); err != nil {
484 return nil, err
485 }
486 if err := validateStringParam(resource, "resource"); err != nil {
487 return nil, err
488 }
489 if secret == nil {
490 return nil, fmt.Errorf("parameter 'secret' cannot be nil")
491 }
492 spt := &ServicePrincipalToken{
493 inner: servicePrincipalToken{
494 Token: newToken(),
495 OauthConfig: oauthConfig,
496 Secret: secret,
497 ClientID: id,
498 Resource: resource,
499 AutoRefresh: true,
500 RefreshWithin: defaultRefresh,
501 },
502 refreshLock: &sync.RWMutex{},
503 sender: sender(),
504 refreshCallbacks: callbacks,
505 }
506 return spt, nil
507 }
508
509 // NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
510 func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
511 if err := validateOAuthConfig(oauthConfig); err != nil {
512 return nil, err
513 }
514 if err := validateStringParam(clientID, "clientID"); err != nil {
515 return nil, err
516 }
517 if err := validateStringParam(resource, "resource"); err != nil {
518 return nil, err
519 }
520 if token.IsZero() {
521 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
522 }
523 spt, err := NewServicePrincipalTokenWithSecret(
524 oauthConfig,
525 clientID,
526 resource,
527 &ServicePrincipalNoSecret{},
528 callbacks...)
529 if err != nil {
530 return nil, err
531 }
532
533 spt.inner.Token = token
534
535 return spt, nil
536 }
537
538 // NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
539 func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
540 if err := validateOAuthConfig(oauthConfig); err != nil {
541 return nil, err
542 }
543 if err := validateStringParam(clientID, "clientID"); err != nil {
544 return nil, err
545 }
546 if err := validateStringParam(resource, "resource"); err != nil {
547 return nil, err
548 }
549 if secret == nil {
550 return nil, fmt.Errorf("parameter 'secret' cannot be nil")
551 }
552 if token.IsZero() {
553 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
554 }
555 spt, err := NewServicePrincipalTokenWithSecret(
556 oauthConfig,
557 clientID,
558 resource,
559 secret,
560 callbacks...)
561 if err != nil {
562 return nil, err
563 }
564
565 spt.inner.Token = token
566
567 return spt, nil
568 }
569
570 // NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
571 // credentials scoped to the named resource.
572 func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
573 if err := validateOAuthConfig(oauthConfig); err != nil {
574 return nil, err
575 }
576 if err := validateStringParam(clientID, "clientID"); err != nil {
577 return nil, err
578 }
579 if err := validateStringParam(secret, "secret"); err != nil {
580 return nil, err
581 }
582 if err := validateStringParam(resource, "resource"); err != nil {
583 return nil, err
584 }
585 return NewServicePrincipalTokenWithSecret(
586 oauthConfig,
587 clientID,
588 resource,
589 &ServicePrincipalTokenSecret{
590 ClientSecret: secret,
591 },
592 callbacks...,
593 )
594 }
595
596 // NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
597 func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
598 if err := validateOAuthConfig(oauthConfig); err != nil {
599 return nil, err
600 }
601 if err := validateStringParam(clientID, "clientID"); err != nil {
602 return nil, err
603 }
604 if err := validateStringParam(resource, "resource"); err != nil {
605 return nil, err
606 }
607 if certificate == nil {
608 return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
609 }
610 if privateKey == nil {
611 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
612 }
613 return NewServicePrincipalTokenWithSecret(
614 oauthConfig,
615 clientID,
616 resource,
617 &ServicePrincipalCertificateSecret{
618 PrivateKey: privateKey,
619 Certificate: certificate,
620 },
621 callbacks...,
622 )
623 }
624
625 // NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
626 func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
627 if err := validateOAuthConfig(oauthConfig); err != nil {
628 return nil, err
629 }
630 if err := validateStringParam(clientID, "clientID"); err != nil {
631 return nil, err
632 }
633 if err := validateStringParam(username, "username"); err != nil {
634 return nil, err
635 }
636 if err := validateStringParam(password, "password"); err != nil {
637 return nil, err
638 }
639 if err := validateStringParam(resource, "resource"); err != nil {
640 return nil, err
641 }
642 return NewServicePrincipalTokenWithSecret(
643 oauthConfig,
644 clientID,
645 resource,
646 &ServicePrincipalUsernamePasswordSecret{
647 Username: username,
648 Password: password,
649 },
650 callbacks...,
651 )
652 }
653
654 // NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
655 func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
656
657 if err := validateOAuthConfig(oauthConfig); err != nil {
658 return nil, err
659 }
660 if err := validateStringParam(clientID, "clientID"); err != nil {
661 return nil, err
662 }
663 if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
664 return nil, err
665 }
666 if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
667 return nil, err
668 }
669 if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
670 return nil, err
671 }
672 if err := validateStringParam(resource, "resource"); err != nil {
673 return nil, err
674 }
675
676 return NewServicePrincipalTokenWithSecret(
677 oauthConfig,
678 clientID,
679 resource,
680 &ServicePrincipalAuthorizationCodeSecret{
681 ClientSecret: clientSecret,
682 AuthorizationCode: authorizationCode,
683 RedirectURI: redirectURI,
684 },
685 callbacks...,
686 )
687 }
688
689 // NewServicePrincipalTokenFromFederatedToken creates a ServicePrincipalToken from the supplied federated OIDC JWT.
690 func NewServicePrincipalTokenFromFederatedToken(oauthConfig OAuthConfig, clientID string, jwt string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
691 if err := validateOAuthConfig(oauthConfig); err != nil {
692 return nil, err
693 }
694 if err := validateStringParam(clientID, "clientID"); err != nil {
695 return nil, err
696 }
697 if err := validateStringParam(resource, "resource"); err != nil {
698 return nil, err
699 }
700 if jwt == "" {
701 return nil, fmt.Errorf("parameter 'jwt' cannot be empty")
702 }
703 return NewServicePrincipalTokenWithSecret(
704 oauthConfig,
705 clientID,
706 resource,
707 &ServicePrincipalFederatedSecret{
708 jwt: jwt,
709 },
710 callbacks...,
711 )
712 }
713
// msiType identifies the managed identity environment the process runs in.
type msiType int

const (
	msiTypeUnavailable msiType = iota
	msiTypeAppServiceV20170901
	msiTypeCloudShell
	msiTypeIMDS
)

// String returns a readable name for m; values without a name (including
// msiTypeUnavailable) produce "unhandled MSI type N".
func (m msiType) String() string {
	name := ""
	switch m {
	case msiTypeAppServiceV20170901:
		name = "AppServiceV20170901"
	case msiTypeCloudShell:
		name = "CloudShell"
	case msiTypeIMDS:
		name = "IMDS"
	}
	if name == "" {
		return fmt.Sprintf("unhandled MSI type %d", m)
	}
	return name
}
735
736 // returns the MSI type and endpoint, or an error
737 func getMSIType() (msiType, string, error) {
738 if endpointEnvVar := os.Getenv(msiEndpointEnv); endpointEnvVar != "" {
739 // if the env var MSI_ENDPOINT is set
740 if secretEnvVar := os.Getenv(msiSecretEnv); secretEnvVar != "" {
741 // if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the msiType is AppService
742 return msiTypeAppServiceV20170901, endpointEnvVar, nil
743 }
744 // if ONLY the env var MSI_ENDPOINT is set the msiType is CloudShell
745 return msiTypeCloudShell, endpointEnvVar, nil
746 }
747 // if MSI_ENDPOINT is NOT set assume the msiType is IMDS
748 return msiTypeIMDS, msiEndpoint, nil
749 }
750
751 // GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
752 // NOTE: this always returns the IMDS endpoint, it does not work for app services or cloud shell.
753 // Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
754 func GetMSIVMEndpoint() (string, error) {
755 return msiEndpoint, nil
756 }
757
758 // GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions.
759 // It will return an error when not running in an app service/functions environment.
760 // Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
761 func GetMSIAppServiceEndpoint() (string, error) {
762 msiType, endpoint, err := getMSIType()
763 if err != nil {
764 return "", err
765 }
766 switch msiType {
767 case msiTypeAppServiceV20170901:
768 return endpoint, nil
769 default:
770 return "", fmt.Errorf("%s is not app service environment", msiType)
771 }
772 }
773
774 // GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment
775 // Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
776 func GetMSIEndpoint() (string, error) {
777 _, endpoint, err := getMSIType()
778 return endpoint, err
779 }
780
781 // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
782 // It will use the system assigned identity when creating the token.
783 // msiEndpoint - empty string, or pass a non-empty string to override the default value.
784 // Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
785 func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
786 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", "", callbacks...)
787 }
788
789 // NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
790 // It will use the clientID of specified user assigned identity when creating the token.
791 // msiEndpoint - empty string, or pass a non-empty string to override the default value.
792 // Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
793 func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
794 if err := validateStringParam(userAssignedID, "userAssignedID"); err != nil {
795 return nil, err
796 }
797 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, "", callbacks...)
798 }
799
800 // NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension.
801 // It will use the azure resource id of user assigned identity when creating the token.
802 // msiEndpoint - empty string, or pass a non-empty string to override the default value.
803 // Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
804 func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
805 if err := validateStringParam(identityResourceID, "identityResourceID"); err != nil {
806 return nil, err
807 }
808 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", identityResourceID, callbacks...)
809 }
810
// ManagedIdentityOptions holds the optional settings for managed identity
// authentication. Leaving both fields empty selects the system-assigned identity.
type ManagedIdentityOptions struct {
	// ClientID selects a user-assigned identity by client ID.
	// Setting it together with IdentityResourceID is rejected.
	ClientID string

	// IdentityResourceID selects a user-assigned identity by its resource ID.
	// Setting it together with ClientID is rejected.
	IdentityResourceID string
}
821
822 // NewServicePrincipalTokenFromManagedIdentity creates a ServicePrincipalToken using a managed identity.
823 // It supports the following managed identity environments.
824 // - App Service Environment (API version 2017-09-01 only)
825 // - Cloud shell
826 // - IMDS with a system or user assigned identity
827 func NewServicePrincipalTokenFromManagedIdentity(resource string, options *ManagedIdentityOptions, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
828 if options == nil {
829 options = &ManagedIdentityOptions{}
830 }
831 return newServicePrincipalTokenFromMSI("", resource, options.ClientID, options.IdentityResourceID, callbacks...)
832 }
833
834 func newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
835 if err := validateStringParam(resource, "resource"); err != nil {
836 return nil, err
837 }
838 if userAssignedID != "" && identityResourceID != "" {
839 return nil, errors.New("cannot specify userAssignedID and identityResourceID")
840 }
841 msiType, endpoint, err := getMSIType()
842 if err != nil {
843 logger.Instance.Writef(logger.LogError, "Error determining managed identity environment: %v\n", err)
844 return nil, err
845 }
846 logger.Instance.Writef(logger.LogInfo, "Managed identity environment is %s, endpoint is %s\n", msiType, endpoint)
847 if msiEndpoint != "" {
848 endpoint = msiEndpoint
849 logger.Instance.Writef(logger.LogInfo, "Managed identity custom endpoint is %s\n", endpoint)
850 }
851 msiEndpointURL, err := url.Parse(endpoint)
852 if err != nil {
853 return nil, err
854 }
855 // cloud shell sends its data in the request body
856 if msiType != msiTypeCloudShell {
857 v := url.Values{}
858 v.Set("resource", resource)
859 clientIDParam := "client_id"
860 switch msiType {
861 case msiTypeAppServiceV20170901:
862 clientIDParam = "clientid"
863 v.Set("api-version", appServiceAPIVersion2017)
864 break
865 case msiTypeIMDS:
866 v.Set("api-version", msiAPIVersion)
867 }
868 if userAssignedID != "" {
869 v.Set(clientIDParam, userAssignedID)
870 } else if identityResourceID != "" {
871 v.Set("mi_res_id", identityResourceID)
872 }
873 msiEndpointURL.RawQuery = v.Encode()
874 }
875
876 spt := &ServicePrincipalToken{
877 inner: servicePrincipalToken{
878 Token: newToken(),
879 OauthConfig: OAuthConfig{
880 TokenEndpoint: *msiEndpointURL,
881 },
882 Secret: &ServicePrincipalMSISecret{
883 msiType: msiType,
884 clientResourceID: identityResourceID,
885 },
886 Resource: resource,
887 AutoRefresh: true,
888 RefreshWithin: defaultRefresh,
889 ClientID: userAssignedID,
890 },
891 refreshLock: &sync.RWMutex{},
892 sender: sender(),
893 refreshCallbacks: callbacks,
894 MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
895 }
896
897 return spt, nil
898 }
899
// tokenRefreshError is the package-internal implementation of TokenRefreshError.
// It pairs a human-readable message with the HTTP response (possibly nil) that
// produced the failure.
type tokenRefreshError struct {
	message string
	resp    *http.Response
}

// Error implements the error interface which is part of the TokenRefreshError interface.
func (e tokenRefreshError) Error() string { return e.message }

// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
func (e tokenRefreshError) Response() *http.Response { return e.resp }
915
916 func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
917 return tokenRefreshError{message: message, resp: resp}
918 }
919
920 // EnsureFresh will refresh the token if it will expire within the refresh window (as set by
921 // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
922 func (spt *ServicePrincipalToken) EnsureFresh() error {
923 return spt.EnsureFreshWithContext(context.Background())
924 }
925
926 // EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
927 // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
928 func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
929 // must take the read lock when initially checking the token's expiration
930 if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) {
931 // take the write lock then check again to see if the token was already refreshed
932 spt.refreshLock.Lock()
933 defer spt.refreshLock.Unlock()
934 if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
935 return spt.refreshInternal(ctx, spt.inner.Resource)
936 }
937 }
938 return nil
939 }
940
941 // InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
942 func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
943 if spt.refreshCallbacks != nil {
944 for _, callback := range spt.refreshCallbacks {
945 err := callback(spt.inner.Token)
946 if err != nil {
947 return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
948 }
949 }
950 }
951 return nil
952 }
953
954 // Refresh obtains a fresh token for the Service Principal.
955 // This method is safe for concurrent use.
956 func (spt *ServicePrincipalToken) Refresh() error {
957 return spt.RefreshWithContext(context.Background())
958 }
959
960 // RefreshWithContext obtains a fresh token for the Service Principal.
961 // This method is safe for concurrent use.
962 func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
963 spt.refreshLock.Lock()
964 defer spt.refreshLock.Unlock()
965 return spt.refreshInternal(ctx, spt.inner.Resource)
966 }
967
968 // RefreshExchange refreshes the token, but for a different resource.
969 // This method is safe for concurrent use.
970 func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
971 return spt.RefreshExchangeWithContext(context.Background(), resource)
972 }
973
974 // RefreshExchangeWithContext refreshes the token, but for a different resource.
975 // This method is safe for concurrent use.
976 func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
977 spt.refreshLock.Lock()
978 defer spt.refreshLock.Unlock()
979 return spt.refreshInternal(ctx, resource)
980 }
981
982 func (spt *ServicePrincipalToken) getGrantType() string {
983 switch spt.inner.Secret.(type) {
984 case *ServicePrincipalUsernamePasswordSecret:
985 return OAuthGrantTypeUserPass
986 case *ServicePrincipalAuthorizationCodeSecret:
987 return OAuthGrantTypeAuthorizationCode
988 default:
989 return OAuthGrantTypeClientCredentials
990 }
991 }
992
993 func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
994 if spt.customRefreshFunc != nil {
995 token, err := spt.customRefreshFunc(ctx, resource)
996 if err != nil {
997 return err
998 }
999 spt.inner.Token = *token
1000 return spt.InvokeRefreshCallbacks(spt.inner.Token)
1001 }
1002 req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
1003 if err != nil {
1004 return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
1005 }
1006 req.Header.Add("User-Agent", UserAgent())
1007 req = req.WithContext(ctx)
1008 var resp *http.Response
1009 authBodyFilter := func(b []byte) []byte {
1010 if logger.Level() != logger.LogAuth {
1011 return []byte("**REDACTED** authentication body")
1012 }
1013 return b
1014 }
1015 if msiSecret, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
1016 switch msiSecret.msiType {
1017 case msiTypeAppServiceV20170901:
1018 req.Method = http.MethodGet
1019 req.Header.Set("secret", os.Getenv(msiSecretEnv))
1020 break
1021 case msiTypeCloudShell:
1022 req.Header.Set("Metadata", "true")
1023 data := url.Values{}
1024 data.Set("resource", spt.inner.Resource)
1025 if spt.inner.ClientID != "" {
1026 data.Set("client_id", spt.inner.ClientID)
1027 } else if msiSecret.clientResourceID != "" {
1028 data.Set("msi_res_id", msiSecret.clientResourceID)
1029 }
1030 req.Body = ioutil.NopCloser(strings.NewReader(data.Encode()))
1031 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
1032 break
1033 case msiTypeIMDS:
1034 req.Method = http.MethodGet
1035 req.Header.Set("Metadata", "true")
1036 break
1037 }
1038 logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
1039 resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
1040 } else {
1041 v := url.Values{}
1042 v.Set("client_id", spt.inner.ClientID)
1043 v.Set("resource", resource)
1044
1045 if spt.inner.Token.RefreshToken != "" {
1046 v.Set("grant_type", OAuthGrantTypeRefreshToken)
1047 v.Set("refresh_token", spt.inner.Token.RefreshToken)
1048 // web apps must specify client_secret when refreshing tokens
1049 // see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
1050 if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
1051 err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
1052 if err != nil {
1053 return err
1054 }
1055 }
1056 } else {
1057 v.Set("grant_type", spt.getGrantType())
1058 err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
1059 if err != nil {
1060 return err
1061 }
1062 }
1063
1064 s := v.Encode()
1065 body := ioutil.NopCloser(strings.NewReader(s))
1066 req.ContentLength = int64(len(s))
1067 req.Header.Set(contentType, mimeTypeFormPost)
1068 req.Body = body
1069 logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
1070 resp, err = spt.sender.Do(req)
1071 }
1072
1073 // don't return a TokenRefreshError here; this will allow retry logic to apply
1074 if err != nil {
1075 return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
1076 } else if resp == nil {
1077 return fmt.Errorf("adal: received nil response and error")
1078 }
1079
1080 logger.Instance.WriteResponse(resp, logger.Filter{Body: authBodyFilter})
1081 defer resp.Body.Close()
1082 rb, err := ioutil.ReadAll(resp.Body)
1083
1084 if resp.StatusCode != http.StatusOK {
1085 if err != nil {
1086 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v Endpoint %s", resp.StatusCode, err, req.URL.String()), resp)
1087 }
1088 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s Endpoint %s", resp.StatusCode, string(rb), req.URL.String()), resp)
1089 }
1090
1091 // for the following error cases don't return a TokenRefreshError. the operation succeeded
1092 // but some transient failure happened during deserialization. by returning a generic error
1093 // the retry logic will kick in (we don't retry on TokenRefreshError).
1094
1095 if err != nil {
1096 return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
1097 }
1098 if len(strings.Trim(string(rb), " ")) == 0 {
1099 return fmt.Errorf("adal: Empty service principal token received during refresh")
1100 }
1101 token := struct {
1102 AccessToken string `json:"access_token"`
1103 RefreshToken string `json:"refresh_token"`
1104
1105 // AAD returns expires_in as a string, ADFS returns it as an int
1106 ExpiresIn json.Number `json:"expires_in"`
1107 // expires_on can be in three formats, a UTC time stamp, or the number of seconds as a string *or* int.
1108 ExpiresOn interface{} `json:"expires_on"`
1109 NotBefore json.Number `json:"not_before"`
1110
1111 Resource string `json:"resource"`
1112 Type string `json:"token_type"`
1113 }{}
1114 // return a TokenRefreshError in the follow error cases as the token is in an unexpected format
1115 err = json.Unmarshal(rb, &token)
1116 if err != nil {
1117 return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)), resp)
1118 }
1119 expiresOn := json.Number("")
1120 // ADFS doesn't include the expires_on field
1121 if token.ExpiresOn != nil {
1122 if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil {
1123 return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp)
1124 }
1125 }
1126 spt.inner.Token.AccessToken = token.AccessToken
1127 spt.inner.Token.RefreshToken = token.RefreshToken
1128 spt.inner.Token.ExpiresIn = token.ExpiresIn
1129 spt.inner.Token.ExpiresOn = expiresOn
1130 spt.inner.Token.NotBefore = token.NotBefore
1131 spt.inner.Token.Resource = token.Resource
1132 spt.inner.Token.Type = token.Type
1133
1134 return spt.InvokeRefreshCallbacks(spt.inner.Token)
1135 }
1136
1137 // converts expires_on to the number of seconds
1138 func parseExpiresOn(s interface{}) (json.Number, error) {
1139 // the JSON unmarshaler treats JSON numbers unmarshaled into an interface{} as float64
1140 asFloat64, ok := s.(float64)
1141 if ok {
1142 // this is the number of seconds as int case
1143 return json.Number(strconv.FormatInt(int64(asFloat64), 10)), nil
1144 }
1145 asStr, ok := s.(string)
1146 if !ok {
1147 return "", fmt.Errorf("unexpected expires_on type %T", s)
1148 }
1149 // convert the expiration date to the number of seconds from the unix epoch
1150 timeToDuration := func(t time.Time) json.Number {
1151 return json.Number(strconv.FormatInt(t.UTC().Unix(), 10))
1152 }
1153 if _, err := json.Number(asStr).Int64(); err == nil {
1154 // this is the number of seconds case, no conversion required
1155 return json.Number(asStr), nil
1156 } else if eo, err := time.Parse(expiresOnDateFormatPM, asStr); err == nil {
1157 return timeToDuration(eo), nil
1158 } else if eo, err := time.Parse(expiresOnDateFormat, asStr); err == nil {
1159 return timeToDuration(eo), nil
1160 } else {
1161 // unknown format
1162 return json.Number(""), err
1163 }
1164 }
1165
1166 // retry logic specific to retrieving a token from the IMDS endpoint
1167 func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
1168 // copied from client.go due to circular dependency
1169 retries := []int{
1170 http.StatusRequestTimeout, // 408
1171 http.StatusTooManyRequests, // 429
1172 http.StatusInternalServerError, // 500
1173 http.StatusBadGateway, // 502
1174 http.StatusServiceUnavailable, // 503
1175 http.StatusGatewayTimeout, // 504
1176 }
1177 // extra retry status codes specific to IMDS
1178 retries = append(retries,
1179 http.StatusNotFound,
1180 http.StatusGone,
1181 // all remaining 5xx
1182 http.StatusNotImplemented,
1183 http.StatusHTTPVersionNotSupported,
1184 http.StatusVariantAlsoNegotiates,
1185 http.StatusInsufficientStorage,
1186 http.StatusLoopDetected,
1187 http.StatusNotExtended,
1188 http.StatusNetworkAuthenticationRequired)
1189
1190 // see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
1191
1192 const maxDelay time.Duration = 60 * time.Second
1193
1194 attempt := 0
1195 delay := time.Duration(0)
1196
1197 // maxAttempts is user-specified, ensure that its value is greater than zero else no request will be made
1198 if maxAttempts < 1 {
1199 maxAttempts = defaultMaxMSIRefreshAttempts
1200 }
1201
1202 for attempt < maxAttempts {
1203 if resp != nil && resp.Body != nil {
1204 io.Copy(ioutil.Discard, resp.Body)
1205 resp.Body.Close()
1206 }
1207 resp, err = sender.Do(req)
1208 // we want to retry if err is not nil or the status code is in the list of retry codes
1209 if err == nil && !responseHasStatusCode(resp, retries...) {
1210 return
1211 }
1212
1213 // perform exponential backoff with a cap.
1214 // must increment attempt before calculating delay.
1215 attempt++
1216 // the base value of 2 is the "delta backoff" as specified in the guidance doc
1217 delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
1218 if delay > maxDelay {
1219 delay = maxDelay
1220 }
1221
1222 select {
1223 case <-time.After(delay):
1224 // intentionally left blank
1225 case <-req.Context().Done():
1226 err = req.Context().Err()
1227 return
1228 }
1229 }
1230 return
1231 }
1232
// responseHasStatusCode reports whether resp is non-nil and its status code is one
// of codes. A nil response never matches.
func responseHasStatusCode(resp *http.Response, codes ...int) bool {
	if resp == nil {
		return false
	}
	for _, code := range codes {
		if code == resp.StatusCode {
			return true
		}
	}
	return false
}
1243
1244 // SetAutoRefresh enables or disables automatic refreshing of stale tokens.
1245 func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
1246 spt.inner.AutoRefresh = autoRefresh
1247 }
1248
1249 // SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
1250 // refresh the token.
1251 func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
1252 spt.inner.RefreshWithin = d
1253 return
1254 }
1255
1256 // SetSender sets the http.Client used when obtaining the Service Principal token. An
1257 // undecorated http.Client is used by default.
1258 func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
1259
1260 // OAuthToken implements the OAuthTokenProvider interface. It returns the current access token.
1261 func (spt *ServicePrincipalToken) OAuthToken() string {
1262 spt.refreshLock.RLock()
1263 defer spt.refreshLock.RUnlock()
1264 return spt.inner.Token.OAuthToken()
1265 }
1266
1267 // Token returns a copy of the current token.
1268 func (spt *ServicePrincipalToken) Token() Token {
1269 spt.refreshLock.RLock()
1270 defer spt.refreshLock.RUnlock()
1271 return spt.inner.Token
1272 }
1273
1274 // MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization.
1275 type MultiTenantServicePrincipalToken struct {
1276 PrimaryToken *ServicePrincipalToken
1277 AuxiliaryTokens []*ServicePrincipalToken
1278 }
1279
1280 // PrimaryOAuthToken returns the primary authorization token.
1281 func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string {
1282 return mt.PrimaryToken.OAuthToken()
1283 }
1284
1285 // AuxiliaryOAuthTokens returns one to three auxiliary authorization tokens.
1286 func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string {
1287 tokens := make([]string, len(mt.AuxiliaryTokens))
1288 for i := range mt.AuxiliaryTokens {
1289 tokens[i] = mt.AuxiliaryTokens[i].OAuthToken()
1290 }
1291 return tokens
1292 }
1293
1294 // NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource.
1295 func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) {
1296 if err := validateStringParam(clientID, "clientID"); err != nil {
1297 return nil, err
1298 }
1299 if err := validateStringParam(secret, "secret"); err != nil {
1300 return nil, err
1301 }
1302 if err := validateStringParam(resource, "resource"); err != nil {
1303 return nil, err
1304 }
1305 auxTenants := multiTenantCfg.AuxiliaryTenants()
1306 m := MultiTenantServicePrincipalToken{
1307 AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1308 }
1309 primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource)
1310 if err != nil {
1311 return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1312 }
1313 m.PrimaryToken = primary
1314 for i := range auxTenants {
1315 aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource)
1316 if err != nil {
1317 return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1318 }
1319 m.AuxiliaryTokens[i] = aux
1320 }
1321 return &m, nil
1322 }
1323
1324 // NewMultiTenantServicePrincipalTokenFromCertificate creates a new MultiTenantServicePrincipalToken with the specified certificate credentials and resource.
1325 func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTenantOAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string) (*MultiTenantServicePrincipalToken, error) {
1326 if err := validateStringParam(clientID, "clientID"); err != nil {
1327 return nil, err
1328 }
1329 if err := validateStringParam(resource, "resource"); err != nil {
1330 return nil, err
1331 }
1332 if certificate == nil {
1333 return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
1334 }
1335 if privateKey == nil {
1336 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
1337 }
1338 auxTenants := multiTenantCfg.AuxiliaryTenants()
1339 m := MultiTenantServicePrincipalToken{
1340 AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1341 }
1342 primary, err := NewServicePrincipalTokenWithSecret(
1343 *multiTenantCfg.PrimaryTenant(),
1344 clientID,
1345 resource,
1346 &ServicePrincipalCertificateSecret{
1347 PrivateKey: privateKey,
1348 Certificate: certificate,
1349 },
1350 )
1351 if err != nil {
1352 return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1353 }
1354 m.PrimaryToken = primary
1355 for i := range auxTenants {
1356 aux, err := NewServicePrincipalTokenWithSecret(
1357 *auxTenants[i],
1358 clientID,
1359 resource,
1360 &ServicePrincipalCertificateSecret{
1361 PrivateKey: privateKey,
1362 Certificate: certificate,
1363 },
1364 )
1365 if err != nil {
1366 return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1367 }
1368 m.AuxiliaryTokens[i] = aux
1369 }
1370 return &m, nil
1371 }
1372
1373 // MSIAvailable returns true if the MSI endpoint is available for authentication.
1374 func MSIAvailable(ctx context.Context, s Sender) bool {
1375 msiType, _, err := getMSIType()
1376
1377 if err != nil {
1378 return false
1379 }
1380
1381 if msiType != msiTypeIMDS {
1382 return true
1383 }
1384
1385 if s == nil {
1386 s = sender()
1387 }
1388
1389 resp, err := getMSIEndpoint(ctx, s)
1390
1391 if err == nil {
1392 resp.Body.Close()
1393 }
1394
1395 return err == nil
1396 }
1397