client.go raw

   1  // Copyright 2018 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package socks
   6  
   7  import (
   8  	"context"
   9  	"errors"
  10  	"io"
  11  	"net"
  12  	"strconv"
  13  	"time"
  14  )
  15  
  var (
	// noDeadline is the zero time.Time. Passed to SetDeadline, it means I/O
	// operations on the connection will not time out.
	noDeadline time.Time

	// aLongTimeAgo is a non-zero instant in the past (one second after the
	// Unix epoch). Used as a connection deadline, it makes pending and future
	// I/O fail instead of blocking.
	aLongTimeAgo = time.Unix(1, 0)
)
  20  
  21  func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
  22  	host, port, err := splitHostPort(address)
  23  	if err != nil {
  24  		return nil, err
  25  	}
  26  	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
  27  		c.SetDeadline(deadline)
  28  		defer c.SetDeadline(noDeadline)
  29  	}
  30  	if ctx != context.Background() {
  31  		errCh := make(chan error, 1)
  32  		done := make(chan struct{})
  33  		defer func() {
  34  			close(done)
  35  			if ctxErr == nil {
  36  				ctxErr = <-errCh
  37  			}
  38  		}()
  39  		go func() {
  40  			select {
  41  			case <-ctx.Done():
  42  				c.SetDeadline(aLongTimeAgo)
  43  				errCh <- ctx.Err()
  44  			case <-done:
  45  				errCh <- nil
  46  			}
  47  		}()
  48  	}
  49  
  50  	b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
  51  	b = append(b, Version5)
  52  	if len(d.AuthMethods) == 0 || d.Authenticate == nil {
  53  		b = append(b, 1, byte(AuthMethodNotRequired))
  54  	} else {
  55  		ams := d.AuthMethods
  56  		if len(ams) > 255 {
  57  			return nil, errors.New("too many authentication methods")
  58  		}
  59  		b = append(b, byte(len(ams)))
  60  		for _, am := range ams {
  61  			b = append(b, byte(am))
  62  		}
  63  	}
  64  	if _, ctxErr = c.Write(b); ctxErr != nil {
  65  		return
  66  	}
  67  
  68  	if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
  69  		return
  70  	}
  71  	if b[0] != Version5 {
  72  		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
  73  	}
  74  	am := AuthMethod(b[1])
  75  	if am == AuthMethodNoAcceptableMethods {
  76  		return nil, errors.New("no acceptable authentication methods")
  77  	}
  78  	if d.Authenticate != nil {
  79  		if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
  80  			return
  81  		}
  82  	}
  83  
  84  	b = b[:0]
  85  	b = append(b, Version5, byte(d.cmd), 0)
  86  	if ip := net.ParseIP(host); ip != nil {
  87  		if ip4 := ip.To4(); ip4 != nil {
  88  			b = append(b, AddrTypeIPv4)
  89  			b = append(b, ip4...)
  90  		} else if ip6 := ip.To16(); ip6 != nil {
  91  			b = append(b, AddrTypeIPv6)
  92  			b = append(b, ip6...)
  93  		} else {
  94  			return nil, errors.New("unknown address type")
  95  		}
  96  	} else {
  97  		if len(host) > 255 {
  98  			return nil, errors.New("FQDN too long")
  99  		}
 100  		b = append(b, AddrTypeFQDN)
 101  		b = append(b, byte(len(host)))
 102  		b = append(b, host...)
 103  	}
 104  	b = append(b, byte(port>>8), byte(port))
 105  	if _, ctxErr = c.Write(b); ctxErr != nil {
 106  		return
 107  	}
 108  
 109  	if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
 110  		return
 111  	}
 112  	if b[0] != Version5 {
 113  		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
 114  	}
 115  	if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
 116  		return nil, errors.New("unknown error " + cmdErr.String())
 117  	}
 118  	if b[2] != 0 {
 119  		return nil, errors.New("non-zero reserved field")
 120  	}
 121  	l := 2
 122  	var a Addr
 123  	switch b[3] {
 124  	case AddrTypeIPv4:
 125  		l += net.IPv4len
 126  		a.IP = make(net.IP, net.IPv4len)
 127  	case AddrTypeIPv6:
 128  		l += net.IPv6len
 129  		a.IP = make(net.IP, net.IPv6len)
 130  	case AddrTypeFQDN:
 131  		if _, err := io.ReadFull(c, b[:1]); err != nil {
 132  			return nil, err
 133  		}
 134  		l += int(b[0])
 135  	default:
 136  		return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
 137  	}
 138  	if cap(b) < l {
 139  		b = make([]byte, l)
 140  	} else {
 141  		b = b[:l]
 142  	}
 143  	if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
 144  		return
 145  	}
 146  	if a.IP != nil {
 147  		copy(a.IP, b)
 148  	} else {
 149  		a.Name = string(b[:len(b)-2])
 150  	}
 151  	a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
 152  	return &a, nil
 153  }
 154  
 // splitHostPort splits a "host:port" address into its host part and a
// numeric port. The port text is parsed with strconv.Atoi and must lie in
// the range 1 through 65535 (0xffff); anything else is reported as an error.
func splitHostPort(address string) (string, int, error) {
	hostPart, portText, err := net.SplitHostPort(address)
	if err != nil {
		return "", 0, err
	}
	n, err := strconv.Atoi(portText)
	switch {
	case err != nil:
		return "", 0, err
	case n < 1 || n > 0xffff:
		return "", 0, errors.New("port number out of range " + portText)
	default:
		return hostPart, n, nil
	}
}
 169