aws.go raw

   1  // Copyright 2021 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package externalaccount
   6  
   7  import (
   8  	"context"
   9  	"crypto/hmac"
  10  	"crypto/sha256"
  11  	"encoding/hex"
  12  	"encoding/json"
  13  	"errors"
  14  	"fmt"
  15  	"io"
  16  	"net/http"
  17  	"net/url"
  18  	"os"
  19  	"path"
  20  	"sort"
  21  	"strings"
  22  	"time"
  23  
  24  	"golang.org/x/oauth2"
  25  )
  26  
// AwsSecurityCredentials models AWS security credentials.
//
// The JSON tags are the keys used when these credentials are decoded from an
// AWS metadata server response (see getMetadataSecurityCredentials). Note that
// the session token is read from the "Token" key, not "SessionToken".
type AwsSecurityCredentials struct {
	// AccessKeyID is the AWS Access Key ID - Required.
	AccessKeyID string `json:"AccessKeyID"`
	// SecretAccessKey is the AWS Secret Access Key - Required.
	SecretAccessKey string `json:"SecretAccessKey"`
	// SessionToken is the AWS Session token. This should be provided for temporary AWS security credentials - Optional.
	SessionToken string `json:"Token"`
}
  36  
// awsRequestSigner signs HTTP requests with an AWS Signature Version 4
// (SigV4) "Authorization" header.
type awsRequestSigner struct {
	// RegionName is the AWS region placed in the SigV4 credential scope and
	// mixed into the signing-key derivation.
	RegionName string
	// AwsSecurityCredentials supplies the access key ID (in the Authorization
	// header), the secret key (signing-key derivation) and the optional
	// session token (sent as x-amz-security-token).
	AwsSecurityCredentials *AwsSecurityCredentials
}
  42  
// getenv aliases os.Getenv for testing; tests can replace it to control which
// environment variables the credential source sees.
var getenv = os.Getenv

