middleware_content_length.go raw

   1  package http
   2  
   3  import (
   4  	"context"
   5  	"fmt"
   6  
   7  	"github.com/aws/smithy-go/middleware"
   8  )
   9  
  10  // ComputeContentLength provides a middleware to set the content-length
  11  // header for the length of a serialize request body.
  12  type ComputeContentLength struct {
  13  }
  14  
  15  // AddComputeContentLengthMiddleware adds ComputeContentLength to the middleware
  16  // stack's Build step.
  17  func AddComputeContentLengthMiddleware(stack *middleware.Stack) error {
  18  	return stack.Build.Add(&ComputeContentLength{}, middleware.After)
  19  }
  20  
  21  // ID returns the identifier for the ComputeContentLength.
  22  func (m *ComputeContentLength) ID() string { return "ComputeContentLength" }
  23  
  24  // HandleBuild adds the length of the serialized request to the HTTP header
  25  // if the length can be determined.
  26  func (m *ComputeContentLength) HandleBuild(
  27  	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
  28  ) (
  29  	out middleware.BuildOutput, metadata middleware.Metadata, err error,
  30  ) {
  31  	req, ok := in.Request.(*Request)
  32  	if !ok {
  33  		return out, metadata, fmt.Errorf("unknown request type %T", req)
  34  	}
  35  
  36  	// do nothing if request content-length was set to 0 or above.
  37  	if req.ContentLength >= 0 {
  38  		return next.HandleBuild(ctx, in)
  39  	}
  40  
  41  	// attempt to compute stream length
  42  	if n, ok, err := req.StreamLength(); err != nil {
  43  		return out, metadata, fmt.Errorf(
  44  			"failed getting length of request stream, %w", err)
  45  	} else if ok {
  46  		req.ContentLength = n
  47  	}
  48  
  49  	return next.HandleBuild(ctx, in)
  50  }
  51  
  52  // validateContentLength provides a middleware to validate the content-length
  53  // is valid (greater than zero), for the serialized request payload.
  54  type validateContentLength struct{}
  55  
  56  // ValidateContentLengthHeader adds middleware that validates request content-length
  57  // is set to value greater than zero.
  58  func ValidateContentLengthHeader(stack *middleware.Stack) error {
  59  	return stack.Build.Add(&validateContentLength{}, middleware.After)
  60  }
  61  
  62  // ID returns the identifier for the ComputeContentLength.
  63  func (m *validateContentLength) ID() string { return "ValidateContentLength" }
  64  
  65  // HandleBuild adds the length of the serialized request to the HTTP header
  66  // if the length can be determined.
  67  func (m *validateContentLength) HandleBuild(
  68  	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
  69  ) (
  70  	out middleware.BuildOutput, metadata middleware.Metadata, err error,
  71  ) {
  72  	req, ok := in.Request.(*Request)
  73  	if !ok {
  74  		return out, metadata, fmt.Errorf("unknown request type %T", req)
  75  	}
  76  
  77  	// if request content-length was set to less than 0, return an error
  78  	if req.ContentLength < 0 {
  79  		return out, metadata, fmt.Errorf(
  80  			"content length for payload is required and must be at least 0")
  81  	}
  82  
  83  	return next.HandleBuild(ctx, in)
  84  }
  85