client.go raw

   1  /*
   2   * Copyright 2017 Baidu, Inc.
   3   *
   4   * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
   5   * except in compliance with the License. You may obtain a copy of the License at
   6   *
   7   * http://www.apache.org/licenses/LICENSE-2.0
   8   *
   9   * Unless required by applicable law or agreed to in writing, software distributed under the
  10   * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
  11   * either express or implied. See the License for the specific language governing permissions
  12   * and limitations under the License.
  13   */
  14  
  15  // client.go - define the execute function to send http request and get response
  16  
// Package http defines the request and response structures used to access BCE
// services, along with the HTTP header and method constants, and implements the
// `Execute` function that sends requests.
  20  package http
  21  
  22  import (
  23  	"context"
  24  	"crypto/tls"
  25  	"fmt"
  26  	"net"
  27  	"net/http"
  28  	"net/url"
  29  	"sync"
  30  	"time"
  31  
  32  	"github.com/baidubce/bce-sdk-go/util/log"
  33  )
  34  
  35  var (
  36  	defaultMaxIdleConnsPerHost   = 500
  37  	defaultResponseHeaderTimeout = 60 * time.Second
  38  	defaultDialTimeout           = 30 * time.Second
  39  	defaultKeepAlive             = 30 * time.Second
  40  	defaultSmallInterval         = 600 * time.Second
  41  	defaultLargeInterval         = 1200 * time.Second
  42  	defaultTLSHandshakeTimeout   = 10 * time.Second
  43  	defaultIdleConnTimeout       = 90 * time.Second
  44  	defaultHTTPClientTimeout     = 1200 * time.Second
  45  	NilHTTPClient                = fmt.Errorf("custom HTTP Client is nil")
  46  	UnknownHTTPTransport         = fmt.Errorf("invalid customized HTTP Client: RoundTripper is not http.Transport")
  47  )
  48  
  49  // The httpClient is the global variable to send the request and get response
  50  // for reuse and the Client provided by the Go standard library is thread safe.
  51  var (
  52  	httpClient *http.Client
  53  	transport  *http.Transport
  54  )
  55  
  56  var defaultHTTPClient = &http.Client{
  57  	Timeout:   defaultHTTPClientTimeout,
  58  	Transport: NewTransportCustom(&DefaultClientConfig),
  59  }
  60  
// Dialer wraps net.Dialer so that every connection it successfully opens is
// returned as a timeoutConn. That connection refreshes its read/write deadlines
// around each Read/Write and runs the hooks stored here.
type Dialer struct {
	net.Dialer
	// ReadTimeout, when non-nil, is the read deadline used for dialed
	// connections instead of the package's interval defaults.
	ReadTimeout  *time.Duration
	// WriteTimeout is the write-side counterpart of ReadTimeout.
	WriteTimeout *time.Duration
	// postRead hooks run after each successful Read with the byte count and
	// the error (if any) from refreshing the read deadline.
	postRead     []func(n int, err error)
	// postWrite hooks are the Write-side counterpart of postRead.
	postWrite    []func(n int, err error)
}
  68  
  69  func NewDialer(config *ClientConfig) *Dialer {
  70  	dialer := &Dialer{
  71  		Dialer: net.Dialer{
  72  			Timeout:   *config.DialTimeout,
  73  			KeepAlive: *config.KeepAlive,
  74  		},
  75  		ReadTimeout:  config.ReadTimeout,
  76  		WriteTimeout: config.WriteTimeout,
  77  		postRead:     config.PostRead,
  78  		postWrite:    config.PostWrite,
  79  	}
  80  	return dialer
  81  }
  82  
  83  func (d *Dialer) Dial(network, address string) (net.Conn, error) {
  84  	return d.DialContext(context.Background(), network, address)
  85  }
  86  
  87  func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  88  	c, err := d.Dialer.DialContext(ctx, network, address)
  89  	if err != nil {
  90  		return c, err
  91  	}
  92  	tc := &timeoutConn{
  93  		conn:          c,
  94  		smallInterval: defaultSmallInterval,
  95  		largeInterval: defaultLargeInterval,
  96  		readTimeout:   d.ReadTimeout,
  97  		writeTimeout:  d.WriteTimeout,
  98  		dialer:        d,
  99  	}
 100  	if tc.readTimeout != nil {
 101  		err = tc.SetReadDeadline(time.Now().Add(*tc.readTimeout))
 102  	} else {
 103  		err = tc.SetReadDeadline(time.Now().Add(defaultLargeInterval))
 104  	}
 105  	if err != nil {
 106  		return nil, err
 107  	}
 108  	if tc.writeTimeout != nil {
 109  		err = tc.SetWriteDeadline(time.Now().Add(*tc.writeTimeout))
 110  	} else {
 111  		err = tc.SetWriteDeadline(time.Now().Add(defaultLargeInterval))
 112  	}
 113  	if err != nil {
 114  		return nil, err
 115  	}
 116  	return tc, err
 117  }
 118  
