managed_identity_client.go raw

   1  //go:build go1.18
   2  // +build go1.18
   3  
   4  // Copyright (c) Microsoft Corporation. All rights reserved.
   5  // Licensed under the MIT License.
   6  
   7  package azidentity
   8  
   9  import (
  10  	"context"
  11  	"errors"
  12  	"fmt"
  13  	"net/http"
  14  	"strings"
  15  	"time"
  16  
  17  	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
  18  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
  19  	azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
  20  	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
  21  	msalerrors "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
  22  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
  23  )
  24  
  25  const (
  26  	arcIMDSEndpoint          = "IMDS_ENDPOINT"
  27  	defaultIdentityClientID  = "DEFAULT_IDENTITY_CLIENT_ID"
  28  	identityEndpoint         = "IDENTITY_ENDPOINT"
  29  	identityHeader           = "IDENTITY_HEADER"
  30  	identityServerThumbprint = "IDENTITY_SERVER_THUMBPRINT"
  31  	headerMetadata           = "Metadata"
  32  	imdsEndpoint             = "http://169.254.169.254/metadata/identity/oauth2/token"
  33  	miResID                  = "mi_res_id"
  34  	msiEndpoint              = "MSI_ENDPOINT"
  35  	msiResID                 = "msi_res_id"
  36  	msiSecret                = "MSI_SECRET"
  37  	imdsAPIVersion           = "2018-02-01"
  38  	azureArcAPIVersion       = "2020-06-01"
  39  	qpClientID               = "client_id"
  40  	serviceFabricAPIVersion  = "2019-07-01-preview"
  41  )
  42  
  43  var imdsProbeTimeout = time.Second
  44  
  45  type managedIdentityClient struct {
  46  	azClient                      *azcore.Client
  47  	imds, probeIMDS, userAssigned bool
  48  	// chained indicates whether the client is part of a credential chain. If true, the client will return
  49  	// a credentialUnavailableError instead of an AuthenticationFailedError for an unexpected IMDS response.
  50  	chained    bool
  51  	msalClient msalManagedIdentityClient
  52  }
  53  
  54  // setIMDSRetryOptionDefaults sets zero-valued fields to default values appropriate for IMDS
  55  func setIMDSRetryOptionDefaults(o *policy.RetryOptions) {
  56  	if o.MaxRetries == 0 {
  57  		o.MaxRetries = 6
  58  	}
  59  	if o.MaxRetryDelay == 0 {
  60  		o.MaxRetryDelay = 25 * time.Second
  61  	}
  62  	if o.RetryDelay == 0 {
  63  		o.RetryDelay = 2 * time.Second
  64  	}
  65  	if o.StatusCodes == nil {
  66  		o.StatusCodes = []int{
  67  			// IMDS docs recommend retrying 404, 410, 429 and 5xx
  68  			// https://learn.microsoft.com/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling
  69  			http.StatusNotFound,                      // 404
  70  			http.StatusGone,                          // 410
  71  			http.StatusTooManyRequests,               // 429
  72  			http.StatusInternalServerError,           // 500
  73  			http.StatusNotImplemented,                // 501
  74  			http.StatusBadGateway,                    // 502
  75  			http.StatusServiceUnavailable,            // 503
  76  			http.StatusGatewayTimeout,                // 504
  77  			http.StatusHTTPVersionNotSupported,       // 505
  78  			http.StatusVariantAlsoNegotiates,         // 506
  79  			http.StatusInsufficientStorage,           // 507
  80  			http.StatusLoopDetected,                  // 508
  81  			http.StatusNotExtended,                   // 510
  82  			http.StatusNetworkAuthenticationRequired, // 511
  83  		}
  84  	}
  85  	if o.TryTimeout == 0 {
  86  		o.TryTimeout = 1 * time.Minute
  87  	}
  88  }
  89  
  90  // newManagedIdentityClient creates a new instance of the ManagedIdentityClient with the ManagedIdentityCredentialOptions
  91  // that are passed into it along with a default pipeline.
  92  // options: ManagedIdentityCredentialOptions configure policies for the pipeline and the authority host that
  93  // will be used to retrieve tokens and authenticate
  94  func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*managedIdentityClient, error) {
  95  	if options == nil {
  96  		options = &ManagedIdentityCredentialOptions{}
  97  	}
  98  	cp := options.ClientOptions
  99  	c := managedIdentityClient{}
 100  	source, err := managedidentity.GetSource()
 101  	if err != nil {
 102  		return nil, err
 103  	}
 104  	env := string(source)
 105  	if source == managedidentity.DefaultToIMDS {
 106  		env = "IMDS"
 107  		c.imds = true
 108  		c.probeIMDS = options.dac
 109  		setIMDSRetryOptionDefaults(&cp.Retry)
 110  	}
 111  
 112  	c.azClient, err = azcore.NewClient(module, version, azruntime.PipelineOptions{
 113  		Tracing: azruntime.TracingOptions{
 114  			Namespace: traceNamespace,
 115  		},
 116  	}, &cp)
 117  	if err != nil {
 118  		return nil, err
 119  	}
 120  
 121  	id := managedidentity.SystemAssigned()
 122  	if options.ID != nil {
 123  		c.userAssigned = true
 124  		switch s := options.ID.String(); options.ID.idKind() {
 125  		case miClientID:
 126  			id = managedidentity.UserAssignedClientID(s)
 127  		case miObjectID:
 128  			id = managedidentity.UserAssignedObjectID(s)
 129  		case miResourceID:
 130  			id = managedidentity.UserAssignedResourceID(s)
 131  		}
 132  	}
 133  	msalClient, err := managedidentity.New(id, managedidentity.WithHTTPClient(&c), managedidentity.WithRetryPolicyDisabled())
 134  	if err != nil {
 135  		return nil, err
 136  	}
 137  	c.msalClient = &msalClient
 138  
 139  	if log.Should(EventAuthentication) {
 140  		msg := fmt.Sprintf("%s will use %s managed identity", credNameManagedIdentity, env)
 141  		if options.ID != nil {
 142  			kind := "client"
 143  			switch options.ID.(type) {
 144  			case ObjectID:
 145  				kind = "object"
 146  			case ResourceID:
 147  				kind = "resource"
 148  			}
 149  			msg += fmt.Sprintf(" with %s ID %q", kind, options.ID.String())
 150  		}
 151  		log.Write(EventAuthentication, msg)
 152  	}
 153  
 154  	return &c, nil
 155  }
 156  
 157  func (*managedIdentityClient) CloseIdleConnections() {
 158  	// do nothing
 159  }
 160  
 161  func (c *managedIdentityClient) Do(r *http.Request) (*http.Response, error) {
 162  	return doForClient(c.azClient, r)
 163  }
 164  
 165  // authenticate acquires an access token
 166  func (c *managedIdentityClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
 167  	// no need to synchronize around this value because it's true only when DefaultAzureCredential constructed the client,
 168  	// and in that case ChainedTokenCredential.GetToken synchronizes goroutines that would execute this block
 169  	if c.probeIMDS {
 170  		// send a malformed request (no Metadata header) to IMDS to determine whether the endpoint is available
 171  		cx, cancel := context.WithTimeout(ctx, imdsProbeTimeout)
 172  		defer cancel()
 173  		cx = policy.WithRetryOptions(cx, policy.RetryOptions{MaxRetries: -1})
 174  		req, err := azruntime.NewRequest(cx, http.MethodGet, imdsEndpoint)
 175  		if err != nil {
 176  			return azcore.AccessToken{}, fmt.Errorf("failed to create IMDS probe request: %s", err)
 177  		}
 178  		if _, err = c.azClient.Pipeline().Do(req); err != nil {
 179  			msg := err.Error()
 180  			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 181  				msg = "managed identity timed out. See https://aka.ms/azsdk/go/identity/troubleshoot#dac for more information"
 182  			}
 183  			return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
 184  		}
 185  		// send normal token requests from now on because something responded
 186  		c.probeIMDS = false
 187  	}
 188  
 189  	ar, err := c.msalClient.AcquireToken(ctx, tro.Scopes[0], managedidentity.WithClaims(tro.Claims))
 190  	if err == nil {
 191  		msg := fmt.Sprintf(scopeLogFmt, credNameManagedIdentity, strings.Join(ar.GrantedScopes, ", "))
 192  		log.Write(EventAuthentication, msg)
 193  		return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC(), RefreshOn: ar.Metadata.RefreshOn.UTC()}, err
 194  	}
 195  	if c.imds {
 196  		var ije msalerrors.InvalidJsonErr
 197  		if c.chained && errors.As(err, &ije) {
 198  			// an unmarshaling error implies the response is from something other than IMDS such as a proxy listening at
 199  			// the same address. Return a credentialUnavailableError so credential chains continue to their next credential
 200  			return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, err.Error())
 201  		}
 202  		resp := getResponseFromError(err)
 203  		if resp == nil {
 204  			return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSAL(credNameManagedIdentity, err)
 205  		}
 206  		switch resp.StatusCode {
 207  		case http.StatusBadRequest:
 208  			if c.userAssigned {
 209  				return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp)
 210  			}
 211  			msg := "failed to authenticate a system assigned identity"
 212  			if body, err := azruntime.Payload(resp); err == nil && len(body) > 0 {
 213  				msg += fmt.Sprintf(". The endpoint responded with %s", body)
 214  			}
 215  			return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
 216  		case http.StatusForbidden:
 217  			// Docker Desktop runs a proxy that responds 403 to IMDS token requests. If we get that response,
 218  			// we return credentialUnavailableError so credential chains continue to their next credential
 219  			body, err := azruntime.Payload(resp)
 220  			if err == nil && strings.Contains(string(body), "unreachable") {
 221  				return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("unexpected response %q", string(body)))
 222  			}
 223  		}
 224  	}
 225  	err = newAuthenticationFailedErrorFromMSAL(credNameManagedIdentity, err)
 226  	return azcore.AccessToken{}, err
 227  }
 228