oauth.go raw

   1  // Copyright (c) Microsoft Corporation.
   2  // Licensed under the MIT license.
   3  
   4  package oauth
   5  
   6  import (
   7  	"context"
   8  	"encoding/json"
   9  	"fmt"
  10  	"io"
  11  	"time"
  12  
  13  	"github.com/google/uuid"
  14  
  15  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
  16  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
  17  	internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
  18  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
  19  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
  20  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
  21  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
  22  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
  23  )
  24  
  25  // ResolveEndpointer contains the methods for resolving authority endpoints.
  26  type ResolveEndpointer interface {
  27  	ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error)
  28  }
  29  
  30  // AccessTokens contains the methods for fetching tokens from different sources.
  31  type AccessTokens interface {
  32  	DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (accesstokens.DeviceCodeResult, error)
  33  	FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (accesstokens.TokenResponse, error)
  34  	FromAuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error)
  35  	FromRefreshToken(ctx context.Context, appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) (accesstokens.TokenResponse, error)
  36  	FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (accesstokens.TokenResponse, error)
  37  	FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (accesstokens.TokenResponse, error)
  38  	FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (accesstokens.TokenResponse, error)
  39  	FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion string, assertion string) (accesstokens.TokenResponse, error)
  40  	FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult accesstokens.DeviceCodeResult) (accesstokens.TokenResponse, error)
  41  	FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (accesstokens.TokenResponse, error)
  42  }
  43  
  44  // FetchAuthority will be implemented by authority.Authority.
  45  type FetchAuthority interface {
  46  	UserRealm(context.Context, authority.AuthParams) (authority.UserRealm, error)
  47  	AADInstanceDiscovery(context.Context, authority.Info) (authority.InstanceDiscoveryResponse, error)
  48  }
  49  
  50  // FetchWSTrust contains the methods for interacting with WSTrust endpoints.
  51  type FetchWSTrust interface {
  52  	Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error)
  53  	SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (wstrust.SamlTokenInfo, error)
  54  }
  55  