const (
	// defaultRegionalCredentialVerificationUrl is the STS GetCallerIdentity
	// URL used when no regional verification URL is configured. subjectToken
	// replaces the "{region}" placeholder with the resolved AWS region.
	defaultRegionalCredentialVerificationUrl = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"

	// AWS Signature Version 4 signing algorithm identifier.
	awsAlgorithm = "AWS4-HMAC-SHA256"

	// The termination string for the AWS credential scope value as defined in
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
	awsRequestType = "aws4_request"

	// The AWS authorization header name for the security session token if available.
	awsSecurityTokenHeader = "x-amz-security-token"

	// The name of the header containing the session token for metadata endpoint calls
	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"

	// The header carrying the requested IMDSv2 session-token lifetime on the
	// token PUT request (getAWSSessionToken).
	awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"

	// The requested session-token lifetime sent in awsIMDSv2SessionTtlHeader
	// (the header name indicates the unit is seconds).
	awsIMDSv2SessionTtl = "300"

	// The AWS authorization header name for the auto-generated date.
	awsDateHeader = "x-amz-date"

	// Supported AWS configuration environment variables.
	awsAccessKeyId     = "AWS_ACCESS_KEY_ID"
	awsDefaultRegion   = "AWS_DEFAULT_REGION"
	awsRegion          = "AWS_REGION"
	awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
	awsSessionToken    = "AWS_SESSION_TOKEN"

	// Go time layouts for the SigV4 timestamp (x-amz-date / string to sign)
	// and the date stamp (credential scope). The trailing "Z" in the long
	// layout is a literal character, not a zone directive.
	// NOTE(review): the Z is only correct for UTC timestamps — confirm that
	// `now` always yields UTC.
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
)
  79  
  80  func getSha256(input []byte) (string, error) {
  81  	hash := sha256.New()
  82  	if _, err := hash.Write(input); err != nil {
  83  		return "", err
  84  	}
  85  	return hex.EncodeToString(hash.Sum(nil)), nil
  86  }
  87  
  88  func getHmacSha256(key, input []byte) ([]byte, error) {
  89  	hash := hmac.New(sha256.New, key)
  90  	if _, err := hash.Write(input); err != nil {
  91  		return nil, err
  92  	}
  93  	return hash.Sum(nil), nil
  94  }
  95  
  96  func cloneRequest(r *http.Request) *http.Request {
  97  	r2 := new(http.Request)
  98  	*r2 = *r
  99  	if r.Header != nil {
 100  		r2.Header = make(http.Header, len(r.Header))
 101  
 102  		// Find total number of values.
 103  		headerCount := 0
 104  		for _, headerValues := range r.Header {
 105  			headerCount += len(headerValues)
 106  		}
 107  		copiedHeaders := make([]string, headerCount) // shared backing array for headers' values
 108  
 109  		for headerKey, headerValues := range r.Header {
 110  			headerCount = copy(copiedHeaders, headerValues)
 111  			r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount]
 112  			copiedHeaders = copiedHeaders[headerCount:]
 113  		}
 114  	}
 115  	return r2
 116  }
 117  
 118  func canonicalPath(req *http.Request) string {
 119  	result := req.URL.EscapedPath()
 120  	if result == "" {
 121  		return "/"
 122  	}
 123  	return path.Clean(result)
 124  }
 125  
 126  func canonicalQuery(req *http.Request) string {
 127  	queryValues := req.URL.Query()
 128  	for queryKey := range queryValues {
 129  		sort.Strings(queryValues[queryKey])
 130  	}
 131  	return queryValues.Encode()
 132  }
 133  
 134  func canonicalHeaders(req *http.Request) (string, string) {
 135  	// Header keys need to be sorted alphabetically.
 136  	var headers []string
 137  	lowerCaseHeaders := make(http.Header)
 138  	for k, v := range req.Header {
 139  		k := strings.ToLower(k)
 140  		if _, ok := lowerCaseHeaders[k]; ok {
 141  			// include additional values
 142  			lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...)
 143  		} else {
 144  			headers = append(headers, k)
 145  			lowerCaseHeaders[k] = v
 146  		}
 147  	}
 148  	sort.Strings(headers)
 149  
 150  	var fullHeaders strings.Builder
 151  	for _, header := range headers {
 152  		headerValue := strings.Join(lowerCaseHeaders[header], ",")
 153  		fullHeaders.WriteString(header)
 154  		fullHeaders.WriteByte(':')
 155  		fullHeaders.WriteString(headerValue)
 156  		fullHeaders.WriteByte('\n')
 157  	}
 158  
 159  	return strings.Join(headers, ";"), fullHeaders.String()
 160  }
 161  
 162  func requestDataHash(req *http.Request) (string, error) {
 163  	var requestData []byte
 164  	if req.Body != nil {
 165  		requestBody, err := req.GetBody()
 166  		if err != nil {
 167  			return "", err
 168  		}
 169  		defer requestBody.Close()
 170  
 171  		requestData, err = io.ReadAll(io.LimitReader(requestBody, 1<<20))
 172  		if err != nil {
 173  			return "", err
 174  		}
 175  	}
 176  
 177  	return getSha256(requestData)
 178  }
 179  
 180  func requestHost(req *http.Request) string {
 181  	if req.Host != "" {
 182  		return req.Host
 183  	}
 184  	return req.URL.Host
 185  }
 186  
 187  func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
 188  	dataHash, err := requestDataHash(req)
 189  	if err != nil {
 190  		return "", err
 191  	}
 192  
 193  	return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
 194  }
 195  
 196  // SignRequest adds the appropriate headers to an http.Request
 197  // or returns an error if something prevented this.
 198  func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
 199  	signedRequest := cloneRequest(req)
 200  	timestamp := now()
 201  
 202  	signedRequest.Header.Add("host", requestHost(req))
 203  
 204  	if rs.AwsSecurityCredentials.SessionToken != "" {
 205  		signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
 206  	}
 207  
 208  	if signedRequest.Header.Get("date") == "" {
 209  		signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
 210  	}
 211  
 212  	authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
 213  	if err != nil {
 214  		return err
 215  	}
 216  	signedRequest.Header.Set("Authorization", authorizationCode)
 217  
 218  	req.Header = signedRequest.Header
 219  	return nil
 220  }
 221  
 222  func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
 223  	canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
 224  
 225  	dateStamp := timestamp.Format(awsTimeFormatShort)
 226  	serviceName := ""
 227  	if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
 228  		serviceName = splitHost[0]
 229  	}
 230  
 231  	credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
 232  
 233  	requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
 234  	if err != nil {
 235  		return "", err
 236  	}
 237  	requestHash, err := getSha256([]byte(requestString))
 238  	if err != nil {
 239  		return "", err
 240  	}
 241  
 242  	stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
 243  
 244  	signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
 245  	for _, signingInput := range []string{
 246  		dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
 247  	} {
 248  		signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
 249  		if err != nil {
 250  			return "", err
 251  		}
 252  	}
 253  
 254  	return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
 255  }
 256  
