transport.go raw

   1  // Copyright 2011 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package ssh
   6  
   7  import (
   8  	"bufio"
   9  	"bytes"
  10  	"errors"
  11  	"fmt"
  12  	"io"
  13  	"log"
  14  )
  15  
// debugTransport, if set, makes the transport print the type of each
// packet as it goes over the wire. No message decoding is done, to
// minimize the impact on timing.
const debugTransport = false
  19  
// packetConn represents a transport that implements packet based
// operations.
type packetConn interface {
	// writePacket encrypts and sends a packet of data to the remote
	// peer.
	writePacket(packet []byte) error

	// readPacket reads a packet from the connection. The read is
	// blocking, i.e. if error is nil, then the returned byte slice
	// is always non-empty.
	readPacket() ([]byte, error)

	// Close closes the write-side of the connection.
	Close() error
}
  34  
// transport is the keyingTransport that implements the SSH packet
// protocol.
type transport struct {
	// reader and writer hold the per-direction state (current cipher,
	// sequence number, key derivation tags and queued key change) for
	// incoming and outgoing packets respectively.
	reader connectionState
	writer connectionState

	// bufReader and bufWriter buffer the underlying connection; every
	// written packet is flushed through bufWriter. rand is handed to
	// the write cipher for each outgoing packet. isClient records which
	// side of the connection this transport is; it selects the key
	// derivation direction in newTransport and labels debug output.
	// The embedded io.Closer is the underlying connection and provides
	// Close.
	bufReader *bufio.Reader
	bufWriter *bufio.Writer
	rand      io.Reader
	isClient  bool
	io.Closer

	// strictMode is set by setStrictMode. While set, sequence numbers
	// are reset to zero whenever keys change, and msgIgnore/msgDebug
	// packets are not filtered out by readPacket until initialKEXDone
	// has been set by setInitialKEXDone.
	strictMode     bool
	initialKEXDone bool
}
  50  
// packetCipher represents a combination of SSH encryption/MAC
// protocol. A single instance should be used for one direction only.
type packetCipher interface {
	// writeCipherPacket encrypts the packet and writes it to w,
	// using seqnum as the packet's sequence number and rand as the
	// source of randomness. The contents of the packet are generally
	// scrambled, so callers must not rely on packet afterwards.
	writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error

	// readCipherPacket reads and decrypts a packet of data from r,
	// using seqnum as the packet's sequence number. The returned
	// packet may be overwritten by future calls of readPacket, so
	// callers that keep it must copy it first.
	readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error)
}
  63  
// connectionState represents one side (read or write) of the
// connection. This is necessary because each direction has its own
// keys, and can even have its own algorithms.
//
// The embedded packetCipher is the cipher currently in use for this
// direction. seqNum is the sequence number passed to the cipher for
// the next packet. dir holds the key derivation tags for this
// direction. pendingKeyChange carries the cipher queued by
// prepareKeyChange; it is taken over when a msgNewKeys packet is read
// or written in this direction.
type connectionState struct {
	packetCipher
	seqNum           uint32
	dir              direction
	pendingKeyChange chan packetCipher
}
  73  
  74  func (t *transport) setStrictMode() error {
  75  	if t.reader.seqNum != 1 {
  76  		return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
  77  	}
  78  	t.strictMode = true
  79  	return nil
  80  }
  81  
