1 package stscreds
2 3 import (
4 "context"
5 "fmt"
6 "io/ioutil"
7 "strconv"
8 "strings"
9 "time"
10 11 "github.com/aws/aws-sdk-go-v2/aws"
12 "github.com/aws/aws-sdk-go-v2/aws/retry"
13 "github.com/aws/aws-sdk-go-v2/internal/sdk"
14 "github.com/aws/aws-sdk-go-v2/service/sts"
15 "github.com/aws/aws-sdk-go-v2/service/sts/types"
16 )
17 18 var invalidIdentityTokenExceptionCode = (&types.InvalidIdentityTokenException{}).ErrorCode()
19 20 const (
21 // WebIdentityProviderName is the web identity provider name
22 WebIdentityProviderName = "WebIdentityCredentials"
23 )
// AssumeRoleWithWebIdentityAPIClient is a client capable of the STS AssumeRoleWithWebIdentity operation.
//
// WebIdentityRoleProvider.Retrieve calls this operation once per invocation.
// It passes an option function that adjusts the client's Retryer for that
// call.
type AssumeRoleWithWebIdentityAPIClient interface {
	AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error)
}
// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
//
// Create it with NewWebIdentityRoleProvider. Its configuration lives in the
// unexported options field, which Retrieve and ProviderSources read.
type WebIdentityRoleProvider struct {
	options WebIdentityRoleOptions
}
// WebIdentityRoleOptions is a structure of configurable options for WebIdentityRoleProvider.
type WebIdentityRoleOptions struct {
	// Client implementation of the AssumeRoleWithWebIdentity operation. Required.
	Client AssumeRoleWithWebIdentityAPIClient

	// JWT Token Provider. Required. It is asked for a token on every Retrieve
	// call, and that token is sent to STS as the WebIdentityToken.
	TokenRetriever IdentityTokenRetriever

	// IAM Role ARN to assume. Required.
	RoleARN string

	// Session name, if you wish to uniquely identify this session. If empty,
	// Retrieve uses the current time in Unix nanoseconds as the session name.
	RoleSessionName string

	// Expiry duration of the STS credentials. STS will assign a default expiry
	// duration if this value is unset. This is different from the Duration
	// option of AssumeRoleProvider, which automatically assigns 15 minutes if
	// Duration is unset.
	//
	// The value is sent to STS in whole seconds; any fractional second is
	// truncated.
	//
	// See the STS AssumeRoleWithWebIdentity API reference guide for more
	// information on defaults.
	// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
	Duration time.Duration

	// An IAM policy in JSON format that you want to use as an inline session
	// policy. It is only sent to STS when non-nil.
	Policy *string

	// The Amazon Resource Names (ARNs) of the IAM managed policies that you
	// want to use as managed session policies. The policies must exist in the
	// same account as the role.
	PolicyARNs []types.PolicyDescriptorType

	// The chain of providers that was used to create this provider.
	// These values are for reporting purposes and are not meant to be set up directly.
	// When nil, ProviderSources reports aws.CredentialSourceSTSAssumeRoleWebID.
	CredentialSources []aws.CredentialSource
}
// IdentityTokenRetriever is an interface for retrieving a JWT.
type IdentityTokenRetriever interface {
	GetIdentityToken() ([]byte, error)
}

// IdentityTokenFile is for retrieving an identity token from the given file name.
type IdentityTokenFile string

// GetIdentityToken retrieves the JWT token from the file and returns the
// contents as a []byte. The file is re-read on every call.
//
// On failure, the underlying read error is wrapped with %w. Callers can
// therefore inspect it with errors.Is or errors.As, for example to detect a
// missing file via fs.ErrNotExist.
func (j IdentityTokenFile) GetIdentityToken() ([]byte, error) {
	b, err := ioutil.ReadFile(string(j))
	if err != nil {
		return nil, fmt.Errorf("unable to read file at %s: %w", string(j), err)
	}

	return b, nil
}
90 91 // NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
92 // provided stsiface.ClientAPI
93 func NewWebIdentityRoleProvider(client AssumeRoleWithWebIdentityAPIClient, roleARN string, tokenRetriever IdentityTokenRetriever, optFns ...func(*WebIdentityRoleOptions)) *WebIdentityRoleProvider {
94 o := WebIdentityRoleOptions{
95 Client: client,
96 RoleARN: roleARN,
97 TokenRetriever: tokenRetriever,
98 }
99 100 for _, fn := range optFns {
101 fn(&o)
102 }
103 104 return &WebIdentityRoleProvider{options: o}
105 }
106 107 // Retrieve will attempt to assume a role from a token which is located at
108 // 'WebIdentityTokenFilePath' specified destination and if that is empty an
109 // error will be returned.
110 func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
111 b, err := p.options.TokenRetriever.GetIdentityToken()
112 if err != nil {
113 return aws.Credentials{}, fmt.Errorf("failed to retrieve jwt from provide source, %w", err)
114 }
115 116 sessionName := p.options.RoleSessionName
117 if len(sessionName) == 0 {
118 // session name is used to uniquely identify a session. This simply
119 // uses unix time in nanoseconds to uniquely identify sessions.
120 sessionName = strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
121 }
122 input := &sts.AssumeRoleWithWebIdentityInput{
123 PolicyArns: p.options.PolicyARNs,
124 RoleArn: &p.options.RoleARN,
125 RoleSessionName: &sessionName,
126 WebIdentityToken: aws.String(string(b)),
127 }
128 if p.options.Duration != 0 {
129 // If set use the value, otherwise STS will assign a default expiration duration.
130 input.DurationSeconds = aws.Int32(int32(p.options.Duration / time.Second))
131 }
132 if p.options.Policy != nil {
133 input.Policy = p.options.Policy
134 }
135 136 resp, err := p.options.Client.AssumeRoleWithWebIdentity(ctx, input, func(options *sts.Options) {
137 options.Retryer = retry.AddWithErrorCodes(options.Retryer, invalidIdentityTokenExceptionCode)
138 })
139 if err != nil {
140 return aws.Credentials{}, fmt.Errorf("failed to retrieve credentials, %w", err)
141 }
142 143 var accountID string
144 if resp.AssumedRoleUser != nil {
145 accountID = getAccountID(resp.AssumedRoleUser)
146 }
147 148 // InvalidIdentityToken error is a temporary error that can occur
149 // when assuming an Role with a JWT web identity token.
150 151 value := aws.Credentials{
152 AccessKeyID: aws.ToString(resp.Credentials.AccessKeyId),
153 SecretAccessKey: aws.ToString(resp.Credentials.SecretAccessKey),
154 SessionToken: aws.ToString(resp.Credentials.SessionToken),
155 Source: WebIdentityProviderName,
156 CanExpire: true,
157 Expires: *resp.Credentials.Expiration,
158 AccountID: accountID,
159 }
160 return value, nil
161 }
162 163 // extract accountID from arn with format "arn:partition:service:region:account-id:[resource-section]"
164 func getAccountID(u *types.AssumedRoleUser) string {
165 if u.Arn == nil {
166 return ""
167 }
168 parts := strings.Split(*u.Arn, ":")
169 if len(parts) < 5 {
170 return ""
171 }
172 return parts[4]
173 }
174 175 // ProviderSources returns the credential chain that was used to construct this provider
176 func (p *WebIdentityRoleProvider) ProviderSources() []aws.CredentialSource {
177 if p.options.CredentialSources == nil {
178 return []aws.CredentialSource{aws.CredentialSourceSTSAssumeRoleWebID}
179 }
180 return p.options.CredentialSources
181 }
182