mux.go raw

   1  // Copyright 2013 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package ssh
   6  
   7  import (
   8  	"encoding/binary"
   9  	"fmt"
  10  	"io"
  11  	"log"
  12  	"sync"
  13  	"sync/atomic"
  14  )
  15  
  16  // debugMux, if set, causes messages in the connection protocol to be
  17  // logged.
  18  const debugMux = false
  19  
  20  // chanList is a thread safe channel list.
  21  type chanList struct {
  22  	// protects concurrent access to chans
  23  	sync.Mutex
  24  
  25  	// chans are indexed by the local id of the channel, which the
  26  	// other side should send in the PeersId field.
  27  	chans []*channel
  28  
  29  	// This is a debugging aid: it offsets all IDs by this
  30  	// amount. This helps distinguish otherwise identical
  31  	// server/client muxes
  32  	offset uint32
  33  }
  34  
  35  // Assigns a channel ID to the given channel.
  36  func (c *chanList) add(ch *channel) uint32 {
  37  	c.Lock()
  38  	defer c.Unlock()
  39  	for i := range c.chans {
  40  		if c.chans[i] == nil {
  41  			c.chans[i] = ch
  42  			return uint32(i) + c.offset
  43  		}
  44  	}
  45  	c.chans = append(c.chans, ch)
  46  	return uint32(len(c.chans)-1) + c.offset
  47  }
  48  
  49  // getChan returns the channel for the given ID.
  50  func (c *chanList) getChan(id uint32) *channel {
  51  	id -= c.offset
  52  
  53  	c.Lock()
  54  	defer c.Unlock()
  55  	if id < uint32(len(c.chans)) {
  56  		return c.chans[id]
  57  	}
  58  	return nil
  59  }
  60  
  61  func (c *chanList) remove(id uint32) {
  62  	id -= c.offset
  63  	c.Lock()
  64  	if id < uint32(len(c.chans)) {
  65  		c.chans[id] = nil
  66  	}
  67  	c.Unlock()
  68  }
  69  
  70  // dropAll forgets all channels it knows, returning them in a slice.
  71  func (c *chanList) dropAll() []*channel {
  72  	c.Lock()
  73  	defer c.Unlock()
  74  	var r []*channel
  75  
  76  	for _, ch := range c.chans {
  77  		if ch == nil {
  78  			continue
  79  		}
  80  		r = append(r, ch)
  81  	}
  82  	c.chans = nil
  83  	return r
  84  }
  85  
  86  // mux represents the state for the SSH connection protocol, which
  87  // multiplexes many channels onto a single packet transport.
  88  type mux struct {
  89  	conn     packetConn
  90  	chanList chanList
  91  
  92  	incomingChannels chan NewChannel
  93  
  94  	globalSentMu     sync.Mutex
  95  	globalResponses  chan interface{}
  96  	incomingRequests chan *Request
  97  
  98  	errCond *sync.Cond
  99  	err     error
 100  }
 101  
 102  // When debugging, each new chanList instantiation has a different
 103  // offset.
 104  var globalOff uint32
 105  
 106  func (m *mux) Wait() error {
 107  	m.errCond.L.Lock()
 108  	defer m.errCond.L.Unlock()
 109  	for m.err == nil {
 110  		m.errCond.Wait()
 111  	}
 112  	return m.err
 113  }
 114  
 115  // newMux returns a mux that runs over the given connection.
 116  func newMux(p packetConn) *mux {
 117  	m := &mux{
 118  		conn:             p,
 119  		incomingChannels: make(chan NewChannel, chanSize),
 120  		globalResponses:  make(chan interface{}, 1),
 121  		incomingRequests: make(chan *Request, chanSize),
 122  		errCond:          newCond(),
 123  	}
 124  	if debugMux {
 125  		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
 126  	}
 127  
 128  	go m.loop()
 129  	return m
 130  }
 131  
 132  func (m *mux) sendMessage(msg interface{}) error {
 133  	p := Marshal(msg)
 134  	if debugMux {
 135  		log.Printf("send global(%d): %#v", m.chanList.offset, msg)
 136  	}
 137  	return m.conn.writePacket(p)
 138  }
 139  
 140  func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
 141  	if wantReply {
 142  		m.globalSentMu.Lock()
 143  		defer m.globalSentMu.Unlock()
 144  	}
 145  
 146  	if err := m.sendMessage(globalRequestMsg{
 147  		Type:      name,
 148  		WantReply: wantReply,
 149  		Data:      payload,
 150  	}); err != nil {
 151  		return false, nil, err
 152  	}
 153  
 154  	if !wantReply {
 155  		return false, nil, nil
 156  	}
 157  
 158  	msg, ok := <-m.globalResponses
 159  	if !ok {
 160  		return false, nil, io.EOF
 161  	}
 162  	switch msg := msg.(type) {
 163  	case *globalRequestFailureMsg:
 164  		return false, msg.Data, nil
 165  	case *globalRequestSuccessMsg:
 166  		return true, msg.Data, nil
 167  	default:
 168  		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
 169  	}
 170  }
 171  
 172  // ackRequest must be called after processing a global request that
 173  // has WantReply set.
 174  func (m *mux) ackRequest(ok bool, data []byte) error {
 175  	if ok {
 176  		return m.sendMessage(globalRequestSuccessMsg{Data: data})
 177  	}
 178  	return m.sendMessage(globalRequestFailureMsg{Data: data})
 179  }
 180  
 181  func (m *mux) Close() error {
 182  	return m.conn.Close()
 183  }
 184  
 185  // loop runs the connection machine. It will process packets until an
 186  // error is encountered. To synchronize on loop exit, use mux.Wait.
 187  func (m *mux) loop() {
 188  	var err error
 189  	for err == nil {
 190  		err = m.onePacket()
 191  	}
 192  
 193  	for _, ch := range m.chanList.dropAll() {
 194  		ch.close()
 195  	}
 196  
 197  	close(m.incomingChannels)
 198  	close(m.incomingRequests)
 199  	close(m.globalResponses)
 200  
 201  	m.conn.Close()
 202  
 203  	m.errCond.L.Lock()
 204  	m.err = err
 205  	m.errCond.Broadcast()
 206  	m.errCond.L.Unlock()
 207  
 208  	if debugMux {
 209  		log.Println("loop exit", err)
 210  	}
 211  }
 212  
 213  // onePacket reads and processes one packet.
 214  func (m *mux) onePacket() error {
 215  	packet, err := m.conn.readPacket()
 216  	if err != nil {
 217  		return err
 218  	}
 219  
 220  	if debugMux {
 221  		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
 222  			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
 223  		} else {
 224  			p, _ := decode(packet)
 225  			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
 226  		}
 227  	}
 228  
 229  	switch packet[0] {
 230  	case msgChannelOpen:
 231  		return m.handleChannelOpen(packet)
 232  	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
 233  		return m.handleGlobalPacket(packet)
 234  	case msgPing:
 235  		var msg pingMsg
 236  		if err := Unmarshal(packet, &msg); err != nil {
 237  			return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
 238  		}
 239  		return m.sendMessage(pongMsg(msg))
 240  	}
 241  
 242  	// assume a channel packet.
 243  	if len(packet) < 5 {
 244  		return parseError(packet[0])
 245  	}
 246  	id := binary.BigEndian.Uint32(packet[1:])
 247  	ch := m.chanList.getChan(id)
 248  	if ch == nil {
 249  		return m.handleUnknownChannelPacket(id, packet)
 250  	}
 251  
 252  	return ch.handlePacket(packet)
 253  }
 254  
 255  func (m *mux) handleGlobalPacket(packet []byte) error {
 256  	msg, err := decode(packet)
 257  	if err != nil {
 258  		return err
 259  	}
 260  
 261  	switch msg := msg.(type) {
 262  	case *globalRequestMsg:
 263  		m.incomingRequests <- &Request{
 264  			Type:      msg.Type,
 265  			WantReply: msg.WantReply,
 266  			Payload:   msg.Data,
 267  			mux:       m,
 268  		}
 269  	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
 270  		m.globalResponses <- msg
 271  	default:
 272  		panic(fmt.Sprintf("not a global message %#v", msg))
 273  	}
 274  
 275  	return nil
 276  }
 277  
 278  // handleChannelOpen schedules a channel to be Accept()ed.
 279  func (m *mux) handleChannelOpen(packet []byte) error {
 280  	var msg channelOpenMsg
 281  	if err := Unmarshal(packet, &msg); err != nil {
 282  		return err
 283  	}
 284  
 285  	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
 286  		failMsg := channelOpenFailureMsg{
 287  			PeersID:  msg.PeersID,
 288  			Reason:   ConnectionFailed,
 289  			Message:  "invalid request",
 290  			Language: "en_US.UTF-8",
 291  		}
 292  		return m.sendMessage(failMsg)
 293  	}
 294  
 295  	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
 296  	c.remoteId = msg.PeersID
 297  	c.maxRemotePayload = msg.MaxPacketSize
 298  	c.remoteWin.add(msg.PeersWindow)
 299  	m.incomingChannels <- c
 300  	return nil
 301  }
 302  
 303  func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
 304  	ch, err := m.openChannel(chanType, extra)
 305  	if err != nil {
 306  		return nil, nil, err
 307  	}
 308  
 309  	return ch, ch.incomingRequests, nil
 310  }
 311  
 312  func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
 313  	ch := m.newChannel(chanType, channelOutbound, extra)
 314  
 315  	ch.maxIncomingPayload = channelMaxPacket
 316  
 317  	open := channelOpenMsg{
 318  		ChanType:         chanType,
 319  		PeersWindow:      ch.myWindow,
 320  		MaxPacketSize:    ch.maxIncomingPayload,
 321  		TypeSpecificData: extra,
 322  		PeersID:          ch.localId,
 323  	}
 324  	if err := m.sendMessage(open); err != nil {
 325  		return nil, err
 326  	}
 327  
 328  	switch msg := (<-ch.msg).(type) {
 329  	case *channelOpenConfirmMsg:
 330  		return ch, nil
 331  	case *channelOpenFailureMsg:
 332  		return nil, &OpenChannelError{msg.Reason, msg.Message}
 333  	default:
 334  		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
 335  	}
 336  }
 337  
 338  func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
 339  	msg, err := decode(packet)
 340  	if err != nil {
 341  		return err
 342  	}
 343  
 344  	switch msg := msg.(type) {
 345  	// RFC 4254 section 5.4 says unrecognized channel requests should
 346  	// receive a failure response.
 347  	case *channelRequestMsg:
 348  		if msg.WantReply {
 349  			return m.sendMessage(channelRequestFailureMsg{
 350  				PeersID: msg.PeersID,
 351  			})
 352  		}
 353  		return nil
 354  	default:
 355  		return fmt.Errorf("ssh: invalid channel %d", id)
 356  	}
 357  }
 358