checksum_middleware.go raw

   1  package http
   2  
   3  import (
   4  	"context"
   5  	"fmt"
   6  
   7  	"github.com/aws/smithy-go/middleware"
   8  )
   9  
  10  const contentMD5Header = "Content-Md5"
  11  
  12  // contentMD5Checksum provides a middleware to compute and set
  13  // content-md5 checksum for a http request
  14  type contentMD5Checksum struct {
  15  }
  16  
  17  // AddContentChecksumMiddleware adds checksum middleware to middleware's
  18  // build step.
  19  func AddContentChecksumMiddleware(stack *middleware.Stack) error {
  20  	// This middleware must be executed before request body is set.
  21  	return stack.Build.Add(&contentMD5Checksum{}, middleware.Before)
  22  }
  23  
  24  // ID returns the identifier for the checksum middleware
  25  func (m *contentMD5Checksum) ID() string { return "ContentChecksum" }
  26  
  27  // HandleBuild adds behavior to compute md5 checksum and add content-md5 header
  28  // on http request
  29  func (m *contentMD5Checksum) HandleBuild(
  30  	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
  31  ) (
  32  	out middleware.BuildOutput, metadata middleware.Metadata, err error,
  33  ) {
  34  	req, ok := in.Request.(*Request)
  35  	if !ok {
  36  		return out, metadata, fmt.Errorf("unknown request type %T", req)
  37  	}
  38  
  39  	// if Content-MD5 header is already present, return
  40  	if v := req.Header.Get(contentMD5Header); len(v) != 0 {
  41  		return next.HandleBuild(ctx, in)
  42  	}
  43  
  44  	// fetch the request stream.
  45  	stream := req.GetStream()
  46  	// compute checksum if payload is explicit
  47  	if stream != nil {
  48  		if !req.IsStreamSeekable() {
  49  			return out, metadata, fmt.Errorf(
  50  				"unseekable stream is not supported for computing md5 checksum")
  51  		}
  52  
  53  		v, err := computeMD5Checksum(stream)
  54  		if err != nil {
  55  			return out, metadata, fmt.Errorf("error computing md5 checksum, %w", err)
  56  		}
  57  
  58  		// reset the request stream
  59  		if err := req.RewindStream(); err != nil {
  60  			return out, metadata, fmt.Errorf(
  61  				"error rewinding request stream after computing md5 checksum, %w", err)
  62  		}
  63  
  64  		// set the 'Content-MD5' header
  65  		req.Header.Set(contentMD5Header, string(v))
  66  	}
  67  
  68  	// set md5 header value
  69  	return next.HandleBuild(ctx, in)
  70  }
  71