aws_provider.go raw

   1  // Copyright 2023 Google LLC
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //      http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package externalaccount
  16  
  17  import (
  18  	"bytes"
  19  	"context"
  20  	"crypto/hmac"
  21  	"crypto/sha256"
  22  	"encoding/hex"
  23  	"encoding/json"
  24  	"errors"
  25  	"fmt"
  26  	"log/slog"
  27  	"net/http"
  28  	"net/url"
  29  	"os"
  30  	"path"
  31  	"sort"
  32  	"strings"
  33  	"time"
  34  
  35  	"cloud.google.com/go/auth/internal"
  36  	"github.com/googleapis/gax-go/v2/internallog"
  37  )
  38  
  39  var (
  40  	// getenv aliases os.Getenv for testing
  41  	getenv = os.Getenv
  42  )
  43  
  44  const (
  45  	// AWS Signature Version 4 signing algorithm identifier.
  46  	awsAlgorithm = "AWS4-HMAC-SHA256"
  47  
  48  	// The termination string for the AWS credential scope value as defined in
  49  	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
  50  	awsRequestType = "aws4_request"
  51  
  52  	// The AWS authorization header name for the security session token if available.
  53  	awsSecurityTokenHeader = "x-amz-security-token"
  54  
  55  	// The name of the header containing the session token for metadata endpoint calls
  56  	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
  57  
  58  	awsIMDSv2SessionTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"
  59  
  60  	awsIMDSv2SessionTTL = "300"
  61  
  62  	// The AWS authorization header name for the auto-generated date.
  63  	awsDateHeader = "x-amz-date"
  64  
  65  	defaultRegionalCredentialVerificationURL = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
  66  
  67  	// Supported AWS configuration environment variables.
  68  	awsAccessKeyIDEnvVar     = "AWS_ACCESS_KEY_ID"
  69  	awsDefaultRegionEnvVar   = "AWS_DEFAULT_REGION"
  70  	awsRegionEnvVar          = "AWS_REGION"
  71  	awsSecretAccessKeyEnvVar = "AWS_SECRET_ACCESS_KEY"
  72  	awsSessionTokenEnvVar    = "AWS_SESSION_TOKEN"
  73  
  74  	awsTimeFormatLong  = "20060102T150405Z"
  75  	awsTimeFormatShort = "20060102"
  76  	awsProviderType    = "aws"
  77  )
  78  
