provider.go raw

   1  package ec2rolecreds
   2  
   3  import (
   4  	"bufio"
   5  	"context"
   6  	"encoding/json"
   7  	"fmt"
   8  	"math"
   9  	"path"
  10  	"strings"
  11  	"time"
  12  
  13  	"github.com/aws/aws-sdk-go-v2/aws"
  14  	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
  15  	sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
  16  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
  17  	"github.com/aws/smithy-go"
  18  	"github.com/aws/smithy-go/logging"
  19  	"github.com/aws/smithy-go/middleware"
  20  )
  21  
  22  // ProviderName provides a name of EC2Role provider
  23  const ProviderName = "EC2RoleProvider"
  24  
  25  // GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the
  26  // GetMetadata operation.
  27  type GetMetadataAPIClient interface {
  28  	GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
  29  }
  30  
// A Provider retrieves credentials for the instance's IAM role from the EC2
// Instance Metadata Service (IMDS). The credentials it returns are marked as
// able to expire (CanExpire) and carry an Expires time.
//
// Use the New function to create a Provider. New also fills in a default EC2
// IMDS client when none is configured, so the zero value of Provider has no
// client to call. A custom EC2 IMDS client can be supplied through Options:
//
//	p := ec2rolecreds.New(func(o *ec2rolecreds.Options) {
//	     o.Client = imds.New(imds.Options{/* custom options */})
//	})
type Provider struct {
	// options holds the configuration assembled by New.
	options Options
}
  42  
// Options is a list of user settable options for setting the behavior of the Provider.
type Options struct {
	// The API client that will be used by the provider to make GetMetadata API
	// calls to EC2 IMDS.
	//
	// If nil, New sets it to a default EC2 IMDS client created with
	// imds.New(imds.Options{}).
	Client GetMetadataAPIClient

	// The chain of credential sources that was used to create this provider.
	// These values are for reporting purposes only (they are returned by
	// ProviderSources) and are not meant to be set directly. When nil,
	// ProviderSources reports aws.CredentialSourceIMDS.
	CredentialSources []aws.CredentialSource
}
  55  
  56  // New returns an initialized Provider value configured to retrieve
  57  // credentials from EC2 Instance Metadata service.
  58  func New(optFns ...func(*Options)) *Provider {
  59  	options := Options{}
  60  
  61  	for _, fn := range optFns {
  62  		fn(&options)
  63  	}
  64  
  65  	if options.Client == nil {
  66  		options.Client = imds.New(imds.Options{})
  67  	}
  68  
  69  	return &Provider{
  70  		options: options,
  71  	}
  72  }
  73  
  74  // Retrieve retrieves credentials from the EC2 service. Error will be returned
  75  // if the request fails, or unable to extract the desired credentials.
  76  func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
  77  	credsList, err := requestCredList(ctx, p.options.Client)
  78  	if err != nil {
  79  		return aws.Credentials{Source: ProviderName}, err
  80  	}
  81  
  82  	if len(credsList) == 0 {
  83  		return aws.Credentials{Source: ProviderName},
  84  			fmt.Errorf("unexpected empty EC2 IMDS role list")
  85  	}
  86  	credsName := credsList[0]
  87  
  88  	roleCreds, err := requestCred(ctx, p.options.Client, credsName)
  89  	if err != nil {
  90  		return aws.Credentials{Source: ProviderName}, err
  91  	}
  92  
  93  	creds := aws.Credentials{
  94  		AccessKeyID:     roleCreds.AccessKeyID,
  95  		SecretAccessKey: roleCreds.SecretAccessKey,
  96  		SessionToken:    roleCreds.Token,
  97  		Source:          ProviderName,
  98  
  99  		CanExpire: true,
 100  		Expires:   roleCreds.Expiration,
 101  	}
 102  
 103  	// Cap role credentials Expires to 1 hour so they can be refreshed more
 104  	// often. Jitter will be applied credentials cache if being used.
 105  	if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) {
 106  		creds.Expires = anHour
 107  	}
 108  
 109  	return creds, nil
 110  }
 111  
 112  // HandleFailToRefresh will extend the credentials Expires time if it it is
 113  // expired. If the credentials will not expire within the minimum time, they
 114  // will be returned.
 115  //
 116  // If the credentials cannot expire, the original error will be returned.
 117  func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
 118  	aws.Credentials, error,
 119  ) {
 120  	if !prevCreds.CanExpire {
 121  		return aws.Credentials{}, err
 122  	}
 123  
 124  	if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
 125  		return prevCreds, nil
 126  	}
 127  
 128  	newCreds := prevCreds
 129  	randFloat64, err := sdkrand.CryptoRandFloat64()
 130  	if err != nil {
 131  		return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err)
 132  	}
 133  
 134  	// Random distribution of [5,15) minutes.
 135  	expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute
 136  	newCreds.Expires = sdk.NowTime().Add(expireOffset)
 137  
 138  	logger := middleware.GetLogger(ctx)
 139  	logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))
 140  
 141  	return newCreds, nil
 142  }
 143  
 144  // AdjustExpiresBy will adds the passed in duration to the passed in
 145  // credential's Expires time, unless the time until Expires is less than 15
 146  // minutes. Returns the credentials, even if not updated.
 147  func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
 148  	aws.Credentials, error,
 149  ) {
 150  	if !creds.CanExpire {
 151  		return creds, nil
 152  	}
 153  	if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
 154  		return creds, nil
 155  	}
 156  
 157  	creds.Expires = creds.Expires.Add(dur)
 158  	return creds, nil
 159  }
 160  
// ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses.
//
// requestCred decodes the IMDS role credentials body into this type with
// encoding/json, which matches JSON object keys to these field names
// case-insensitively. Presumably IMDS sends keys such as "AccessKeyId";
// verify against the IMDS documentation before relying on exact key names.
// A response counts as successful only when Code equals "Success"
// (case-insensitively). Otherwise Code and Message describe the failure.
type ec2RoleCredRespBody struct {
	// Success State
	// Expiration is decoded by time.Time's JSON unmarshaling (RFC 3339).
	Expiration      time.Time
	AccessKeyID     string
	SecretAccessKey string
	// Token becomes the SessionToken of the credentials built by Retrieve.
	Token string

	// Error state
	Code    string
	Message string
}
 174  
 175  const iamSecurityCredsPath = "/iam/security-credentials/"
 176  
 177  // requestCredList requests a list of credentials from the EC2 service. If
 178  // there are no credentials, or there is an error making or receiving the
 179  // request
 180  func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
 181  	resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
 182  		Path: iamSecurityCredsPath,
 183  	})
 184  	if err != nil {
 185  		return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
 186  	}
 187  	defer resp.Content.Close()
 188  
 189  	credsList := []string{}
 190  	s := bufio.NewScanner(resp.Content)
 191  	for s.Scan() {
 192  		credsList = append(credsList, s.Text())
 193  	}
 194  
 195  	if err := s.Err(); err != nil {
 196  		return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err)
 197  	}
 198  
 199  	return credsList, nil
 200  }
 201  
 202  // requestCred requests the credentials for a specific credentials from the EC2 service.
 203  //
 204  // If the credentials cannot be found, or there is an error reading the response
 205  // and error will be returned.
 206  func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
 207  	resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
 208  		Path: path.Join(iamSecurityCredsPath, credsName),
 209  	})
 210  	if err != nil {
 211  		return ec2RoleCredRespBody{},
 212  			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
 213  				credsName, err)
 214  	}
 215  	defer resp.Content.Close()
 216  
 217  	var respCreds ec2RoleCredRespBody
 218  	if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
 219  		return ec2RoleCredRespBody{},
 220  			fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
 221  				credsName, err)
 222  	}
 223  
 224  	if !strings.EqualFold(respCreds.Code, "Success") {
 225  		// If an error code was returned something failed requesting the role.
 226  		return ec2RoleCredRespBody{},
 227  			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
 228  				credsName,
 229  				&smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
 230  	}
 231  
 232  	return respCreds, nil
 233  }
 234  
 235  // ProviderSources returns the credential chain that was used to construct this provider
 236  func (p *Provider) ProviderSources() []aws.CredentialSource {
 237  	if p.options.CredentialSources == nil {
 238  		return []aws.CredentialSource{aws.CredentialSourceIMDS}
 239  	} // If no source has been set, assume this is used directly which means just call to assume role
 240  	return p.options.CredentialSources
 241  }
 242