// Client provides tokens for various types of token requests. Each collaborator is held as an
// interface so it can be substituted; New wires in the defaults.
type Client struct {
	Resolver     ResolveEndpointer // resolves authority endpoints before each token request
	AccessTokens AccessTokens      // performs the token endpoint request for each grant type
	Authority    FetchAuthority    // user realm lookups and AAD instance discovery
	WSTrust      FetchWSTrust      // MEX documents and SAML tokens for federated users
}
  63  
  64  // New is the constructor for Token.
  65  func New(httpClient ops.HTTPClient) *Client {
  66  	r := ops.New(httpClient)
  67  	return &Client{
  68  		Resolver:     newAuthorityEndpoint(r),
  69  		AccessTokens: r.AccessTokens(),
  70  		Authority:    r.Authority(),
  71  		WSTrust:      r.WSTrust(),
  72  	}
  73  }
  74  
  75  // ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance.
  76  func (t *Client) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
  77  	return t.Resolver.ResolveEndpoints(ctx, authorityInfo, userPrincipalName)
  78  }
  79  
  80  // AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint).
  81  // This is done by AAD which allows for aliasing of tenants (windows.sts.net is the same as login.windows.com).
  82  func (t *Client) AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) {
  83  	return t.Authority.AADInstanceDiscovery(ctx, authorityInfo)
  84  }
  85  
  86  // AuthCode returns a token based on an authorization code.
  87  func (t *Client) AuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error) {
  88  	if err := scopeError(req.AuthParams); err != nil {
  89  		return accesstokens.TokenResponse{}, err
  90  	}
  91  	if err := t.resolveEndpoint(ctx, &req.AuthParams, ""); err != nil {
  92  		return accesstokens.TokenResponse{}, err
  93  	}
  94  
  95  	tResp, err := t.AccessTokens.FromAuthCode(ctx, req)
  96  	if err != nil {
  97  		return accesstokens.TokenResponse{}, fmt.Errorf("could not retrieve token from auth code: %w", err)
  98  	}
  99  	return tResp, nil
 100  }
 101  
 102  // Credential acquires a token from the authority using a client credentials grant.
 103  func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
 104  	if cred.TokenProvider != nil {
 105  		now := time.Now()
 106  		scopes := make([]string, len(authParams.Scopes))
 107  		copy(scopes, authParams.Scopes)
 108  		params := exported.TokenProviderParameters{
 109  			Claims:        authParams.Claims,
 110  			CorrelationID: uuid.New().String(),
 111  			Scopes:        scopes,
 112  			TenantID:      authParams.AuthorityInfo.Tenant,
 113  		}
 114  		pr, err := cred.TokenProvider(ctx, params)
 115  		if err != nil {
 116  			if len(scopes) == 0 {
 117  				err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err)
 118  				return accesstokens.TokenResponse{}, err
 119  			}
 120  			return accesstokens.TokenResponse{}, err
 121  		}
 122  		tr := accesstokens.TokenResponse{
 123  			TokenType:     authParams.AuthnScheme.AccessTokenType(),
 124  			AccessToken:   pr.AccessToken,
 125  			ExpiresOn:     now.Add(time.Duration(pr.ExpiresInSeconds) * time.Second),
 126  			GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes},
 127  		}
 128  		if pr.RefreshInSeconds > 0 {
 129  			tr.RefreshOn = internalTime.DurationTime{
 130  				T: now.Add(time.Duration(pr.RefreshInSeconds) * time.Second),
 131  			}
 132  		}
 133  		return tr, nil
 134  	}
 135  
 136  	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
 137  		return accesstokens.TokenResponse{}, err
 138  	}
 139  
 140  	if cred.Secret != "" {
 141  		return t.AccessTokens.FromClientSecret(ctx, authParams, cred.Secret)
 142  	}
 143  	jwt, err := cred.JWT(ctx, authParams)
 144  	if err != nil {
 145  		return accesstokens.TokenResponse{}, err
 146  	}
 147  	return t.AccessTokens.FromAssertion(ctx, authParams, jwt)
 148  }
 149  
 150  // Credential acquires a token from the authority using a client credentials grant.
 151  func (t *Client) OnBehalfOf(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
 152  	if err := scopeError(authParams); err != nil {
 153  		return accesstokens.TokenResponse{}, err
 154  	}
 155  	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
 156  		return accesstokens.TokenResponse{}, err
 157  	}
 158  
 159  	if cred.Secret != "" {
 160  		return t.AccessTokens.FromUserAssertionClientSecret(ctx, authParams, authParams.UserAssertion, cred.Secret)
 161  	}
 162  	jwt, err := cred.JWT(ctx, authParams)
 163  	if err != nil {
 164  		return accesstokens.TokenResponse{}, err
 165  	}
 166  	tr, err := t.AccessTokens.FromUserAssertionClientCertificate(ctx, authParams, authParams.UserAssertion, jwt)
 167  	if err != nil {
 168  		return accesstokens.TokenResponse{}, err
 169  	}
 170  	return tr, nil
 171  }
 172  
 173  func (t *Client) Refresh(ctx context.Context, reqType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken accesstokens.RefreshToken) (accesstokens.TokenResponse, error) {
 174  	if err := scopeError(authParams); err != nil {
 175  		return accesstokens.TokenResponse{}, err
 176  	}
 177  	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
 178  		return accesstokens.TokenResponse{}, err
 179  	}
 180  
 181  	tr, err := t.AccessTokens.FromRefreshToken(ctx, reqType, authParams, cc, refreshToken.Secret)
 182  	if err != nil {
 183  		return accesstokens.TokenResponse{}, err
 184  	}
 185  	return tr, nil
 186  }
 187  
 188  // UsernamePassword retrieves a token where a username and password is used. However, if this is
 189  // a user realm of "Federated", this uses SAML tokens. If "Managed", uses normal username/password.
 190  func (t *Client) UsernamePassword(ctx context.Context, authParams authority.AuthParams) (accesstokens.TokenResponse, error) {
 191  	if err := scopeError(authParams); err != nil {
 192  		return accesstokens.TokenResponse{}, err
 193  	}
 194  
 195  	if authParams.AuthorityInfo.AuthorityType == authority.ADFS {
 196  		if err := t.resolveEndpoint(ctx, &authParams, authParams.Username); err != nil {
 197  			return accesstokens.TokenResponse{}, err
 198  		}
 199  		return t.AccessTokens.FromUsernamePassword(ctx, authParams)
 200  	}
 201  	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
 202  		return accesstokens.TokenResponse{}, err
 203  	}
 204  
 205  	userRealm, err := t.Authority.UserRealm(ctx, authParams)
 206  	if err != nil {
 207  		return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm from authority: %w", err)
 208  	}
 209  
 210  	switch userRealm.AccountType {
 211  	case authority.Federated:
 212  		mexDoc, err := t.WSTrust.Mex(ctx, userRealm.FederationMetadataURL)
 213  		if err != nil {
 214  			err = fmt.Errorf("problem getting mex doc from federated url(%s): %w", userRealm.FederationMetadataURL, err)
 215  			return accesstokens.TokenResponse{}, err
 216  		}
 217  
 218  		saml, err := t.WSTrust.SAMLTokenInfo(ctx, authParams, userRealm.CloudAudienceURN, mexDoc.UsernamePasswordEndpoint)
 219  		if err != nil {
 220  			err = fmt.Errorf("problem getting SAML token info: %w", err)
 221  			return accesstokens.TokenResponse{}, err
 222  		}
 223  		tr, err := t.AccessTokens.FromSamlGrant(ctx, authParams, saml)
 224  		if err != nil {
 225  			return accesstokens.TokenResponse{}, err
 226  		}
 227  		return tr, nil
 228  	case authority.Managed:
 229  		if len(authParams.Scopes) == 0 {
 230  			err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err)
 231  			return accesstokens.TokenResponse{}, err
 232  		}
 233  		return t.AccessTokens.FromUsernamePassword(ctx, authParams)
 234  	}
 235  	return accesstokens.TokenResponse{}, errors.New("unknown account type")
 236  }
 237  
