token.go raw

   1  package adal
   2  
   3  // Copyright 2017 Microsoft Corporation
   4  //
   5  //  Licensed under the Apache License, Version 2.0 (the "License");
   6  //  you may not use this file except in compliance with the License.
   7  //  You may obtain a copy of the License at
   8  //
   9  //      http://www.apache.org/licenses/LICENSE-2.0
  10  //
  11  //  Unless required by applicable law or agreed to in writing, software
  12  //  distributed under the License is distributed on an "AS IS" BASIS,
  13  //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14  //  See the License for the specific language governing permissions and
  15  //  limitations under the License.
  16  
  17  import (
  18  	"context"
  19  	"crypto/rand"
  20  	"crypto/rsa"
  21  	"crypto/sha1"
  22  	"crypto/x509"
  23  	"encoding/base64"
  24  	"encoding/json"
  25  	"errors"
  26  	"fmt"
  27  	"io"
  28  	"io/ioutil"
  29  	"math"
  30  	"net/http"
  31  	"net/url"
  32  	"os"
  33  	"strconv"
  34  	"strings"
  35  	"sync"
  36  	"time"
  37  
  38  	"github.com/Azure/go-autorest/autorest/date"
  39  	"github.com/Azure/go-autorest/logger"
  40  	"github.com/golang-jwt/jwt/v4"
  41  )
  42  
  43  const (
  44  	defaultRefresh = 5 * time.Minute
  45  
  46  	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
  47  	OAuthGrantTypeDeviceCode = "device_code"
  48  
  49  	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
  50  	OAuthGrantTypeClientCredentials = "client_credentials"
  51  
  52  	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
  53  	OAuthGrantTypeUserPass = "password"
  54  
  55  	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
  56  	OAuthGrantTypeRefreshToken = "refresh_token"
  57  
  58  	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
  59  	OAuthGrantTypeAuthorizationCode = "authorization_code"
  60  
  61  	// metadataHeader is the header required by MSI extension
  62  	metadataHeader = "Metadata"
  63  
  64  	// msiEndpoint is the well known endpoint for getting MSI authentications tokens
  65  	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
  66  
  67  	// the API version to use for the MSI endpoint
  68  	msiAPIVersion = "2018-02-01"
  69  
  70  	// the default number of attempts to refresh an MSI authentication token
  71  	defaultMaxMSIRefreshAttempts = 5
  72  
  73  	// asMSIEndpointEnv is the environment variable used to store the endpoint on App Service and Functions
  74  	msiEndpointEnv = "MSI_ENDPOINT"
  75  
  76  	// asMSISecretEnv is the environment variable used to store the request secret on App Service and Functions
  77  	msiSecretEnv = "MSI_SECRET"
  78  
  79  	// the API version to use for the legacy App Service MSI endpoint
  80  	appServiceAPIVersion2017 = "2017-09-01"
  81  
  82  	// secret header used when authenticating against app service MSI endpoint
  83  	secretHeader = "Secret"
  84  
  85  	// the format for expires_on in UTC with AM/PM
  86  	expiresOnDateFormatPM = "1/2/2006 15:04:05 PM +00:00"
  87  
  88  	// the format for expires_on in UTC without AM/PM
  89  	expiresOnDateFormat = "1/2/2006 15:04:05 +00:00"
  90  )
  91  
  92  // OAuthTokenProvider is an interface which should be implemented by an access token retriever
  93  type OAuthTokenProvider interface {
  94  	OAuthToken() string
  95  }
  96  
  97  // MultitenantOAuthTokenProvider provides tokens used for multi-tenant authorization.
  98  type MultitenantOAuthTokenProvider interface {
  99  	PrimaryOAuthToken() string
 100  	AuxiliaryOAuthTokens() []string
 101  }
 102  
 103  // TokenRefreshError is an interface used by errors returned during token refresh.
 104  type TokenRefreshError interface {
 105  	error
 106  	Response() *http.Response
 107  }
 108  
 109  // Refresher is an interface for token refresh functionality
 110  type Refresher interface {
 111  	Refresh() error
 112  	RefreshExchange(resource string) error
 113  	EnsureFresh() error
 114  }
 115  
 116  // RefresherWithContext is an interface for token refresh functionality
 117  type RefresherWithContext interface {
 118  	RefreshWithContext(ctx context.Context) error
 119  	RefreshExchangeWithContext(ctx context.Context, resource string) error
 120  	EnsureFreshWithContext(ctx context.Context) error
 121  }
 122  
 123  // TokenRefreshCallback is the type representing callbacks that will be called after
 124  // a successful token refresh
 125  type TokenRefreshCallback func(Token) error
 126  
 127  // TokenRefresh is a type representing a custom callback to refresh a token
 128  type TokenRefresh func(ctx context.Context, resource string) (*Token, error)
 129  
 130  // Token encapsulates the access token used to authorize Azure requests.
 131  // https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
 132  type Token struct {
 133  	AccessToken  string `json:"access_token"`
 134  	RefreshToken string `json:"refresh_token"`
 135  
 136  	ExpiresIn json.Number `json:"expires_in"`
 137  	ExpiresOn json.Number `json:"expires_on"`
 138  	NotBefore json.Number `json:"not_before"`
 139  
 140  	Resource string `json:"resource"`
 141  	Type     string `json:"token_type"`
 142  }
 143  
 144  func newToken() Token {
 145  	return Token{
 146  		ExpiresIn: "0",
 147  		ExpiresOn: "0",
 148  		NotBefore: "0",
 149  	}
 150  }
 151  
 152  // IsZero returns true if the token object is zero-initialized.
 153  func (t Token) IsZero() bool {
 154  	return t == Token{}
 155  }
 156  
 157  // Expires returns the time.Time when the Token expires.
 158  func (t Token) Expires() time.Time {
 159  	s, err := t.ExpiresOn.Float64()
 160  	if err != nil {
 161  		s = -3600
 162  	}
 163  
 164  	expiration := date.NewUnixTimeFromSeconds(s)
 165  
 166  	return time.Time(expiration).UTC()
 167  }
 168  
 169  // IsExpired returns true if the Token is expired, false otherwise.
 170  func (t Token) IsExpired() bool {
 171  	return t.WillExpireIn(0)
 172  }
 173  
 174  // WillExpireIn returns true if the Token will expire after the passed time.Duration interval
 175  // from now, false otherwise.
 176  func (t Token) WillExpireIn(d time.Duration) bool {
 177  	return !t.Expires().After(time.Now().Add(d))
 178  }
 179  
 180  // OAuthToken return the current access token
 181  func (t *Token) OAuthToken() string {
 182  	return t.AccessToken
 183  }
 184  
 185  // ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
 186  // that is submitted when acquiring an oAuth token.
 187  type ServicePrincipalSecret interface {
 188  	SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
 189  }
 190  
 191  // ServicePrincipalNoSecret represents a secret type that contains no secret
 192  // meaning it is not valid for fetching a fresh token. This is used by Manual
 193  type ServicePrincipalNoSecret struct {
 194  }
 195  
 196  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret
 197  // It only returns an error for the ServicePrincipalNoSecret type
 198  func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
 199  	return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
 200  }
 201  
 202  // MarshalJSON implements the json.Marshaler interface.
 203  func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
 204  	type tokenType struct {
 205  		Type string `json:"type"`
 206  	}
 207  	return json.Marshal(tokenType{
 208  		Type: "ServicePrincipalNoSecret",
 209  	})
 210  }
 211  
 212  // ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
 213  type ServicePrincipalTokenSecret struct {
 214  	ClientSecret string `json:"value"`
 215  }
 216  
 217  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
 218  // It will populate the form submitted during oAuth Token Acquisition using the client_secret.
 219  func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
 220  	v.Set("client_secret", tokenSecret.ClientSecret)
 221  	return nil
 222  }
 223  
 224  // MarshalJSON implements the json.Marshaler interface.
 225  func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
 226  	type tokenType struct {
 227  		Type  string `json:"type"`
 228  		Value string `json:"value"`
 229  	}
 230  	return json.Marshal(tokenType{
 231  		Type:  "ServicePrincipalTokenSecret",
 232  		Value: tokenSecret.ClientSecret,
 233  	})
 234  }
 235  
 236  // ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
 237  type ServicePrincipalCertificateSecret struct {
 238  	Certificate *x509.Certificate
 239  	PrivateKey  *rsa.PrivateKey
 240  }
 241  
 242  // SignJwt returns the JWT signed with the certificate's private key.
 243  func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
 244  	hasher := sha1.New()
 245  	_, err := hasher.Write(secret.Certificate.Raw)
 246  	if err != nil {
 247  		return "", err
 248  	}
 249  
 250  	thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
 251  
 252  	// The jti (JWT ID) claim provides a unique identifier for the JWT.
 253  	jti := make([]byte, 20)
 254  	_, err = rand.Read(jti)
 255  	if err != nil {
 256  		return "", err
 257  	}
 258  
 259  	token := jwt.New(jwt.SigningMethodRS256)
 260  	token.Header["x5t"] = thumbprint
 261  	x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
 262  	token.Header["x5c"] = x5c
 263  	token.Claims = jwt.MapClaims{
 264  		"aud": spt.inner.OauthConfig.TokenEndpoint.String(),
 265  		"iss": spt.inner.ClientID,
 266  		"sub": spt.inner.ClientID,
 267  		"jti": base64.URLEncoding.EncodeToString(jti),
 268  		"nbf": time.Now().Unix(),
 269  		"exp": time.Now().Add(24 * time.Hour).Unix(),
 270  	}
 271  
 272  	signedString, err := token.SignedString(secret.PrivateKey)
 273  	return signedString, err
 274  }
 275  
 276  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
 277  // It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
 278  func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
 279  	jwt, err := secret.SignJwt(spt)
 280  	if err != nil {
 281  		return err
 282  	}
 283  
 284  	v.Set("client_assertion", jwt)
 285  	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
 286  	return nil
 287  }
 288  
 289  // MarshalJSON implements the json.Marshaler interface.
 290  func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
 291  	return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
 292  }
 293  
 294  // ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
 295  type ServicePrincipalMSISecret struct {
 296  	msiType          msiType
 297  	clientResourceID string
 298  }
 299  
 300  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
 301  func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
 302  	return nil
 303  }
 304  
 305  // MarshalJSON implements the json.Marshaler interface.
 306  func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
 307  	return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
 308  }
 309  
 310  // ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
 311  type ServicePrincipalUsernamePasswordSecret struct {
 312  	Username string `json:"username"`
 313  	Password string `json:"password"`
 314  }
 315  
 316  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
 317  func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
 318  	v.Set("username", secret.Username)
 319  	v.Set("password", secret.Password)
 320  	return nil
 321  }
 322  
 323  // MarshalJSON implements the json.Marshaler interface.
 324  func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
 325  	type tokenType struct {
 326  		Type     string `json:"type"`
 327  		Username string `json:"username"`
 328  		Password string `json:"password"`
 329  	}
 330  	return json.Marshal(tokenType{
 331  		Type:     "ServicePrincipalUsernamePasswordSecret",
 332  		Username: secret.Username,
 333  		Password: secret.Password,
 334  	})
 335  }
 336  
 337  // ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
 338  type ServicePrincipalAuthorizationCodeSecret struct {
 339  	ClientSecret      string `json:"value"`
 340  	AuthorizationCode string `json:"authCode"`
 341  	RedirectURI       string `json:"redirect"`
 342  }
 343  
 344  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
 345  func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
 346  	v.Set("code", secret.AuthorizationCode)
 347  	v.Set("client_secret", secret.ClientSecret)
 348  	v.Set("redirect_uri", secret.RedirectURI)
 349  	return nil
 350  }
 351  
 352  // MarshalJSON implements the json.Marshaler interface.
 353  func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
 354  	type tokenType struct {
 355  		Type     string `json:"type"`
 356  		Value    string `json:"value"`
 357  		AuthCode string `json:"authCode"`
 358  		Redirect string `json:"redirect"`
 359  	}
 360  	return json.Marshal(tokenType{
 361  		Type:     "ServicePrincipalAuthorizationCodeSecret",
 362  		Value:    secret.ClientSecret,
 363  		AuthCode: secret.AuthorizationCode,
 364  		Redirect: secret.RedirectURI,
 365  	})
 366  }
 367  
 368  // ServicePrincipalFederatedSecret implements ServicePrincipalSecret for Federated JWTs.
 369  type ServicePrincipalFederatedSecret struct {
 370  	jwt string
 371  }
 372  
 373  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
 374  // It will populate the form submitted during OAuth Token Acquisition using a JWT signed by an OIDC issuer.
 375  func (secret *ServicePrincipalFederatedSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
 376  
 377  	v.Set("client_assertion", secret.jwt)
 378  	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
 379  	return nil
 380  }
 381  
 382  // MarshalJSON implements the json.Marshaler interface.
 383  func (secret ServicePrincipalFederatedSecret) MarshalJSON() ([]byte, error) {
 384  	return nil, errors.New("marshalling ServicePrincipalFederatedSecret is not supported")
 385  }
 386  
 387  // ServicePrincipalToken encapsulates a Token created for a Service Principal.
 388  type ServicePrincipalToken struct {
 389  	inner             servicePrincipalToken
 390  	refreshLock       *sync.RWMutex
 391  	sender            Sender
 392  	customRefreshFunc TokenRefresh
 393  	refreshCallbacks  []TokenRefreshCallback
 394  	// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
 395  	// Settings this to a value less than 1 will use the default value.
 396  	MaxMSIRefreshAttempts int
 397  }
 398  
 399  // MarshalTokenJSON returns the marshalled inner token.
 400  func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
 401  	return json.Marshal(spt.inner.Token)
 402  }
 403  
 404  // SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
 405  func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
 406  	spt.refreshCallbacks = callbacks
 407  }
 408  
 409  // SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
 410  func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
 411  	spt.customRefreshFunc = customRefreshFunc
 412  }
 413  
 414  // MarshalJSON implements the json.Marshaler interface.
 415  func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
 416  	return json.Marshal(spt.inner)
 417  }
 418  
 419  // UnmarshalJSON implements the json.Unmarshaler interface.
 420  func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
 421  	// need to determine the token type
 422  	raw := map[string]interface{}{}
 423  	err := json.Unmarshal(data, &raw)
 424  	if err != nil {
 425  		return err
 426  	}
 427  	secret := raw["secret"].(map[string]interface{})
 428  	switch secret["type"] {
 429  	case "ServicePrincipalNoSecret":
 430  		spt.inner.Secret = &ServicePrincipalNoSecret{}
 431  	case "ServicePrincipalTokenSecret":
 432  		spt.inner.Secret = &ServicePrincipalTokenSecret{}
 433  	case "ServicePrincipalCertificateSecret":
 434  		return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
 435  	case "ServicePrincipalMSISecret":
 436  		return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
 437  	case "ServicePrincipalUsernamePasswordSecret":
 438  		spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
 439  	case "ServicePrincipalAuthorizationCodeSecret":
 440  		spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
 441  	case "ServicePrincipalFederatedSecret":
 442  		return errors.New("unmarshalling ServicePrincipalFederatedSecret is not supported")
 443  	default:
 444  		return fmt.Errorf("unrecognized token type '%s'", secret["type"])
 445  	}
 446  	err = json.Unmarshal(data, &spt.inner)
 447  	if err != nil {
 448  		return err
 449  	}
 450  	// Don't override the refreshLock or the sender if those have been already set.
 451  	if spt.refreshLock == nil {
 452  		spt.refreshLock = &sync.RWMutex{}
 453  	}
 454  	if spt.sender == nil {
 455  		spt.sender = sender()
 456  	}
 457  	return nil
 458  }
 459  
 460  // internal type used for marshalling/unmarshalling
 461  type servicePrincipalToken struct {
 462  	Token         Token                  `json:"token"`
 463  	Secret        ServicePrincipalSecret `json:"secret"`
 464  	OauthConfig   OAuthConfig            `json:"oauth"`
 465  	ClientID      string                 `json:"clientID"`
 466  	Resource      string                 `json:"resource"`
 467  	AutoRefresh   bool                   `json:"autoRefresh"`
 468  	RefreshWithin time.Duration          `json:"refreshWithin"`
 469  }
 470  
 471  func validateOAuthConfig(oac OAuthConfig) error {
 472  	if oac.IsZero() {
 473  		return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
 474  	}
 475  	return nil
 476  }
 477  
 478  // NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
 479  func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 480  	if err := validateOAuthConfig(oauthConfig); err != nil {
 481  		return nil, err
 482  	}
 483  	if err := validateStringParam(id, "id"); err != nil {
 484  		return nil, err
 485  	}
 486  	if err := validateStringParam(resource, "resource"); err != nil {
 487  		return nil, err
 488  	}
 489  	if secret == nil {
 490  		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
 491  	}
 492  	spt := &ServicePrincipalToken{
 493  		inner: servicePrincipalToken{
 494  			Token:         newToken(),
 495  			OauthConfig:   oauthConfig,
 496  			Secret:        secret,
 497  			ClientID:      id,
 498  			Resource:      resource,
 499  			AutoRefresh:   true,
 500  			RefreshWithin: defaultRefresh,
 501  		},
 502  		refreshLock:      &sync.RWMutex{},
 503  		sender:           sender(),
 504  		refreshCallbacks: callbacks,
 505  	}
 506  	return spt, nil
 507  }
 508  
 509  // NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
 510  func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 511  	if err := validateOAuthConfig(oauthConfig); err != nil {
 512  		return nil, err
 513  	}
 514  	if err := validateStringParam(clientID, "clientID"); err != nil {
 515  		return nil, err
 516  	}
 517  	if err := validateStringParam(resource, "resource"); err != nil {
 518  		return nil, err
 519  	}
 520  	if token.IsZero() {
 521  		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
 522  	}
 523  	spt, err := NewServicePrincipalTokenWithSecret(
 524  		oauthConfig,
 525  		clientID,
 526  		resource,
 527  		&ServicePrincipalNoSecret{},
 528  		callbacks...)
 529  	if err != nil {
 530  		return nil, err
 531  	}
 532  
 533  	spt.inner.Token = token
 534  
 535  	return spt, nil
 536  }
 537  
 538  // NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
 539  func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 540  	if err := validateOAuthConfig(oauthConfig); err != nil {
 541  		return nil, err
 542  	}
 543  	if err := validateStringParam(clientID, "clientID"); err != nil {
 544  		return nil, err
 545  	}
 546  	if err := validateStringParam(resource, "resource"); err != nil {
 547  		return nil, err
 548  	}
 549  	if secret == nil {
 550  		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
 551  	}
 552  	if token.IsZero() {
 553  		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
 554  	}
 555  	spt, err := NewServicePrincipalTokenWithSecret(
 556  		oauthConfig,
 557  		clientID,
 558  		resource,
 559  		secret,
 560  		callbacks...)
 561  	if err != nil {
 562  		return nil, err
 563  	}
 564  
 565  	spt.inner.Token = token
 566  
 567  	return spt, nil
 568  }
 569  
 570  // NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
 571  // credentials scoped to the named resource.
 572  func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 573  	if err := validateOAuthConfig(oauthConfig); err != nil {
 574  		return nil, err
 575  	}
 576  	if err := validateStringParam(clientID, "clientID"); err != nil {
 577  		return nil, err
 578  	}
 579  	if err := validateStringParam(secret, "secret"); err != nil {
 580  		return nil, err
 581  	}
 582  	if err := validateStringParam(resource, "resource"); err != nil {
 583  		return nil, err
 584  	}
 585  	return NewServicePrincipalTokenWithSecret(
 586  		oauthConfig,
 587  		clientID,
 588  		resource,
 589  		&ServicePrincipalTokenSecret{
 590  			ClientSecret: secret,
 591  		},
 592  		callbacks...,
 593  	)
 594  }
 595  
 596  // NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
 597  func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 598  	if err := validateOAuthConfig(oauthConfig); err != nil {
 599  		return nil, err
 600  	}
 601  	if err := validateStringParam(clientID, "clientID"); err != nil {
 602  		return nil, err
 603  	}
 604  	if err := validateStringParam(resource, "resource"); err != nil {
 605  		return nil, err
 606  	}
 607  	if certificate == nil {
 608  		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
 609  	}
 610  	if privateKey == nil {
 611  		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
 612  	}
 613  	return NewServicePrincipalTokenWithSecret(
 614  		oauthConfig,
 615  		clientID,
 616  		resource,
 617  		&ServicePrincipalCertificateSecret{
 618  			PrivateKey:  privateKey,
 619  			Certificate: certificate,
 620  		},
 621  		callbacks...,
 622  	)
 623  }
 624  
 625  // NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
 626  func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 627  	if err := validateOAuthConfig(oauthConfig); err != nil {
 628  		return nil, err
 629  	}
 630  	if err := validateStringParam(clientID, "clientID"); err != nil {
 631  		return nil, err
 632  	}
 633  	if err := validateStringParam(username, "username"); err != nil {
 634  		return nil, err
 635  	}
 636  	if err := validateStringParam(password, "password"); err != nil {
 637  		return nil, err
 638  	}
 639  	if err := validateStringParam(resource, "resource"); err != nil {
 640  		return nil, err
 641  	}
 642  	return NewServicePrincipalTokenWithSecret(
 643  		oauthConfig,
 644  		clientID,
 645  		resource,
 646  		&ServicePrincipalUsernamePasswordSecret{
 647  			Username: username,
 648  			Password: password,
 649  		},
 650  		callbacks...,
 651  	)
 652  }
 653  
 654  // NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
 655  func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 656  
 657  	if err := validateOAuthConfig(oauthConfig); err != nil {
 658  		return nil, err
 659  	}
 660  	if err := validateStringParam(clientID, "clientID"); err != nil {
 661  		return nil, err
 662  	}
 663  	if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
 664  		return nil, err
 665  	}
 666  	if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
 667  		return nil, err
 668  	}
 669  	if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
 670  		return nil, err
 671  	}
 672  	if err := validateStringParam(resource, "resource"); err != nil {
 673  		return nil, err
 674  	}
 675  
 676  	return NewServicePrincipalTokenWithSecret(
 677  		oauthConfig,
 678  		clientID,
 679  		resource,
 680  		&ServicePrincipalAuthorizationCodeSecret{
 681  			ClientSecret:      clientSecret,
 682  			AuthorizationCode: authorizationCode,
 683  			RedirectURI:       redirectURI,
 684  		},
 685  		callbacks...,
 686  	)
 687  }
 688  
 689  // NewServicePrincipalTokenFromFederatedToken creates a ServicePrincipalToken from the supplied federated OIDC JWT.
 690  func NewServicePrincipalTokenFromFederatedToken(oauthConfig OAuthConfig, clientID string, jwt string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 691  	if err := validateOAuthConfig(oauthConfig); err != nil {
 692  		return nil, err
 693  	}
 694  	if err := validateStringParam(clientID, "clientID"); err != nil {
 695  		return nil, err
 696  	}
 697  	if err := validateStringParam(resource, "resource"); err != nil {
 698  		return nil, err
 699  	}
 700  	if jwt == "" {
 701  		return nil, fmt.Errorf("parameter 'jwt' cannot be empty")
 702  	}
 703  	return NewServicePrincipalTokenWithSecret(
 704  		oauthConfig,
 705  		clientID,
 706  		resource,
 707  		&ServicePrincipalFederatedSecret{
 708  			jwt: jwt,
 709  		},
 710  		callbacks...,
 711  	)
 712  }
 713  
 714  type msiType int
 715  
 716  const (
 717  	msiTypeUnavailable msiType = iota
 718  	msiTypeAppServiceV20170901
 719  	msiTypeCloudShell
 720  	msiTypeIMDS
 721  )
 722  
 723  func (m msiType) String() string {
 724  	switch m {
 725  	case msiTypeAppServiceV20170901:
 726  		return "AppServiceV20170901"
 727  	case msiTypeCloudShell:
 728  		return "CloudShell"
 729  	case msiTypeIMDS:
 730  		return "IMDS"
 731  	default:
 732  		return fmt.Sprintf("unhandled MSI type %d", m)
 733  	}
 734  }
 735  
 736  // returns the MSI type and endpoint, or an error
 737  func getMSIType() (msiType, string, error) {
 738  	if endpointEnvVar := os.Getenv(msiEndpointEnv); endpointEnvVar != "" {
 739  		// if the env var MSI_ENDPOINT is set
 740  		if secretEnvVar := os.Getenv(msiSecretEnv); secretEnvVar != "" {
 741  			// if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the msiType is AppService
 742  			return msiTypeAppServiceV20170901, endpointEnvVar, nil
 743  		}
 744  		// if ONLY the env var MSI_ENDPOINT is set the msiType is CloudShell
 745  		return msiTypeCloudShell, endpointEnvVar, nil
 746  	}
 747  	// if MSI_ENDPOINT is NOT set assume the msiType is IMDS
 748  	return msiTypeIMDS, msiEndpoint, nil
 749  }
 750  
 751  // GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
 752  // NOTE: this always returns the IMDS endpoint, it does not work for app services or cloud shell.
 753  // Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
 754  func GetMSIVMEndpoint() (string, error) {
 755  	return msiEndpoint, nil
 756  }
 757  
 758  // GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions.
 759  // It will return an error when not running in an app service/functions environment.
 760  // Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
 761  func GetMSIAppServiceEndpoint() (string, error) {
 762  	msiType, endpoint, err := getMSIType()
 763  	if err != nil {
 764  		return "", err
 765  	}
 766  	switch msiType {
 767  	case msiTypeAppServiceV20170901:
 768  		return endpoint, nil
 769  	default:
 770  		return "", fmt.Errorf("%s is not app service environment", msiType)
 771  	}
 772  }
 773  
 774  // GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment
 775  // Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
 776  func GetMSIEndpoint() (string, error) {
 777  	_, endpoint, err := getMSIType()
 778  	return endpoint, err
 779  }
 780  
 781  // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
 782  // It will use the system assigned identity when creating the token.
 783  // msiEndpoint - empty string, or pass a non-empty string to override the default value.
 784  // Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
 785  func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 786  	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", "", callbacks...)
 787  }
 788  
 789  // NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
 790  // It will use the clientID of specified user assigned identity when creating the token.
 791  // msiEndpoint - empty string, or pass a non-empty string to override the default value.
 792  // Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
 793  func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 794  	if err := validateStringParam(userAssignedID, "userAssignedID"); err != nil {
 795  		return nil, err
 796  	}
 797  	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, "", callbacks...)
 798  }
 799  
 800  // NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension.
 801  // It will use the azure resource id of user assigned identity when creating the token.
 802  // msiEndpoint - empty string, or pass a non-empty string to override the default value.
 803  // Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
 804  func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 805  	if err := validateStringParam(identityResourceID, "identityResourceID"); err != nil {
 806  		return nil, err
 807  	}
 808  	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", identityResourceID, callbacks...)
 809  }
 810  
