middleware.go raw

   1  // Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
   2  // resty source code and usage is governed by a MIT style
   3  // license that can be found in the LICENSE file.
   4  
   5  package resty
   6  
   7  import (
   8  	"bytes"
   9  	"errors"
  10  	"fmt"
  11  	"io"
  12  	"mime/multipart"
  13  	"net/http"
  14  	"net/url"
  15  	"os"
  16  	"path/filepath"
  17  	"reflect"
  18  	"strconv"
  19  	"strings"
  20  	"time"
  21  )
  22  
  23  const debugRequestLogKey = "__restyDebugRequestLog"
  24  
  25  //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
  26  // Request Middleware(s)
  27  //_______________________________________________________________________
  28  
  29  func parseRequestURL(c *Client, r *Request) error {
  30  	if l := len(c.PathParams) + len(c.RawPathParams) + len(r.PathParams) + len(r.RawPathParams); l > 0 {
  31  		params := make(map[string]string, l)
  32  
  33  		// GitHub #103 Path Params
  34  		for p, v := range r.PathParams {
  35  			params[p] = url.PathEscape(v)
  36  		}
  37  		for p, v := range c.PathParams {
  38  			if _, ok := params[p]; !ok {
  39  				params[p] = url.PathEscape(v)
  40  			}
  41  		}
  42  
  43  		// GitHub #663 Raw Path Params
  44  		for p, v := range r.RawPathParams {
  45  			if _, ok := params[p]; !ok {
  46  				params[p] = v
  47  			}
  48  		}
  49  		for p, v := range c.RawPathParams {
  50  			if _, ok := params[p]; !ok {
  51  				params[p] = v
  52  			}
  53  		}
  54  
  55  		if len(params) > 0 {
  56  			var prev int
  57  			buf := acquireBuffer()
  58  			defer releaseBuffer(buf)
  59  			// search for the next or first opened curly bracket
  60  			for curr := strings.Index(r.URL, "{"); curr == 0 || curr >= prev; curr = prev + strings.Index(r.URL[prev:], "{") {
  61  				// write everything from the previous position up to the current
  62  				if curr > prev {
  63  					buf.WriteString(r.URL[prev:curr])
  64  				}
  65  				// search for the closed curly bracket from current position
  66  				next := curr + strings.Index(r.URL[curr:], "}")
  67  				// if not found, then write the remainder and exit
  68  				if next < curr {
  69  					buf.WriteString(r.URL[curr:])
  70  					prev = len(r.URL)
  71  					break
  72  				}
  73  				// special case for {}, without parameter's name
  74  				if next == curr+1 {
  75  					buf.WriteString("{}")
  76  				} else {
  77  					// check for the replacement
  78  					key := r.URL[curr+1 : next]
  79  					value, ok := params[key]
  80  					/// keep the original string if the replacement not found
  81  					if !ok {
  82  						value = r.URL[curr : next+1]
  83  					}
  84  					buf.WriteString(value)
  85  				}
  86  
  87  				// set the previous position after the closed curly bracket
  88  				prev = next + 1
  89  				if prev >= len(r.URL) {
  90  					break
  91  				}
  92  			}
  93  			if buf.Len() > 0 {
  94  				// write remainder
  95  				if prev < len(r.URL) {
  96  					buf.WriteString(r.URL[prev:])
  97  				}
  98  				r.URL = buf.String()
  99  			}
 100  		}
 101  	}
 102  
 103  	// Parsing request URL
 104  	reqURL, err := url.Parse(r.URL)
 105  	if err != nil {
 106  		return err
 107  	}
 108  
 109  	// If Request.URL is relative path then added c.HostURL into
 110  	// the request URL otherwise Request.URL will be used as-is
 111  	if !reqURL.IsAbs() {
 112  		r.URL = reqURL.String()
 113  		if len(r.URL) > 0 && r.URL[0] != '/' {
 114  			r.URL = "/" + r.URL
 115  		}
 116  
 117  		// TODO: change to use c.BaseURL only in v3.0.0
 118  		baseURL := c.BaseURL
 119  		if len(baseURL) == 0 {
 120  			baseURL = c.HostURL
 121  		}
 122  		reqURL, err = url.Parse(baseURL + r.URL)
 123  		if err != nil {
 124  			return err
 125  		}
 126  	}
 127  
 128  	// GH #407 && #318
 129  	if reqURL.Scheme == "" && len(c.scheme) > 0 {
 130  		reqURL.Scheme = c.scheme
 131  	}
 132  
 133  	// Adding Query Param
 134  	if len(c.QueryParam)+len(r.QueryParam) > 0 {
 135  		for k, v := range c.QueryParam {
 136  			// skip query parameter if it was set in request
 137  			if _, ok := r.QueryParam[k]; ok {
 138  				continue
 139  			}
 140  
 141  			r.QueryParam[k] = v[:]
 142  		}
 143  
 144  		// GitHub #123 Preserve query string order partially.
 145  		// Since not feasible in `SetQuery*` resty methods, because
 146  		// standard package `url.Encode(...)` sorts the query params
 147  		// alphabetically
 148  		if len(r.QueryParam) > 0 {
 149  			if IsStringEmpty(reqURL.RawQuery) {
 150  				reqURL.RawQuery = r.QueryParam.Encode()
 151  			} else {
 152  				reqURL.RawQuery = reqURL.RawQuery + "&" + r.QueryParam.Encode()
 153  			}
 154  		}
 155  	}
 156  
 157  	// GH#797 Unescape query parameters
 158  	if r.unescapeQueryParams && len(reqURL.RawQuery) > 0 {
 159  		// at this point, all errors caught up in the above operations
 160  		// so ignore the return error on query unescape; I realized
 161  		// while writing the unit test
 162  		unescapedQuery, _ := url.QueryUnescape(reqURL.RawQuery)
 163  		reqURL.RawQuery = strings.ReplaceAll(unescapedQuery, " ", "+") // otherwise request becomes bad request
 164  	}
 165  
 166  	r.URL = reqURL.String()
 167  
 168  	return nil
 169  }
 170  
 171  func parseRequestHeader(c *Client, r *Request) error {
 172  	for k, v := range c.Header {
 173  		if _, ok := r.Header[k]; ok {
 174  			continue
 175  		}
 176  		r.Header[k] = v[:]
 177  	}
 178  
 179  	if IsStringEmpty(r.Header.Get(hdrUserAgentKey)) {
 180  		r.Header.Set(hdrUserAgentKey, hdrUserAgentValue)
 181  	}
 182  
 183  	if ct := r.Header.Get(hdrContentTypeKey); IsStringEmpty(r.Header.Get(hdrAcceptKey)) && !IsStringEmpty(ct) && (IsJSONType(ct) || IsXMLType(ct)) {
 184  		r.Header.Set(hdrAcceptKey, r.Header.Get(hdrContentTypeKey))
 185  	}
 186  
 187  	return nil
 188  }
 189  
 190  func parseRequestBody(c *Client, r *Request) error {
 191  	if isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
 192  		switch {
 193  		case r.isMultiPart: // Handling Multipart
 194  			if err := handleMultipart(c, r); err != nil {
 195  				return err
 196  			}
 197  		case len(c.FormData) > 0 || len(r.FormData) > 0: // Handling Form Data
 198  			handleFormData(c, r)
 199  		case r.Body == nil && r.bodyBuf == nil: // Handling Request body when nil body
 200  			// Go http library omits Content-Length if body is nil; use http.NoBody to force it if SetContentLength is true
 201  			r.Body = http.NoBody
 202  			fallthrough
 203  		case r.Body != nil: // Handling Request body
 204  			handleContentType(c, r)
 205  
 206  			if err := handleRequestBody(c, r); err != nil {
 207  				return err
 208  			}
 209  		}
 210  	}
 211  
 212  	// by default resty won't set content length, you can if you want to :)
 213  	if c.setContentLength || r.setContentLength {
 214  		if r.bodyBuf == nil {
 215  			r.Header.Set(hdrContentLengthKey, "0")
 216  		} else {
 217  			r.Header.Set(hdrContentLengthKey, strconv.Itoa(r.bodyBuf.Len()))
 218  		}
 219  	}
 220  
 221  	return nil
 222  }
 223  
 224  func createHTTPRequest(c *Client, r *Request) (err error) {
 225  	if r.bodyBuf == nil {
 226  		if reader, ok := r.Body.(io.Reader); ok && isPayloadSupported(r.Method, c.AllowGetMethodPayload) {
 227  			r.RawRequest, err = http.NewRequest(r.Method, r.URL, reader)
 228  		} else if c.setContentLength || r.setContentLength {
 229  			r.RawRequest, err = http.NewRequest(r.Method, r.URL, http.NoBody)
 230  		} else {
 231  			r.RawRequest, err = http.NewRequest(r.Method, r.URL, nil)
 232  		}
 233  	} else {
 234  		// fix data race: must deep copy.
 235  		bodyBuf := bytes.NewBuffer(append([]byte{}, r.bodyBuf.Bytes()...))
 236  		r.RawRequest, err = http.NewRequest(r.Method, r.URL, bodyBuf)
 237  	}
 238  
 239  	if err != nil {
 240  		return
 241  	}
 242  
 243  	// Assign close connection option
 244  	r.RawRequest.Close = c.closeConnection
 245  
 246  	// Add headers into http request
 247  	r.RawRequest.Header = r.Header
 248  
 249  	// Add cookies from client instance into http request
 250  	for _, cookie := range c.Cookies {
 251  		r.RawRequest.AddCookie(cookie)
 252  	}
 253  
 254  	// Add cookies from request instance into http request
 255  	for _, cookie := range r.Cookies {
 256  		r.RawRequest.AddCookie(cookie)
 257  	}
 258  
 259  	// Enable trace
 260  	if c.trace || r.trace {
 261  		r.clientTrace = &clientTrace{}
 262  		r.ctx = r.clientTrace.createContext(r.Context())
 263  	}
 264  
 265  	// Use context if it was specified
 266  	if r.ctx != nil {
 267  		r.RawRequest = r.RawRequest.WithContext(r.ctx)
 268  	}
 269  
 270  	// assign get body func for the underlying raw request instance
 271  	if r.RawRequest.GetBody == nil {
 272  		bodyCopy, err := getBodyCopy(r)
 273  		if err != nil {
 274  			return err
 275  		}
 276  		if bodyCopy != nil {
 277  			buf := bodyCopy.Bytes()
 278  			r.RawRequest.GetBody = func() (io.ReadCloser, error) {
 279  				b := bytes.NewReader(buf)
 280  				return io.NopCloser(b), nil
 281  			}
 282  		}
 283  	}
 284  
 285  	return
 286  }
 287  
 288  func addCredentials(c *Client, r *Request) error {
 289  	var isBasicAuth bool
 290  	// Basic Auth
 291  	if r.UserInfo != nil { // takes precedence
 292  		r.RawRequest.SetBasicAuth(r.UserInfo.Username, r.UserInfo.Password)
 293  		isBasicAuth = true
 294  	} else if c.UserInfo != nil {
 295  		r.RawRequest.SetBasicAuth(c.UserInfo.Username, c.UserInfo.Password)
 296  		isBasicAuth = true
 297  	}
 298  
 299  	if !c.DisableWarn {
 300  		if isBasicAuth && !strings.HasPrefix(r.URL, "https") {
 301  			r.log.Warnf("Using Basic Auth in HTTP mode is not secure, use HTTPS")
 302  		}
 303  	}
 304  
 305  	// Build the token Auth header
 306  	if !IsStringEmpty(r.Token) {
 307  		r.RawRequest.Header.Set(c.HeaderAuthorizationKey, strings.TrimSpace(r.AuthScheme+" "+r.Token))
 308  	} else if !IsStringEmpty(c.Token) {
 309  		r.RawRequest.Header.Set(c.HeaderAuthorizationKey, strings.TrimSpace(r.AuthScheme+" "+c.Token))
 310  	}
 311  
 312  	return nil
 313  }
 314  
 315  func createCurlCmd(c *Client, r *Request) (err error) {
 316  	if r.Debug && r.generateCurlOnDebug {
 317  		if r.resultCurlCmd == nil {
 318  			r.resultCurlCmd = new(string)
 319  		}
 320  		*r.resultCurlCmd = buildCurlRequest(r.RawRequest, c.httpClient.Jar)
 321  	}
 322  
 323  	return nil
 324  }
 325  
 326  func requestLogger(c *Client, r *Request) error {
 327  	if r.Debug {
 328  		rr := r.RawRequest
 329  		rh := copyHeaders(rr.Header)
 330  		if c.GetClient().Jar != nil {
 331  			for _, cookie := range c.GetClient().Jar.Cookies(r.RawRequest.URL) {
 332  				s := fmt.Sprintf("%s=%s", cookie.Name, cookie.Value)
 333  				if c := rh.Get("Cookie"); c != "" {
 334  					rh.Set("Cookie", c+"; "+s)
 335  				} else {
 336  					rh.Set("Cookie", s)
 337  				}
 338  			}
 339  		}
 340  		rl := &RequestLog{Header: rh, Body: r.fmtBodyString(c.debugBodySizeLimit)}
 341  		if c.requestLog != nil {
 342  			if err := c.requestLog(rl); err != nil {
 343  				return err
 344  			}
 345  		}
 346  
 347  		reqLog := "\n==============================================================================\n"
 348  
 349  		if r.Debug && r.generateCurlOnDebug {
 350  			reqLog += "~~~ REQUEST(CURL) ~~~\n" +
 351  				fmt.Sprintf("	%v\n", *r.resultCurlCmd)
 352  		}
 353  
 354  		reqLog += "~~~ REQUEST ~~~\n" +
 355  			fmt.Sprintf("%s  %s  %s\n", r.Method, rr.URL.RequestURI(), rr.Proto) +
 356  			fmt.Sprintf("HOST   : %s\n", rr.URL.Host) +
 357  			fmt.Sprintf("HEADERS:\n%s\n", composeHeaders(c, r, rl.Header)) +
 358  			fmt.Sprintf("BODY   :\n%v\n", rl.Body) +
 359  			"------------------------------------------------------------------------------\n"
 360  
 361  		r.initValuesMap()
 362  		r.values[debugRequestLogKey] = reqLog
 363  	}
 364  
 365  	return nil
 366  }
 367  
 368  //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
 369  // Response Middleware(s)
 370  //_______________________________________________________________________
 371  
 372  func responseLogger(c *Client, res *Response) error {
 373  	if res.Request.Debug {
 374  		rl := &ResponseLog{Header: copyHeaders(res.Header()), Body: res.fmtBodyString(c.debugBodySizeLimit)}
 375  		if c.responseLog != nil {
 376  			if err := c.responseLog(rl); err != nil {
 377  				return err
 378  			}
 379  		}
 380  
 381  		debugLog := res.Request.values[debugRequestLogKey].(string)
 382  		debugLog += "~~~ RESPONSE ~~~\n" +
 383  			fmt.Sprintf("STATUS       : %s\n", res.Status()) +
 384  			fmt.Sprintf("PROTO        : %s\n", res.Proto()) +
 385  			fmt.Sprintf("RECEIVED AT  : %v\n", res.ReceivedAt().Format(time.RFC3339Nano)) +
 386  			fmt.Sprintf("TIME DURATION: %v\n", res.Time()) +
 387  			"HEADERS      :\n" +
 388  			composeHeaders(c, res.Request, rl.Header) + "\n"
 389  		if res.Request.isSaveResponse {
 390  			debugLog += "BODY         :\n***** RESPONSE WRITTEN INTO FILE *****\n"
 391  		} else {
 392  			debugLog += fmt.Sprintf("BODY         :\n%v\n", rl.Body)
 393  		}
 394  		debugLog += "==============================================================================\n"
 395  
 396  		res.Request.log.Debugf("%s", debugLog)
 397  	}
 398  
 399  	return nil
 400  }
 401  
 402  func parseResponseBody(c *Client, res *Response) (err error) {
 403  	if res.StatusCode() == http.StatusNoContent {
 404  		res.Request.Error = nil
 405  		return
 406  	}
 407  	// Handles only JSON or XML content type
 408  	ct := firstNonEmpty(res.Request.forceContentType, res.Header().Get(hdrContentTypeKey), res.Request.fallbackContentType)
 409  	if IsJSONType(ct) || IsXMLType(ct) {
 410  		// HTTP status code > 199 and < 300, considered as Result
 411  		if res.IsSuccess() {
 412  			res.Request.Error = nil
 413  			if res.Request.Result != nil {
 414  				err = Unmarshalc(c, ct, res.body, res.Request.Result)
 415  				return
 416  			}
 417  		}
 418  
 419  		// HTTP status code > 399, considered as Error
 420  		if res.IsError() {
 421  			// global error interface
 422  			if res.Request.Error == nil && c.Error != nil {
 423  				res.Request.Error = reflect.New(c.Error).Interface()
 424  			}
 425  
 426  			if res.Request.Error != nil {
 427  				unmarshalErr := Unmarshalc(c, ct, res.body, res.Request.Error)
 428  				if unmarshalErr != nil {
 429  					c.log.Warnf("Cannot unmarshal response body: %s", unmarshalErr)
 430  				}
 431  			}
 432  		}
 433  	}
 434  
 435  	return
 436  }
 437  
 438  func handleMultipart(c *Client, r *Request) error {
 439  	r.bodyBuf = acquireBuffer()
 440  	w := multipart.NewWriter(r.bodyBuf)
 441  
 442  	// Set boundary if not set by user
 443  	if r.multipartBoundary != "" {
 444  		if err := w.SetBoundary(r.multipartBoundary); err != nil {
 445  			return err
 446  		}
 447  	}
 448  
 449  	for k, v := range c.FormData {
 450  		for _, iv := range v {
 451  			if err := w.WriteField(k, iv); err != nil {
 452  				return err
 453  			}
 454  		}
 455  	}
 456  
 457  	for k, v := range r.FormData {
 458  		for _, iv := range v {
 459  			if strings.HasPrefix(k, "@") { // file
 460  				if err := addFile(w, k[1:], iv); err != nil {
 461  					return err
 462  				}
 463  			} else { // form value
 464  				if err := w.WriteField(k, iv); err != nil {
 465  					return err
 466  				}
 467  			}
 468  		}
 469  	}
 470  
 471  	// #21 - adding io.Reader support
 472  	for _, f := range r.multipartFiles {
 473  		if err := addFileReader(w, f); err != nil {
 474  			return err
 475  		}
 476  	}
 477  
 478  	// GitHub #130 adding multipart field support with content type
 479  	for _, mf := range r.multipartFields {
 480  		if err := addMultipartFormField(w, mf); err != nil {
 481  			return err
 482  		}
 483  	}
 484  
 485  	r.Header.Set(hdrContentTypeKey, w.FormDataContentType())
 486  	return w.Close()
 487  }
 488  
 489  func handleFormData(c *Client, r *Request) {
 490  	for k, v := range c.FormData {
 491  		if _, ok := r.FormData[k]; ok {
 492  			continue
 493  		}
 494  		r.FormData[k] = v[:]
 495  	}
 496  
 497  	r.bodyBuf = acquireBuffer()
 498  	r.bodyBuf.WriteString(r.FormData.Encode())
 499  	r.Header.Set(hdrContentTypeKey, formContentType)
 500  	r.isFormData = true
 501  }
 502  
 503  func handleContentType(c *Client, r *Request) {
 504  	if r.Body == http.NoBody {
 505  		return
 506  	}
 507  	contentType := r.Header.Get(hdrContentTypeKey)
 508  	if IsStringEmpty(contentType) {
 509  		contentType = DetectContentType(r.Body)
 510  		r.Header.Set(hdrContentTypeKey, contentType)
 511  	}
 512  }
 513  
 514  func handleRequestBody(c *Client, r *Request) error {
 515  	var bodyBytes []byte
 516  	r.bodyBuf = nil
 517  
 518  	switch body := r.Body.(type) {
 519  	case io.Reader:
 520  		if c.setContentLength || r.setContentLength { // keep backward compatibility
 521  			r.bodyBuf = acquireBuffer()
 522  			if _, err := r.bodyBuf.ReadFrom(body); err != nil {
 523  				return err
 524  			}
 525  			r.Body = nil
 526  		} else {
 527  			// Otherwise buffer less processing for `io.Reader`, sounds good.
 528  			return nil
 529  		}
 530  	case []byte:
 531  		bodyBytes = body
 532  	case string:
 533  		bodyBytes = []byte(body)
 534  	default:
 535  		contentType := r.Header.Get(hdrContentTypeKey)
 536  		kind := kindOf(r.Body)
 537  		var err error
 538  		if IsJSONType(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) {
 539  			r.bodyBuf, err = jsonMarshal(c, r, r.Body)
 540  		} else if IsXMLType(contentType) && (kind == reflect.Struct) {
 541  			bodyBytes, err = c.XMLMarshal(r.Body)
 542  		}
 543  		if err != nil {
 544  			return err
 545  		}
 546  	}
 547  
 548  	if bodyBytes == nil && r.bodyBuf == nil {
 549  		return errors.New("unsupported 'Body' type/value")
 550  	}
 551  
 552  	// []byte into Buffer
 553  	if bodyBytes != nil && r.bodyBuf == nil {
 554  		r.bodyBuf = acquireBuffer()
 555  		_, _ = r.bodyBuf.Write(bodyBytes)
 556  	}
 557  
 558  	return nil
 559  }
 560  
 561  func saveResponseIntoFile(c *Client, res *Response) error {
 562  	if res.Request.isSaveResponse {
 563  		file := ""
 564  
 565  		if len(c.outputDirectory) > 0 && !filepath.IsAbs(res.Request.outputFile) {
 566  			file += c.outputDirectory + string(filepath.Separator)
 567  		}
 568  
 569  		file = filepath.Clean(file + res.Request.outputFile)
 570  		if err := createDirectory(filepath.Dir(file)); err != nil {
 571  			return err
 572  		}
 573  
 574  		outFile, err := os.Create(file)
 575  		if err != nil {
 576  			return err
 577  		}
 578  		defer closeq(outFile)
 579  
 580  		// io.Copy reads maximum 32kb size, it is perfect for large file download too
 581  		defer closeq(res.RawResponse.Body)
 582  
 583  		written, err := io.Copy(outFile, res.RawResponse.Body)
 584  		if err != nil {
 585  			return err
 586  		}
 587  
 588  		res.size = written
 589  	}
 590  
 591  	return nil
 592  }
 593  
 594  func getBodyCopy(r *Request) (*bytes.Buffer, error) {
 595  	// If r.bodyBuf present, return the copy
 596  	if r.bodyBuf != nil {
 597  		bodyCopy := acquireBuffer()
 598  		if _, err := io.Copy(bodyCopy, bytes.NewReader(r.bodyBuf.Bytes())); err != nil {
 599  			// cannot use io.Copy(bodyCopy, r.bodyBuf) because io.Copy reset r.bodyBuf
 600  			return nil, err
 601  		}
 602  		return bodyCopy, nil
 603  	}
 604  
 605  	// Maybe body is `io.Reader`.
 606  	// Note: Resty user have to watchout for large body size of `io.Reader`
 607  	if r.RawRequest.Body != nil {
 608  		b, err := io.ReadAll(r.RawRequest.Body)
 609  		if err != nil {
 610  			return nil, err
 611  		}
 612  
 613  		// Restore the Body
 614  		closeq(r.RawRequest.Body)
 615  		r.RawRequest.Body = io.NopCloser(bytes.NewBuffer(b))
 616  
 617  		// Return the Body bytes
 618  		return bytes.NewBuffer(b), nil
 619  	}
 620  	return nil, nil
 621  }
 622