// DeviceCode is the result of a call to Client.DeviceCode(). Only values returned by that
// method are usable: Token reports an error when accessTokens is nil (e.g. a zero DeviceCode).
type DeviceCode struct {
	// Result is the device code result from the first call in the device code flow. This allows
	// the caller to retrieve the displayed code that is used to authorize on the second device.
	Result     accesstokens.DeviceCodeResult
	authParams authority.AuthParams // endpoint-resolved parameters reused by every poll in Token

	accessTokens AccessTokens // performs the polling requests; nil marks an invalid DeviceCode
}
 247  
// Token returns a token AFTER the user uses the user code on the second device. This will block
// until either: (1) the code is input by the user and the service releases a token, (2) the
// device code expires (d.Result.ExpiresOn), (3) the Context passed to Token is cancelled or
// expires, or (4) some other service error occurs.
//
// Polling waits 50ms before the first request, and the wait grows after each pending response,
// capped at 5s.
func (d DeviceCode) Token(ctx context.Context) (accesstokens.TokenResponse, error) {
	if d.accessTokens == nil {
		return accesstokens.TokenResponse{}, fmt.Errorf("DeviceCode was either created outside its package or the creating method had an error. DeviceCode is not valid")
	}

	// Bound polling by the device code's expiry, unless the caller's own deadline is no later
	// than that expiry.
	var cancel context.CancelFunc
	if deadline, ok := ctx.Deadline(); !ok || d.Result.ExpiresOn.Before(deadline) {
		ctx, cancel = context.WithDeadline(ctx, d.Result.ExpiresOn)
	} else {
		ctx, cancel = context.WithCancel(ctx)
	}
	defer cancel()

	var interval = 50 * time.Millisecond
	timer := time.NewTimer(interval)
	defer timer.Stop()

	for {
		// NOTE(review): on the first iteration this resets a timer that is still running. Before
		// Go 1.23, the time.Timer docs ask for Reset only on stopped or drained timers. In
		// practice, the 50ms timer is unlikely to have fired this soon.
		timer.Reset(interval)
		select {
		case <-ctx.Done():
			return accesstokens.TokenResponse{}, ctx.Err()
		case <-timer.C:
			// This triples the wait each round (50ms, 150ms, 450ms, 1.35s, 4.05s, then 5s).
			// NOTE(review): if doubling was intended, this should be `interval *= 2`.
			interval += interval * 2
			if interval > 5*time.Second {
				interval = 5 * time.Second
			}
		}

		token, err := d.accessTokens.FromDeviceCodeResult(ctx, d.authParams, d.Result)
		if err != nil && isWaitDeviceCodeErr(err) {
			// "authorization_pending" or "slow_down": the user has not finished yet, so poll again.
			continue
		}
		return token, err // This handles if it was a non-wait error or success
	}
}
 288  
