tcpip.go raw

   1  // Copyright 2011 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package ssh
   6  
   7  import (
   8  	"context"
   9  	"errors"
  10  	"fmt"
  11  	"io"
  12  	"math/rand"
  13  	"net"
  14  	"net/netip"
  15  	"strconv"
  16  	"strings"
  17  	"sync"
  18  	"time"
  19  )
  20  
  21  // Listen requests the remote peer open a listening socket on
  22  // addr. Incoming connections will be available by calling Accept on
  23  // the returned net.Listener. The listener must be serviced, or the
  24  // SSH connection may hang.
  25  // N must be "tcp", "tcp4", "tcp6", or "unix".
  26  //
  27  // If the address is a hostname, it is sent to the remote peer as-is, without
  28  // being resolved locally, and the Listener Addr method will return a zero IP.
  29  func (c *Client) Listen(n, addr string) (net.Listener, error) {
  30  	switch n {
  31  	case "tcp", "tcp4", "tcp6":
  32  		host, portStr, err := net.SplitHostPort(addr)
  33  		if err != nil {
  34  			return nil, err
  35  		}
  36  		port, err := strconv.ParseInt(portStr, 10, 32)
  37  		if err != nil {
  38  			return nil, err
  39  		}
  40  		return c.listenTCPInternal(host, int(port))
  41  	case "unix":
  42  		return c.ListenUnix(addr)
  43  	default:
  44  		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
  45  	}
  46  }
  47  
  48  // Automatic port allocation is broken with OpenSSH before 6.0. See
  49  // also https://bugzilla.mindrot.org/show_bug.cgi?id=2017.  In
  50  // particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
  51  // rather than the actual port number. This means you can never open
  52  // two different listeners with auto allocated ports. We work around
  53  // this by trying explicit ports until we succeed.
  54  
  55  const openSSHPrefix = "OpenSSH_"
  56  
  57  var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
  58  
  59  // isBrokenOpenSSHVersion returns true if the given version string
  60  // specifies a version of OpenSSH that is known to have a bug in port
  61  // forwarding.
  62  func isBrokenOpenSSHVersion(versionStr string) bool {
  63  	i := strings.Index(versionStr, openSSHPrefix)
  64  	if i < 0 {
  65  		return false
  66  	}
  67  	i += len(openSSHPrefix)
  68  	j := i
  69  	for ; j < len(versionStr); j++ {
  70  		if versionStr[j] < '0' || versionStr[j] > '9' {
  71  			break
  72  		}
  73  	}
  74  	version, _ := strconv.Atoi(versionStr[i:j])
  75  	return version < 6
  76  }
  77  
  78  // autoPortListenWorkaround simulates automatic port allocation by
  79  // trying random ports repeatedly.
  80  func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
  81  	var sshListener net.Listener
  82  	var err error
  83  	const tries = 10
  84  	for i := 0; i < tries; i++ {
  85  		addr := *laddr
  86  		addr.Port = 1024 + portRandomizer.Intn(60000)
  87  		sshListener, err = c.ListenTCP(&addr)
  88  		if err == nil {
  89  			laddr.Port = addr.Port
  90  			return sshListener, err
  91  		}
  92  	}
  93  	return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
  94  }
  95  
  96  // RFC 4254 7.1
  97  type channelForwardMsg struct {
  98  	addr  string
  99  	rport uint32
 100  }
 101  
 102  // handleForwards starts goroutines handling forwarded connections.
 103  // It's called on first use by (*Client).ListenTCP to not launch
 104  // goroutines until needed.
 105  func (c *Client) handleForwards() {
 106  	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
 107  	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
 108  }
 109  
 110  // ListenTCP requests the remote peer open a listening socket
 111  // on laddr. Incoming connections will be available by calling
 112  // Accept on the returned net.Listener.
 113  //
 114  // ListenTCP accepts an IP address, to provide a hostname use [Client.Listen]
 115  // with "tcp", "tcp4", or "tcp6" network instead.
 116  func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
 117  	c.handleForwardsOnce.Do(c.handleForwards)
 118  	if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
 119  		return c.autoPortListenWorkaround(laddr)
 120  	}
 121  
 122  	return c.listenTCPInternal(laddr.IP.String(), laddr.Port)
 123  }
 124  
 125  func (c *Client) listenTCPInternal(host string, port int) (net.Listener, error) {
 126  	c.handleForwardsOnce.Do(c.handleForwards)
 127  
 128  	m := channelForwardMsg{
 129  		host,
 130  		uint32(port),
 131  	}
 132  	// send message
 133  	ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
 134  	if err != nil {
 135  		return nil, err
 136  	}
 137  	if !ok {
 138  		return nil, errors.New("ssh: tcpip-forward request denied by peer")
 139  	}
 140  
 141  	// If the original port was 0, then the remote side will
 142  	// supply a real port number in the response.
 143  	if port == 0 {
 144  		var p struct {
 145  			Port uint32
 146  		}
 147  		if err := Unmarshal(resp, &p); err != nil {
 148  			return nil, err
 149  		}
 150  		port = int(p.Port)
 151  	}
 152  	// Construct a local address placeholder for the remote listener. If the
 153  	// original host is an IP address, preserve it so that Listener.Addr()
 154  	// reports the same IP. If the host is a hostname or cannot be parsed as an
 155  	// IP, fall back to IPv4zero. The port field is always set, even if the
 156  	// original port was 0, because in that case the remote server will assign
 157  	// one, allowing callers to determine which port was selected.
 158  	ip := net.IPv4zero
 159  	if parsed, err := netip.ParseAddr(host); err == nil {
 160  		ip = net.IP(parsed.AsSlice())
 161  	}
 162  	laddr := &net.TCPAddr{
 163  		IP:   ip,
 164  		Port: port,
 165  	}
 166  	addr := net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
 167  	ch := c.forwards.add("tcp", addr)
 168  
 169  	return &tcpListener{laddr, addr, c, ch}, nil
 170  }
 171  
 172  // forwardList stores a mapping between remote
 173  // forward requests and the tcpListeners.
 174  type forwardList struct {
 175  	sync.Mutex
 176  	entries []forwardEntry
 177  }
 178  
 179  // forwardEntry represents an established mapping of a laddr on a
 180  // remote ssh server to a channel connected to a tcpListener.
 181  type forwardEntry struct {
 182  	addr    string // host:port or socket path
 183  	network string // tcp or unix
 184  	c       chan forward
 185  }
 186  
 187  // forward represents an incoming forwarded tcpip connection. The
 188  // arguments to add/remove/lookup should be address as specified in
 189  // the original forward-request.
 190  type forward struct {
 191  	newCh NewChannel // the ssh client channel underlying this forward
 192  	raddr net.Addr   // the raddr of the incoming connection
 193  }
 194  
 195  func (l *forwardList) add(n, addr string) chan forward {
 196  	l.Lock()
 197  	defer l.Unlock()
 198  	f := forwardEntry{
 199  		addr:    addr,
 200  		network: n,
 201  		c:       make(chan forward, 1),
 202  	}
 203  	l.entries = append(l.entries, f)
 204  	return f.c
 205  }
 206  
 207  // See RFC 4254, section 7.2
 208  type forwardedTCPPayload struct {
 209  	Addr       string
 210  	Port       uint32
 211  	OriginAddr string
 212  	OriginPort uint32
 213  }
 214  
 215  // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
 216  func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
 217  	if port == 0 || port > 65535 {
 218  		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
 219  	}
 220  	ip, err := netip.ParseAddr(addr)
 221  	if err != nil {
 222  		return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
 223  	}
 224  	return &net.TCPAddr{IP: net.IP(ip.AsSlice()), Port: int(port)}, nil
 225  }
 226  
 227  func (l *forwardList) handleChannels(in <-chan NewChannel) {
 228  	for ch := range in {
 229  		var (
 230  			addr    string
 231  			network string
 232  			raddr   net.Addr
 233  			err     error
 234  		)
 235  		switch channelType := ch.ChannelType(); channelType {
 236  		case "forwarded-tcpip":
 237  			var payload forwardedTCPPayload
 238  			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
 239  				ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
 240  				continue
 241  			}
 242  
 243  			// RFC 4254 section 7.2 specifies that incoming addresses should
 244  			// list the address that was connected, in string format. It is the
 245  			// same address used in the tcpip-forward request. The originator
 246  			// address is an IP address instead.
 247  			addr = net.JoinHostPort(payload.Addr, strconv.FormatUint(uint64(payload.Port), 10))
 248  
 249  			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
 250  			if err != nil {
 251  				ch.Reject(ConnectionFailed, err.Error())
 252  				continue
 253  			}
 254  			network = "tcp"
 255  		case "forwarded-streamlocal@openssh.com":
 256  			var payload forwardedStreamLocalPayload
 257  			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
 258  				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
 259  				continue
 260  			}
 261  			addr = payload.SocketPath
 262  			raddr = &net.UnixAddr{
 263  				Name: "@",
 264  				Net:  "unix",
 265  			}
 266  			network = "unix"
 267  		default:
 268  			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
 269  		}
 270  		if ok := l.forward(network, addr, raddr, ch); !ok {
 271  			// Section 7.2, implementations MUST reject spurious incoming
 272  			// connections.
 273  			ch.Reject(Prohibited, "no forward for address")
 274  			continue
 275  		}
 276  
 277  	}
 278  }
 279  
 280  // remove removes the forward entry, and the channel feeding its
 281  // listener.
 282  func (l *forwardList) remove(n, addr string) {
 283  	l.Lock()
 284  	defer l.Unlock()
 285  	for i, f := range l.entries {
 286  		if n == f.network && addr == f.addr {
 287  			l.entries = append(l.entries[:i], l.entries[i+1:]...)
 288  			close(f.c)
 289  			return
 290  		}
 291  	}
 292  }
 293  
 294  // closeAll closes and clears all forwards.
 295  func (l *forwardList) closeAll() {
 296  	l.Lock()
 297  	defer l.Unlock()
 298  	for _, f := range l.entries {
 299  		close(f.c)
 300  	}
 301  	l.entries = nil
 302  }
 303  
 304  func (l *forwardList) forward(n, addr string, raddr net.Addr, ch NewChannel) bool {
 305  	l.Lock()
 306  	defer l.Unlock()
 307  	for _, f := range l.entries {
 308  		if n == f.network && addr == f.addr {
 309  			f.c <- forward{newCh: ch, raddr: raddr}
 310  			return true
 311  		}
 312  	}
 313  	return false
 314  }
 315  
 316  type tcpListener struct {
 317  	laddr *net.TCPAddr
 318  	addr  string
 319  
 320  	conn *Client
 321  	in   <-chan forward
 322  }
 323  
 324  // Accept waits for and returns the next connection to the listener.
 325  func (l *tcpListener) Accept() (net.Conn, error) {
 326  	s, ok := <-l.in
 327  	if !ok {
 328  		return nil, io.EOF
 329  	}
 330  	ch, incoming, err := s.newCh.Accept()
 331  	if err != nil {
 332  		return nil, err
 333  	}
 334  	go DiscardRequests(incoming)
 335  
 336  	return &chanConn{
 337  		Channel: ch,
 338  		laddr:   l.laddr,
 339  		raddr:   s.raddr,
 340  	}, nil
 341  }
 342  
 343  // Close closes the listener.
 344  func (l *tcpListener) Close() error {
 345  	host, port, err := net.SplitHostPort(l.addr)
 346  	if err != nil {
 347  		return err
 348  	}
 349  	rport, err := strconv.ParseUint(port, 10, 32)
 350  	if err != nil {
 351  		return err
 352  	}
 353  	m := channelForwardMsg{
 354  		host,
 355  		uint32(rport),
 356  	}
 357  
 358  	// this also closes the listener.
 359  	l.conn.forwards.remove("tcp", l.addr)
 360  	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
 361  	if err == nil && !ok {
 362  		err = errors.New("ssh: cancel-tcpip-forward failed")
 363  	}
 364  	return err
 365  }
 366  
 367  // Addr returns the listener's network address.
 368  func (l *tcpListener) Addr() net.Addr {
 369  	return l.laddr
 370  }
 371  
 372  // DialContext initiates a connection to the addr from the remote host.
 373  //
 374  // The provided Context must be non-nil. If the context expires before the
 375  // connection is complete, an error is returned. Once successfully connected,
 376  // any expiration of the context will not affect the connection.
 377  //
 378  // See func Dial for additional information.
 379  func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
 380  	if err := ctx.Err(); err != nil {
 381  		return nil, err
 382  	}
 383  	type connErr struct {
 384  		conn net.Conn
 385  		err  error
 386  	}
 387  	ch := make(chan connErr)
 388  	go func() {
 389  		conn, err := c.Dial(n, addr)
 390  		select {
 391  		case ch <- connErr{conn, err}:
 392  		case <-ctx.Done():
 393  			if conn != nil {
 394  				conn.Close()
 395  			}
 396  		}
 397  	}()
 398  	select {
 399  	case res := <-ch:
 400  		return res.conn, res.err
 401  	case <-ctx.Done():
 402  		return nil, ctx.Err()
 403  	}
 404  }
 405  
 406  // Dial initiates a connection to the addr from the remote host.
 407  // The resulting connection has a zero LocalAddr() and RemoteAddr().
 408  func (c *Client) Dial(n, addr string) (net.Conn, error) {
 409  	var ch Channel
 410  	switch n {
 411  	case "tcp", "tcp4", "tcp6":
 412  		// Parse the address into host and numeric port.
 413  		host, portString, err := net.SplitHostPort(addr)
 414  		if err != nil {
 415  			return nil, err
 416  		}
 417  		port, err := strconv.ParseUint(portString, 10, 16)
 418  		if err != nil {
 419  			return nil, err
 420  		}
 421  		ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
 422  		if err != nil {
 423  			return nil, err
 424  		}
 425  		// Use a zero address for local and remote address.
 426  		zeroAddr := &net.TCPAddr{
 427  			IP:   net.IPv4zero,
 428  			Port: 0,
 429  		}
 430  		return &chanConn{
 431  			Channel: ch,
 432  			laddr:   zeroAddr,
 433  			raddr:   zeroAddr,
 434  		}, nil
 435  	case "unix":
 436  		var err error
 437  		ch, err = c.dialStreamLocal(addr)
 438  		if err != nil {
 439  			return nil, err
 440  		}
 441  		return &chanConn{
 442  			Channel: ch,
 443  			laddr: &net.UnixAddr{
 444  				Name: "@",
 445  				Net:  "unix",
 446  			},
 447  			raddr: &net.UnixAddr{
 448  				Name: addr,
 449  				Net:  "unix",
 450  			},
 451  		}, nil
 452  	default:
 453  		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
 454  	}
 455  }
 456  
 457  // DialTCP connects to the remote address raddr on the network net,
 458  // which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is used
 459  // as the local address for the connection.
 460  func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
 461  	if laddr == nil {
 462  		laddr = &net.TCPAddr{
 463  			IP:   net.IPv4zero,
 464  			Port: 0,
 465  		}
 466  	}
 467  	ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
 468  	if err != nil {
 469  		return nil, err
 470  	}
 471  	return &chanConn{
 472  		Channel: ch,
 473  		laddr:   laddr,
 474  		raddr:   raddr,
 475  	}, nil
 476  }
 477  
 478  // RFC 4254 7.2
 479  type channelOpenDirectMsg struct {
 480  	raddr string
 481  	rport uint32
 482  	laddr string
 483  	lport uint32
 484  }
 485  
 486  func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
 487  	msg := channelOpenDirectMsg{
 488  		raddr: raddr,
 489  		rport: uint32(rport),
 490  		laddr: laddr,
 491  		lport: uint32(lport),
 492  	}
 493  	ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
 494  	if err != nil {
 495  		return nil, err
 496  	}
 497  	go DiscardRequests(in)
 498  	return ch, nil
 499  }
 500  
 501  type tcpChan struct {
 502  	Channel // the backing channel
 503  }
 504  
 505  // chanConn fulfills the net.Conn interface without
 506  // the tcpChan having to hold laddr or raddr directly.
 507  type chanConn struct {
 508  	Channel
 509  	laddr, raddr net.Addr
 510  }
 511  
 512  // LocalAddr returns the local network address.
 513  func (t *chanConn) LocalAddr() net.Addr {
 514  	return t.laddr
 515  }
 516  
 517  // RemoteAddr returns the remote network address.
 518  func (t *chanConn) RemoteAddr() net.Addr {
 519  	return t.raddr
 520  }
 521  
 522  // SetDeadline sets the read and write deadlines associated
 523  // with the connection.
 524  func (t *chanConn) SetDeadline(deadline time.Time) error {
 525  	if err := t.SetReadDeadline(deadline); err != nil {
 526  		return err
 527  	}
 528  	return t.SetWriteDeadline(deadline)
 529  }
 530  
 531  // SetReadDeadline sets the read deadline.
 532  // A zero value for t means Read will not time out.
 533  // After the deadline, the error from Read will implement net.Error
 534  // with Timeout() == true.
 535  func (t *chanConn) SetReadDeadline(deadline time.Time) error {
 536  	// for compatibility with previous version,
 537  	// the error message contains "tcpChan"
 538  	return errors.New("ssh: tcpChan: deadline not supported")
 539  }
 540  
 541  // SetWriteDeadline exists to satisfy the net.Conn interface
 542  // but is not implemented by this type.  It always returns an error.
 543  func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
 544  	return errors.New("ssh: tcpChan: deadline not supported")
 545  }
 546