confidential_client.go raw

   1  //go:build go1.18
   2  // +build go1.18
   3  
   4  // Copyright (c) Microsoft Corporation. All rights reserved.
   5  // Licensed under the MIT License.
   6  
   7  package azidentity
   8  
   9  import (
  10  	"context"
  11  	"errors"
  12  	"fmt"
  13  	"net/http"
  14  	"os"
  15  	"strings"
  16  	"sync"
  17  
  18  	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
  19  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
  20  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
  21  	"github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
  22  	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
  23  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
  24  )
  25  
  26  type confidentialClientOptions struct {
  27  	azcore.ClientOptions
  28  
  29  	AdditionallyAllowedTenants []string
  30  	// Assertion for on-behalf-of authentication
  31  	Assertion                         string
  32  	Cache                             Cache
  33  	DisableInstanceDiscovery, SendX5C bool
  34  }
  35  
  36  // confidentialClient wraps the MSAL confidential client
  37  type confidentialClient struct {
  38  	cae, noCAE               msalConfidentialClient
  39  	caeMu, noCAEMu, clientMu *sync.Mutex
  40  	clientID, tenantID       string
  41  	cred                     confidential.Credential
  42  	host                     string
  43  	name                     string
  44  	opts                     confidentialClientOptions
  45  	region                   string
  46  	azClient                 *azcore.Client
  47  }
  48  
  49  func newConfidentialClient(tenantID, clientID, name string, cred confidential.Credential, opts confidentialClientOptions) (*confidentialClient, error) {
  50  	if !validTenantID(tenantID) {
  51  		return nil, errInvalidTenantID
  52  	}
  53  	host, err := setAuthorityHost(opts.Cloud)
  54  	if err != nil {
  55  		return nil, err
  56  	}
  57  	client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
  58  		Tracing: runtime.TracingOptions{
  59  			Namespace: traceNamespace,
  60  		},
  61  	}, &opts.ClientOptions)
  62  	if err != nil {
  63  		return nil, err
  64  	}
  65  	opts.AdditionallyAllowedTenants = resolveAdditionalTenants(opts.AdditionallyAllowedTenants)
  66  	return &confidentialClient{
  67  		caeMu:    &sync.Mutex{},
  68  		clientID: clientID,
  69  		clientMu: &sync.Mutex{},
  70  		cred:     cred,
  71  		host:     host,
  72  		name:     name,
  73  		noCAEMu:  &sync.Mutex{},
  74  		opts:     opts,
  75  		region:   os.Getenv(azureRegionalAuthorityName),
  76  		tenantID: tenantID,
  77  		azClient: client,
  78  	}, nil
  79  }
  80  
  81  // GetToken requests an access token from MSAL, checking the cache first.
  82  func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
  83  	if len(tro.Scopes) < 1 {
  84  		return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", c.name)
  85  	}
  86  	// we don't resolve the tenant for managed identities because they acquire tokens only from their home tenants
  87  	if c.name != credNameManagedIdentity {
  88  		tenant, err := c.resolveTenant(tro.TenantID)
  89  		if err != nil {
  90  			return azcore.AccessToken{}, err
  91  		}
  92  		tro.TenantID = tenant
  93  	}
  94  	client, mu, err := c.client(tro)
  95  	if err != nil {
  96  		return azcore.AccessToken{}, err
  97  	}
  98  	mu.Lock()
  99  	defer mu.Unlock()
 100  	var ar confidential.AuthResult
 101  	if c.opts.Assertion != "" {
 102  		ar, err = client.AcquireTokenOnBehalfOf(ctx, c.opts.Assertion, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
 103  	} else {
 104  		ar, err = client.AcquireTokenSilent(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
 105  		if err != nil {
 106  			ar, err = client.AcquireTokenByCredential(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
 107  		}
 108  	}
 109  	if err != nil {
 110  		var (
 111  			authFailedErr  *AuthenticationFailedError
 112  			unavailableErr credentialUnavailable
 113  		)
 114  		if !(errors.As(err, &unavailableErr) || errors.As(err, &authFailedErr)) {
 115  			err = newAuthenticationFailedErrorFromMSAL(c.name, err)
 116  		}
 117  	} else {
 118  		msg := fmt.Sprintf(scopeLogFmt, c.name, strings.Join(ar.GrantedScopes, ", "))
 119  		log.Write(EventAuthentication, msg)
 120  	}
 121  	return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC(), RefreshOn: ar.Metadata.RefreshOn.UTC()}, err
 122  }
 123  
 124  func (c *confidentialClient) client(tro policy.TokenRequestOptions) (msalConfidentialClient, *sync.Mutex, error) {
 125  	c.clientMu.Lock()
 126  	defer c.clientMu.Unlock()
 127  	if tro.EnableCAE {
 128  		if c.cae == nil {
 129  			client, err := c.newMSALClient(true)
 130  			if err != nil {
 131  				return nil, nil, err
 132  			}
 133  			c.cae = client
 134  		}
 135  		return c.cae, c.caeMu, nil
 136  	}
 137  	if c.noCAE == nil {
 138  		client, err := c.newMSALClient(false)
 139  		if err != nil {
 140  			return nil, nil, err
 141  		}
 142  		c.noCAE = client
 143  	}
 144  	return c.noCAE, c.noCAEMu, nil
 145  }
 146  
 147  func (c *confidentialClient) newMSALClient(enableCAE bool) (msalConfidentialClient, error) {
 148  	cache, err := internal.ExportReplace(c.opts.Cache, enableCAE)
 149  	if err != nil {
 150  		return nil, err
 151  	}
 152  	authority := runtime.JoinPaths(c.host, c.tenantID)
 153  	o := []confidential.Option{
 154  		confidential.WithAzureRegion(c.region),
 155  		confidential.WithCache(cache),
 156  		confidential.WithHTTPClient(c),
 157  	}
 158  	if enableCAE {
 159  		o = append(o, confidential.WithClientCapabilities(cp1))
 160  	}
 161  	if c.opts.SendX5C {
 162  		o = append(o, confidential.WithX5C())
 163  	}
 164  	if c.opts.DisableInstanceDiscovery || strings.ToLower(c.tenantID) == "adfs" {
 165  		o = append(o, confidential.WithInstanceDiscovery(false))
 166  	}
 167  	return confidential.New(authority, c.clientID, c.cred, o...)
 168  }
 169  
 170  // resolveTenant returns the correct WithTenantID() argument for a token request given the client's
 171  // configuration, or an error when that configuration doesn't allow the specified tenant
 172  func (c *confidentialClient) resolveTenant(specified string) (string, error) {
 173  	return resolveTenant(c.tenantID, specified, c.name, c.opts.AdditionallyAllowedTenants)
 174  }
 175  
 176  // these methods satisfy the MSAL ops.HTTPClient interface
 177  
 178  func (c *confidentialClient) CloseIdleConnections() {
 179  	// do nothing
 180  }
 181  
 182  func (c *confidentialClient) Do(r *http.Request) (*http.Response, error) {
 183  	return doForClient(c.azClient, r)
 184  }
 185