middleware.go raw
1 // Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
2 // resty source code and usage is governed by a MIT style
3 // license that can be found in the LICENSE file.
4
5 package resty
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "mime/multipart"
13 "net/http"
14 "net/url"
15 "os"
16 "path/filepath"
17 "reflect"
18 "strconv"
19 "strings"
20 "time"
21 )
22
// debugRequestLogKey is the key under which requestLogger stores the
// formatted request section of the debug log in Request.values, so that
// responseLogger can append the response section and emit both together.
const (
	debugRequestLogKey = "__restyDebugRequestLog"
)
24
25 //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
26 // Request Middleware(s)
27 //_______________________________________________________________________
28
29 func parseRequestURL(c *Client, r *Request) error {
30 if l := len(c.PathParams) + len(c.RawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 {
31 params := make(map[string]string, l)
32
33 // GitHub #103 Path Params
34 for p, v := range r.PathParams {
35 params[p] = url.PathEscape(v)
36 }
37 for p, v := range c.PathParams {
38 if _, ok := params[p]; !ok {
39 params[p] = url.PathEscape(v)
40 }
41 }
42
43 // GitHub #663 Raw Path Params
44 for p, v := range r.RawPathParams {
45 if _, ok := params[p]; !ok {
46 params[p] = v
47 }
48 }
49 for p, v := range c.RawPathParams {
50 if _, ok := params[p]; !ok {
51 params[p] = v
52 }
53 }
54
55 if len(params) > 0 {
56 var prev int
57 buf := acquireBuffer()
58 defer releaseBuffer(buf)
59 // search for the next or first opened curly bracket
60 for curr := strings.Index(r.URL, "{"); curr == 0 || curr >= prev; curr = prev + strings.Index(r.URL[prev:], "{") {
61 // write everything from the previous position up to the current
62 if curr > prev {
63 buf.WriteString(r.URL[prev:curr])
64 }
65 // search for the closed curly bracket from current position
66 next := curr + strings.Index(r.URL[curr:], "}")
67 // if not found, then write the remainder and exit
68 if next < curr {
69 buf.WriteString(r.URL[curr:])
70 prev = len(r.URL)
71 break
72 }
73 // special case for {}, without parameter's name
74 if next == curr+1 {
75 buf.WriteString("{}")
76 } else {
77 // check for the replacement
78 key := r.URL[curr+1 : next]
79 value, ok := params[key]
80 /// keep the original string if the replacement not found
81 if !ok {
82 value = r.URL[curr : next+1]
83 }
84 buf.WriteString(value)
85 }
86
87 // set the previous position after the closed curly bracket
88 prev = next + 1
89 if prev >= len(r.URL) {
90 break
91 }
92 }
93 if buf.Len() > 0 {
94 // write remainder
95 if prev < len(r.URL) {
96 buf.WriteString(r.URL[prev:])
97 }
98 r.URL = buf.String()
99 }
100 }
101 }
102
103 // Parsing request URL
104 reqURL, err := url.Parse(r.URL)
105 if err != nil {
106 return err
107 }
108
109 // If Request.URL is relative path then added c.HostURL into
110 // the request URL otherwise Request.URL will be used as-is
111 if !reqURL.IsAbs() {
112 r.URL = reqURL.String()
113 if len(r.URL) > 0 && r.URL[0] != '/' {
114 r.URL = "/" + r.URL
115 }
116
117 // TODO: change to use c.BaseURL only in v3.0.0
118 baseURL := c.BaseURL
119 if len(baseURL) == 0 {
120 baseURL = c.HostURL
121 }
122 reqURL, err = url.Parse(baseURL + r.URL)
123 if err != nil {
124 return err
125 }
126 }
127
128 // GH #407 && #318
129 if reqURL.Scheme == "" && len(c.scheme) > 0 {
130 reqURL.Scheme = c.scheme
131 }
132
133 // Adding Query Param
134 if len(c.QueryParam)+len(r.QueryParam) > 0 {
135 for k, v := range c.QueryParam {
136 // skip query parameter if it was set in request
137 if _, ok := r.QueryParam[k]; ok {
138 continue
139 }
140
141 r.QueryParam[k] = v[:]
142 }
143
144 // GitHub #123 Preserve query string order partially.
145 // Since not feasible in `SetQuery*` resty methods, because
146 // standard package `url.Encode(...)` sorts the query params
147 // alphabetically
148 if len(r.QueryParam) > 0 {
149 if IsStringEmpty(reqURL.RawQuery) {
150 reqURL.RawQuery = r.QueryParam.Encode()
151 } else {
152 reqURL.RawQuery = reqURL.RawQuery + "&" + r.QueryParam.Encode()
153 }
154 }
155 }
156
157 // GH#797 Unescape query parameters
158 if r.unescapeQueryParams && len(reqURL.RawQuery) > 0 {
159 // at this point, all errors caught up in the above operations
160 // so ignore the return error on query unescape; I realized
161 // while writing the unit test
162 unescapedQuery, _ := url.QueryUnescape(reqURL.RawQuery)
163 reqURL.RawQuery = strings.ReplaceAll(unescapedQuery, " ", "+") // otherwise request becomes bad request
164 }
165
166 r.URL = reqURL.String()
167
168 return nil
169 }
170
171 func parseRequestHeader(c *Client, r *Request) error {
172 for k, v := range c.Header {
173 if _, ok := r.Header[k]; ok {
174 continue
175 }
176 r.Header[k] = v[:]
177 }
178
179 if IsStringEmpty(r.Header.Get(hdrUserAgentKey)) {
180 r.Header.Set(hdrUserAgentKey, hdrUserAgentValue)
181 }
182
183 if ct := r.Header.Get(hdrContentTypeKey); IsStringEmpty(r.Header.Get(hdrAcceptKey)) && !IsStringEmpty(ct) && (IsJSONType(ct) || IsXMLType(ct)) {
184 r.Header.Set(hdrAcceptKey, r.Header.Get(hdrContentTypeKey))
185 }
186
187 return nil
188 }
189
190 func parseRequestBody(c *Client, r *Request) error {
191 if isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
192 switch {
193 case r.isMultiPart: // Handling Multipart
194 if err := handleMultipart(c, r); err != nil {
195 return err
196 }
197 case len(c.FormData) > 0 || len(r.FormData) > 0: // Handling Form Data
198 handleFormData(c, r)
199 case r.Body == nil && r.bodyBuf == nil: // Handling Request body when nil body
200 // Go http library omits Content-Length if body is nil; use http.NoBody to force it if SetContentLength is true
201 r.Body = http.NoBody
202 fallthrough
203 case r.Body != nil: // Handling Request body
204 handleContentType(c, r)
205
206 if err := handleRequestBody(c, r); err != nil {
207 return err
208 }
209 }
210 }
211
212 // by default resty won't set content length, you can if you want to :)
213 if c.setContentLength || r.setContentLength {
214 if r.bodyBuf == nil {
215 r.Header.Set(hdrContentLengthKey, "0")
216 } else {
217 r.Header.Set(hdrContentLengthKey, strconv.Itoa(r.bodyBuf.Len()))
218 }
219 }
220
221 return nil
222 }
223
224 func createHTTPRequest(c *Client, r *Request) (err error) {
225 if r.bodyBuf == nil {
226 if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
227 r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
228 } else if c.setContentLength || r.setContentLength {
229 r.RawRequest, err = http.NewRequest(r.Method, r.URL, http.NoBody)
230 } else {
231 r.RawRequest, err = http.NewRequest(r.Method, r.URL, nil)
232 }
233 } else {
234 // fix data race: must deep copy.
235 bodyBuf := bytes.NewBuffer(append([]byte{}, r.bodyBuf.Bytes()...))
236 r.RawRequest, err = http.NewRequest(r.Method, r.URL, bodyBuf)
237 }
238
239 if err != nil {
240 return
241 }
242
243 // Assign close connection option
244 r.RawRequest.Close = c.closeConnection
245
246 // Add headers into http request
247 r.RawRequest.Header = r.Header
248
249 // Add cookies from client instance into http request
250 for _, cookie := range c.Cookies {
251 r.RawRequest.AddCookie(cookie)
252 }
253
254 // Add cookies from request instance into http request
255 for _, cookie := range r.Cookies {
256 r.RawRequest.AddCookie(cookie)
257 }
258
259 // Enable trace
260 if c.trace || r.trace {
261 r.clientTrace = &clientTrace{}
262 r.ctx = r.clientTrace.createContext(r.Context())
263 }
264
265 // Use context if it was specified
266 if r.ctx != nil {
267 r.RawRequest = r.RawRequest.WithContext(r.ctx)
268 }
269
270 // assign get body func for the underlying raw request instance
271 if r.RawRequest.GetBody == nil {
272 bodyCopy, err := getBodyCopy(r)
273 if err != nil {
274 return err
275 }
276 if bodyCopy != nil {
277 buf := bodyCopy.Bytes()
278 r.RawRequest.GetBody = func() (io.ReadCloser, error) {
279 b := bytes.NewReader(buf)
280 return io.NopCloser(b), nil
281 }
282 }
283 }
284
285 return
286 }
287
288 func addCredentials(c *Client, r *Request) error {
289 var isBasicAuth bool
290 // Basic Auth
291 if r.UserInfo != nil { // takes precedence
292 r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password)
293 isBasicAuth = true
294 } else if c.UserInfo != nil {
295 r.RawRequest.SetBasicAuth(c.UserInfo.Username, c.UserInfo.Password)
296 isBasicAuth = true
297 }
298
299 if !c.DisableWarn {
300 if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
301 r.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS")
302 }
303 }
304
305 // Build the token Auth header
306 if !IsStringEmpty(r.Token) {
307 r.RawRequest.Header.Set(c.HeaderAuthorizationKey, strings.TrimSpace(r.AuthScheme+" "+r.Token))
308 } else if !IsStringEmpty(c.Token) {
309 r.RawRequest.Header.Set(c.HeaderAuthorizationKey, strings.TrimSpace(r.AuthScheme+" "+c.Token))
310 }
311
312 return nil
313 }
314
315 func createCurlCmd(c *Client, r *Request) (err error) {
316 if r.Debug && r.generateCurlOnDebug {
317 if r.resultCurlCmd == nil {
318 r.resultCurlCmd = new(string)
319 }
320 *r.resultCurlCmd = buildCurlRequest(r.RawRequest, c.httpClient.Jar)
321 }
322
323 return nil
324 }
325
326 func requestLogger(c *Client, r *Request) error {
327 if r.Debug {
328 rr := r.RawRequest
329 rh := copyHeaders(rr.Header)
330 if c.GetClient().Jar != nil {
331 for _, cookie := range c.GetClient().Jar.Cookies(r.RawRequest.URL) {
332 s := fmt.Sprintf("%s=%s", cookie.Name, cookie.Value)
333 if c := rh.Get("Cookie"); c != "" {
334 rh.Set("Cookie", c+"; "+s)
335 } else {
336 rh.Set("Cookie", s)
337 }
338 }
339 }
340 rl := &RequestLog{Header: rh, Body: r.fmtBodyString(c.debugBodySizeLimit)}
341 if c.requestLog != nil {
342 if err := c.requestLog(rl); err != nil {
343 return err
344 }
345 }
346
347 reqLog := "\n==============================================================================\n"
348
349 if r.Debug && r.generateCurlOnDebug {
350 reqLog += "~~~ REQUEST(CURL) ~~~\n" +
351 fmt.Sprintf(" %v\n", *r.resultCurlCmd)
352 }
353
354 reqLog += "~~~ REQUEST ~~~\n" +
355 fmt.Sprintf("%s %s %s\n", r.Method, rr.URL.RequestURI(), rr.Proto) +
356 fmt.Sprintf("HOST : %s\n", rr.URL.Host) +
357 fmt.Sprintf("HEADERS:\n%s\n", composeHeaders(c, r, rl.Header)) +
358 fmt.Sprintf("BODY :\n%v\n", rl.Body) +
359 "------------------------------------------------------------------------------\n"
360
361 r.initValuesMap()
362 r.values[debugRequestLogKey] = reqLog
363 }
364
365 return nil
366 }
367
368 //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
369 // Response Middleware(s)
370 //_______________________________________________________________________
371
372 func responseLogger(c *Client, res *Response) error {
373 if res.Request.Debug {
374 rl := &ResponseLog{Header: copyHeaders(res.Header()), Body: res.fmtBodyString(c.debugBodySizeLimit)}
375 if c.responseLog != nil {
376 if err := c.responseLog(rl); err != nil {
377 return err
378 }
379 }
380
381 debugLog := res.Request.values[debugRequestLogKey].(string)
382 debugLog += "~~~ RESPONSE ~~~\n" +
383 fmt.Sprintf("STATUS : %s\n", res.Status()) +
384 fmt.Sprintf("PROTO : %s\n", res.Proto()) +
385 fmt.Sprintf("RECEIVED AT : %v\n", res.ReceivedAt().Format(time.RFC3339Nano)) +
386 fmt.Sprintf("TIME DURATION: %v\n", res.Time()) +
387 "HEADERS :\n" +
388 composeHeaders(c, res.Request, rl.Header) + "\n"
389 if res.Request.isSaveResponse {
390 debugLog += "BODY :\n***** RESPONSE WRITTEN INTO FILE *****\n"
391 } else {
392 debugLog += fmt.Sprintf("BODY :\n%v\n", rl.Body)
393 }
394 debugLog += "==============================================================================\n"
395
396 res.Request.log.Debugf("%s", debugLog)
397 }
398
399 return nil
400 }
401
402 func parseResponseBody(c *Client, res *Response) (err error) {
403 if res.StatusCode() == http.StatusNoContent {
404 res.Request.Error = nil
405 return
406 }
407 // Handles only JSON or XML content type
408 ct := firstNonEmpty(res.Request.forceContentType, res.Header().Get(hdrContentTypeKey), res.Request.fallbackContentType)
409 if IsJSONType(ct) || IsXMLType(ct) {
410 // HTTP status code > 199 and < 300, considered as Result
411 if res.IsSuccess() {
412 res.Request.Error = nil
413 if res.Request.Result != nil {
414 err = Unmarshalc(c, ct, res.body, res.Request.Result)
415 return
416 }
417 }
418
419 // HTTP status code > 399, considered as Error
420 if res.IsError() {
421 // global error interface
422 if res.Request.Error == nil && c.Error != nil {
423 res.Request.Error = reflect.New(c.Error).Interface()
424 }
425
426 if res.Request.Error != nil {
427 unmarshalErr := Unmarshalc(c, ct, res.body, res.Request.Error)
428 if unmarshalErr != nil {
429 c.log.Warnf("Cannot unmarshal response body: %s", unmarshalErr)
430 }
431 }
432 }
433 }
434
435 return
436 }
437
438 func handleMultipart(c *Client, r *Request) error {
439 r.bodyBuf = acquireBuffer()
440 w := multipart.NewWriter(r.bodyBuf)
441
442 // Set boundary if not set by user
443 if r.multipartBoundary != "" {
444 if err := w.SetBoundary(r.multipartBoundary); err != nil {
445 return err
446 }
447 }
448
449 for k, v := range c.FormData {
450 for _, iv := range v {
451 if err := w.WriteField(k, iv); err != nil {
452 return err
453 }
454 }
455 }
456
457 for k, v := range r.FormData {
458 for _, iv := range v {
459 if strings.HasPrefix(k, "@") { // file
460 if err := addFile(w, k[1:], iv); err != nil {
461 return err
462 }
463 } else { // form value
464 if err := w.WriteField(k, iv); err != nil {
465 return err
466 }
467 }
468 }
469 }
470
471 // #21 - adding io.Reader support
472 for _, f := range r.multipartFiles {
473 if err := addFileReader(w, f); err != nil {
474 return err
475 }
476 }
477
478 // GitHub #130 adding multipart field support with content type
479 for _, mf := range r.multipartFields {
480 if err := addMultipartFormField(w, mf); err != nil {
481 return err
482 }
483 }
484
485 r.Header.Set(hdrContentTypeKey, w.FormDataContentType())
486 return w.Close()
487 }
488
489 func handleFormData(c *Client, r *Request) {
490 for k, v := range c.FormData {
491 if _, ok := r.FormData[k]; ok {
492 continue
493 }
494 r.FormData[k] = v[:]
495 }
496
497 r.bodyBuf = acquireBuffer()
498 r.bodyBuf.WriteString(r.FormData.Encode())
499 r.Header.Set(hdrContentTypeKey, formContentType)
500 r.isFormData = true
501 }
502
503 func handleContentType(c *Client, r *Request) {
504 if r.Body == http.NoBody {
505 return
506 }
507 contentType := r.Header.Get(hdrContentTypeKey)
508 if IsStringEmpty(contentType) {
509 contentType = DetectContentType(r.Body)
510 r.Header.Set(hdrContentTypeKey, contentType)
511 }
512 }
513
514 func handleRequestBody(c *Client, r *Request) error {
515 var bodyBytes []byte
516 r.bodyBuf = nil
517
518 switch body := r.Body.(type) {
519 case io.Reader:
520 if c.setContentLength || r.setContentLength { // keep backward compatibility
521 r.bodyBuf = acquireBuffer()
522 if _, err := r.bodyBuf.ReadFrom(body); err != nil {
523 return err
524 }
525 r.Body = nil
526 } else {
527 // Otherwise buffer less processing for `io.Reader`, sounds good.
528 return nil
529 }
530 case []byte:
531 bodyBytes = body
532 case string:
533 bodyBytes = []byte(body)
534 default:
535 contentType := r.Header.Get(hdrContentTypeKey)
536 kind := kindOf(r.Body)
537 var err error
538 if IsJSONType(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
539 r.bodyBuf, err = jsonMarshal(c, r, r.Body)
540 } else if IsXMLType(contentType) && (kind == reflect.Struct) {
541 bodyBytes, err = c.XMLMarshal(r.Body)
542 }
543 if err != nil {
544 return err
545 }
546 }
547
548 if bodyBytes == nil && r.bodyBuf == nil {
549 return errors.New("unsupported 'Body' type/value")
550 }
551
552 // []byte into Buffer
553 if bodyBytes != nil && r.bodyBuf == nil {
554 r.bodyBuf = acquireBuffer()
555 _, _ = r.bodyBuf.Write(bodyBytes)
556 }
557
558 return nil
559 }
560
561 func saveResponseIntoFile(c *Client, res *Response) error {
562 if res.Request.isSaveResponse {
563 file := ""
564
565 if len(c.outputDirectory) > 0 && !filepath.IsAbs(res.Request.outputFile) {
566 file += c.outputDirectory + string(filepath.Separator)
567 }
568
569 file = filepath.Clean(file + res.Request.outputFile)
570 if err := createDirectory(filepath.Dir(file)); err != nil {
571 return err
572 }
573
574 outFile, err := os.Create(file)
575 if err != nil {
576 return err
577 }
578 defer closeq(outFile)
579
580 // io.Copy reads maximum 32kb size, it is perfect for large file download too
581 defer closeq(res.RawResponse.Body)
582
583 written, err := io.Copy(outFile, res.RawResponse.Body)
584 if err != nil {
585 return err
586 }
587
588 res.size = written
589 }
590
591 return nil
592 }
593
594 func getBodyCopy(r *Request) (*bytes.Buffer, error) {
595 // If r.bodyBuf present, return the copy
596 if r.bodyBuf != nil {
597 bodyCopy := acquireBuffer()
598 if _, err := io.Copy(bodyCopy, bytes.NewReader(r.bodyBuf.Bytes())); err != nil {
599 // cannot use io.Copy(bodyCopy, r.bodyBuf) because io.Copy reset r.bodyBuf
600 return nil, err
601 }
602 return bodyCopy, nil
603 }
604
605 // Maybe body is `io.Reader`.
606 // Note: Resty user have to watchout for large body size of `io.Reader`
607 if r.RawRequest.Body != nil {
608 b, err := io.ReadAll(r.RawRequest.Body)
609 if err != nil {
610 return nil, err
611 }
612
613 // Restore the Body
614 closeq(r.RawRequest.Body)
615 r.RawRequest.Body = io.NopCloser(bytes.NewBuffer(b))
616
617 // Return the Body bytes
618 return bytes.NewBuffer(b), nil
619 }
620 return nil, nil
621 }
622