// setInitialKEXDone records that the initial key exchange has
// finished. In strict mode this re-enables the filtering of msgIgnore
// and msgDebug packets in readPacket.
func (t *transport) setInitialKEXDone() {
	t.initialKEXDone = true
}
  85  
  86  // prepareKeyChange sets up key material for a keychange. The key changes in
  87  // both directions are triggered by reading and writing a msgNewKey packet
  88  // respectively.
  89  func (t *transport) prepareKeyChange(algs *NegotiatedAlgorithms, kexResult *kexResult) error {
  90  	ciph, err := newPacketCipher(t.reader.dir, algs.Read, kexResult)
  91  	if err != nil {
  92  		return err
  93  	}
  94  	t.reader.pendingKeyChange <- ciph
  95  
  96  	ciph, err = newPacketCipher(t.writer.dir, algs.Write, kexResult)
  97  	if err != nil {
  98  		return err
  99  	}
 100  	t.writer.pendingKeyChange <- ciph
 101  
 102  	return nil
 103  }
 104  
 105  func (t *transport) printPacket(p []byte, write bool) {
 106  	if len(p) == 0 {
 107  		return
 108  	}
 109  	who := "server"
 110  	if t.isClient {
 111  		who = "client"
 112  	}
 113  	what := "read"
 114  	if write {
 115  		what = "write"
 116  	}
 117  
 118  	log.Println(what, who, p[0])
 119  }
 120  
 121  // Read and decrypt next packet.
 122  func (t *transport) readPacket() (p []byte, err error) {
 123  	for {
 124  		p, err = t.reader.readPacket(t.bufReader, t.strictMode)
 125  		if err != nil {
 126  			break
 127  		}
 128  		// in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
 129  		if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
 130  			break
 131  		}
 132  	}
 133  	if debugTransport {
 134  		t.printPacket(p, false)
 135  	}
 136  
 137  	return p, err
 138  }
 139  
 140  func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
 141  	packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
 142  	s.seqNum++
 143  	if err == nil && len(packet) == 0 {
 144  		err = errors.New("ssh: zero length packet")
 145  	}
 146  
 147  	if len(packet) > 0 {
 148  		switch packet[0] {
 149  		case msgNewKeys:
 150  			select {
 151  			case cipher := <-s.pendingKeyChange:
 152  				s.packetCipher = cipher
 153  				if strictMode {
 154  					s.seqNum = 0
 155  				}
 156  			default:
 157  				return nil, errors.New("ssh: got bogus newkeys message")
 158  			}
 159  
 160  		case msgDisconnect:
 161  			// Transform a disconnect message into an
 162  			// error. Since this is lowest level at which
 163  			// we interpret message types, doing it here
 164  			// ensures that we don't have to handle it
 165  			// elsewhere.
 166  			var msg disconnectMsg
 167  			if err := Unmarshal(packet, &msg); err != nil {
 168  				return nil, err
 169  			}
 170  			return nil, &msg
 171  		}
 172  	}
 173  
 174  	// The packet may point to an internal buffer, so copy the
 175  	// packet out here.
 176  	fresh := make([]byte, len(packet))
 177  	copy(fresh, packet)
 178  
 179  	return fresh, err
 180  }
 181  
// writePacket encrypts packet with the current write cipher and sends
// it to the peer, flushing the buffered writer. When debugTransport is
// set, the packet type is logged first.
func (t *transport) writePacket(packet []byte) error {
	if debugTransport {
		// Log before writing: the cipher may scramble packet.
		t.printPacket(packet, true)
	}
	return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
}
 188  
 189  func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
 190  	changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
 191  
 192  	err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
 193  	if err != nil {
 194  		return err
 195  	}
 196  	if err = w.Flush(); err != nil {
 197  		return err
 198  	}
 199  	s.seqNum++
 200  	if changeKeys {
 201  		select {
 202  		case cipher := <-s.pendingKeyChange:
 203  			s.packetCipher = cipher
 204  			if strictMode {
 205  				s.seqNum = 0
 206  			}
 207  		default:
 208  			panic("ssh: no key material for msgNewKeys")
 209  		}
 210  	}
 211  	return err
 212  }
 213  
 214  func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
 215  	t := &transport{
 216  		bufReader: bufio.NewReader(rwc),
 217  		bufWriter: bufio.NewWriter(rwc),
 218  		rand:      rand,
 219  		reader: connectionState{
 220  			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
 221  			pendingKeyChange: make(chan packetCipher, 1),
 222  		},
 223  		writer: connectionState{
 224  			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
 225  			pendingKeyChange: make(chan packetCipher, 1),
 226  		},
 227  		Closer: rwc,
 228  	}
 229  	t.isClient = isClient
 230  
 231  	if isClient {
 232  		t.reader.dir = serverKeys
 233  		t.writer.dir = clientKeys
 234  	} else {
 235  		t.reader.dir = clientKeys
 236  		t.writer.dir = serverKeys
 237  	}
 238  
 239  	return t
 240  }
 241  
// direction holds the tags that select which key material is derived
// for one direction of the connection: ivTag for the initialization
// vector, keyTag for the cipher key and macKeyTag for the MAC key.
// Each tag is passed to generateKeyMaterial by newPacketCipher.
type direction struct {
	ivTag     []byte
	keyTag    []byte
	macKeyTag []byte
}
 247  
