channel.go raw

   1  // Copyright 2011 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package ssh
   6  
   7  import (
   8  	"encoding/binary"
   9  	"errors"
  10  	"fmt"
  11  	"io"
  12  	"log"
  13  	"sync"
  14  )
  15  
  16  const (
  17  	minPacketLength = 9
  18  	// channelMaxPacket contains the maximum number of bytes that will be
  19  	// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
  20  	// the minimum.
  21  	channelMaxPacket = 1 << 15
  22  	// We follow OpenSSH here.
  23  	channelWindowSize = 64 * channelMaxPacket
  24  )
  25  
  26  // NewChannel represents an incoming request to a channel. It must either be
  27  // accepted for use by calling Accept, or rejected by calling Reject.
  28  type NewChannel interface {
  29  	// Accept accepts the channel creation request. It returns the Channel
  30  	// and a Go channel containing SSH requests. The Go channel must be
  31  	// serviced otherwise the Channel will hang.
  32  	Accept() (Channel, <-chan *Request, error)
  33  
  34  	// Reject rejects the channel creation request. After calling
  35  	// this, no other methods on the Channel may be called.
  36  	Reject(reason RejectionReason, message string) error
  37  
  38  	// ChannelType returns the type of the channel, as supplied by the
  39  	// client.
  40  	ChannelType() string
  41  
  42  	// ExtraData returns the arbitrary payload for this channel, as supplied
  43  	// by the client. This data is specific to the channel type.
  44  	ExtraData() []byte
  45  }
  46  
  47  // A Channel is an ordered, reliable, flow-controlled, duplex stream
  48  // that is multiplexed over an SSH connection.
  49  type Channel interface {
  50  	// Read reads up to len(data) bytes from the channel.
  51  	Read(data []byte) (int, error)
  52  
  53  	// Write writes len(data) bytes to the channel.
  54  	Write(data []byte) (int, error)
  55  
  56  	// Close signals end of channel use. No data may be sent after this
  57  	// call.
  58  	Close() error
  59  
  60  	// CloseWrite signals the end of sending in-band
  61  	// data. Requests may still be sent, and the other side may
  62  	// still send data
  63  	CloseWrite() error
  64  
  65  	// SendRequest sends a channel request.  If wantReply is true,
  66  	// it will wait for a reply and return the result as a
  67  	// boolean, otherwise the return value will be false. Channel
  68  	// requests are out-of-band messages so they may be sent even
  69  	// if the data stream is closed or blocked by flow control.
  70  	// If the channel is closed before a reply is returned, io.EOF
  71  	// is returned.
  72  	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
  73  
  74  	// Stderr returns an io.ReadWriter that writes to this channel
  75  	// with the extended data type set to stderr. Stderr may
  76  	// safely be read and written from a different goroutine than
  77  	// Read and Write respectively.
  78  	Stderr() io.ReadWriter
  79  }
  80  
  81  // Request is a request sent outside of the normal stream of
  82  // data. Requests can either be specific to an SSH channel, or they
  83  // can be global.
  84  type Request struct {
  85  	Type      string
  86  	WantReply bool
  87  	Payload   []byte
  88  
  89  	ch  *channel
  90  	mux *mux
  91  }
  92  
  93  // Reply sends a response to a request. It must be called for all requests
  94  // where WantReply is true and is a no-op otherwise. The payload argument is
  95  // ignored for replies to channel-specific requests.
  96  func (r *Request) Reply(ok bool, payload []byte) error {
  97  	if !r.WantReply {
  98  		return nil
  99  	}
 100  
 101  	if r.ch == nil {
 102  		return r.mux.ackRequest(ok, payload)
 103  	}
 104  
 105  	return r.ch.ackRequest(ok)
 106  }
 107  
 108  // RejectionReason is an enumeration used when rejecting channel creation
 109  // requests. See RFC 4254, section 5.1.
 110  type RejectionReason uint32
 111  
 112  const (
 113  	Prohibited RejectionReason = iota + 1
 114  	ConnectionFailed
 115  	UnknownChannelType
 116  	ResourceShortage
 117  )
 118  
 119  // String converts the rejection reason to human readable form.
 120  func (r RejectionReason) String() string {
 121  	switch r {
 122  	case Prohibited:
 123  		return "administratively prohibited"
 124  	case ConnectionFailed:
 125  		return "connect failed"
 126  	case UnknownChannelType:
 127  		return "unknown channel type"
 128  	case ResourceShortage:
 129  		return "resource shortage"
 130  	}
 131  	return fmt.Sprintf("unknown reason %d", int(r))
 132  }
 133  
 134  func min(a uint32, b int) uint32 {
 135  	if a < uint32(b) {
 136  		return a
 137  	}
 138  	return uint32(b)
 139  }
 140  
 141  type channelDirection uint8
 142  
 143  const (
 144  	channelInbound channelDirection = iota
 145  	channelOutbound
 146  )
 147  
 148  // channel is an implementation of the Channel interface that works
 149  // with the mux class.
 150  type channel struct {
 151  	// R/O after creation
 152  	chanType          string
 153  	extraData         []byte
 154  	localId, remoteId uint32
 155  
 156  	// maxIncomingPayload and maxRemotePayload are the maximum
 157  	// payload sizes of normal and extended data packets for
 158  	// receiving and sending, respectively. The wire packet will
 159  	// be 9 or 13 bytes larger (excluding encryption overhead).
 160  	maxIncomingPayload uint32
 161  	maxRemotePayload   uint32
 162  
 163  	mux *mux
 164  
 165  	// decided is set to true if an accept or reject message has been sent
 166  	// (for outbound channels) or received (for inbound channels).
 167  	decided bool
 168  
 169  	// direction contains either channelOutbound, for channels created
 170  	// locally, or channelInbound, for channels created by the peer.
 171  	direction channelDirection
 172  
 173  	// Pending internal channel messages.
 174  	msg chan interface{}
 175  
 176  	// Since requests have no ID, there can be only one request
 177  	// with WantReply=true outstanding.  This lock is held by a
 178  	// goroutine that has such an outgoing request pending.
 179  	sentRequestMu sync.Mutex
 180  
 181  	incomingRequests chan *Request
 182  
 183  	sentEOF bool
 184  
 185  	// thread-safe data
 186  	remoteWin  window
 187  	pending    *buffer
 188  	extPending *buffer
 189  
 190  	// windowMu protects myWindow, the flow-control window, and myConsumed,
 191  	// the number of bytes consumed since we last increased myWindow
 192  	windowMu   sync.Mutex
 193  	myWindow   uint32
 194  	myConsumed uint32
 195  
 196  	// writeMu serializes calls to mux.conn.writePacket() and
 197  	// protects sentClose and packetPool. This mutex must be
 198  	// different from windowMu, as writePacket can block if there
 199  	// is a key exchange pending.
 200  	writeMu   sync.Mutex
 201  	sentClose bool
 202  
 203  	// packetPool has a buffer for each extended channel ID to
 204  	// save allocations during writes.
 205  	packetPool map[uint32][]byte
 206  }
 207  
 208  // writePacket sends a packet. If the packet is a channel close, it updates
 209  // sentClose. This method takes the lock c.writeMu.
 210  func (ch *channel) writePacket(packet []byte) error {
 211  	ch.writeMu.Lock()
 212  	if ch.sentClose {
 213  		ch.writeMu.Unlock()
 214  		return io.EOF
 215  	}
 216  	ch.sentClose = (packet[0] == msgChannelClose)
 217  	err := ch.mux.conn.writePacket(packet)
 218  	ch.writeMu.Unlock()
 219  	return err
 220  }
 221  
 222  func (ch *channel) sendMessage(msg interface{}) error {
 223  	if debugMux {
 224  		log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
 225  	}
 226  
 227  	p := Marshal(msg)
 228  	binary.BigEndian.PutUint32(p[1:], ch.remoteId)
 229  	return ch.writePacket(p)
 230  }
 231  
 232  // WriteExtended writes data to a specific extended stream. These streams are
 233  // used, for example, for stderr.
 234  func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
 235  	if ch.sentEOF {
 236  		return 0, io.EOF
 237  	}
 238  	// 1 byte message type, 4 bytes remoteId, 4 bytes data length
 239  	opCode := byte(msgChannelData)
 240  	headerLength := uint32(9)
 241  	if extendedCode > 0 {
 242  		headerLength += 4
 243  		opCode = msgChannelExtendedData
 244  	}
 245  
 246  	ch.writeMu.Lock()
 247  	packet := ch.packetPool[extendedCode]
 248  	// We don't remove the buffer from packetPool, so
 249  	// WriteExtended calls from different goroutines will be
 250  	// flagged as errors by the race detector.
 251  	ch.writeMu.Unlock()
 252  
 253  	for len(data) > 0 {
 254  		space := min(ch.maxRemotePayload, len(data))
 255  		if space, err = ch.remoteWin.reserve(space); err != nil {
 256  			return n, err
 257  		}
 258  		if want := headerLength + space; uint32(cap(packet)) < want {
 259  			packet = make([]byte, want)
 260  		} else {
 261  			packet = packet[:want]
 262  		}
 263  
 264  		todo := data[:space]
 265  
 266  		packet[0] = opCode
 267  		binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
 268  		if extendedCode > 0 {
 269  			binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
 270  		}
 271  		binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
 272  		copy(packet[headerLength:], todo)
 273  		if err = ch.writePacket(packet); err != nil {
 274  			return n, err
 275  		}
 276  
 277  		n += len(todo)
 278  		data = data[len(todo):]
 279  	}
 280  
 281  	ch.writeMu.Lock()
 282  	ch.packetPool[extendedCode] = packet
 283  	ch.writeMu.Unlock()
 284  
 285  	return n, err
 286  }
 287  
 288  func (ch *channel) handleData(packet []byte) error {
 289  	headerLen := 9
 290  	isExtendedData := packet[0] == msgChannelExtendedData
 291  	if isExtendedData {
 292  		headerLen = 13
 293  	}
 294  	if len(packet) < headerLen {
 295  		// malformed data packet
 296  		return parseError(packet[0])
 297  	}
 298  
 299  	var extended uint32
 300  	if isExtendedData {
 301  		extended = binary.BigEndian.Uint32(packet[5:])
 302  	}
 303  
 304  	length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
 305  	if length == 0 {
 306  		return nil
 307  	}
 308  	if length > ch.maxIncomingPayload {
 309  		// TODO(hanwen): should send Disconnect?
 310  		return errors.New("ssh: incoming packet exceeds maximum payload size")
 311  	}
 312  
 313  	data := packet[headerLen:]
 314  	if length != uint32(len(data)) {
 315  		return errors.New("ssh: wrong packet length")
 316  	}
 317  
 318  	ch.windowMu.Lock()
 319  	if ch.myWindow < length {
 320  		ch.windowMu.Unlock()
 321  		// TODO(hanwen): should send Disconnect with reason?
 322  		return errors.New("ssh: remote side wrote too much")
 323  	}
 324  	ch.myWindow -= length
 325  	ch.windowMu.Unlock()
 326  
 327  	if extended == 1 {
 328  		ch.extPending.write(data)
 329  	} else if extended > 0 {
 330  		// discard other extended data.
 331  	} else {
 332  		ch.pending.write(data)
 333  	}
 334  	return nil
 335  }
 336  
 337  func (c *channel) adjustWindow(adj uint32) error {
 338  	c.windowMu.Lock()
 339  	// Since myConsumed and myWindow are managed on our side, and can never
 340  	// exceed the initial window setting, we don't worry about overflow.
 341  	c.myConsumed += adj
 342  	var sendAdj uint32
 343  	if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) ||
 344  		(c.myWindow < channelWindowSize/2) {
 345  		sendAdj = c.myConsumed
 346  		c.myConsumed = 0
 347  		c.myWindow += sendAdj
 348  	}
 349  	c.windowMu.Unlock()
 350  	if sendAdj == 0 {
 351  		return nil
 352  	}
 353  	return c.sendMessage(windowAdjustMsg{
 354  		AdditionalBytes: sendAdj,
 355  	})
 356  }
 357  
 358  func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
 359  	switch extended {
 360  	case 1:
 361  		n, err = c.extPending.Read(data)
 362  	case 0:
 363  		n, err = c.pending.Read(data)
 364  	default:
 365  		return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
 366  	}
 367  
 368  	if n > 0 {
 369  		err = c.adjustWindow(uint32(n))
 370  		// sendWindowAdjust can return io.EOF if the remote
 371  		// peer has closed the connection, however we want to
 372  		// defer forwarding io.EOF to the caller of Read until
 373  		// the buffer has been drained.
 374  		if n > 0 && err == io.EOF {
 375  			err = nil
 376  		}
 377  	}
 378  
 379  	return n, err
 380  }
 381  
 382  func (c *channel) close() {
 383  	c.pending.eof()
 384  	c.extPending.eof()
 385  	close(c.msg)
 386  	close(c.incomingRequests)
 387  	c.writeMu.Lock()
 388  	// This is not necessary for a normal channel teardown, but if
 389  	// there was another error, it is.
 390  	c.sentClose = true
 391  	c.writeMu.Unlock()
 392  	// Unblock writers.
 393  	c.remoteWin.close()
 394  }
 395  
 396  // responseMessageReceived is called when a success or failure message is
 397  // received on a channel to check that such a message is reasonable for the
 398  // given channel.
 399  func (ch *channel) responseMessageReceived() error {
 400  	if ch.direction == channelInbound {
 401  		return errors.New("ssh: channel response message received on inbound channel")
 402  	}
 403  	if ch.decided {
 404  		return errors.New("ssh: duplicate response received for channel")
 405  	}
 406  	ch.decided = true
 407  	return nil
 408  }
 409  
 410  func (ch *channel) handlePacket(packet []byte) error {
 411  	switch packet[0] {
 412  	case msgChannelData, msgChannelExtendedData:
 413  		return ch.handleData(packet)
 414  	case msgChannelClose:
 415  		ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
 416  		ch.mux.chanList.remove(ch.localId)
 417  		ch.close()
 418  		return nil
 419  	case msgChannelEOF:
 420  		// RFC 4254 is mute on how EOF affects dataExt messages but
 421  		// it is logical to signal EOF at the same time.
 422  		ch.extPending.eof()
 423  		ch.pending.eof()
 424  		return nil
 425  	}
 426  
 427  	decoded, err := decode(packet)
 428  	if err != nil {
 429  		return err
 430  	}
 431  
 432  	switch msg := decoded.(type) {
 433  	case *channelOpenFailureMsg:
 434  		if err := ch.responseMessageReceived(); err != nil {
 435  			return err
 436  		}
 437  		ch.mux.chanList.remove(msg.PeersID)
 438  		ch.msg <- msg
 439  	case *channelOpenConfirmMsg:
 440  		if err := ch.responseMessageReceived(); err != nil {
 441  			return err
 442  		}
 443  		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
 444  			return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
 445  		}
 446  		ch.remoteId = msg.MyID
 447  		ch.maxRemotePayload = msg.MaxPacketSize
 448  		ch.remoteWin.add(msg.MyWindow)
 449  		ch.msg <- msg
 450  	case *windowAdjustMsg:
 451  		if !ch.remoteWin.add(msg.AdditionalBytes) {
 452  			return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
 453  		}
 454  	case *channelRequestMsg:
 455  		req := Request{
 456  			Type:      msg.Request,
 457  			WantReply: msg.WantReply,
 458  			Payload:   msg.RequestSpecificData,
 459  			ch:        ch,
 460  		}
 461  
 462  		ch.incomingRequests <- &req
 463  	default:
 464  		ch.msg <- msg
 465  	}
 466  	return nil
 467  }
 468  
 469  func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
 470  	ch := &channel{
 471  		remoteWin:        window{Cond: newCond()},
 472  		myWindow:         channelWindowSize,
 473  		pending:          newBuffer(),
 474  		extPending:       newBuffer(),
 475  		direction:        direction,
 476  		incomingRequests: make(chan *Request, chanSize),
 477  		msg:              make(chan interface{}, chanSize),
 478  		chanType:         chanType,
 479  		extraData:        extraData,
 480  		mux:              m,
 481  		packetPool:       make(map[uint32][]byte),
 482  	}
 483  	ch.localId = m.chanList.add(ch)
 484  	return ch
 485  }
 486  
 487  var errUndecided = errors.New("ssh: must Accept or Reject channel")
 488  var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
 489  
 490  type extChannel struct {
 491  	code uint32
 492  	ch   *channel
 493  }
 494  
 495  func (e *extChannel) Write(data []byte) (n int, err error) {
 496  	return e.ch.WriteExtended(data, e.code)
 497  }
 498  
 499  func (e *extChannel) Read(data []byte) (n int, err error) {
 500  	return e.ch.ReadExtended(data, e.code)
 501  }
 502  
 503  func (ch *channel) Accept() (Channel, <-chan *Request, error) {
 504  	if ch.decided {
 505  		return nil, nil, errDecidedAlready
 506  	}
 507  	ch.maxIncomingPayload = channelMaxPacket
 508  	confirm := channelOpenConfirmMsg{
 509  		PeersID:       ch.remoteId,
 510  		MyID:          ch.localId,
 511  		MyWindow:      ch.myWindow,
 512  		MaxPacketSize: ch.maxIncomingPayload,
 513  	}
 514  	ch.decided = true
 515  	if err := ch.sendMessage(confirm); err != nil {
 516  		return nil, nil, err
 517  	}
 518  
 519  	return ch, ch.incomingRequests, nil
 520  }
 521  
 522  func (ch *channel) Reject(reason RejectionReason, message string) error {
 523  	if ch.decided {
 524  		return errDecidedAlready
 525  	}
 526  	reject := channelOpenFailureMsg{
 527  		PeersID:  ch.remoteId,
 528  		Reason:   reason,
 529  		Message:  message,
 530  		Language: "en",
 531  	}
 532  	ch.decided = true
 533  	return ch.sendMessage(reject)
 534  }
 535  
 536  func (ch *channel) Read(data []byte) (int, error) {
 537  	if !ch.decided {
 538  		return 0, errUndecided
 539  	}
 540  	return ch.ReadExtended(data, 0)
 541  }
 542  
 543  func (ch *channel) Write(data []byte) (int, error) {
 544  	if !ch.decided {
 545  		return 0, errUndecided
 546  	}
 547  	return ch.WriteExtended(data, 0)
 548  }
 549  
 550  func (ch *channel) CloseWrite() error {
 551  	if !ch.decided {
 552  		return errUndecided
 553  	}
 554  	ch.sentEOF = true
 555  	return ch.sendMessage(channelEOFMsg{
 556  		PeersID: ch.remoteId})
 557  }
 558  
 559  func (ch *channel) Close() error {
 560  	if !ch.decided {
 561  		return errUndecided
 562  	}
 563  
 564  	return ch.sendMessage(channelCloseMsg{
 565  		PeersID: ch.remoteId})
 566  }
 567  
 568  // Extended returns an io.ReadWriter that sends and receives data on the given,
 569  // SSH extended stream. Such streams are used, for example, for stderr.
 570  func (ch *channel) Extended(code uint32) io.ReadWriter {
 571  	if !ch.decided {
 572  		return nil
 573  	}
 574  	return &extChannel{code, ch}
 575  }
 576  
 577  func (ch *channel) Stderr() io.ReadWriter {
 578  	return ch.Extended(1)
 579  }
 580  
 581  func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
 582  	if !ch.decided {
 583  		return false, errUndecided
 584  	}
 585  
 586  	if wantReply {
 587  		ch.sentRequestMu.Lock()
 588  		defer ch.sentRequestMu.Unlock()
 589  	}
 590  
 591  	msg := channelRequestMsg{
 592  		PeersID:             ch.remoteId,
 593  		Request:             name,
 594  		WantReply:           wantReply,
 595  		RequestSpecificData: payload,
 596  	}
 597  
 598  	if err := ch.sendMessage(msg); err != nil {
 599  		return false, err
 600  	}
 601  
 602  	if wantReply {
 603  		m, ok := (<-ch.msg)
 604  		if !ok {
 605  			return false, io.EOF
 606  		}
 607  		switch m.(type) {
 608  		case *channelRequestFailureMsg:
 609  			return false, nil
 610  		case *channelRequestSuccessMsg:
 611  			return true, nil
 612  		default:
 613  			return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
 614  		}
 615  	}
 616  
 617  	return false, nil
 618  }
 619  
 620  // ackRequest either sends an ack or nack to the channel request.
 621  func (ch *channel) ackRequest(ok bool) error {
 622  	if !ch.decided {
 623  		return errUndecided
 624  	}
 625  
 626  	var msg interface{}
 627  	if !ok {
 628  		msg = channelRequestFailureMsg{
 629  			PeersID: ch.remoteId,
 630  		}
 631  	} else {
 632  		msg = channelRequestSuccessMsg{
 633  			PeersID: ch.remoteId,
 634  		}
 635  	}
 636  	return ch.sendMessage(msg)
 637  }
 638  
 639  func (ch *channel) ChannelType() string {
 640  	return ch.chanType
 641  }
 642  
 643  func (ch *channel) ExtraData() []byte {
 644  	return ch.extraData
 645  }
 646