dial.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package proxy
   6  
   7  import (
   8  	"context"
   9  	"net"
  10  )
  11  
  12  // A ContextDialer dials using a context.
  13  type ContextDialer interface {
  14  	DialContext(ctx context.Context, network, address string) (net.Conn, error)
  15  }
  16  
  17  // Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment.
  18  //
  19  // The passed ctx is only used for returning the Conn, not the lifetime of the Conn.
  20  //
  21  // Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
  22  // can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
  23  //
  24  // A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
  25  func Dial(ctx context.Context, network, address string) (net.Conn, error) {
  26  	d := FromEnvironment()
  27  	if xd, ok := d.(ContextDialer); ok {
  28  		return xd.DialContext(ctx, network, address)
  29  	}
  30  	return dialContext(ctx, d, network, address)
  31  }
  32  
  33  // WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout
  34  // A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
  35  func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) {
  36  	var (
  37  		conn net.Conn
  38  		done = make(chan struct{}, 1)
  39  		err  error
  40  	)
  41  	go func() {
  42  		conn, err = d.Dial(network, address)
  43  		close(done)
  44  		if conn != nil && ctx.Err() != nil {
  45  			conn.Close()
  46  		}
  47  	}()
  48  	select {
  49  	case <-ctx.Done():
  50  		err = ctx.Err()
  51  	case <-done:
  52  	}
  53  	return conn, err
  54  }
  55