per_host.go raw

   1  // Copyright 2011 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package proxy
   6  
   7  import (
   8  	"context"
   9  	"net"
  10  	"net/netip"
  11  	"strings"
  12  )
  13  
  14  // A PerHost directs connections to a default Dialer unless the host name
  15  // requested matches one of a number of exceptions.
  16  type PerHost struct {
  17  	def, bypass Dialer
  18  
  19  	bypassNetworks []*net.IPNet
  20  	bypassIPs      []net.IP
  21  	bypassZones    []string
  22  	bypassHosts    []string
  23  }
  24  
  25  // NewPerHost returns a PerHost Dialer that directs connections to either
  26  // defaultDialer or bypass, depending on whether the connection matches one of
  27  // the configured rules.
  28  func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
  29  	return &PerHost{
  30  		def:    defaultDialer,
  31  		bypass: bypass,
  32  	}
  33  }
  34  
  35  // Dial connects to the address addr on the given network through either
  36  // defaultDialer or bypass.
  37  func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
  38  	host, _, err := net.SplitHostPort(addr)
  39  	if err != nil {
  40  		return nil, err
  41  	}
  42  
  43  	return p.dialerForRequest(host).Dial(network, addr)
  44  }
  45  
  46  // DialContext connects to the address addr on the given network through either
  47  // defaultDialer or bypass.
  48  func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
  49  	host, _, err := net.SplitHostPort(addr)
  50  	if err != nil {
  51  		return nil, err
  52  	}
  53  	d := p.dialerForRequest(host)
  54  	if x, ok := d.(ContextDialer); ok {
  55  		return x.DialContext(ctx, network, addr)
  56  	}
  57  	return dialContext(ctx, d, network, addr)
  58  }
  59  
  60  func (p *PerHost) dialerForRequest(host string) Dialer {
  61  	if nip, err := netip.ParseAddr(host); err == nil {
  62  		ip := net.IP(nip.AsSlice())
  63  		for _, net := range p.bypassNetworks {
  64  			if net.Contains(ip) {
  65  				return p.bypass
  66  			}
  67  		}
  68  		for _, bypassIP := range p.bypassIPs {
  69  			if bypassIP.Equal(ip) {
  70  				return p.bypass
  71  			}
  72  		}
  73  		return p.def
  74  	}
  75  
  76  	for _, zone := range p.bypassZones {
  77  		if strings.HasSuffix(host, zone) {
  78  			return p.bypass
  79  		}
  80  		if host == zone[1:] {
  81  			// For a zone ".example.com", we match "example.com"
  82  			// too.
  83  			return p.bypass
  84  		}
  85  	}
  86  	for _, bypassHost := range p.bypassHosts {
  87  		if bypassHost == host {
  88  			return p.bypass
  89  		}
  90  	}
  91  	return p.def
  92  }
  93  
  94  // AddFromString parses a string that contains comma-separated values
  95  // specifying hosts that should use the bypass proxy. Each value is either an
  96  // IP address, a CIDR range, a zone (*.example.com) or a host name
  97  // (localhost). A best effort is made to parse the string and errors are
  98  // ignored.
  99  func (p *PerHost) AddFromString(s string) {
 100  	hosts := strings.Split(s, ",")
 101  	for _, host := range hosts {
 102  		host = strings.TrimSpace(host)
 103  		if len(host) == 0 {
 104  			continue
 105  		}
 106  		if strings.Contains(host, "/") {
 107  			// We assume that it's a CIDR address like 127.0.0.0/8
 108  			if _, net, err := net.ParseCIDR(host); err == nil {
 109  				p.AddNetwork(net)
 110  			}
 111  			continue
 112  		}
 113  		if nip, err := netip.ParseAddr(host); err == nil {
 114  			p.AddIP(net.IP(nip.AsSlice()))
 115  			continue
 116  		}
 117  		if strings.HasPrefix(host, "*.") {
 118  			p.AddZone(host[1:])
 119  			continue
 120  		}
 121  		p.AddHost(host)
 122  	}
 123  }
 124  
 125  // AddIP specifies an IP address that will use the bypass proxy. Note that
 126  // this will only take effect if a literal IP address is dialed. A connection
 127  // to a named host will never match an IP.
 128  func (p *PerHost) AddIP(ip net.IP) {
 129  	p.bypassIPs = append(p.bypassIPs, ip)
 130  }
 131  
 132  // AddNetwork specifies an IP range that will use the bypass proxy. Note that
 133  // this will only take effect if a literal IP address is dialed. A connection
 134  // to a named host will never match.
 135  func (p *PerHost) AddNetwork(net *net.IPNet) {
 136  	p.bypassNetworks = append(p.bypassNetworks, net)
 137  }
 138  
 139  // AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
 140  // "example.com" matches "example.com" and all of its subdomains.
 141  func (p *PerHost) AddZone(zone string) {
 142  	zone = strings.TrimSuffix(zone, ".")
 143  	if !strings.HasPrefix(zone, ".") {
 144  		zone = "." + zone
 145  	}
 146  	p.bypassZones = append(p.bypassZones, zone)
 147  }
 148  
 149  // AddHost specifies a host name that will use the bypass proxy.
 150  func (p *PerHost) AddHost(host string) {
 151  	host = strings.TrimSuffix(host, ".")
 152  	p.bypassHosts = append(p.bypassHosts, host)
 153  }
 154