1 // Copyright 2018 Adam S Levy
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to
5 // deal in the Software without restriction, including without limitation the
6 // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
7 // sell copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
19 // IN THE SOFTWARE.
20 21 package jsonrpc2
22 23 import (
24 "context"
25 "encoding/json"
26 "fmt"
27 "io/ioutil"
28 "log"
29 "net/http"
30 "os"
31 "strings"
32 )
33 34 // HTTPRequestHandler returns an http.HandlerFunc for the given methods.
35 //
36 // The returned http.HandlerFunc efficiently handles any conforming single or
37 // batch requests or notifications, accurately catches all defined protocol
38 // errors, calls the appropriate MethodFuncs, recovers from any panics or
39 // invalid return values, and returns an Internal Error or Response with the
40 // correct ID, if not a Notification.
41 //
42 // See MethodFunc for more details.
43 //
44 // It is not safe to modify methods while the returned http.HandlerFunc is in
45 // use.
46 //
47 // This will panic if a method name beginning with "rpc." is used. See
48 // MethodMap for more details.
49 //
50 // The handler will use lgr to log any errors and debug information, if
51 // DebugMethodFunc is true. If lgr is nil, the default Logger from the log
52 // package is used.
53 func HTTPRequestHandler(methods MethodMap, lgr Logger) http.HandlerFunc {
54 for name := range methods {
55 if strings.HasPrefix(name, "rpc.") {
56 panic(fmt.Errorf("invalid method name: %v", name))
57 }
58 }
59 if lgr == nil {
60 lgr = log.New(os.Stderr, "", log.LstdFlags)
61 }
62 63 return func(w http.ResponseWriter, req *http.Request) {
64 res := handle(methods, req, lgr)
65 if req.Context().Err() != nil || res == nil {
66 return
67 }
68 // We should never have a JSON encoding related error because
69 // MethodFunc.call() already Marshaled any user provided Data
70 // or Result, and everything else is marshalable.
71 //
72 // However an error can be returned related to w.Write, which
73 // there is nothing we can do about, so we just log it here.
74 enc := json.NewEncoder(w)
75 if err := enc.Encode(res); err != nil {
76 lgr.Printf("req.Body.Write(): %v", err)
77 }
78 }
79 }
80 81 // handle an http.Request for the given methods.
82 func handle(methods MethodMap, req *http.Request, lgr Logger) interface{} {
83 // Read all bytes of HTTP request body.
84 reqBytes, err := ioutil.ReadAll(req.Body)
85 if err != nil {
86 return Response{Error: errorInternal(err.Error())}
87 }
88 89 // Ensure valid JSON so it can be assumed going forward.
90 if !json.Valid(reqBytes) {
91 return Response{Error: errorParse(nil)}
92 }
93 94 // Attempt to unmarshal into a slice to detect a batch request. Use
95 // []json.RawMessage so that each Request can be unmarshaled
96 // individually.
97 batch := true
98 rawReqs := make([]json.RawMessage, 1)
99 if json.Unmarshal(reqBytes, &rawReqs) != nil {
100 // Since the JSON is valid, this Unmarshal error indicates that
101 // this is just a single request.
102 batch = false
103 rawReqs[0] = json.RawMessage(reqBytes)
104 }
105 106 // Catch empty batch requests.
107 if len(rawReqs) == 0 {
108 return Response{Error: errorInvalidRequest("empty batch request")}
109 }
110 111 // Process each Request, omitting any returned Response that is empty.
112 responses := make(BatchResponse, 0, len(rawReqs))
113 for _, rawReq := range rawReqs {
114 if req.Context().Err() != nil {
115 return nil
116 }
117 res := processRequest(req.Context(), methods, rawReq, lgr)
118 if res == (Response{}) {
119 // Don't respond to Notifications.
120 continue
121 }
122 responses = append(responses, res)
123 }
124 125 // Send nothing if there are no responses.
126 if len(responses) == 0 {
127 return nil
128 }
129 130 // Return the BatchResponse if this was a batch request.
131 if batch {
132 return responses
133 }
134 135 // Return a single Response.
136 return responses[0]
137 }
138 139 // processRequest unmarshals and processes a single Request stored in rawReq
140 // using the methods defined in methods. If res is zero valued, then the
141 // Request was a Notification and should not be responded to.
142 func processRequest(ctx context.Context,
143 methods MethodMap, rawReq json.RawMessage, lgr Logger) (res Response) {
144 145 // Unmarshal into req with an error on any unknown fields.
146 var req Request
147 if err := json.Unmarshal(rawReq, &req); err != nil {
148 // At this point we know that this was valid JSON, so this is a
149 // not a ParseError, but something about the Request object did
150 // not conform to spec, so we return an invalidRequest Error.
151 //
152 // At this point we have no way to know if this was a Request
153 // or Notification, so we must respond with "id" set to null.
154 return Response{Error: errorInvalidRequest(err.Error())}
155 }
156 157 // Use a type assertion to get req.ID and req.Params as
158 // json.RawMessage. See Request.UnmarshalJSON for details about why
159 // this type assertion is safe here.
160 id, params := req.ID.(json.RawMessage), req.Params.(json.RawMessage)
161 162 // Never respond to Notifications, even if an Error occurs. For
163 // Requests, always use the Request ID in the Response.
164 defer func() {
165 if id == nil {
166 res = Response{}
167 return
168 }
169 res.ID = id
170 }()
171 172 // Look up the requested method and call it if found.
173 method, ok := methods[req.Method]
174 if !ok {
175 return Response{Error: errorMethodNotFound(req.Method)}
176 }
177 res = method.call(ctx, req.Method, params, lgr)
178 179 // Log the method name if debugging is enabled and the method had an
180 // internal error.
181 if DebugMethodFunc && res.HasError() && res.Error.Code == ErrorCodeInternal {
182 lgr.Printf("Method: %#v\n\n", req.Method)
183 }
184 185 return res
186 }
187