// serverKeys and clientKeys are the key derivation tags for the two
// directions of a connection: serverKeys is used for server->client
// keys and clientKeys for client->server keys. The values are given
// positionally in the order ivTag, keyTag, macKeyTag.
var (
	serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
	clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
 252  
 253  // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
 254  // described in RFC 4253, section 6.4. direction should either be serverKeys
 255  // (to setup server->client keys) or clientKeys (for client->server keys).
 256  func newPacketCipher(d direction, algs DirectionAlgorithms, kex *kexResult) (packetCipher, error) {
 257  	cipherMode := cipherModes[algs.Cipher]
 258  	if cipherMode == nil {
 259  		return nil, fmt.Errorf("ssh: unsupported cipher %v", algs.Cipher)
 260  	}
 261  
 262  	iv := make([]byte, cipherMode.ivSize)
 263  	key := make([]byte, cipherMode.keySize)
 264  
 265  	generateKeyMaterial(iv, d.ivTag, kex)
 266  	generateKeyMaterial(key, d.keyTag, kex)
 267  
 268  	var macKey []byte
 269  	if !aeadCiphers[algs.Cipher] {
 270  		macMode := macModes[algs.MAC]
 271  		macKey = make([]byte, macMode.keySize)
 272  		generateKeyMaterial(macKey, d.macKeyTag, kex)
 273  	}
 274  
 275  	return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
 276  }
 277  
 278  // generateKeyMaterial fills out with key material generated from tag, K, H
 279  // and sessionId, as specified in RFC 4253, section 7.2.
 280  func generateKeyMaterial(out, tag []byte, r *kexResult) {
 281  	var digestsSoFar []byte
 282  
 283  	h := r.Hash.New()
 284  	for len(out) > 0 {
 285  		h.Reset()
 286  		h.Write(r.K)
 287  		h.Write(r.H)
 288  
 289  		if len(digestsSoFar) == 0 {
 290  			h.Write(tag)
 291  			h.Write(r.SessionID)
 292  		} else {
 293  			h.Write(digestsSoFar)
 294  		}
 295  
 296  		digest := h.Sum(nil)
 297  		n := copy(out, digest)
 298  		out = out[n:]
 299  		if len(out) > 0 {
 300  			digestsSoFar = append(digestsSoFar, digest...)
 301  		}
 302  	}
 303  }
 304  
// packageVersion is this package's SSH version identification string.
// NOTE(review): presumably the version line sent when the caller does
// not configure its own — confirm against the callers.
const packageVersion = "SSH-2.0-Go"
 306  
 307  // Sends and receives a version line.  The versionLine string should
 308  // be US ASCII, start with "SSH-2.0-", and should not include a
 309  // newline. exchangeVersions returns the other side's version line.
 310  func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
 311  	// Contrary to the RFC, we do not ignore lines that don't
 312  	// start with "SSH-2.0-" to make the library usable with
 313  	// nonconforming servers.
 314  	for _, c := range versionLine {
 315  		// The spec disallows non US-ASCII chars, and
 316  		// specifically forbids null chars.
 317  		if c < 32 {
 318  			return nil, errors.New("ssh: junk character in version line")
 319  		}
 320  	}
 321  	if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
 322  		return
 323  	}
 324  
 325  	them, err = readVersion(rw)
 326  	return them, err
 327  }
 328  
 329  // maxVersionStringBytes is the maximum number of bytes that we'll
 330  // accept as a version string. RFC 4253 section 4.2 limits this at 255
 331  // chars
 332  const maxVersionStringBytes = 255
 333  
 334  // Read version string as specified by RFC 4253, section 4.2.
 335  func readVersion(r io.Reader) ([]byte, error) {
 336  	versionString := make([]byte, 0, 64)
 337  	var ok bool
 338  	var buf [1]byte
 339  
 340  	for length := 0; length < maxVersionStringBytes; length++ {
 341  		_, err := io.ReadFull(r, buf[:])
 342  		if err != nil {
 343  			return nil, err
 344  		}
 345  		// The RFC says that the version should be terminated with \r\n
 346  		// but several SSH servers actually only send a \n.
 347  		if buf[0] == '\n' {
 348  			if !bytes.HasPrefix(versionString, []byte("SSH-")) {
 349  				// RFC 4253 says we need to ignore all version string lines
 350  				// except the one containing the SSH version (provided that
 351  				// all the lines do not exceed 255 bytes in total).
 352  				versionString = versionString[:0]
 353  				continue
 354  			}
 355  			ok = true
 356  			break
 357  		}
 358  
 359  		// non ASCII chars are disallowed, but we are lenient,
 360  		// since Go doesn't use null-terminated strings.
 361  
 362  		// The RFC allows a comment after a space, however,
 363  		// all of it (version and comments) goes into the
 364  		// session hash.
 365  		versionString = append(versionString, buf[0])
 366  	}
 367  
 368  	if !ok {
 369  		return nil, errors.New("ssh: overflow reading version string")
 370  	}
 371  
 372  	// There might be a '\r' on the end which we should remove.
 373  	if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' {
 374  		versionString = versionString[:len(versionString)-1]
 375  	}
 376  	return versionString, nil
 377  }
 378