// deviceCodeError is the part of an error response body that isWaitDeviceCodeErr decodes:
// only the "error" code field.
type deviceCodeError struct {
	Error string `json:"error"`
}
 292  
 293  func isWaitDeviceCodeErr(err error) bool {
 294  	var c errors.CallErr
 295  	if !errors.As(err, &c) {
 296  		return false
 297  	}
 298  	if c.Resp.StatusCode != 400 {
 299  		return false
 300  	}
 301  	var dCErr deviceCodeError
 302  	defer c.Resp.Body.Close()
 303  	body, err := io.ReadAll(c.Resp.Body)
 304  	if err != nil {
 305  		return false
 306  	}
 307  	err = json.Unmarshal(body, &dCErr)
 308  	if err != nil {
 309  		return false
 310  	}
 311  	if dCErr.Error == "authorization_pending" || dCErr.Error == "slow_down" {
 312  		return true
 313  	}
 314  	return false
 315  }
 316  
 317  // DeviceCode returns a DeviceCode object that can be used to get the code that must be entered on the second
 318  // device and optionally the token once the code has been entered on the second device.
 319  func (t *Client) DeviceCode(ctx context.Context, authParams authority.AuthParams) (DeviceCode, error) {
 320  	if err := scopeError(authParams); err != nil {
 321  		return DeviceCode{}, err
 322  	}
 323  
 324  	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
 325  		return DeviceCode{}, err
 326  	}
 327  
 328  	dcr, err := t.AccessTokens.DeviceCodeResult(ctx, authParams)
 329  	if err != nil {
 330  		return DeviceCode{}, err
 331  	}
 332  
 333  	return DeviceCode{Result: dcr, authParams: authParams, accessTokens: t.AccessTokens}, nil
 334  }
 335  
 336  func (t *Client) resolveEndpoint(ctx context.Context, authParams *authority.AuthParams, userPrincipalName string) error {
 337  	endpoints, err := t.Resolver.ResolveEndpoints(ctx, authParams.AuthorityInfo, userPrincipalName)
 338  	if err != nil {
 339  		return fmt.Errorf("unable to resolve an endpoint: %w", err)
 340  	}
 341  	authParams.Endpoints = endpoints
 342  	return nil
 343  }
 344  
 345  // scopeError takes an authority.AuthParams and returns an error
 346  // if len(AuthParams.Scope) == 0.
 347  func scopeError(a authority.AuthParams) error {
 348  	// TODO(someone): we could look deeper at the message to determine if
 349  	// it's a scope error, but this is a good start.
 350  	/*
 351  		{error":"invalid_scope","error_description":"AADSTS1002012: The provided value for scope
 352  		openid offline_access profile is not valid. Client credential flows must have a scope value
 353  		with /.default suffixed to the resource identifier (application ID URI)...}
 354  	*/
 355  	if len(a.Scopes) == 0 {
 356  		return fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which is invalid")
 357  	}
 358  	return nil
 359  }
 360