// timeoutConn wraps a net.Conn and refreshes its read/write deadlines around
// Read and Write calls, then runs the owning Dialer's post-I/O hooks.
type timeoutConn struct {
	// conn is the wrapped connection every call is delegated to.
	conn          net.Conn
	// smallInterval is the deadline set before a Read/Write when the
	// corresponding timeout below is nil.
	smallInterval time.Duration
	// largeInterval is the deadline set after a successful Read/Write when the
	// corresponding timeout below is nil.
	largeInterval time.Duration
	// readTimeout and writeTimeout, when non-nil, take the place of the
	// interval defaults for the respective direction.
	readTimeout   *time.Duration
	writeTimeout  *time.Duration
	// dialer supplies the postRead/postWrite hooks. It may be nil: the
	// connections built by InitClient and SetResponseHeaderTimeout leave it unset.
	dialer        *Dialer
}
 127  
 128  func (c *timeoutConn) Read(b []byte) (n int, err error) {
 129  	if c.readTimeout == nil {
 130  		err = c.SetReadDeadline(time.Now().Add(c.smallInterval))
 131  		if err != nil {
 132  			return n, err
 133  		}
 134  	}
 135  	n, err = c.conn.Read(b)
 136  	if err != nil {
 137  		return n, err
 138  	}
 139  	if c.readTimeout == nil {
 140  		err = c.SetReadDeadline(time.Now().Add(c.largeInterval))
 141  	} else {
 142  		err = c.SetReadDeadline(time.Now().Add(*c.readTimeout))
 143  	}
 144  	if c.dialer != nil {
 145  		for _, fn := range c.dialer.postRead {
 146  			fn(n, err)
 147  		}
 148  	}
 149  	return n, err
 150  }
 151  func (c *timeoutConn) Write(b []byte) (n int, err error) {
 152  	if c.writeTimeout == nil {
 153  		err = c.SetWriteDeadline(time.Now().Add(c.smallInterval))
 154  	}
 155  	if err != nil {
 156  		return n, err
 157  	}
 158  	n, err = c.conn.Write(b)
 159  	if err != nil {
 160  		return n, err
 161  	}
 162  	if c.writeTimeout == nil {
 163  		err = c.SetWriteDeadline(time.Now().Add(c.largeInterval))
 164  	} else {
 165  		err = c.SetWriteDeadline(time.Now().Add(*c.writeTimeout))
 166  	}
 167  	if c.dialer != nil {
 168  		for _, fn := range c.dialer.postWrite {
 169  			fn(n, err)
 170  		}
 171  	}
 172  	return n, err
 173  }
 174  func (c *timeoutConn) Close() error                       { return c.conn.Close() }
 175  func (c *timeoutConn) LocalAddr() net.Addr                { return c.conn.LocalAddr() }
 176  func (c *timeoutConn) RemoteAddr() net.Addr               { return c.conn.RemoteAddr() }
 177  func (c *timeoutConn) SetDeadline(t time.Time) error      { return c.conn.SetDeadline(t) }
 178  func (c *timeoutConn) SetReadDeadline(t time.Time) error  { return c.conn.SetReadDeadline(t) }
 179  func (c *timeoutConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
 180  
// ClientConfig collects the settings used to build HTTP clients and transports
// in this package. Pointer-typed fields are optional: nil means "not set", and
// Copy only overrides a destination field when the source field is non-nil.
type ClientConfig struct {
	// RedirectDisabled makes the Init* helpers install a CheckRedirect that
	// returns the last response instead of following redirects.
	RedirectDisabled      bool
	// DisableKeepAlives is passed to http.Transport.DisableKeepAlives and also
	// sets the transport's MaxIdleConnsPerHost to -1.
	DisableKeepAlives     bool
	// NoVerifySSL makes NewTransportCustom set InsecureSkipVerify, disabling
	// TLS certificate verification. Intended only for trusted test setups.
	NoVerifySSL           bool
	// DialTimeout and KeepAlive configure the net.Dialer built by NewDialer.
	DialTimeout           *time.Duration
	KeepAlive             *time.Duration
	// ReadTimeout and WriteTimeout are the per-connection read/write deadlines
	// applied by dialed connections; nil selects the package interval defaults.
	ReadTimeout           *time.Duration
	WriteTimeout          *time.Duration
	// TLSHandshakeTimeout, IdleConnectionTimeout and ResponseHeaderTimeout are
	// copied into the matching http.Transport fields by NewTransportCustom.
	TLSHandshakeTimeout   *time.Duration
	IdleConnectionTimeout *time.Duration
	ResponseHeaderTimeout *time.Duration
	// HTTPClientTimeout becomes http.Client.Timeout for clients built by
	// InitExclusiveHTTPClient and InitClientWithTimeout.
	HTTPClientTimeout     *time.Duration
	// PostRead and PostWrite are callbacks run after each successful Read or
	// Write on dialed connections, receiving the byte count and the error from
	// refreshing the deadline. Copy appends to these rather than replacing them.
	PostRead              []func(n int, err error)
	PostWrite             []func(n int, err error)
}
 196  
 197  var DefaultClientConfig = ClientConfig{
 198  	RedirectDisabled:      false,
 199  	DisableKeepAlives:     false,
 200  	DialTimeout:           &defaultDialTimeout,
 201  	KeepAlive:             &defaultKeepAlive,
 202  	TLSHandshakeTimeout:   &defaultTLSHandshakeTimeout,
 203  	IdleConnectionTimeout: &defaultIdleConnTimeout,
 204  	ResponseHeaderTimeout: &defaultResponseHeaderTimeout,
 205  	HTTPClientTimeout:     &defaultHTTPClientTimeout,
 206  }
 207  
 208  func (cfg *ClientConfig) Copy(src *ClientConfig) {
 209  	if src == nil {
 210  		return
 211  	}
 212  	cfg.RedirectDisabled = src.RedirectDisabled
 213  	cfg.DisableKeepAlives = src.DisableKeepAlives
 214  	if src.DialTimeout != nil {
 215  		cfg.DialTimeout = src.DialTimeout
 216  	}
 217  	if src.KeepAlive != nil {
 218  		cfg.KeepAlive = src.KeepAlive
 219  	}
 220  	if src.ReadTimeout != nil {
 221  		cfg.ReadTimeout = src.ReadTimeout
 222  	}
 223  	if src.WriteTimeout != nil {
 224  		cfg.WriteTimeout = src.WriteTimeout
 225  	}
 226  	if src.TLSHandshakeTimeout != nil {
 227  		cfg.TLSHandshakeTimeout = src.TLSHandshakeTimeout
 228  	}
 229  	if src.IdleConnectionTimeout != nil {
 230  		cfg.IdleConnectionTimeout = src.IdleConnectionTimeout
 231  	}
 232  	if src.ResponseHeaderTimeout != nil {
 233  		cfg.ResponseHeaderTimeout = src.ResponseHeaderTimeout
 234  	}
 235  	if src.HTTPClientTimeout != nil {
 236  		cfg.HTTPClientTimeout = src.HTTPClientTimeout
 237  	}
 238  	if len(src.PostRead) > 0 {
 239  		cfg.PostRead = append(cfg.PostRead, src.PostRead...)
 240  	}
 241  	if len(src.PostWrite) > 0 {
 242  		cfg.PostWrite = append(cfg.PostWrite, src.PostWrite...)
 243  	}
 244  }
 245  
 246  func MergeWithDefaultConfig(cfgs ...*ClientConfig) *ClientConfig {
 247  	dst := &ClientConfig{}
 248  	dst.Copy(&DefaultClientConfig)
 249  	for _, cfg := range cfgs {
 250  		dst.Copy(cfg)
 251  	}
 252  	return dst
 253  }
 254  
 255  var customizeInit sync.Once
 256  
 257  func InitClient(config ClientConfig) {
 258  	customizeInit.Do(func() {
 259  		httpClient = &http.Client{}
 260  		maxIdleConnsPerHost := defaultMaxIdleConnsPerHost
 261  		if config.DisableKeepAlives {
 262  			maxIdleConnsPerHost = -1
 263  		}
 264  		transport = &http.Transport{
 265  			MaxIdleConnsPerHost:   maxIdleConnsPerHost,
 266  			ResponseHeaderTimeout: defaultResponseHeaderTimeout,
 267  			DisableKeepAlives:     config.DisableKeepAlives,
 268  			Dial: func(network, address string) (net.Conn, error) {
 269  				conn, err := net.DialTimeout(network, address, defaultDialTimeout)
 270  				if err != nil {
 271  					return nil, err
 272  				}
 273  				tc := &timeoutConn{
 274  					conn:          conn,
 275  					smallInterval: defaultSmallInterval,
 276  					largeInterval: defaultLargeInterval,
 277  				}
 278  				tc.SetReadDeadline(time.Now().Add(defaultLargeInterval))
 279  				return tc, nil
 280  			},
 281  		}
 282  		httpClient.Transport = transport
 283  		if config.RedirectDisabled {
 284  			httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
 285  				return http.ErrUseLastResponse
 286  			}
 287  		}
 288  	})
 289  }
 290  
 291  func NewTransportCustom(config *ClientConfig) *http.Transport {
 292  	maxIdleConnsPerHost := defaultMaxIdleConnsPerHost
 293  	if config.DisableKeepAlives {
 294  		maxIdleConnsPerHost = -1
 295  	}
 296  	transport := &http.Transport{
 297  		DialContext:           NewDialer(config).DialContext,
 298  		MaxIdleConnsPerHost:   maxIdleConnsPerHost,
 299  		DisableKeepAlives:     config.DisableKeepAlives,
 300  		TLSHandshakeTimeout:   *config.TLSHandshakeTimeout,
 301  		ResponseHeaderTimeout: *config.ResponseHeaderTimeout,
 302  		IdleConnTimeout:       *config.IdleConnectionTimeout,
 303  	}
 304  	if config.NoVerifySSL {
 305  		transport.TLSClientConfig = &tls.Config{
 306  			InsecureSkipVerify: true,
 307  		}
 308  	}
 309  	return transport
 310  }
 311  
 312  func InitWithSpecifiedClient(customHTTPClient *http.Client) error {
 313  	if customHTTPClient == nil {
 314  		log.Warnf("customized HTTP Client is nil")
 315  		return NilHTTPClient
 316  	}
 317  	if customHTTPClient.Transport == nil {
 318  		log.Infof("transport is nil")
 319  		return nil
 320  	}
 321  	if customTransport, ok := customHTTPClient.Transport.(*http.Transport); ok {
 322  		transport = customTransport
 323  		return nil
 324  	}
 325  	log.Warnf("unknown transport type")
 326  	return UnknownHTTPTransport
 327  }
 328  
 329  func InitExclusiveHTTPClient(config *ClientConfig) *http.Client {
 330  	config = MergeWithDefaultConfig(config)
 331  	transport = NewTransportCustom(config)
 332  	myHTTPClient := &http.Client{
 333  		Timeout:   *config.HTTPClientTimeout,
 334  		Transport: transport,
 335  	}
 336  	if config.RedirectDisabled {
 337  		myHTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
 338  			return http.ErrUseLastResponse
 339  		}
 340  	}
 341  	return myHTTPClient
 342  }
 343  
 344  func InitClientWithTimeout(config *ClientConfig) {
 345  	customizeInit.Do(func() {
 346  		config = MergeWithDefaultConfig(config)
 347  		transport = NewTransportCustom(config)
 348  		httpClient = &http.Client{
 349  			Timeout:   *config.HTTPClientTimeout,
 350  			Transport: transport,
 351  		}
 352  		if config.RedirectDisabled {
 353  			httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
 354  				return http.ErrUseLastResponse
 355  			}
 356  		}
 357  	})
 358  }
 359  
 360  // Execute - do the http requset and get the response
 361  //
 362  // PARAMS:
 363  //   - request: the http request instance to be sent
 364  //
 365  // RETURNS:
 366  //   - response: the http response returned from the server
 367  //   - error: nil if ok otherwise the specific error
 368  func Execute(request *Request) (*Response, error) {
 369  	// Build the request object for the current requesting
 370  	httpRequest := &http.Request{
 371  		Proto:      "HTTP/1.1",
 372  		ProtoMajor: 1,
 373  		ProtoMinor: 1,
 374  	}
 375  
 376  	// get http client
 377  	curHTTPClient := httpClient
 378  	if request.HTTPClient() != nil {
 379  		curHTTPClient = request.HTTPClient()
 380  	}
 381  	if curHTTPClient == nil {
 382  		log.Infof("use default http client to execute request")
 383  		curHTTPClient = defaultHTTPClient
 384  	}
 385  
 386  	if request.Context() != nil {
 387  		httpRequest = httpRequest.WithContext(request.Context())
 388  	} else {
 389  		// Set the connection timeout for current request
 390  		curHTTPClient.Timeout = time.Duration(request.Timeout()) * time.Second
 391  	}
 392  
 393  	// Set the request method
 394  	httpRequest.Method = request.Method()
 395  
 396  	// Set the request url
 397  	internalUrl := &url.URL{
 398  		Scheme:   request.Protocol(),
 399  		Host:     request.Host(),
 400  		Path:     request.Uri(),
 401  		RawQuery: request.QueryString()}
 402  	httpRequest.URL = internalUrl
 403  
 404  	// Set the request headers
 405  	internalHeader := make(http.Header)
 406  	for k, v := range request.Headers() {
 407  		val := make([]string, 0, 1)
 408  		val = append(val, v)
 409  		internalHeader[k] = val
 410  	}
 411  	httpRequest.Header = internalHeader
 412  
 413  	if request.Body() != nil {
 414  		if request.Length() > 0 {
 415  			httpRequest.ContentLength = request.Length()
 416  			httpRequest.Body = request.Body()
 417  		} else if request.Length() < 0 {
 418  			// if set body and ContentLength <= 0, will be chunked
 419  			httpRequest.Body = request.Body()
 420  		} // else {} body == nil and ContentLength == 0
 421  	}
 422  
 423  	// Set the proxy setting if needed
 424  	if len(request.ProxyUrl()) != 0 && transport != nil {
 425  		transport.Proxy = func(_ *http.Request) (*url.URL, error) {
 426  			return url.Parse(request.ProxyUrl())
 427  		}
 428  	}
 429  
 430  	// Perform the http request and get response
 431  	// It needs to explicitly close the keep-alive connections when error occurs for the request
 432  	// that may continue sending request's data subsequently.
 433  	start := time.Now()
 434  
 435  	httpResponse, err := curHTTPClient.Do(httpRequest)
 436  	end := time.Now()
 437  	if err != nil {
 438  		if transport != nil {
 439  			transport.CloseIdleConnections()
 440  		}
 441  		return nil, err
 442  	}
 443  	if transport != nil && httpResponse.StatusCode >= 400 &&
 444  		(httpRequest.Method == PUT || httpRequest.Method == POST) {
 445  		transport.CloseIdleConnections()
 446  	}
 447  	response := &Response{httpResponse, end.Sub(start)}
 448  	return response, nil
 449  }
 450  func SetResponseHeaderTimeout(t int) {
 451  	transport = &http.Transport{
 452  		MaxIdleConnsPerHost:   defaultMaxIdleConnsPerHost,
 453  		ResponseHeaderTimeout: time.Duration(t) * time.Second,
 454  		Dial: func(network, address string) (net.Conn, error) {
 455  			conn, err := net.DialTimeout(network, address, defaultDialTimeout)
 456  			if err != nil {
 457  				return nil, err
 458  			}
 459  			tc := &timeoutConn{
 460  				conn:          conn,
 461  				smallInterval: defaultSmallInterval,
 462  				largeInterval: defaultLargeInterval,
 463  			}
 464  			err = tc.SetReadDeadline(time.Now().Add(defaultLargeInterval))
 465  			if err != nil {
 466  				return nil, err
 467  			}
 468  			return tc, nil
 469  		},
 470  	}
 471  	httpClient.Transport = transport
 472  }
 473