// awsSubjectProvider produces a subject token from a signed AWS STS
// GetCallerIdentity request. Region and security credentials come from a
// programmatic supplier when one is set, otherwise from environment variables
// or the AWS metadata endpoints configured below.
type awsSubjectProvider struct {
	// EnvironmentID is not read by the methods in this file.
	// NOTE(review): presumably identifies the AWS environment version; confirm with the caller.
	EnvironmentID               string
	// RegionURL is the metadata endpoint queried for the region when it is
	// not available from the environment.
	RegionURL                   string
	// RegionalCredVerificationURL is the URL of the request that gets signed;
	// its "{region}" placeholder is replaced with the resolved region. When
	// empty, subjectToken fills in the package default.
	RegionalCredVerificationURL string
	// CredVerificationURL is the metadata endpoint used to look up the role
	// name and, with the role name appended, its security credentials.
	CredVerificationURL         string
	// IMDSv2SessionTokenURL is the endpoint used to obtain a metadata session
	// token. When empty, no session token is requested.
	IMDSv2SessionTokenURL       string
	// TargetResource, when non-empty, is sent as the
	// x-goog-cloud-target-resource header of the signed request.
	TargetResource              string
	// requestSigner is replaced on every subjectToken call.
	requestSigner               *awsRequestSigner
	// region is the region resolved by the most recent subjectToken call.
	region                      string
	// securityCredentialsProvider, when non-nil, supplies the region and the
	// security credentials instead of the environment or metadata server.
	securityCredentialsProvider AwsSecurityCredentialsProvider
	// reqOpts is passed through to securityCredentialsProvider.
	reqOpts                     *RequestOptions

	// Client performs the HTTP calls to the metadata endpoints.
	Client *http.Client
	// logger receives debug records of metadata requests and responses.
	logger *slog.Logger
}
  94  
  95  func (sp *awsSubjectProvider) subjectToken(ctx context.Context) (string, error) {
  96  	// Set Defaults
  97  	if sp.RegionalCredVerificationURL == "" {
  98  		sp.RegionalCredVerificationURL = defaultRegionalCredentialVerificationURL
  99  	}
 100  	headers := make(map[string]string)
 101  	if sp.shouldUseMetadataServer() {
 102  		awsSessionToken, err := sp.getAWSSessionToken(ctx)
 103  		if err != nil {
 104  			return "", err
 105  		}
 106  
 107  		if awsSessionToken != "" {
 108  			headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
 109  		}
 110  	}
 111  
 112  	awsSecurityCredentials, err := sp.getSecurityCredentials(ctx, headers)
 113  	if err != nil {
 114  		return "", err
 115  	}
 116  	if sp.region, err = sp.getRegion(ctx, headers); err != nil {
 117  		return "", err
 118  	}
 119  	sp.requestSigner = &awsRequestSigner{
 120  		RegionName:             sp.region,
 121  		AwsSecurityCredentials: awsSecurityCredentials,
 122  	}
 123  
 124  	// Generate the signed request to AWS STS GetCallerIdentity API.
 125  	// Use the required regional endpoint. Otherwise, the request will fail.
 126  	req, err := http.NewRequestWithContext(ctx, "POST", strings.Replace(sp.RegionalCredVerificationURL, "{region}", sp.region, 1), nil)
 127  	if err != nil {
 128  		return "", err
 129  	}
 130  	// The full, canonical resource name of the workload identity pool
 131  	// provider, with or without the HTTPS prefix.
 132  	// Including this header as part of the signature is recommended to
 133  	// ensure data integrity.
 134  	if sp.TargetResource != "" {
 135  		req.Header.Set("x-goog-cloud-target-resource", sp.TargetResource)
 136  	}
 137  	sp.requestSigner.signRequest(req)
 138  
 139  	/*
 140  	   The GCP STS endpoint expects the headers to be formatted as:
 141  	   # [
 142  	   #   {key: 'x-amz-date', value: '...'},
 143  	   #   {key: 'Authorization', value: '...'},
 144  	   #   ...
 145  	   # ]
 146  	   # And then serialized as:
 147  	   # quote(json.dumps({
 148  	   #   url: '...',
 149  	   #   method: 'POST',
 150  	   #   headers: [{key: 'x-amz-date', value: '...'}, ...]
 151  	   # }))
 152  	*/
 153  
 154  	awsSignedReq := awsRequest{
 155  		URL:    req.URL.String(),
 156  		Method: "POST",
 157  	}
 158  	for headerKey, headerList := range req.Header {
 159  		for _, headerValue := range headerList {
 160  			awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
 161  				Key:   headerKey,
 162  				Value: headerValue,
 163  			})
 164  		}
 165  	}
 166  	sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
 167  		headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
 168  		if headerCompare == 0 {
 169  			return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
 170  		}
 171  		return headerCompare < 0
 172  	})
 173  
 174  	result, err := json.Marshal(awsSignedReq)
 175  	if err != nil {
 176  		return "", err
 177  	}
 178  	return url.QueryEscape(string(result)), nil
 179  }
 180  
 181  func (sp *awsSubjectProvider) providerType() string {
 182  	if sp.securityCredentialsProvider != nil {
 183  		return programmaticProviderType
 184  	}
 185  	return awsProviderType
 186  }
 187  
 188  func (sp *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) {
 189  	if sp.IMDSv2SessionTokenURL == "" {
 190  		return "", nil
 191  	}
 192  	req, err := http.NewRequestWithContext(ctx, "PUT", sp.IMDSv2SessionTokenURL, nil)
 193  	if err != nil {
 194  		return "", err
 195  	}
 196  	req.Header.Set(awsIMDSv2SessionTTLHeader, awsIMDSv2SessionTTL)
 197  
 198  	sp.logger.DebugContext(ctx, "aws session token request", "request", internallog.HTTPRequest(req, nil))
 199  	resp, body, err := internal.DoRequest(sp.Client, req)
 200  	if err != nil {
 201  		return "", err
 202  	}
 203  	sp.logger.DebugContext(ctx, "aws session token response", "response", internallog.HTTPResponse(resp, body))
 204  	if resp.StatusCode != http.StatusOK {
 205  		return "", fmt.Errorf("credentials: unable to retrieve AWS session token: %s", body)
 206  	}
 207  	return string(body), nil
 208  }
 209  
 210  func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
 211  	if sp.securityCredentialsProvider != nil {
 212  		return sp.securityCredentialsProvider.AwsRegion(ctx, sp.reqOpts)
 213  	}
 214  	if canRetrieveRegionFromEnvironment() {
 215  		if envAwsRegion := getenv(awsRegionEnvVar); envAwsRegion != "" {
 216  			return envAwsRegion, nil
 217  		}
 218  		return getenv(awsDefaultRegionEnvVar), nil
 219  	}
 220  
 221  	if sp.RegionURL == "" {
 222  		return "", errors.New("credentials: unable to determine AWS region")
 223  	}
 224  
 225  	req, err := http.NewRequestWithContext(ctx, "GET", sp.RegionURL, nil)
 226  	if err != nil {
 227  		return "", err
 228  	}
 229  
 230  	for name, value := range headers {
 231  		req.Header.Add(name, value)
 232  	}
 233  	sp.logger.DebugContext(ctx, "aws region request", "request", internallog.HTTPRequest(req, nil))
 234  	resp, body, err := internal.DoRequest(sp.Client, req)
 235  	if err != nil {
 236  		return "", err
 237  	}
 238  	sp.logger.DebugContext(ctx, "aws region response", "response", internallog.HTTPResponse(resp, body))
 239  	if resp.StatusCode != http.StatusOK {
 240  		return "", fmt.Errorf("credentials: unable to retrieve AWS region - %s", body)
 241  	}
 242  
 243  	// This endpoint will return the region in format: us-east-2b.
 244  	// Only the us-east-2 part should be used.
 245  	bodyLen := len(body)
 246  	if bodyLen == 0 {
 247  		return "", nil
 248  	}
 249  	return string(body[:bodyLen-1]), nil
 250  }
 251  
 252  func (sp *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result *AwsSecurityCredentials, err error) {
 253  	if sp.securityCredentialsProvider != nil {
 254  		return sp.securityCredentialsProvider.AwsSecurityCredentials(ctx, sp.reqOpts)
 255  	}
 256  	if canRetrieveSecurityCredentialFromEnvironment() {
 257  		return &AwsSecurityCredentials{
 258  			AccessKeyID:     getenv(awsAccessKeyIDEnvVar),
 259  			SecretAccessKey: getenv(awsSecretAccessKeyEnvVar),
 260  			SessionToken:    getenv(awsSessionTokenEnvVar),
 261  		}, nil
 262  	}
 263  
 264  	roleName, err := sp.getMetadataRoleName(ctx, headers)
 265  	if err != nil {
 266  		return
 267  	}
 268  	credentials, err := sp.getMetadataSecurityCredentials(ctx, roleName, headers)
 269  	if err != nil {
 270  		return
 271  	}
 272  
 273  	if credentials.AccessKeyID == "" {
 274  		return result, errors.New("credentials: missing AccessKeyId credential")
 275  	}
 276  	if credentials.SecretAccessKey == "" {
 277  		return result, errors.New("credentials: missing SecretAccessKey credential")
 278  	}
 279  
 280  	return credentials, nil
 281  }
 282  
 283  func (sp *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (*AwsSecurityCredentials, error) {
 284  	var result *AwsSecurityCredentials
 285  
 286  	req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", sp.CredVerificationURL, roleName), nil)
 287  	if err != nil {
 288  		return result, err
 289  	}
 290  	for name, value := range headers {
 291  		req.Header.Add(name, value)
 292  	}
 293  	sp.logger.DebugContext(ctx, "aws security credential request", "request", internallog.HTTPRequest(req, nil))
 294  	resp, body, err := internal.DoRequest(sp.Client, req)
 295  	if err != nil {
 296  		return result, err
 297  	}
 298  	sp.logger.DebugContext(ctx, "aws security credential response", "response", internallog.HTTPResponse(resp, body))
 299  	if resp.StatusCode != http.StatusOK {
 300  		return result, fmt.Errorf("credentials: unable to retrieve AWS security credentials - %s", body)
 301  	}
 302  	if err := json.Unmarshal(body, &result); err != nil {
 303  		return nil, err
 304  	}
 305  	return result, nil
 306  }
 307  
 308  func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
 309  	if sp.CredVerificationURL == "" {
 310  		return "", errors.New("credentials: unable to determine the AWS metadata server security credentials endpoint")
 311  	}
 312  	req, err := http.NewRequestWithContext(ctx, "GET", sp.CredVerificationURL, nil)
 313  	if err != nil {
 314  		return "", err
 315  	}
 316  	for name, value := range headers {
 317  		req.Header.Add(name, value)
 318  	}
 319  
 320  	sp.logger.DebugContext(ctx, "aws metadata role request", "request", internallog.HTTPRequest(req, nil))
 321  	resp, body, err := internal.DoRequest(sp.Client, req)
 322  	if err != nil {
 323  		return "", err
 324  	}
 325  	sp.logger.DebugContext(ctx, "aws metadata role response", "response", internallog.HTTPResponse(resp, body))
 326  	if resp.StatusCode != http.StatusOK {
 327  		return "", fmt.Errorf("credentials: unable to retrieve AWS role name - %s", body)
 328  	}
 329  	return string(body), nil
 330  }
 331  
