request_middleware.go raw

   1  package imds
   2  
   3  import (
   4  	"bytes"
   5  	"context"
   6  	"fmt"
   7  	"io/ioutil"
   8  	"net/url"
   9  	"path"
  10  	"time"
  11  
  12  	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
  13  	"github.com/aws/aws-sdk-go-v2/aws/retry"
  14  	"github.com/aws/smithy-go/middleware"
  15  	smithyhttp "github.com/aws/smithy-go/transport/http"
  16  )
  17  
  18  func addAPIRequestMiddleware(stack *middleware.Stack,
  19  	options Options,
  20  	operation string,
  21  	getPath func(interface{}) (string, error),
  22  	getOutput func(*smithyhttp.Response) (interface{}, error),
  23  ) (err error) {
  24  	err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput)
  25  	if err != nil {
  26  		return err
  27  	}
  28  
  29  	// Token Serializer build and state management.
  30  	if !options.disableAPIToken {
  31  		err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
  32  		if err != nil {
  33  			return err
  34  		}
  35  
  36  		err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
  37  		if err != nil {
  38  			return err
  39  		}
  40  	}
  41  
  42  	return nil
  43  }
  44  
  45  func addRequestMiddleware(stack *middleware.Stack,
  46  	options Options,
  47  	method string,
  48  	operation string,
  49  	getPath func(interface{}) (string, error),
  50  	getOutput func(*smithyhttp.Response) (interface{}, error),
  51  ) (err error) {
  52  	err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
  53  	if err != nil {
  54  		return err
  55  	}
  56  
  57  	// Operation timeout
  58  	err = stack.Initialize.Add(&operationTimeout{
  59  		Disabled:       options.DisableDefaultTimeout,
  60  		DefaultTimeout: defaultOperationTimeout,
  61  	}, middleware.Before)
  62  	if err != nil {
  63  		return err
  64  	}
  65  
  66  	// Operation Serializer
  67  	err = stack.Serialize.Add(&serializeRequest{
  68  		GetPath: getPath,
  69  		Method:  method,
  70  	}, middleware.After)
  71  	if err != nil {
  72  		return err
  73  	}
  74  
  75  	// Operation endpoint resolver
  76  	err = stack.Serialize.Insert(&resolveEndpoint{
  77  		Endpoint:     options.Endpoint,
  78  		EndpointMode: options.EndpointMode,
  79  	}, "OperationSerializer", middleware.Before)
  80  	if err != nil {
  81  		return err
  82  	}
  83  
  84  	// Operation Deserializer
  85  	err = stack.Deserialize.Add(&deserializeResponse{
  86  		GetOutput: getOutput,
  87  	}, middleware.After)
  88  	if err != nil {
  89  		return err
  90  	}
  91  
  92  	err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
  93  		LogRequest:          options.ClientLogMode.IsRequest(),
  94  		LogRequestWithBody:  options.ClientLogMode.IsRequestWithBody(),
  95  		LogResponse:         options.ClientLogMode.IsResponse(),
  96  		LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
  97  	}, middleware.After)
  98  	if err != nil {
  99  		return err
 100  	}
 101  
 102  	err = addSetLoggerMiddleware(stack, options)
 103  	if err != nil {
 104  		return err
 105  	}
 106  
 107  	if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil {
 108  		return fmt.Errorf("add protocol finalizers: %w", err)
 109  	}
 110  
 111  	// Retry support
 112  	return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
 113  		Retryer:          options.Retryer,
 114  		LogRetryAttempts: options.ClientLogMode.IsRetries(),
 115  	})
 116  }
 117  
 118  func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
 119  	return middleware.AddSetLoggerMiddleware(stack, o.Logger)
 120  }
 121  
 122  type serializeRequest struct {
 123  	GetPath func(interface{}) (string, error)
 124  	Method  string
 125  }
 126  
 127  func (*serializeRequest) ID() string {
 128  	return "OperationSerializer"
 129  }
 130  
 131  func (m *serializeRequest) HandleSerialize(
 132  	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
 133  ) (
 134  	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
 135  ) {
 136  	request, ok := in.Request.(*smithyhttp.Request)
 137  	if !ok {
 138  		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
 139  	}
 140  
 141  	reqPath, err := m.GetPath(in.Parameters)
 142  	if err != nil {
 143  		return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
 144  	}
 145  
 146  	request.Request.URL.Path = reqPath
 147  	request.Request.Method = m.Method
 148  
 149  	return next.HandleSerialize(ctx, in)
 150  }
 151  
 152  type deserializeResponse struct {
 153  	GetOutput func(*smithyhttp.Response) (interface{}, error)
 154  }
 155  
 156  func (*deserializeResponse) ID() string {
 157  	return "OperationDeserializer"
 158  }
 159  
 160  func (m *deserializeResponse) HandleDeserialize(
 161  	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
 162  ) (
 163  	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
 164  ) {
 165  	out, metadata, err = next.HandleDeserialize(ctx, in)
 166  	if err != nil {
 167  		return out, metadata, err
 168  	}
 169  
 170  	resp, ok := out.RawResponse.(*smithyhttp.Response)
 171  	if !ok {
 172  		return out, metadata, fmt.Errorf(
 173  			"unexpected transport response type, %T, want %T", out.RawResponse, resp)
 174  	}
 175  	defer resp.Body.Close()
 176  
 177  	// read the full body so that any operation timeouts cleanup will not race
 178  	// the body being read.
 179  	body, err := ioutil.ReadAll(resp.Body)
 180  	if err != nil {
 181  		return out, metadata, fmt.Errorf("read response body failed, %w", err)
 182  	}
 183  	resp.Body = ioutil.NopCloser(bytes.NewReader(body))
 184  
 185  	// Anything that's not 200 |< 300 is error
 186  	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
 187  		return out, metadata, &smithyhttp.ResponseError{
 188  			Response: resp,
 189  			Err:      fmt.Errorf("request to EC2 IMDS failed"),
 190  		}
 191  	}
 192  
 193  	result, err := m.GetOutput(resp)
 194  	if err != nil {
 195  		return out, metadata, fmt.Errorf(
 196  			"unable to get deserialized result for response, %w", err,
 197  		)
 198  	}
 199  	out.Result = result
 200  
 201  	return out, metadata, err
 202  }
 203  
 204  type resolveEndpoint struct {
 205  	Endpoint     string
 206  	EndpointMode EndpointModeState
 207  }
 208  
 209  func (*resolveEndpoint) ID() string {
 210  	return "ResolveEndpoint"
 211  }
 212  
 213  func (m *resolveEndpoint) HandleSerialize(
 214  	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
 215  ) (
 216  	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
 217  ) {
 218  
 219  	req, ok := in.Request.(*smithyhttp.Request)
 220  	if !ok {
 221  		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
 222  	}
 223  
 224  	var endpoint string
 225  	if len(m.Endpoint) > 0 {
 226  		endpoint = m.Endpoint
 227  	} else {
 228  		switch m.EndpointMode {
 229  		case EndpointModeStateIPv6:
 230  			endpoint = defaultIPv6Endpoint
 231  		case EndpointModeStateIPv4:
 232  			fallthrough
 233  		case EndpointModeStateUnset:
 234  			endpoint = defaultIPv4Endpoint
 235  		default:
 236  			return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
 237  		}
 238  	}
 239  
 240  	req.URL, err = url.Parse(endpoint)
 241  	if err != nil {
 242  		return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
 243  	}
 244  
 245  	return next.HandleSerialize(ctx, in)
 246  }
 247  
 248  const (
 249  	defaultOperationTimeout = 5 * time.Second
 250  )
 251  
 252  // operationTimeout adds a timeout on the middleware stack if the Context the
 253  // stack was called with does not have a deadline. The next middleware must
 254  // complete before the timeout, or the context will be canceled.
 255  //
 256  // If DefaultTimeout is zero, no default timeout will be used if the Context
 257  // does not have a timeout.
 258  //
 259  // The next middleware must also ensure that any resources that are also
 260  // canceled by the stack's context are completely consumed before returning.
 261  // Otherwise the timeout cleanup will race the resource being consumed
 262  // upstream.
 263  type operationTimeout struct {
 264  	Disabled       bool
 265  	DefaultTimeout time.Duration
 266  }
 267  
 268  func (*operationTimeout) ID() string { return "OperationTimeout" }
 269  
 270  func (m *operationTimeout) HandleInitialize(
 271  	ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
 272  ) (
 273  	output middleware.InitializeOutput, metadata middleware.Metadata, err error,
 274  ) {
 275  	if m.Disabled {
 276  		return next.HandleInitialize(ctx, input)
 277  	}
 278  
 279  	if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
 280  		var cancelFn func()
 281  		ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
 282  		defer cancelFn()
 283  	}
 284  
 285  	return next.HandleInitialize(ctx, input)
 286  }
 287  
 288  // appendURIPath joins a URI path component to the existing path with `/`
 289  // separators between the path components. If the path being added ends with a
 290  // trailing `/` that slash will be maintained.
 291  func appendURIPath(base, add string) string {
 292  	reqPath := path.Join(base, add)
 293  	if len(add) != 0 && add[len(add)-1] == '/' {
 294  		reqPath += "/"
 295  	}
 296  	return reqPath
 297  }
 298  
 299  func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error {
 300  	if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil {
 301  		return fmt.Errorf("add ResolveAuthScheme: %w", err)
 302  	}
 303  	if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil {
 304  		return fmt.Errorf("add GetIdentity: %w", err)
 305  	}
 306  	if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil {
 307  		return fmt.Errorf("add ResolveEndpointV2: %w", err)
 308  	}
 309  	if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil {
 310  		return fmt.Errorf("add Signing: %w", err)
 311  	}
 312  	return nil
 313  }
 314