1 package http
2 3 import (
4 "context"
5 "fmt"
6 7 "github.com/aws/smithy-go/middleware"
8 )
// ComputeContentLength is a Build-step middleware that fills in a request's
// ContentLength from the length of the serialized request stream, when no
// non-negative ContentLength has been set and the stream length can be
// determined.
type ComputeContentLength struct{}
14 15 // AddComputeContentLengthMiddleware adds ComputeContentLength to the middleware
16 // stack's Build step.
17 func AddComputeContentLengthMiddleware(stack *middleware.Stack) error {
18 return stack.Build.Add(&ComputeContentLength{}, middleware.After)
19 }
20 21 // ID returns the identifier for the ComputeContentLength.
22 func (m *ComputeContentLength) ID() string { return "ComputeContentLength" }
23 24 // HandleBuild adds the length of the serialized request to the HTTP header
25 // if the length can be determined.
26 func (m *ComputeContentLength) HandleBuild(
27 ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
28 ) (
29 out middleware.BuildOutput, metadata middleware.Metadata, err error,
30 ) {
31 req, ok := in.Request.(*Request)
32 if !ok {
33 return out, metadata, fmt.Errorf("unknown request type %T", req)
34 }
35 36 // do nothing if request content-length was set to 0 or above.
37 if req.ContentLength >= 0 {
38 return next.HandleBuild(ctx, in)
39 }
40 41 // attempt to compute stream length
42 if n, ok, err := req.StreamLength(); err != nil {
43 return out, metadata, fmt.Errorf(
44 "failed getting length of request stream, %w", err)
45 } else if ok {
46 req.ContentLength = n
47 }
48 49 return next.HandleBuild(ctx, in)
50 }
// validateContentLength is a Build-step middleware that rejects a request
// whose ContentLength is negative, i.e. no length was set or computed. A
// ContentLength of zero or greater is accepted.
type validateContentLength struct{}
55 56 // ValidateContentLengthHeader adds middleware that validates request content-length
57 // is set to value greater than zero.
58 func ValidateContentLengthHeader(stack *middleware.Stack) error {
59 return stack.Build.Add(&validateContentLength{}, middleware.After)
60 }
61 62 // ID returns the identifier for the ComputeContentLength.
63 func (m *validateContentLength) ID() string { return "ValidateContentLength" }
64 65 // HandleBuild adds the length of the serialized request to the HTTP header
66 // if the length can be determined.
67 func (m *validateContentLength) HandleBuild(
68 ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
69 ) (
70 out middleware.BuildOutput, metadata middleware.Metadata, err error,
71 ) {
72 req, ok := in.Request.(*Request)
73 if !ok {
74 return out, metadata, fmt.Errorf("unknown request type %T", req)
75 }
76 77 // if request content-length was set to less than 0, return an error
78 if req.ContentLength < 0 {
79 return out, metadata, fmt.Errorf(
80 "content length for payload is required and must be at least 0")
81 }
82 83 return next.HandleBuild(ctx, in)
84 }
85