1 package ec2rolecreds
2 3 import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "math"
9 "path"
10 "strings"
11 "time"
12 13 "github.com/aws/aws-sdk-go-v2/aws"
14 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
15 sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
16 "github.com/aws/aws-sdk-go-v2/internal/sdk"
17 "github.com/aws/smithy-go"
18 "github.com/aws/smithy-go/logging"
19 "github.com/aws/smithy-go/middleware"
20 )
// ProviderName is the name of the EC2 role credentials provider. It is used
// as the Source value on every aws.Credentials value this provider returns.
const (
	ProviderName = "EC2RoleProvider"
)
24 25 // GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the
26 // GetMetadata operation.
27 type GetMetadataAPIClient interface {
28 GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
29 }
30 31 // A Provider retrieves credentials from the EC2 service, and keeps track if
32 // those credentials are expired.
33 //
34 // The New function must be used to create the with a custom EC2 IMDS client.
35 //
36 // p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{
37 // o.Client = imds.New(imds.Options{/* custom options */})
38 // })
39 type Provider struct {
40 options Options
41 }
// Options is a list of user settable options for setting the behavior of the
// Provider. Options are applied by the functions passed to New.
type Options struct {
	// The API client that will be used by the provider to make GetMetadata API
	// calls to EC2 IMDS.
	//
	// If nil, New sets it to the default EC2 IMDS client,
	// imds.New(imds.Options{}).
	Client GetMetadataAPIClient

	// The chain of providers that was used to create this provider.
	// These values are for reporting purposes and are not meant to be set
	// directly.
	//
	// If nil, ProviderSources reports only aws.CredentialSourceIMDS.
	CredentialSources []aws.CredentialSource
}
55 56 // New returns an initialized Provider value configured to retrieve
57 // credentials from EC2 Instance Metadata service.
58 func New(optFns ...func(*Options)) *Provider {
59 options := Options{}
60 61 for _, fn := range optFns {
62 fn(&options)
63 }
64 65 if options.Client == nil {
66 options.Client = imds.New(imds.Options{})
67 }
68 69 return &Provider{
70 options: options,
71 }
72 }
73 74 // Retrieve retrieves credentials from the EC2 service. Error will be returned
75 // if the request fails, or unable to extract the desired credentials.
76 func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
77 credsList, err := requestCredList(ctx, p.options.Client)
78 if err != nil {
79 return aws.Credentials{Source: ProviderName}, err
80 }
81 82 if len(credsList) == 0 {
83 return aws.Credentials{Source: ProviderName},
84 fmt.Errorf("unexpected empty EC2 IMDS role list")
85 }
86 credsName := credsList[0]
87 88 roleCreds, err := requestCred(ctx, p.options.Client, credsName)
89 if err != nil {
90 return aws.Credentials{Source: ProviderName}, err
91 }
92 93 creds := aws.Credentials{
94 AccessKeyID: roleCreds.AccessKeyID,
95 SecretAccessKey: roleCreds.SecretAccessKey,
96 SessionToken: roleCreds.Token,
97 Source: ProviderName,
98 99 CanExpire: true,
100 Expires: roleCreds.Expiration,
101 }
102 103 // Cap role credentials Expires to 1 hour so they can be refreshed more
104 // often. Jitter will be applied credentials cache if being used.
105 if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) {
106 creds.Expires = anHour
107 }
108 109 return creds, nil
110 }
111 112 // HandleFailToRefresh will extend the credentials Expires time if it it is
113 // expired. If the credentials will not expire within the minimum time, they
114 // will be returned.
115 //
116 // If the credentials cannot expire, the original error will be returned.
117 func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
118 aws.Credentials, error,
119 ) {
120 if !prevCreds.CanExpire {
121 return aws.Credentials{}, err
122 }
123 124 if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
125 return prevCreds, nil
126 }
127 128 newCreds := prevCreds
129 randFloat64, err := sdkrand.CryptoRandFloat64()
130 if err != nil {
131 return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err)
132 }
133 134 // Random distribution of [5,15) minutes.
135 expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute
136 newCreds.Expires = sdk.NowTime().Add(expireOffset)
137 138 logger := middleware.GetLogger(ctx)
139 logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))
140 141 return newCreds, nil
142 }
143 144 // AdjustExpiresBy will adds the passed in duration to the passed in
145 // credential's Expires time, unless the time until Expires is less than 15
146 // minutes. Returns the credentials, even if not updated.
147 func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
148 aws.Credentials, error,
149 ) {
150 if !creds.CanExpire {
151 return creds, nil
152 }
153 if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
154 return creds, nil
155 }
156 157 creds.Expires = creds.Expires.Add(dur)
158 return creds, nil
159 }
// ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses.
//
// It is decoded with encoding/json and has no struct tags. JSON keys are
// therefore matched to these Go field names case-insensitively; for example,
// a key "AccessKeyId" would populate AccessKeyID.
type ec2RoleCredRespBody struct {
	// Success state: the role's credentials and their expiration. Expiration
	// uses time.Time's JSON decoding, which expects an RFC 3339 string.
	Expiration      time.Time
	AccessKeyID     string
	SecretAccessKey string
	Token           string

	// Error state. requestCred treats any Code other than "Success" (compared
	// case-insensitively) as a failure and reports Code and Message.
	Code    string
	Message string
}
// iamSecurityCredsPath is the IMDS path that lists the instance's IAM role
// names. A given role's credentials are requested from
// path.Join(iamSecurityCredsPath, roleName).
const (
	iamSecurityCredsPath = "/iam/security-credentials/"
)
176 177 // requestCredList requests a list of credentials from the EC2 service. If
178 // there are no credentials, or there is an error making or receiving the
179 // request
180 func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
181 resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
182 Path: iamSecurityCredsPath,
183 })
184 if err != nil {
185 return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
186 }
187 defer resp.Content.Close()
188 189 credsList := []string{}
190 s := bufio.NewScanner(resp.Content)
191 for s.Scan() {
192 credsList = append(credsList, s.Text())
193 }
194 195 if err := s.Err(); err != nil {
196 return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err)
197 }
198 199 return credsList, nil
200 }
201 202 // requestCred requests the credentials for a specific credentials from the EC2 service.
203 //
204 // If the credentials cannot be found, or there is an error reading the response
205 // and error will be returned.
206 func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
207 resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
208 Path: path.Join(iamSecurityCredsPath, credsName),
209 })
210 if err != nil {
211 return ec2RoleCredRespBody{},
212 fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
213 credsName, err)
214 }
215 defer resp.Content.Close()
216 217 var respCreds ec2RoleCredRespBody
218 if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
219 return ec2RoleCredRespBody{},
220 fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
221 credsName, err)
222 }
223 224 if !strings.EqualFold(respCreds.Code, "Success") {
225 // If an error code was returned something failed requesting the role.
226 return ec2RoleCredRespBody{},
227 fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
228 credsName,
229 &smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
230 }
231 232 return respCreds, nil
233 }
234 235 // ProviderSources returns the credential chain that was used to construct this provider
236 func (p *Provider) ProviderSources() []aws.CredentialSource {
237 if p.options.CredentialSources == nil {
238 return []aws.CredentialSource{aws.CredentialSourceIMDS}
239 } // If no source has been set, assume this is used directly which means just call to assume role
240 return p.options.CredentialSources
241 }
242