// awsCredentialSource produces subject tokens from AWS credentials. The
// credentials and region come from a programmatic supplier, the process
// environment, or an AWS metadata server, in that order of preference.
type awsCredentialSource struct {
	// NOTE(review): not read by any code in this section; presumably the
	// configured AWS environment identifier — confirm against the caller.
	environmentID string
	// Endpoint queried with GET by getRegion when no supplier or environment
	// region is available; its response has its last character dropped.
	regionURL string
	// STS GetCallerIdentity URL template; "{region}" is replaced with the
	// region. subjectToken falls back to
	// defaultRegionalCredentialVerificationUrl when this is empty.
	regionalCredVerificationURL string
	// Metadata endpoint: a GET on it yields the role name, and a GET on
	// "<url>/<role>" yields the credentials JSON.
	credVerificationURL string
	// Optional IMDSv2 session-token endpoint (PUT); when empty, no session
	// token is requested.
	imdsv2SessionTokenURL string
	// When non-empty, sent as the x-goog-cloud-target-resource header on the
	// signed GetCallerIdentity request.
	targetResource string
	// Optional preconfigured signer; when nil, subjectToken builds one from
	// the resolved credentials and region.
	requestSigner *awsRequestSigner
	// Region substituted into regionalCredVerificationURL.
	region string
	// Context attached to outbound requests and passed to supplier calls.
	ctx context.Context
	// HTTP client for outbound requests; when nil, doRequest falls back to
	// oauth2.NewClient(ctx, nil).
	client *http.Client
	// When non-nil, credentials and region come from this supplier instead of
	// the environment or metadata server.
	awsSecurityCredentialsSupplier AwsSecurityCredentialsSupplier
	// Options passed through to awsSecurityCredentialsSupplier calls.
	supplierOptions SupplierOptions
}
 271  
// awsRequestHeader is one header of the signed request, serialized as
// {"key": ..., "value": ...} inside awsRequest.Headers.
type awsRequestHeader struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}
 276  