// ManagedIdentityOptions contains optional values for configuring managed identity authentication.
// NewServicePrincipalTokenFromManagedIdentity treats a nil *ManagedIdentityOptions like the zero
// value; leaving both fields empty selects the system assigned identity.
type ManagedIdentityOptions struct {
	// ClientID is the client ID of the user-assigned identity to use during authentication.
	// It is mutually exclusive with IdentityResourceID; setting both makes token creation fail.
	ClientID string

	// IdentityResourceID is the Azure resource ID of the user-assigned identity to use during
	// authentication. It is mutually exclusive with ClientID; setting both makes token creation fail.
	IdentityResourceID string
}
 821  
 822  // NewServicePrincipalTokenFromManagedIdentity creates a ServicePrincipalToken using a managed identity.
 823  // It supports the following managed identity environments.
 824  // - App Service Environment (API version 2017-09-01 only)
 825  // - Cloud shell
 826  // - IMDS with a system or user assigned identity
 827  func NewServicePrincipalTokenFromManagedIdentity(resource string, options *ManagedIdentityOptions, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 828  	if options == nil {
 829  		options = &ManagedIdentityOptions{}
 830  	}
 831  	return newServicePrincipalTokenFromMSI("", resource, options.ClientID, options.IdentityResourceID, callbacks...)
 832  }
 833  
 834  func newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
 835  	if err := validateStringParam(resource, "resource"); err != nil {
 836  		return nil, err
 837  	}
 838  	if userAssignedID != "" && identityResourceID != "" {
 839  		return nil, errors.New("cannot specify userAssignedID and identityResourceID")
 840  	}
 841  	msiType, endpoint, err := getMSIType()
 842  	if err != nil {
 843  		logger.Instance.Writef(logger.LogError, "Error determining managed identity environment: %v\n", err)
 844  		return nil, err
 845  	}
 846  	logger.Instance.Writef(logger.LogInfo, "Managed identity environment is %s, endpoint is %s\n", msiType, endpoint)
 847  	if msiEndpoint != "" {
 848  		endpoint = msiEndpoint
 849  		logger.Instance.Writef(logger.LogInfo, "Managed identity custom endpoint is %s\n", endpoint)
 850  	}
 851  	msiEndpointURL, err := url.Parse(endpoint)
 852  	if err != nil {
 853  		return nil, err
 854  	}
 855  	// cloud shell sends its data in the request body
 856  	if msiType != msiTypeCloudShell {
 857  		v := url.Values{}
 858  		v.Set("resource", resource)
 859  		clientIDParam := "client_id"
 860  		switch msiType {
 861  		case msiTypeAppServiceV20170901:
 862  			clientIDParam = "clientid"
 863  			v.Set("api-version", appServiceAPIVersion2017)
 864  			break
 865  		case msiTypeIMDS:
 866  			v.Set("api-version", msiAPIVersion)
 867  		}
 868  		if userAssignedID != "" {
 869  			v.Set(clientIDParam, userAssignedID)
 870  		} else if identityResourceID != "" {
 871  			v.Set("mi_res_id", identityResourceID)
 872  		}
 873  		msiEndpointURL.RawQuery = v.Encode()
 874  	}
 875  
 876  	spt := &ServicePrincipalToken{
 877  		inner: servicePrincipalToken{
 878  			Token: newToken(),
 879  			OauthConfig: OAuthConfig{
 880  				TokenEndpoint: *msiEndpointURL,
 881  			},
 882  			Secret: &ServicePrincipalMSISecret{
 883  				msiType:          msiType,
 884  				clientResourceID: identityResourceID,
 885  			},
 886  			Resource:      resource,
 887  			AutoRefresh:   true,
 888  			RefreshWithin: defaultRefresh,
 889  			ClientID:      userAssignedID,
 890  		},
 891  		refreshLock:           &sync.RWMutex{},
 892  		sender:                sender(),
 893  		refreshCallbacks:      callbacks,
 894  		MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
 895  	}
 896  
 897  	return spt, nil
 898  }
 899  
 900  // internal type that implements TokenRefreshError
 901  type tokenRefreshError struct {
 902  	message string
 903  	resp    *http.Response
 904  }
 905  
 906  // Error implements the error interface which is part of the TokenRefreshError interface.
 907  func (tre tokenRefreshError) Error() string {
 908  	return tre.message
 909  }
 910  
 911  // Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
 912  func (tre tokenRefreshError) Response() *http.Response {
 913  	return tre.resp
 914  }
 915  
 916  func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
 917  	return tokenRefreshError{message: message, resp: resp}
 918  }
 919  
 920  // EnsureFresh will refresh the token if it will expire within the refresh window (as set by
 921  // RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
 922  func (spt *ServicePrincipalToken) EnsureFresh() error {
 923  	return spt.EnsureFreshWithContext(context.Background())
 924  }
 925  
 926  // EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
 927  // RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
 928  func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
 929  	// must take the read lock when initially checking the token's expiration
 930  	if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) {
 931  		// take the write lock then check again to see if the token was already refreshed
 932  		spt.refreshLock.Lock()
 933  		defer spt.refreshLock.Unlock()
 934  		if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
 935  			return spt.refreshInternal(ctx, spt.inner.Resource)
 936  		}
 937  	}
 938  	return nil
 939  }
 940  
 941  // InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
 942  func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
 943  	if spt.refreshCallbacks != nil {
 944  		for _, callback := range spt.refreshCallbacks {
 945  			err := callback(spt.inner.Token)
 946  			if err != nil {
 947  				return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
 948  			}
 949  		}
 950  	}
 951  	return nil
 952  }
 953  
 954  // Refresh obtains a fresh token for the Service Principal.
 955  // This method is safe for concurrent use.
 956  func (spt *ServicePrincipalToken) Refresh() error {
 957  	return spt.RefreshWithContext(context.Background())
 958  }
 959  
 960  // RefreshWithContext obtains a fresh token for the Service Principal.
 961  // This method is safe for concurrent use.
 962  func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
 963  	spt.refreshLock.Lock()
 964  	defer spt.refreshLock.Unlock()
 965  	return spt.refreshInternal(ctx, spt.inner.Resource)
 966  }
 967  
 968  // RefreshExchange refreshes the token, but for a different resource.
 969  // This method is safe for concurrent use.
 970  func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
 971  	return spt.RefreshExchangeWithContext(context.Background(), resource)
 972  }
 973  
 974  // RefreshExchangeWithContext refreshes the token, but for a different resource.
 975  // This method is safe for concurrent use.
 976  func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
 977  	spt.refreshLock.Lock()
 978  	defer spt.refreshLock.Unlock()
 979  	return spt.refreshInternal(ctx, resource)
 980  }
 981  
 982  func (spt *ServicePrincipalToken) getGrantType() string {
 983  	switch spt.inner.Secret.(type) {
 984  	case *ServicePrincipalUsernamePasswordSecret:
 985  		return OAuthGrantTypeUserPass
 986  	case *ServicePrincipalAuthorizationCodeSecret:
 987  		return OAuthGrantTypeAuthorizationCode
 988  	default:
 989  		return OAuthGrantTypeClientCredentials
 990  	}
 991  }
 992  
 993  func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
 994  	if spt.customRefreshFunc != nil {
 995  		token, err := spt.customRefreshFunc(ctx, resource)
 996  		if err != nil {
 997  			return err
 998  		}
 999  		spt.inner.Token = *token
1000  		return spt.InvokeRefreshCallbacks(spt.inner.Token)
1001  	}
1002  	req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
1003  	if err != nil {
1004  		return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
1005  	}
1006  	req.Header.Add("User-Agent", UserAgent())
1007  	req = req.WithContext(ctx)
1008  	var resp *http.Response
1009  	authBodyFilter := func(b []byte) []byte {
1010  		if logger.Level() != logger.LogAuth {
1011  			return []byte("**REDACTED** authentication body")
1012  		}
1013  		return b
1014  	}
1015  	if msiSecret, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
1016  		switch msiSecret.msiType {
1017  		case msiTypeAppServiceV20170901:
1018  			req.Method = http.MethodGet
1019  			req.Header.Set("secret", os.Getenv(msiSecretEnv))
1020  			break
1021  		case msiTypeCloudShell:
1022  			req.Header.Set("Metadata", "true")
1023  			data := url.Values{}
1024  			data.Set("resource", spt.inner.Resource)
1025  			if spt.inner.ClientID != "" {
1026  				data.Set("client_id", spt.inner.ClientID)
1027  			} else if msiSecret.clientResourceID != "" {
1028  				data.Set("msi_res_id", msiSecret.clientResourceID)
1029  			}
1030  			req.Body = ioutil.NopCloser(strings.NewReader(data.Encode()))
1031  			req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
1032  			break
1033  		case msiTypeIMDS:
1034  			req.Method = http.MethodGet
1035  			req.Header.Set("Metadata", "true")
1036  			break
1037  		}
1038  		logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
1039  		resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
1040  	} else {
1041  		v := url.Values{}
1042  		v.Set("client_id", spt.inner.ClientID)
1043  		v.Set("resource", resource)
1044  
1045  		if spt.inner.Token.RefreshToken != "" {
1046  			v.Set("grant_type", OAuthGrantTypeRefreshToken)
1047  			v.Set("refresh_token", spt.inner.Token.RefreshToken)
1048  			// web apps must specify client_secret when refreshing tokens
1049  			// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
1050  			if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
1051  				err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
1052  				if err != nil {
1053  					return err
1054  				}
1055  			}
1056  		} else {
1057  			v.Set("grant_type", spt.getGrantType())
1058  			err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
1059  			if err != nil {
1060  				return err
1061  			}
1062  		}
1063  
1064  		s := v.Encode()
1065  		body := ioutil.NopCloser(strings.NewReader(s))
1066  		req.ContentLength = int64(len(s))
1067  		req.Header.Set(contentType, mimeTypeFormPost)
1068  		req.Body = body
1069  		logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
1070  		resp, err = spt.sender.Do(req)
1071  	}
1072  
1073  	// don't return a TokenRefreshError here; this will allow retry logic to apply
1074  	if err != nil {
1075  		return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
1076  	} else if resp == nil {
1077  		return fmt.Errorf("adal: received nil response and error")
1078  	}
1079  
1080  	logger.Instance.WriteResponse(resp, logger.Filter{Body: authBodyFilter})
1081  	defer resp.Body.Close()
1082  	rb, err := ioutil.ReadAll(resp.Body)
1083  
1084  	if resp.StatusCode != http.StatusOK {
1085  		if err != nil {
1086  			return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v Endpoint %s", resp.StatusCode, err, req.URL.String()), resp)
1087  		}
1088  		return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s Endpoint %s", resp.StatusCode, string(rb), req.URL.String()), resp)
1089  	}
1090  
1091  	// for the following error cases don't return a TokenRefreshError.  the operation succeeded
1092  	// but some transient failure happened during deserialization.  by returning a generic error
1093  	// the retry logic will kick in (we don't retry on TokenRefreshError).
1094  
1095  	if err != nil {
1096  		return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
1097  	}
1098  	if len(strings.Trim(string(rb), " ")) == 0 {
1099  		return fmt.Errorf("adal: Empty service principal token received during refresh")
1100  	}
1101  	token := struct {
1102  		AccessToken  string `json:"access_token"`
1103  		RefreshToken string `json:"refresh_token"`
1104  
1105  		// AAD returns expires_in as a string, ADFS returns it as an int
1106  		ExpiresIn json.Number `json:"expires_in"`
1107  		// expires_on can be in three formats, a UTC time stamp, or the number of seconds as a string *or* int.
1108  		ExpiresOn interface{} `json:"expires_on"`
1109  		NotBefore json.Number `json:"not_before"`
1110  
1111  		Resource string `json:"resource"`
1112  		Type     string `json:"token_type"`
1113  	}{}
1114  	// return a TokenRefreshError in the follow error cases as the token is in an unexpected format
1115  	err = json.Unmarshal(rb, &token)
1116  	if err != nil {
1117  		return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)), resp)
1118  	}
1119  	expiresOn := json.Number("")
1120  	// ADFS doesn't include the expires_on field
1121  	if token.ExpiresOn != nil {
1122  		if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil {
1123  			return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp)
1124  		}
1125  	}
1126  	spt.inner.Token.AccessToken = token.AccessToken
1127  	spt.inner.Token.RefreshToken = token.RefreshToken
1128  	spt.inner.Token.ExpiresIn = token.ExpiresIn
1129  	spt.inner.Token.ExpiresOn = expiresOn
1130  	spt.inner.Token.NotBefore = token.NotBefore
1131  	spt.inner.Token.Resource = token.Resource
1132  	spt.inner.Token.Type = token.Type
1133  
1134  	return spt.InvokeRefreshCallbacks(spt.inner.Token)
1135  }
1136  
1137  // converts expires_on to the number of seconds
1138  func parseExpiresOn(s interface{}) (json.Number, error) {
1139  	// the JSON unmarshaler treats JSON numbers unmarshaled into an interface{} as float64
1140  	asFloat64, ok := s.(float64)
1141  	if ok {
1142  		// this is the number of seconds as int case
1143  		return json.Number(strconv.FormatInt(int64(asFloat64), 10)), nil
1144  	}
1145  	asStr, ok := s.(string)
1146  	if !ok {
1147  		return "", fmt.Errorf("unexpected expires_on type %T", s)
1148  	}
1149  	// convert the expiration date to the number of seconds from the unix epoch
1150  	timeToDuration := func(t time.Time) json.Number {
1151  		return json.Number(strconv.FormatInt(t.UTC().Unix(), 10))
1152  	}
1153  	if _, err := json.Number(asStr).Int64(); err == nil {
1154  		// this is the number of seconds case, no conversion required
1155  		return json.Number(asStr), nil
1156  	} else if eo, err := time.Parse(expiresOnDateFormatPM, asStr); err == nil {
1157  		return timeToDuration(eo), nil
1158  	} else if eo, err := time.Parse(expiresOnDateFormat, asStr); err == nil {
1159  		return timeToDuration(eo), nil
1160  	} else {
1161  		// unknown format
1162  		return json.Number(""), err
1163  	}
1164  }
1165  
1166  // retry logic specific to retrieving a token from the IMDS endpoint
1167  func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
1168  	// copied from client.go due to circular dependency
1169  	retries := []int{
1170  		http.StatusRequestTimeout,      // 408
1171  		http.StatusTooManyRequests,     // 429
1172  		http.StatusInternalServerError, // 500
1173  		http.StatusBadGateway,          // 502
1174  		http.StatusServiceUnavailable,  // 503
1175  		http.StatusGatewayTimeout,      // 504
1176  	}
1177  	// extra retry status codes specific to IMDS
1178  	retries = append(retries,
1179  		http.StatusNotFound,
1180  		http.StatusGone,
1181  		// all remaining 5xx
1182  		http.StatusNotImplemented,
1183  		http.StatusHTTPVersionNotSupported,
1184  		http.StatusVariantAlsoNegotiates,
1185  		http.StatusInsufficientStorage,
1186  		http.StatusLoopDetected,
1187  		http.StatusNotExtended,
1188  		http.StatusNetworkAuthenticationRequired)
1189  
1190  	// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
1191  
1192  	const maxDelay time.Duration = 60 * time.Second
1193  
1194  	attempt := 0
1195  	delay := time.Duration(0)
1196  
1197  	// maxAttempts is user-specified, ensure that its value is greater than zero else no request will be made
1198  	if maxAttempts < 1 {
1199  		maxAttempts = defaultMaxMSIRefreshAttempts
1200  	}
1201  
1202  	for attempt < maxAttempts {
1203  		if resp != nil && resp.Body != nil {
1204  			io.Copy(ioutil.Discard, resp.Body)
1205  			resp.Body.Close()
1206  		}
1207  		resp, err = sender.Do(req)
1208  		// we want to retry if err is not nil or the status code is in the list of retry codes
1209  		if err == nil && !responseHasStatusCode(resp, retries...) {
1210  			return
1211  		}
1212  
1213  		// perform exponential backoff with a cap.
1214  		// must increment attempt before calculating delay.
1215  		attempt++
1216  		// the base value of 2 is the "delta backoff" as specified in the guidance doc
1217  		delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
1218  		if delay > maxDelay {
1219  			delay = maxDelay
1220  		}
1221  
1222  		select {
1223  		case <-time.After(delay):
1224  			// intentionally left blank
1225  		case <-req.Context().Done():
1226  			err = req.Context().Err()
1227  			return
1228  		}
1229  	}
1230  	return
1231  }
1232  
// responseHasStatusCode reports whether resp is non-nil and its status code is one of codes.
func responseHasStatusCode(resp *http.Response, codes ...int) bool {
	if resp == nil {
		return false
	}
	for _, code := range codes {
		if code == resp.StatusCode {
			return true
		}
	}
	return false
}
1243  
1244  // SetAutoRefresh enables or disables automatic refreshing of stale tokens.
1245  func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
1246  	spt.inner.AutoRefresh = autoRefresh
1247  }
1248  
1249  // SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
1250  // refresh the token.
1251  func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
1252  	spt.inner.RefreshWithin = d
1253  	return
1254  }
1255  
1256  // SetSender sets the http.Client used when obtaining the Service Principal token. An
1257  // undecorated http.Client is used by default.
1258  func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
1259  
1260  // OAuthToken implements the OAuthTokenProvider interface.  It returns the current access token.
1261  func (spt *ServicePrincipalToken) OAuthToken() string {
1262  	spt.refreshLock.RLock()
1263  	defer spt.refreshLock.RUnlock()
1264  	return spt.inner.Token.OAuthToken()
1265  }
1266  
1267  // Token returns a copy of the current token.
1268  func (spt *ServicePrincipalToken) Token() Token {
1269  	spt.refreshLock.RLock()
1270  	defer spt.refreshLock.RUnlock()
1271  	return spt.inner.Token
1272  }
1273  
// MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization.
//
// PrimaryToken authenticates against the primary tenant. AuxiliaryTokens holds one token per
// auxiliary tenant; the constructors in this file fill it in the order returned by
// MultiTenantOAuthConfig.AuxiliaryTenants().
type MultiTenantServicePrincipalToken struct {
	PrimaryToken    *ServicePrincipalToken
	AuxiliaryTokens []*ServicePrincipalToken
}
1279  
1280  // PrimaryOAuthToken returns the primary authorization token.
1281  func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string {
1282  	return mt.PrimaryToken.OAuthToken()
1283  }
1284  
1285  // AuxiliaryOAuthTokens returns one to three auxiliary authorization tokens.
1286  func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string {
1287  	tokens := make([]string, len(mt.AuxiliaryTokens))
1288  	for i := range mt.AuxiliaryTokens {
1289  		tokens[i] = mt.AuxiliaryTokens[i].OAuthToken()
1290  	}
1291  	return tokens
1292  }
1293  
1294  // NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource.
1295  func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) {
1296  	if err := validateStringParam(clientID, "clientID"); err != nil {
1297  		return nil, err
1298  	}
1299  	if err := validateStringParam(secret, "secret"); err != nil {
1300  		return nil, err
1301  	}
1302  	if err := validateStringParam(resource, "resource"); err != nil {
1303  		return nil, err
1304  	}
1305  	auxTenants := multiTenantCfg.AuxiliaryTenants()
1306  	m := MultiTenantServicePrincipalToken{
1307  		AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1308  	}
1309  	primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource)
1310  	if err != nil {
1311  		return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1312  	}
1313  	m.PrimaryToken = primary
1314  	for i := range auxTenants {
1315  		aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource)
1316  		if err != nil {
1317  			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1318  		}
1319  		m.AuxiliaryTokens[i] = aux
1320  	}
1321  	return &m, nil
1322  }
1323  
1324  // NewMultiTenantServicePrincipalTokenFromCertificate creates a new MultiTenantServicePrincipalToken with the specified certificate credentials and resource.
1325  func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTenantOAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string) (*MultiTenantServicePrincipalToken, error) {
1326  	if err := validateStringParam(clientID, "clientID"); err != nil {
1327  		return nil, err
1328  	}
1329  	if err := validateStringParam(resource, "resource"); err != nil {
1330  		return nil, err
1331  	}
1332  	if certificate == nil {
1333  		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
1334  	}
1335  	if privateKey == nil {
1336  		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
1337  	}
1338  	auxTenants := multiTenantCfg.AuxiliaryTenants()
1339  	m := MultiTenantServicePrincipalToken{
1340  		AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1341  	}
1342  	primary, err := NewServicePrincipalTokenWithSecret(
1343  		*multiTenantCfg.PrimaryTenant(),
1344  		clientID,
1345  		resource,
1346  		&ServicePrincipalCertificateSecret{
1347  			PrivateKey:  privateKey,
1348  			Certificate: certificate,
1349  		},
1350  	)
1351  	if err != nil {
1352  		return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1353  	}
1354  	m.PrimaryToken = primary
1355  	for i := range auxTenants {
1356  		aux, err := NewServicePrincipalTokenWithSecret(
1357  			*auxTenants[i],
1358  			clientID,
1359  			resource,
1360  			&ServicePrincipalCertificateSecret{
1361  				PrivateKey:  privateKey,
1362  				Certificate: certificate,
1363  			},
1364  		)
1365  		if err != nil {
1366  			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1367  		}
1368  		m.AuxiliaryTokens[i] = aux
1369  	}
1370  	return &m, nil
1371  }
1372  
1373  // MSIAvailable returns true if the MSI endpoint is available for authentication.
1374  func MSIAvailable(ctx context.Context, s Sender) bool {
1375  	msiType, _, err := getMSIType()
1376  
1377  	if err != nil {
1378  		return false
1379  	}
1380  
1381  	if msiType != msiTypeIMDS {
1382  		return true
1383  	}
1384  
1385  	if s == nil {
1386  		s = sender()
1387  	}
1388  
1389  	resp, err := getMSIEndpoint(ctx, s)
1390  
1391  	if err == nil {
1392  		resp.Body.Close()
1393  	}
1394  
1395  	return err == nil
1396  }
1397