aws.go raw
1 // Copyright 2021 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package externalaccount
6
7 import (
8 "context"
9 "crypto/hmac"
10 "crypto/sha256"
11 "encoding/hex"
12 "encoding/json"
13 "errors"
14 "fmt"
15 "io"
16 "net/http"
17 "net/url"
18 "os"
19 "path"
20 "sort"
21 "strings"
22 "time"
23
24 "golang.org/x/oauth2"
25 )
26
// AwsSecurityCredentials models AWS security credentials.
//
// The same type is used as the JSON decoding target for the security
// credentials document fetched from the metadata server. Note that
// SessionToken is read from the "Token" JSON key.
type AwsSecurityCredentials struct {
	// AccessKeyID is the AWS Access Key ID - Required.
	AccessKeyID string `json:"AccessKeyID"`
	// SecretAccessKey is the AWS Secret Access Key - Required.
	SecretAccessKey string `json:"SecretAccessKey"`
	// SessionToken is the AWS Session token. This should be provided for temporary AWS security credentials - Optional.
	SessionToken string `json:"Token"`
}
36
// awsRequestSigner is a utility type that signs http requests using an AWS
// Signature Version 4 (AWS4-HMAC-SHA256) signature.
type awsRequestSigner struct {
	// RegionName is placed in the credential scope and mixed into the
	// signing-key derivation.
	RegionName string
	// AwsSecurityCredentials supplies the access key ID, secret access key and
	// optional session token used to sign requests.
	AwsSecurityCredentials *AwsSecurityCredentials
}
42
// getenv aliases os.Getenv so that tests can substitute a fake environment.
var getenv = os.Getenv
45
const (
	// defaultRegionalCredentialVerificationUrl is the AWS STS GetCallerIdentity
	// URL template; "{region}" is replaced with the resolved AWS region.
	defaultRegionalCredentialVerificationUrl = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"

	// AWS Signature Version 4 signing algorithm identifier.
	awsAlgorithm = "AWS4-HMAC-SHA256"

	// The termination string for the AWS credential scope value as defined in
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
	awsRequestType = "aws4_request"

	// The AWS authorization header name for the security session token if available.
	awsSecurityTokenHeader = "x-amz-security-token"

	// The name of the header containing the session token for metadata endpoint calls.
	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"

	// The header sent with the session-token request to ask for a token
	// lifetime, in seconds.
	awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"

	// The requested session-token lifetime in seconds (5 minutes).
	awsIMDSv2SessionTtl = "300"

	// The AWS authorization header name for the auto-generated date.
	awsDateHeader = "x-amz-date"

	// Supported AWS configuration environment variables.
	awsAccessKeyId     = "AWS_ACCESS_KEY_ID"
	awsDefaultRegion   = "AWS_DEFAULT_REGION"
	awsRegion          = "AWS_REGION"
	awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
	awsSessionToken    = "AWS_SESSION_TOKEN"

	// Time layouts: the full timestamp used for x-amz-date and the string to
	// sign, and the date-only stamp used in the credential scope.
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
)
79
// getSha256 returns the lowercase hex encoding of the SHA-256 digest of input.
// The error result is always nil; it is kept for call-site uniformity.
func getSha256(input []byte) (string, error) {
	digest := sha256.Sum256(input)
	return hex.EncodeToString(digest[:]), nil
}
87
// getHmacSha256 returns the raw HMAC-SHA256 of input keyed with key.
// On a write failure it returns a nil slice together with the error.
func getHmacSha256(key, input []byte) ([]byte, error) {
	var sum []byte
	mac := hmac.New(sha256.New, key)
	_, err := mac.Write(input)
	if err == nil {
		sum = mac.Sum(nil)
	}
	return sum, err
}
95
// cloneRequest returns a shallow copy of r whose Header map is independent of
// r's, so headers can be added to the copy without touching the original.
// Every other field (URL, Body, GetBody, ...) is shared with r. A nil Header
// stays nil.
func cloneRequest(r *http.Request) *http.Request {
	clone := *r
	if r.Header == nil {
		return &clone
	}
	clone.Header = make(http.Header, len(r.Header))

	total := 0
	for _, vals := range r.Header {
		total += len(vals)
	}
	// A single allocation backs every copied value slice. Each slice is capped
	// at its own length so that a later append reallocates instead of spilling
	// into the neighbouring header's values.
	pool := make([]string, total)
	for key, vals := range r.Header {
		n := copy(pool, vals)
		clone.Header[key] = pool[:n:n]
		pool = pool[n:]
	}
	return &clone
}
117
// canonicalPath returns the SigV4 canonical URI for req: the escaped URL path
// normalized with path.Clean, or "/" when the path is empty.
func canonicalPath(req *http.Request) string {
	if p := req.URL.EscapedPath(); p != "" {
		return path.Clean(p)
	}
	return "/"
}
125
// canonicalQuery returns the SigV4 canonical query string for req. Values of
// each parameter are sorted, and url.Values.Encode sorts the parameter names
// and performs the escaping.
func canonicalQuery(req *http.Request) string {
	values := req.URL.Query()
	for _, vs := range values {
		sort.Strings(vs)
	}
	return values.Encode()
}
133
// canonicalHeaders builds the SigV4 signed-header list and canonical header
// block for req.
//
// Header names are lower-cased and sorted. The values of names that differ
// only by case are merged, in sorted order of the original names, and joined
// with ",". It returns the ";"-separated list of lower-cased names and the
// block of "name:value\n" lines.
func canonicalHeaders(req *http.Request) (string, string) {
	// Visit the original keys in a fixed order so that merging keys which
	// differ only by case does not depend on map iteration order.
	keys := make([]string, 0, len(req.Header))
	for k := range req.Header {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var headers []string
	lowerCaseHeaders := make(map[string][]string, len(keys))
	for _, k := range keys {
		lk := strings.ToLower(k)
		if _, seen := lowerCaseHeaders[lk]; !seen {
			headers = append(headers, lk)
		}
		// Always append into a slice owned by this function. Storing
		// req.Header's own slice and later appending to it could overwrite
		// values living in a shared backing array.
		lowerCaseHeaders[lk] = append(lowerCaseHeaders[lk], req.Header[k]...)
	}
	sort.Strings(headers)

	var fullHeaders strings.Builder
	for _, header := range headers {
		fullHeaders.WriteString(header)
		fullHeaders.WriteByte(':')
		fullHeaders.WriteString(strings.Join(lowerCaseHeaders[header], ","))
		fullHeaders.WriteByte('\n')
	}

	return strings.Join(headers, ";"), fullHeaders.String()
}
161
162 func requestDataHash(req *http.Request) (string, error) {
163 var requestData []byte
164 if req.Body != nil {
165 requestBody, err := req.GetBody()
166 if err != nil {
167 return "", err
168 }
169 defer requestBody.Close()
170
171 requestData, err = io.ReadAll(io.LimitReader(requestBody, 1<<20))
172 if err != nil {
173 return "", err
174 }
175 }
176
177 return getSha256(requestData)
178 }
179
// requestHost returns the host the request is addressed to. req.Host takes
// precedence, and req.URL.Host is used when req.Host is empty.
func requestHost(req *http.Request) string {
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	return host
}
186
187 func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
188 dataHash, err := requestDataHash(req)
189 if err != nil {
190 return "", err
191 }
192
193 return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
194 }
195
196 // SignRequest adds the appropriate headers to an http.Request
197 // or returns an error if something prevented this.
198 func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
199 signedRequest := cloneRequest(req)
200 timestamp := now()
201
202 signedRequest.Header.Add("host", requestHost(req))
203
204 if rs.AwsSecurityCredentials.SessionToken != "" {
205 signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
206 }
207
208 if signedRequest.Header.Get("date") == "" {
209 signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
210 }
211
212 authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
213 if err != nil {
214 return err
215 }
216 signedRequest.Header.Set("Authorization", authorizationCode)
217
218 req.Header = signedRequest.Header
219 return nil
220 }
221
222 func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
223 canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
224
225 dateStamp := timestamp.Format(awsTimeFormatShort)
226 serviceName := ""
227 if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
228 serviceName = splitHost[0]
229 }
230
231 credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
232
233 requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
234 if err != nil {
235 return "", err
236 }
237 requestHash, err := getSha256([]byte(requestString))
238 if err != nil {
239 return "", err
240 }
241
242 stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
243
244 signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
245 for _, signingInput := range []string{
246 dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
247 } {
248 signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
249 if err != nil {
250 return "", err
251 }
252 }
253
254 return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
255 }
256
// awsCredentialSource produces subject tokens from AWS credentials. The
// credentials come from a programmatic supplier, the environment, or the AWS
// metadata server.
type awsCredentialSource struct {
	// NOTE(review): environmentID is not read by the code in this section;
	// presumably the configured AWS environment identifier — TODO confirm.
	environmentID string
	// regionURL is the metadata endpoint queried for the region when it is not
	// available from a supplier or the environment.
	regionURL string
	// regionalCredVerificationURL is the STS GetCallerIdentity URL template;
	// "{region}" is substituted. An empty value falls back to
	// defaultRegionalCredentialVerificationUrl.
	regionalCredVerificationURL string
	// credVerificationURL is the metadata security-credentials endpoint. A GET
	// on it returns the role name, and a GET on "<url>/<role>" returns the
	// credentials.
	credVerificationURL string
	// imdsv2SessionTokenURL, when set, is PUT to obtain a session token that is
	// sent with the metadata requests.
	imdsv2SessionTokenURL string
	// targetResource, when set, is added and signed as the
	// x-goog-cloud-target-resource header.
	targetResource string
	// requestSigner is an optional pre-built signer. When nil, one is built
	// from the resolved credentials and region.
	requestSigner *awsRequestSigner
	// region is the resolved AWS region.
	region string
	// ctx is used for outgoing HTTP requests and supplier calls.
	ctx context.Context
	// client sends HTTP requests. When nil, a default client from
	// oauth2.NewClient is used.
	client *http.Client
	// awsSecurityCredentialsSupplier, when non-nil, supplies the credentials
	// and region instead of the environment or the metadata server.
	awsSecurityCredentialsSupplier AwsSecurityCredentialsSupplier
	// supplierOptions is passed through to every supplier call.
	supplierOptions SupplierOptions
}
271
// awsRequestHeader is one header entry of the serialized signed request,
// encoded as {"key": ..., "value": ...}.
type awsRequestHeader struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}
276
// awsRequest is the JSON form of the signed AWS STS GetCallerIdentity request
// that is serialized into the subject token.
//
// json.Marshal emits struct fields in declaration order, so the field order
// here fixes the key order of the serialized token.
type awsRequest struct {
	URL     string             `json:"url"`
	Method  string             `json:"method"`
	Headers []awsRequestHeader `json:"headers"`
}
282
283 func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
284 if cs.client == nil {
285 cs.client = oauth2.NewClient(cs.ctx, nil)
286 }
287 return cs.client.Do(req.WithContext(cs.ctx))
288 }
289
290 func canRetrieveRegionFromEnvironment() bool {
291 // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
292 // required.
293 return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != ""
294 }
295
296 func canRetrieveSecurityCredentialFromEnvironment() bool {
297 // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
298 return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
299 }
300
301 func (cs awsCredentialSource) shouldUseMetadataServer() bool {
302 return cs.awsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
303 }
304
305 func (cs awsCredentialSource) credentialSourceType() string {
306 if cs.awsSecurityCredentialsSupplier != nil {
307 return "programmatic"
308 }
309 return "aws"
310 }
311
312 func (cs awsCredentialSource) subjectToken() (string, error) {
313 // Set Defaults
314 if cs.regionalCredVerificationURL == "" {
315 cs.regionalCredVerificationURL = defaultRegionalCredentialVerificationUrl
316 }
317 if cs.requestSigner == nil {
318 headers := make(map[string]string)
319 if cs.shouldUseMetadataServer() {
320 awsSessionToken, err := cs.getAWSSessionToken()
321 if err != nil {
322 return "", err
323 }
324
325 if awsSessionToken != "" {
326 headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
327 }
328 }
329
330 awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
331 if err != nil {
332 return "", err
333 }
334 cs.region, err = cs.getRegion(headers)
335 if err != nil {
336 return "", err
337 }
338
339 cs.requestSigner = &awsRequestSigner{
340 RegionName: cs.region,
341 AwsSecurityCredentials: awsSecurityCredentials,
342 }
343 }
344
345 // Generate the signed request to AWS STS GetCallerIdentity API.
346 // Use the required regional endpoint. Otherwise, the request will fail.
347 req, err := http.NewRequest("POST", strings.Replace(cs.regionalCredVerificationURL, "{region}", cs.region, 1), nil)
348 if err != nil {
349 return "", err
350 }
351 // The full, canonical resource name of the workload identity pool
352 // provider, with or without the HTTPS prefix.
353 // Including this header as part of the signature is recommended to
354 // ensure data integrity.
355 if cs.targetResource != "" {
356 req.Header.Add("x-goog-cloud-target-resource", cs.targetResource)
357 }
358 cs.requestSigner.SignRequest(req)
359
360 /*
361 The GCP STS endpoint expects the headers to be formatted as:
362 # [
363 # {key: 'x-amz-date', value: '...'},
364 # {key: 'Authorization', value: '...'},
365 # ...
366 # ]
367 # And then serialized as:
368 # quote(json.dumps({
369 # url: '...',
370 # method: 'POST',
371 # headers: [{key: 'x-amz-date', value: '...'}, ...]
372 # }))
373 */
374
375 awsSignedReq := awsRequest{
376 URL: req.URL.String(),
377 Method: "POST",
378 }
379 for headerKey, headerList := range req.Header {
380 for _, headerValue := range headerList {
381 awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
382 Key: headerKey,
383 Value: headerValue,
384 })
385 }
386 }
387 sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
388 headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
389 if headerCompare == 0 {
390 return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
391 }
392 return headerCompare < 0
393 })
394
395 result, err := json.Marshal(awsSignedReq)
396 if err != nil {
397 return "", err
398 }
399 return url.QueryEscape(string(result)), nil
400 }
401
402 func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
403 if cs.imdsv2SessionTokenURL == "" {
404 return "", nil
405 }
406
407 req, err := http.NewRequest("PUT", cs.imdsv2SessionTokenURL, nil)
408 if err != nil {
409 return "", err
410 }
411
412 req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl)
413
414 resp, err := cs.doRequest(req)
415 if err != nil {
416 return "", err
417 }
418 defer resp.Body.Close()
419
420 respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
421 if err != nil {
422 return "", err
423 }
424
425 if resp.StatusCode != 200 {
426 return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS session token - %s", string(respBody))
427 }
428
429 return string(respBody), nil
430 }
431
432 func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
433 if cs.awsSecurityCredentialsSupplier != nil {
434 return cs.awsSecurityCredentialsSupplier.AwsRegion(cs.ctx, cs.supplierOptions)
435 }
436 if canRetrieveRegionFromEnvironment() {
437 if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
438 cs.region = envAwsRegion
439 return envAwsRegion, nil
440 }
441 return getenv("AWS_DEFAULT_REGION"), nil
442 }
443
444 if cs.regionURL == "" {
445 return "", errors.New("oauth2/google/externalaccount: unable to determine AWS region")
446 }
447
448 req, err := http.NewRequest("GET", cs.regionURL, nil)
449 if err != nil {
450 return "", err
451 }
452
453 for name, value := range headers {
454 req.Header.Add(name, value)
455 }
456
457 resp, err := cs.doRequest(req)
458 if err != nil {
459 return "", err
460 }
461 defer resp.Body.Close()
462
463 respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
464 if err != nil {
465 return "", err
466 }
467
468 if resp.StatusCode != 200 {
469 return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS region - %s", string(respBody))
470 }
471
472 // This endpoint will return the region in format: us-east-2b.
473 // Only the us-east-2 part should be used.
474 respBodyEnd := 0
475 if len(respBody) > 1 {
476 respBodyEnd = len(respBody) - 1
477 }
478 return string(respBody[:respBodyEnd]), nil
479 }
480
481 func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result *AwsSecurityCredentials, err error) {
482 if cs.awsSecurityCredentialsSupplier != nil {
483 return cs.awsSecurityCredentialsSupplier.AwsSecurityCredentials(cs.ctx, cs.supplierOptions)
484 }
485 if canRetrieveSecurityCredentialFromEnvironment() {
486 return &AwsSecurityCredentials{
487 AccessKeyID: getenv(awsAccessKeyId),
488 SecretAccessKey: getenv(awsSecretAccessKey),
489 SessionToken: getenv(awsSessionToken),
490 }, nil
491 }
492
493 roleName, err := cs.getMetadataRoleName(headers)
494 if err != nil {
495 return
496 }
497
498 credentials, err := cs.getMetadataSecurityCredentials(roleName, headers)
499 if err != nil {
500 return
501 }
502
503 if credentials.AccessKeyID == "" {
504 return result, errors.New("oauth2/google/externalaccount: missing AccessKeyId credential")
505 }
506
507 if credentials.SecretAccessKey == "" {
508 return result, errors.New("oauth2/google/externalaccount: missing SecretAccessKey credential")
509 }
510
511 return &credentials, nil
512 }
513
514 func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (AwsSecurityCredentials, error) {
515 var result AwsSecurityCredentials
516
517 req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.credVerificationURL, roleName), nil)
518 if err != nil {
519 return result, err
520 }
521
522 for name, value := range headers {
523 req.Header.Add(name, value)
524 }
525
526 resp, err := cs.doRequest(req)
527 if err != nil {
528 return result, err
529 }
530 defer resp.Body.Close()
531
532 respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
533 if err != nil {
534 return result, err
535 }
536
537 if resp.StatusCode != 200 {
538 return result, fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS security credentials - %s", string(respBody))
539 }
540
541 err = json.Unmarshal(respBody, &result)
542 return result, err
543 }
544
545 func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
546 if cs.credVerificationURL == "" {
547 return "", errors.New("oauth2/google/externalaccount: unable to determine the AWS metadata server security credentials endpoint")
548 }
549
550 req, err := http.NewRequest("GET", cs.credVerificationURL, nil)
551 if err != nil {
552 return "", err
553 }
554
555 for name, value := range headers {
556 req.Header.Add(name, value)
557 }
558
559 resp, err := cs.doRequest(req)
560 if err != nil {
561 return "", err
562 }
563 defer resp.Body.Close()
564
565 respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
566 if err != nil {
567 return "", err
568 }
569
570 if resp.StatusCode != 200 {
571 return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS role name - %s", string(respBody))
572 }
573
574 return string(respBody), nil
575 }
576