// awsRequestSigner signs HTTP requests using an AWS Signature Version 4
// signature.
type awsRequestSigner struct {
	// RegionName is the AWS region placed in the credential scope and mixed
	// into the signing key.
	RegionName             string
	// AwsSecurityCredentials supplies the access key ID, the secret access
	// key and the optional session token used while signing. It must be
	// non-nil when signing.
	AwsSecurityCredentials *AwsSecurityCredentials
}
 337  
 338  // signRequest adds the appropriate headers to an http.Request
 339  // or returns an error if something prevented this.
 340  func (rs *awsRequestSigner) signRequest(req *http.Request) error {
 341  	// req is assumed non-nil
 342  	signedRequest := cloneRequest(req)
 343  	timestamp := Now()
 344  	signedRequest.Header.Set("host", requestHost(req))
 345  	if rs.AwsSecurityCredentials.SessionToken != "" {
 346  		signedRequest.Header.Set(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
 347  	}
 348  	if signedRequest.Header.Get("date") == "" {
 349  		signedRequest.Header.Set(awsDateHeader, timestamp.Format(awsTimeFormatLong))
 350  	}
 351  	authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
 352  	if err != nil {
 353  		return err
 354  	}
 355  	signedRequest.Header.Set("Authorization", authorizationCode)
 356  	req.Header = signedRequest.Header
 357  	return nil
 358  }
 359  
 360  func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
 361  	canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
 362  	dateStamp := timestamp.Format(awsTimeFormatShort)
 363  	serviceName := ""
 364  
 365  	if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
 366  		serviceName = splitHost[0]
 367  	}
 368  	credentialScope := strings.Join([]string{dateStamp, rs.RegionName, serviceName, awsRequestType}, "/")
 369  	requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
 370  	if err != nil {
 371  		return "", err
 372  	}
 373  	requestHash, err := getSha256([]byte(requestString))
 374  	if err != nil {
 375  		return "", err
 376  	}
 377  
 378  	stringToSign := strings.Join([]string{awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash}, "\n")
 379  	signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
 380  	for _, signingInput := range []string{
 381  		dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
 382  	} {
 383  		signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
 384  		if err != nil {
 385  			return "", err
 386  		}
 387  	}
 388  
 389  	return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
 390  }
 391  
 392  func getSha256(input []byte) (string, error) {
 393  	hash := sha256.New()
 394  	if _, err := hash.Write(input); err != nil {
 395  		return "", err
 396  	}
 397  	return hex.EncodeToString(hash.Sum(nil)), nil
 398  }
 399  
 400  func getHmacSha256(key, input []byte) ([]byte, error) {
 401  	hash := hmac.New(sha256.New, key)
 402  	if _, err := hash.Write(input); err != nil {
 403  		return nil, err
 404  	}
 405  	return hash.Sum(nil), nil
 406  }
 407  
 408  func cloneRequest(r *http.Request) *http.Request {
 409  	r2 := new(http.Request)
 410  	*r2 = *r
 411  	if r.Header != nil {
 412  		r2.Header = make(http.Header, len(r.Header))
 413  
 414  		// Find total number of values.
 415  		headerCount := 0
 416  		for _, headerValues := range r.Header {
 417  			headerCount += len(headerValues)
 418  		}
 419  		copiedHeaders := make([]string, headerCount) // shared backing array for headers' values
 420  
 421  		for headerKey, headerValues := range r.Header {
 422  			headerCount = copy(copiedHeaders, headerValues)
 423  			r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount]
 424  			copiedHeaders = copiedHeaders[headerCount:]
 425  		}
 426  	}
 427  	return r2
 428  }
 429  
 430  func canonicalPath(req *http.Request) string {
 431  	result := req.URL.EscapedPath()
 432  	if result == "" {
 433  		return "/"
 434  	}
 435  	return path.Clean(result)
 436  }
 437  
 438  func canonicalQuery(req *http.Request) string {
 439  	queryValues := req.URL.Query()
 440  	for queryKey := range queryValues {
 441  		sort.Strings(queryValues[queryKey])
 442  	}
 443  	return queryValues.Encode()
 444  }
 445  
 446  func canonicalHeaders(req *http.Request) (string, string) {
 447  	// Header keys need to be sorted alphabetically.
 448  	var headers []string
 449  	lowerCaseHeaders := make(http.Header)
 450  	for k, v := range req.Header {
 451  		k := strings.ToLower(k)
 452  		if _, ok := lowerCaseHeaders[k]; ok {
 453  			// include additional values
 454  			lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...)
 455  		} else {
 456  			headers = append(headers, k)
 457  			lowerCaseHeaders[k] = v
 458  		}
 459  	}
 460  	sort.Strings(headers)
 461  
 462  	var fullHeaders bytes.Buffer
 463  	for _, header := range headers {
 464  		headerValue := strings.Join(lowerCaseHeaders[header], ",")
 465  		fullHeaders.WriteString(header)
 466  		fullHeaders.WriteRune(':')
 467  		fullHeaders.WriteString(headerValue)
 468  		fullHeaders.WriteRune('\n')
 469  	}
 470  
 471  	return strings.Join(headers, ";"), fullHeaders.String()
 472  }
 473  
 474  func requestDataHash(req *http.Request) (string, error) {
 475  	var requestData []byte
 476  	if req.Body != nil {
 477  		requestBody, err := req.GetBody()
 478  		if err != nil {
 479  			return "", err
 480  		}
 481  		defer requestBody.Close()
 482  
 483  		requestData, err = internal.ReadAll(requestBody)
 484  		if err != nil {
 485  			return "", err
 486  		}
 487  	}
 488  
 489  	return getSha256(requestData)
 490  }
 491  
 492  func requestHost(req *http.Request) string {
 493  	if req.Host != "" {
 494  		return req.Host
 495  	}
 496  	return req.URL.Host
 497  }
 498  
 499  func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
 500  	dataHash, err := requestDataHash(req)
 501  	if err != nil {
 502  		return "", err
 503  	}
 504  	return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
 505  }
 506  
 507  type awsRequestHeader struct {
 508  	Key   string `json:"key"`
 509  	Value string `json:"value"`
 510  }
 511  
 512  type awsRequest struct {
 513  	URL     string             `json:"url"`
 514  	Method  string             `json:"method"`
 515  	Headers []awsRequestHeader `json:"headers"`
 516  }
 517  
 518  // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
 519  // required.
 520  func canRetrieveRegionFromEnvironment() bool {
 521  	return getenv(awsRegionEnvVar) != "" || getenv(awsDefaultRegionEnvVar) != ""
 522  }
 523  
 524  // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
 525  func canRetrieveSecurityCredentialFromEnvironment() bool {
 526  	return getenv(awsAccessKeyIDEnvVar) != "" && getenv(awsSecretAccessKeyEnvVar) != ""
 527  }
 528  
 529  func (sp *awsSubjectProvider) shouldUseMetadataServer() bool {
 530  	return sp.securityCredentialsProvider == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
 531  }
 532