public_client.go raw

   1  //go:build go1.18
   2  // +build go1.18
   3  
   4  // Copyright (c) Microsoft Corporation. All rights reserved.
   5  // Licensed under the MIT License.
   6  
   7  package azidentity
   8  
   9  import (
  10  	"context"
  11  	"errors"
  12  	"fmt"
  13  	"net/http"
  14  	"strings"
  15  	"sync"
  16  
  17  	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
  18  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
  19  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
  20  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
  21  	"github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
  22  	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
  23  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
  24  
  25  	// this import ensures well-known configurations in azcore/cloud have ARM audiences for Authenticate()
  26  	_ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime"
  27  )
  28  
  29  type publicClientOptions struct {
  30  	azcore.ClientOptions
  31  
  32  	AdditionallyAllowedTenants     []string
  33  	Cache                          Cache
  34  	DeviceCodePrompt               func(context.Context, DeviceCodeMessage) error
  35  	DisableAutomaticAuthentication bool
  36  	DisableInstanceDiscovery       bool
  37  	LoginHint, RedirectURL         string
  38  	Record                         AuthenticationRecord
  39  	Username, Password             string
  40  }
  41  
  42  // publicClient wraps the MSAL public client
  43  type publicClient struct {
  44  	cae, noCAE               msalPublicClient
  45  	caeMu, noCAEMu, clientMu *sync.Mutex
  46  	clientID, tenantID       string
  47  	defaultScope             []string
  48  	host                     string
  49  	name                     string
  50  	opts                     publicClientOptions
  51  	record                   AuthenticationRecord
  52  	azClient                 *azcore.Client
  53  }
  54  
  55  var errScopeRequired = errors.New("authenticating in this environment requires specifying a scope in TokenRequestOptions")
  56  
  57  func newPublicClient(tenantID, clientID, name string, o publicClientOptions) (*publicClient, error) {
  58  	if !validTenantID(tenantID) {
  59  		return nil, errInvalidTenantID
  60  	}
  61  	host, err := setAuthorityHost(o.Cloud)
  62  	if err != nil {
  63  		return nil, err
  64  	}
  65  	// if the application specified a cloud configuration, use its ARM audience as the default scope for Authenticate()
  66  	audience := o.Cloud.Services[cloud.ResourceManager].Audience
  67  	if audience == "" {
  68  		// no cloud configuration, or no ARM audience, specified; try to map the host to a well-known one (all of which have a trailing slash)
  69  		if !strings.HasSuffix(host, "/") {
  70  			host += "/"
  71  		}
  72  		switch host {
  73  		case cloud.AzureChina.ActiveDirectoryAuthorityHost:
  74  			audience = cloud.AzureChina.Services[cloud.ResourceManager].Audience
  75  		case cloud.AzureGovernment.ActiveDirectoryAuthorityHost:
  76  			audience = cloud.AzureGovernment.Services[cloud.ResourceManager].Audience
  77  		case cloud.AzurePublic.ActiveDirectoryAuthorityHost:
  78  			audience = cloud.AzurePublic.Services[cloud.ResourceManager].Audience
  79  		}
  80  	}
  81  	// if we didn't come up with an audience, the application will have to specify a scope for Authenticate()
  82  	var defaultScope []string
  83  	if audience != "" {
  84  		defaultScope = []string{audience + defaultSuffix}
  85  	}
  86  	client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
  87  		Tracing: runtime.TracingOptions{
  88  			Namespace: traceNamespace,
  89  		},
  90  	}, &o.ClientOptions)
  91  	if err != nil {
  92  		return nil, err
  93  	}
  94  	o.AdditionallyAllowedTenants = resolveAdditionalTenants(o.AdditionallyAllowedTenants)
  95  	return &publicClient{
  96  		caeMu:        &sync.Mutex{},
  97  		clientID:     clientID,
  98  		clientMu:     &sync.Mutex{},
  99  		defaultScope: defaultScope,
 100  		host:         host,
 101  		name:         name,
 102  		noCAEMu:      &sync.Mutex{},
 103  		opts:         o,
 104  		record:       o.Record,
 105  		tenantID:     tenantID,
 106  		azClient:     client,
 107  	}, nil
 108  }
 109  
 110  func (p *publicClient) Authenticate(ctx context.Context, tro *policy.TokenRequestOptions) (AuthenticationRecord, error) {
 111  	if tro == nil {
 112  		tro = &policy.TokenRequestOptions{}
 113  	}
 114  	if len(tro.Scopes) == 0 {
 115  		if p.defaultScope == nil {
 116  			return AuthenticationRecord{}, errScopeRequired
 117  		}
 118  		tro.Scopes = p.defaultScope
 119  	}
 120  	client, mu, err := p.client(*tro)
 121  	if err != nil {
 122  		return AuthenticationRecord{}, err
 123  	}
 124  	mu.Lock()
 125  	defer mu.Unlock()
 126  	_, err = p.reqToken(ctx, client, *tro)
 127  	if err == nil {
 128  		scope := strings.Join(tro.Scopes, ", ")
 129  		msg := fmt.Sprintf("%s.Authenticate() acquired a token for scope %q", p.name, scope)
 130  		log.Write(EventAuthentication, msg)
 131  	}
 132  	return p.record, err
 133  }
 134  
 135  // GetToken requests an access token from MSAL, checking the cache first.
 136  func (p *publicClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
 137  	if len(tro.Scopes) < 1 {
 138  		return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", p.name)
 139  	}
 140  	tenant, err := p.resolveTenant(tro.TenantID)
 141  	if err != nil {
 142  		return azcore.AccessToken{}, err
 143  	}
 144  	client, mu, err := p.client(tro)
 145  	if err != nil {
 146  		return azcore.AccessToken{}, err
 147  	}
 148  	mu.Lock()
 149  	defer mu.Unlock()
 150  	ar, err := client.AcquireTokenSilent(ctx, tro.Scopes, public.WithSilentAccount(p.record.account()), public.WithClaims(tro.Claims), public.WithTenantID(tenant))
 151  	if err == nil {
 152  		return p.token(ar, err)
 153  	}
 154  	if p.opts.DisableAutomaticAuthentication {
 155  		return azcore.AccessToken{}, newAuthenticationRequiredError(p.name, tro)
 156  	}
 157  	return p.reqToken(ctx, client, tro)
 158  }
 159  
 160  // reqToken requests a token from the MSAL public client. It's separate from GetToken() to enable Authenticate() to bypass the cache.
 161  func (p *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
 162  	tenant, err := p.resolveTenant(tro.TenantID)
 163  	if err != nil {
 164  		return azcore.AccessToken{}, err
 165  	}
 166  	var ar public.AuthResult
 167  	switch p.name {
 168  	case credNameBrowser:
 169  		ar, err = c.AcquireTokenInteractive(ctx, tro.Scopes,
 170  			public.WithClaims(tro.Claims),
 171  			public.WithLoginHint(p.opts.LoginHint),
 172  			public.WithRedirectURI(p.opts.RedirectURL),
 173  			public.WithTenantID(tenant),
 174  		)
 175  	case credNameDeviceCode:
 176  		dc, e := c.AcquireTokenByDeviceCode(ctx, tro.Scopes, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
 177  		if e != nil {
 178  			return azcore.AccessToken{}, e
 179  		}
 180  		err = p.opts.DeviceCodePrompt(ctx, DeviceCodeMessage{
 181  			Message:         dc.Result.Message,
 182  			UserCode:        dc.Result.UserCode,
 183  			VerificationURL: dc.Result.VerificationURL,
 184  		})
 185  		if err == nil {
 186  			ar, err = dc.AuthenticationResult(ctx)
 187  		}
 188  	case credNameUserPassword:
 189  		ar, err = c.AcquireTokenByUsernamePassword(ctx, tro.Scopes, p.opts.Username, p.opts.Password, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
 190  	default:
 191  		return azcore.AccessToken{}, fmt.Errorf("unknown credential %q", p.name)
 192  	}
 193  	return p.token(ar, err)
 194  }
 195  
 196  func (p *publicClient) client(tro policy.TokenRequestOptions) (msalPublicClient, *sync.Mutex, error) {
 197  	p.clientMu.Lock()
 198  	defer p.clientMu.Unlock()
 199  	if tro.EnableCAE {
 200  		if p.cae == nil {
 201  			client, err := p.newMSALClient(true)
 202  			if err != nil {
 203  				return nil, nil, err
 204  			}
 205  			p.cae = client
 206  		}
 207  		return p.cae, p.caeMu, nil
 208  	}
 209  	if p.noCAE == nil {
 210  		client, err := p.newMSALClient(false)
 211  		if err != nil {
 212  			return nil, nil, err
 213  		}
 214  		p.noCAE = client
 215  	}
 216  	return p.noCAE, p.noCAEMu, nil
 217  }
 218  
 219  func (p *publicClient) newMSALClient(enableCAE bool) (msalPublicClient, error) {
 220  	c, err := internal.ExportReplace(p.opts.Cache, enableCAE)
 221  	if err != nil {
 222  		return nil, err
 223  	}
 224  	o := []public.Option{
 225  		public.WithAuthority(runtime.JoinPaths(p.host, p.tenantID)),
 226  		public.WithCache(c),
 227  		public.WithHTTPClient(p),
 228  	}
 229  	if enableCAE {
 230  		o = append(o, public.WithClientCapabilities(cp1))
 231  	}
 232  	if p.opts.DisableInstanceDiscovery || strings.ToLower(p.tenantID) == "adfs" {
 233  		o = append(o, public.WithInstanceDiscovery(false))
 234  	}
 235  	return public.New(p.clientID, o...)
 236  }
 237  
 238  func (p *publicClient) token(ar public.AuthResult, err error) (azcore.AccessToken, error) {
 239  	if err == nil {
 240  		msg := fmt.Sprintf(scopeLogFmt, p.name, strings.Join(ar.GrantedScopes, ", "))
 241  		log.Write(EventAuthentication, msg)
 242  		p.record, err = newAuthenticationRecord(ar)
 243  	} else {
 244  		err = newAuthenticationFailedErrorFromMSAL(p.name, err)
 245  	}
 246  	return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC(), RefreshOn: ar.Metadata.RefreshOn.UTC()}, err
 247  }
 248  
 249  // resolveTenant returns the correct WithTenantID() argument for a token request given the client's
 250  // configuration, or an error when that configuration doesn't allow the specified tenant
 251  func (p *publicClient) resolveTenant(specified string) (string, error) {
 252  	t, err := resolveTenant(p.tenantID, specified, p.name, p.opts.AdditionallyAllowedTenants)
 253  	if t == p.tenantID {
 254  		// callers pass this value to MSAL's WithTenantID(). There's no need to redundantly specify
 255  		// the client's default tenant and doing so is an error when that tenant is "organizations"
 256  		t = ""
 257  	}
 258  	return t, err
 259  }
 260  
 261  // these methods satisfy the MSAL ops.HTTPClient interface
 262  
 263  func (p *publicClient) CloseIdleConnections() {
 264  	// do nothing
 265  }
 266  
 267  func (p *publicClient) Do(r *http.Request) (*http.Response, error) {
 268  	return doForClient(p.azClient, r)
 269  }
 270