// awsRequest is the JSON description of the signed GetCallerIdentity request.
// subjectToken marshals it and URL-query-escapes the result to form the
// subject token. encoding/json emits fields in declaration order, so the keys
// appear as url, method, headers.
type awsRequest struct {
	URL     string             `json:"url"`
	Method  string             `json:"method"`
	Headers []awsRequestHeader `json:"headers"`
}
 282  
 283  func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
 284  	if cs.client == nil {
 285  		cs.client = oauth2.NewClient(cs.ctx, nil)
 286  	}
 287  	return cs.client.Do(req.WithContext(cs.ctx))
 288  }
 289  
 290  func canRetrieveRegionFromEnvironment() bool {
 291  	// The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
 292  	// required.
 293  	return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != ""
 294  }
 295  
 296  func canRetrieveSecurityCredentialFromEnvironment() bool {
 297  	// Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
 298  	return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
 299  }
 300  
 301  func (cs awsCredentialSource) shouldUseMetadataServer() bool {
 302  	return cs.awsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
 303  }
 304  
 305  func (cs awsCredentialSource) credentialSourceType() string {
 306  	if cs.awsSecurityCredentialsSupplier != nil {
 307  		return "programmatic"
 308  	}
 309  	return "aws"
 310  }
 311  
 312  func (cs awsCredentialSource) subjectToken() (string, error) {
 313  	// Set Defaults
 314  	if cs.regionalCredVerificationURL == "" {
 315  		cs.regionalCredVerificationURL = defaultRegionalCredentialVerificationUrl
 316  	}
 317  	if cs.requestSigner == nil {
 318  		headers := make(map[string]string)
 319  		if cs.shouldUseMetadataServer() {
 320  			awsSessionToken, err := cs.getAWSSessionToken()
 321  			if err != nil {
 322  				return "", err
 323  			}
 324  
 325  			if awsSessionToken != "" {
 326  				headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
 327  			}
 328  		}
 329  
 330  		awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
 331  		if err != nil {
 332  			return "", err
 333  		}
 334  		cs.region, err = cs.getRegion(headers)
 335  		if err != nil {
 336  			return "", err
 337  		}
 338  
 339  		cs.requestSigner = &awsRequestSigner{
 340  			RegionName:             cs.region,
 341  			AwsSecurityCredentials: awsSecurityCredentials,
 342  		}
 343  	}
 344  
 345  	// Generate the signed request to AWS STS GetCallerIdentity API.
 346  	// Use the required regional endpoint. Otherwise, the request will fail.
 347  	req, err := http.NewRequest("POST", strings.Replace(cs.regionalCredVerificationURL, "{region}", cs.region, 1), nil)
 348  	if err != nil {
 349  		return "", err
 350  	}
 351  	// The full, canonical resource name of the workload identity pool
 352  	// provider, with or without the HTTPS prefix.
 353  	// Including this header as part of the signature is recommended to
 354  	// ensure data integrity.
 355  	if cs.targetResource != "" {
 356  		req.Header.Add("x-goog-cloud-target-resource", cs.targetResource)
 357  	}
 358  	cs.requestSigner.SignRequest(req)
 359  
 360  	/*
 361  	   The GCP STS endpoint expects the headers to be formatted as:
 362  	   # [
 363  	   #   {key: 'x-amz-date', value: '...'},
 364  	   #   {key: 'Authorization', value: '...'},
 365  	   #   ...
 366  	   # ]
 367  	   # And then serialized as:
 368  	   # quote(json.dumps({
 369  	   #   url: '...',
 370  	   #   method: 'POST',
 371  	   #   headers: [{key: 'x-amz-date', value: '...'}, ...]
 372  	   # }))
 373  	*/
 374  
 375  	awsSignedReq := awsRequest{
 376  		URL:    req.URL.String(),
 377  		Method: "POST",
 378  	}
 379  	for headerKey, headerList := range req.Header {
 380  		for _, headerValue := range headerList {
 381  			awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
 382  				Key:   headerKey,
 383  				Value: headerValue,
 384  			})
 385  		}
 386  	}
 387  	sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
 388  		headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
 389  		if headerCompare == 0 {
 390  			return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
 391  		}
 392  		return headerCompare < 0
 393  	})
 394  
 395  	result, err := json.Marshal(awsSignedReq)
 396  	if err != nil {
 397  		return "", err
 398  	}
 399  	return url.QueryEscape(string(result)), nil
 400  }
 401  
 402  func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
 403  	if cs.imdsv2SessionTokenURL == "" {
 404  		return "", nil
 405  	}
 406  
 407  	req, err := http.NewRequest("PUT", cs.imdsv2SessionTokenURL, nil)
 408  	if err != nil {
 409  		return "", err
 410  	}
 411  
 412  	req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl)
 413  
 414  	resp, err := cs.doRequest(req)
 415  	if err != nil {
 416  		return "", err
 417  	}
 418  	defer resp.Body.Close()
 419  
 420  	respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
 421  	if err != nil {
 422  		return "", err
 423  	}
 424  
 425  	if resp.StatusCode != 200 {
 426  		return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS session token - %s", string(respBody))
 427  	}
 428  
 429  	return string(respBody), nil
 430  }
 431  
 432  func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
 433  	if cs.awsSecurityCredentialsSupplier != nil {
 434  		return cs.awsSecurityCredentialsSupplier.AwsRegion(cs.ctx, cs.supplierOptions)
 435  	}
 436  	if canRetrieveRegionFromEnvironment() {
 437  		if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
 438  			cs.region = envAwsRegion
 439  			return envAwsRegion, nil
 440  		}
 441  		return getenv("AWS_DEFAULT_REGION"), nil
 442  	}
 443  
 444  	if cs.regionURL == "" {
 445  		return "", errors.New("oauth2/google/externalaccount: unable to determine AWS region")
 446  	}
 447  
 448  	req, err := http.NewRequest("GET", cs.regionURL, nil)
 449  	if err != nil {
 450  		return "", err
 451  	}
 452  
 453  	for name, value := range headers {
 454  		req.Header.Add(name, value)
 455  	}
 456  
 457  	resp, err := cs.doRequest(req)
 458  	if err != nil {
 459  		return "", err
 460  	}
 461  	defer resp.Body.Close()
 462  
 463  	respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
 464  	if err != nil {
 465  		return "", err
 466  	}
 467  
 468  	if resp.StatusCode != 200 {
 469  		return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS region - %s", string(respBody))
 470  	}
 471  
 472  	// This endpoint will return the region in format: us-east-2b.
 473  	// Only the us-east-2 part should be used.
 474  	respBodyEnd := 0
 475  	if len(respBody) > 1 {
 476  		respBodyEnd = len(respBody) - 1
 477  	}
 478  	return string(respBody[:respBodyEnd]), nil
 479  }
 480  
 481  func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result *AwsSecurityCredentials, err error) {
 482  	if cs.awsSecurityCredentialsSupplier != nil {
 483  		return cs.awsSecurityCredentialsSupplier.AwsSecurityCredentials(cs.ctx, cs.supplierOptions)
 484  	}
 485  	if canRetrieveSecurityCredentialFromEnvironment() {
 486  		return &AwsSecurityCredentials{
 487  			AccessKeyID:     getenv(awsAccessKeyId),
 488  			SecretAccessKey: getenv(awsSecretAccessKey),
 489  			SessionToken:    getenv(awsSessionToken),
 490  		}, nil
 491  	}
 492  
 493  	roleName, err := cs.getMetadataRoleName(headers)
 494  	if err != nil {
 495  		return
 496  	}
 497  
 498  	credentials, err := cs.getMetadataSecurityCredentials(roleName, headers)
 499  	if err != nil {
 500  		return
 501  	}
 502  
 503  	if credentials.AccessKeyID == "" {
 504  		return result, errors.New("oauth2/google/externalaccount: missing AccessKeyId credential")
 505  	}
 506  
 507  	if credentials.SecretAccessKey == "" {
 508  		return result, errors.New("oauth2/google/externalaccount: missing SecretAccessKey credential")
 509  	}
 510  
 511  	return &credentials, nil
 512  }
 513  
 514  func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (AwsSecurityCredentials, error) {
 515  	var result AwsSecurityCredentials
 516  
 517  	req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.credVerificationURL, roleName), nil)
 518  	if err != nil {
 519  		return result, err
 520  	}
 521  
 522  	for name, value := range headers {
 523  		req.Header.Add(name, value)
 524  	}
 525  
 526  	resp, err := cs.doRequest(req)
 527  	if err != nil {
 528  		return result, err
 529  	}
 530  	defer resp.Body.Close()
 531  
 532  	respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
 533  	if err != nil {
 534  		return result, err
 535  	}
 536  
 537  	if resp.StatusCode != 200 {
 538  		return result, fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS security credentials - %s", string(respBody))
 539  	}
 540  
 541  	err = json.Unmarshal(respBody, &result)
 542  	return result, err
 543  }
 544  
 545  func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
 546  	if cs.credVerificationURL == "" {
 547  		return "", errors.New("oauth2/google/externalaccount: unable to determine the AWS metadata server security credentials endpoint")
 548  	}
 549  
 550  	req, err := http.NewRequest("GET", cs.credVerificationURL, nil)
 551  	if err != nil {
 552  		return "", err
 553  	}
 554  
 555  	for name, value := range headers {
 556  		req.Header.Add(name, value)
 557  	}
 558  
 559  	resp, err := cs.doRequest(req)
 560  	if err != nil {
 561  		return "", err
 562  	}
 563  	defer resp.Body.Close()
 564  
 565  	respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
 566  	if err != nil {
 567  		return "", err
 568  	}
 569  
 570  	if resp.StatusCode != 200 {
 571  		return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS role name - %s", string(respBody))
 572  	}
 573  
 574  	return string(respBody), nil
 575  }
 576