handshake.go raw

   1  // Copyright 2013 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package ssh
   6  
   7  import (
   8  	"errors"
   9  	"fmt"
  10  	"io"
  11  	"log"
  12  	"net"
  13  	"slices"
  14  	"strings"
  15  	"sync"
  16  )
  17  
  18  // debugHandshake, if set, prints messages sent and received.  Key
  19  // exchange messages are printed as if DH were used, so the debug
  20  // messages are wrong when using ECDH.
  21  const debugHandshake = false
  22  
  23  // chanSize sets the amount of buffering SSH connections. This is
  24  // primarily for testing: setting chanSize=0 uncovers deadlocks more
  25  // quickly.
  26  const chanSize = 16
  27  
  28  // maxPendingPackets sets the maximum number of packets to queue while waiting
  29  // for KEX to complete. This limits the total pending data to maxPendingPackets
  30  // * maxPacket bytes, which is ~16.8MB.
  31  const maxPendingPackets = 64
  32  
  33  // keyingTransport is a packet based transport that supports key
  34  // changes. It need not be thread-safe. It should pass through
  35  // msgNewKeys in both directions.
  36  type keyingTransport interface {
  37  	packetConn
  38  
  39  	// prepareKeyChange sets up a key change. The key change for a
  40  	// direction will be effected if a msgNewKeys message is sent
  41  	// or received.
  42  	prepareKeyChange(*NegotiatedAlgorithms, *kexResult) error
  43  
  44  	// setStrictMode sets the strict KEX mode, notably triggering
  45  	// sequence number resets on sending or receiving msgNewKeys.
  46  	// If the sequence number is already > 1 when setStrictMode
  47  	// is called, an error is returned.
  48  	setStrictMode() error
  49  
  50  	// setInitialKEXDone indicates to the transport that the initial key exchange
  51  	// was completed
  52  	setInitialKEXDone()
  53  }
  54  
// handshakeTransport implements rekeying on top of a keyingTransport
// and offers a thread-safe writePacket() interface.
//
// newClientTransport and newServerTransport start two goroutines per
// transport: readLoop, which reads packets and hands every kexInit received
// from the peer to kexLoop via startKex, and kexLoop, which runs the key
// exchanges and flushes writes that were queued while one was in progress.
type handshakeTransport struct {
	conn   keyingTransport
	config *Config

	serverVersion []byte
	clientVersion []byte

	// hostKeys is non-empty if we are the server. In that case,
	// it contains all host keys that can be used to sign the
	// connection.
	hostKeys []Signer

	// publicKeyAuthAlgorithms is non-empty if we are the server. In that case,
	// it contains the supported client public key authentication algorithms.
	publicKeyAuthAlgorithms []string

	// hostKeyAlgorithms is non-empty if we are the client. In that case,
	// we accept these key types from the server as host key.
	hostKeyAlgorithms []string

	// On read error, incoming is closed, and readError is set.
	incoming  chan []byte
	readError error

	// mu serializes access to the write-side state that follows, from
	// writeError through userAuthComplete.
	mu sync.Mutex
	// Condition for the above mutex. It is used to notify a completed key
	// exchange or a write failure. Writes can wait for this condition while a
	// key exchange is in progress.
	writeCond      *sync.Cond
	writeError     error
	sentInitPacket []byte
	sentInitMsg    *kexInitMsg
	// Used to queue writes when a key exchange is in progress. The length is
	// limited by maxPendingPackets. Once full, writes will block until the key
	// exchange is completed or an error occurs. If not empty, it is emptied
	// all at once when the key exchange is completed in kexLoop.
	pendingPackets [][]byte
	// Remaining packet and byte budget for writes before writePacket
	// requests a new key exchange; reset by resetWriteThresholds.
	writePacketsLeft uint32
	writeBytesLeft   int64
	userAuthComplete bool // whether the user authentication phase is complete

	// requestKex is pinged (by newHandshakeTransport, readOnePacket and
	// writePacket) to ask kexLoop to send our kexInit and start a key
	// exchange. It has a buffer of one, so pings coalesce.
	requestKex chan struct{}

	// If the other side requests or confirms a kex, readLoop sends its
	// kexInit packet here so that kexLoop can run the exchange.
	startKex    chan *pendingKex
	kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits

	// data for host key checking
	hostKeyCallback HostKeyCallback
	dialAddress     string
	remoteAddr      net.Addr

	// bannerCallback is non-empty if we are the client and it has been set in
	// ClientConfig. In that case it is called during the user authentication
	// dance to handle a custom server's message.
	bannerCallback BannerCallback

	// Algorithms agreed in the last key exchange. Written by
	// enterKeyExchange, which runs on the kexLoop goroutine.
	algorithms *NegotiatedAlgorithms

	// Counters exclusively owned by readLoop.
	readPacketsLeft uint32
	readBytesLeft   int64

	// The session ID or nil if first kex did not complete yet. It is set
	// once, to the hash of the first key exchange.
	sessionID []byte

	// strictMode indicates if the other side of the handshake indicated
	// that we should be following the strict KEX protocol restrictions.
	strictMode bool
}
 132  
 133  type pendingKex struct {
 134  	otherInit []byte
 135  	done      chan error
 136  }
 137  
 138  func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
 139  	t := &handshakeTransport{
 140  		conn:          conn,
 141  		serverVersion: serverVersion,
 142  		clientVersion: clientVersion,
 143  		incoming:      make(chan []byte, chanSize),
 144  		requestKex:    make(chan struct{}, 1),
 145  		startKex:      make(chan *pendingKex),
 146  		kexLoopDone:   make(chan struct{}),
 147  
 148  		config: config,
 149  	}
 150  	t.writeCond = sync.NewCond(&t.mu)
 151  	t.resetReadThresholds()
 152  	t.resetWriteThresholds()
 153  
 154  	// We always start with a mandatory key exchange.
 155  	t.requestKex <- struct{}{}
 156  	return t
 157  }
 158  
 159  func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
 160  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
 161  	t.dialAddress = dialAddr
 162  	t.remoteAddr = addr
 163  	t.hostKeyCallback = config.HostKeyCallback
 164  	t.bannerCallback = config.BannerCallback
 165  	if config.HostKeyAlgorithms != nil {
 166  		t.hostKeyAlgorithms = config.HostKeyAlgorithms
 167  	} else {
 168  		t.hostKeyAlgorithms = defaultHostKeyAlgos
 169  	}
 170  	go t.readLoop()
 171  	go t.kexLoop()
 172  	return t
 173  }
 174  
 175  func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
 176  	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
 177  	t.hostKeys = config.hostKeys
 178  	t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms
 179  	go t.readLoop()
 180  	go t.kexLoop()
 181  	return t
 182  }
 183  
 184  func (t *handshakeTransport) getSessionID() []byte {
 185  	return t.sessionID
 186  }
 187  
 188  func (t *handshakeTransport) getAlgorithms() NegotiatedAlgorithms {
 189  	return *t.algorithms
 190  }
 191  
 192  // waitSession waits for the session to be established. This should be
 193  // the first thing to call after instantiating handshakeTransport.
 194  func (t *handshakeTransport) waitSession() error {
 195  	p, err := t.readPacket()
 196  	if err != nil {
 197  		return err
 198  	}
 199  	if p[0] != msgNewKeys {
 200  		return fmt.Errorf("ssh: first packet should be msgNewKeys")
 201  	}
 202  
 203  	return nil
 204  }
 205  
 206  func (t *handshakeTransport) id() string {
 207  	if len(t.hostKeys) > 0 {
 208  		return "server"
 209  	}
 210  	return "client"
 211  }
 212  
 213  func (t *handshakeTransport) printPacket(p []byte, write bool) {
 214  	action := "got"
 215  	if write {
 216  		action = "sent"
 217  	}
 218  
 219  	if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
 220  		log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
 221  	} else {
 222  		msg, err := decode(p)
 223  		log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
 224  	}
 225  }
 226  
 227  func (t *handshakeTransport) readPacket() ([]byte, error) {
 228  	p, ok := <-t.incoming
 229  	if !ok {
 230  		return nil, t.readError
 231  	}
 232  	return p, nil
 233  }
 234  
 235  func (t *handshakeTransport) readLoop() {
 236  	first := true
 237  	for {
 238  		p, err := t.readOnePacket(first)
 239  		first = false
 240  		if err != nil {
 241  			t.readError = err
 242  			close(t.incoming)
 243  			break
 244  		}
 245  		// If this is the first kex, and strict KEX mode is enabled,
 246  		// we don't ignore any messages, as they may be used to manipulate
 247  		// the packet sequence numbers.
 248  		if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) {
 249  			continue
 250  		}
 251  		t.incoming <- p
 252  	}
 253  
 254  	// Stop writers too.
 255  	t.recordWriteError(t.readError)
 256  
 257  	// Unblock the writer should it wait for this.
 258  	close(t.startKex)
 259  
 260  	// Don't close t.requestKex; it's also written to from writePacket.
 261  }
 262  
 263  func (t *handshakeTransport) pushPacket(p []byte) error {
 264  	if debugHandshake {
 265  		t.printPacket(p, true)
 266  	}
 267  	return t.conn.writePacket(p)
 268  }
 269  
 270  func (t *handshakeTransport) getWriteError() error {
 271  	t.mu.Lock()
 272  	defer t.mu.Unlock()
 273  	return t.writeError
 274  }
 275  
 276  func (t *handshakeTransport) recordWriteError(err error) {
 277  	t.mu.Lock()
 278  	defer t.mu.Unlock()
 279  	if t.writeError == nil && err != nil {
 280  		t.writeError = err
 281  		t.writeCond.Broadcast()
 282  	}
 283  }
 284  
 285  func (t *handshakeTransport) requestKeyExchange() {
 286  	select {
 287  	case t.requestKex <- struct{}{}:
 288  	default:
 289  		// something already requested a kex, so do nothing.
 290  	}
 291  }
 292  
 293  func (t *handshakeTransport) resetWriteThresholds() {
 294  	t.writePacketsLeft = packetRekeyThreshold
 295  	if t.config.RekeyThreshold > 0 {
 296  		t.writeBytesLeft = int64(t.config.RekeyThreshold)
 297  	} else if t.algorithms != nil {
 298  		t.writeBytesLeft = t.algorithms.Write.rekeyBytes()
 299  	} else {
 300  		t.writeBytesLeft = 1 << 30
 301  	}
 302  }
 303  
 304  func (t *handshakeTransport) kexLoop() {
 305  
 306  write:
 307  	for t.getWriteError() == nil {
 308  		var request *pendingKex
 309  		var sent bool
 310  
 311  		for request == nil || !sent {
 312  			var ok bool
 313  			select {
 314  			case request, ok = <-t.startKex:
 315  				if !ok {
 316  					break write
 317  				}
 318  			case <-t.requestKex:
 319  				break
 320  			}
 321  
 322  			if !sent {
 323  				if err := t.sendKexInit(); err != nil {
 324  					t.recordWriteError(err)
 325  					break
 326  				}
 327  				sent = true
 328  			}
 329  		}
 330  
 331  		if err := t.getWriteError(); err != nil {
 332  			if request != nil {
 333  				request.done <- err
 334  			}
 335  			break
 336  		}
 337  
 338  		// We're not servicing t.requestKex, but that is OK:
 339  		// we never block on sending to t.requestKex.
 340  
 341  		// We're not servicing t.startKex, but the remote end
 342  		// has just sent us a kexInitMsg, so it can't send
 343  		// another key change request, until we close the done
 344  		// channel on the pendingKex request.
 345  
 346  		err := t.enterKeyExchange(request.otherInit)
 347  
 348  		t.mu.Lock()
 349  		t.writeError = err
 350  		t.sentInitPacket = nil
 351  		t.sentInitMsg = nil
 352  
 353  		t.resetWriteThresholds()
 354  
 355  		// we have completed the key exchange. Since the
 356  		// reader is still blocked, it is safe to clear out
 357  		// the requestKex channel. This avoids the situation
 358  		// where: 1) we consumed our own request for the
 359  		// initial kex, and 2) the kex from the remote side
 360  		// caused another send on the requestKex channel,
 361  	clear:
 362  		for {
 363  			select {
 364  			case <-t.requestKex:
 365  				//
 366  			default:
 367  				break clear
 368  			}
 369  		}
 370  
 371  		request.done <- t.writeError
 372  
 373  		// kex finished. Push packets that we received while
 374  		// the kex was in progress. Don't look at t.startKex
 375  		// and don't increment writtenSinceKex: if we trigger
 376  		// another kex while we are still busy with the last
 377  		// one, things will become very confusing.
 378  		for _, p := range t.pendingPackets {
 379  			t.writeError = t.pushPacket(p)
 380  			if t.writeError != nil {
 381  				break
 382  			}
 383  		}
 384  		t.pendingPackets = t.pendingPackets[:0]
 385  		// Unblock writePacket if waiting for KEX.
 386  		t.writeCond.Broadcast()
 387  		t.mu.Unlock()
 388  	}
 389  
 390  	// Unblock reader.
 391  	t.conn.Close()
 392  
 393  	// drain startKex channel. We don't service t.requestKex
 394  	// because nobody does blocking sends there.
 395  	for request := range t.startKex {
 396  		request.done <- t.getWriteError()
 397  	}
 398  
 399  	// Mark that the loop is done so that Close can return.
 400  	close(t.kexLoopDone)
 401  }
 402  
 403  // The protocol uses uint32 for packet counters, so we can't let them
 404  // reach 1<<32.  We will actually read and write more packets than
 405  // this, though: the other side may send more packets, and after we
 406  // hit this limit on writing we will send a few more packets for the
 407  // key exchange itself.
 408  const packetRekeyThreshold = (1 << 31)
 409  
 410  func (t *handshakeTransport) resetReadThresholds() {
 411  	t.readPacketsLeft = packetRekeyThreshold
 412  	if t.config.RekeyThreshold > 0 {
 413  		t.readBytesLeft = int64(t.config.RekeyThreshold)
 414  	} else if t.algorithms != nil {
 415  		t.readBytesLeft = t.algorithms.Read.rekeyBytes()
 416  	} else {
 417  		t.readBytesLeft = 1 << 30
 418  	}
 419  }
 420  
 421  func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
 422  	p, err := t.conn.readPacket()
 423  	if err != nil {
 424  		return nil, err
 425  	}
 426  
 427  	if t.readPacketsLeft > 0 {
 428  		t.readPacketsLeft--
 429  	} else {
 430  		t.requestKeyExchange()
 431  	}
 432  
 433  	if t.readBytesLeft > 0 {
 434  		t.readBytesLeft -= int64(len(p))
 435  	} else {
 436  		t.requestKeyExchange()
 437  	}
 438  
 439  	if debugHandshake {
 440  		t.printPacket(p, false)
 441  	}
 442  
 443  	if first && p[0] != msgKexInit {
 444  		return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
 445  	}
 446  
 447  	if p[0] != msgKexInit {
 448  		return p, nil
 449  	}
 450  
 451  	firstKex := t.sessionID == nil
 452  
 453  	kex := pendingKex{
 454  		done:      make(chan error, 1),
 455  		otherInit: p,
 456  	}
 457  	t.startKex <- &kex
 458  	err = <-kex.done
 459  
 460  	if debugHandshake {
 461  		log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
 462  	}
 463  
 464  	if err != nil {
 465  		return nil, err
 466  	}
 467  
 468  	t.resetReadThresholds()
 469  
 470  	// By default, a key exchange is hidden from higher layers by
 471  	// translating it into msgIgnore.
 472  	successPacket := []byte{msgIgnore}
 473  	if firstKex {
 474  		// sendKexInit() for the first kex waits for
 475  		// msgNewKeys so the authentication process is
 476  		// guaranteed to happen over an encrypted transport.
 477  		successPacket = []byte{msgNewKeys}
 478  	}
 479  
 480  	return successPacket, nil
 481  }
 482  
 483  const (
 484  	kexStrictClient = "kex-strict-c-v00@openssh.com"
 485  	kexStrictServer = "kex-strict-s-v00@openssh.com"
 486  )
 487  
 488  // sendKexInit sends a key change message.
 489  func (t *handshakeTransport) sendKexInit() error {
 490  	t.mu.Lock()
 491  	defer t.mu.Unlock()
 492  	if t.sentInitMsg != nil {
 493  		// kexInits may be sent either in response to the other side,
 494  		// or because our side wants to initiate a key change, so we
 495  		// may have already sent a kexInit. In that case, don't send a
 496  		// second kexInit.
 497  		return nil
 498  	}
 499  
 500  	msg := &kexInitMsg{
 501  		CiphersClientServer:     t.config.Ciphers,
 502  		CiphersServerClient:     t.config.Ciphers,
 503  		MACsClientServer:        t.config.MACs,
 504  		MACsServerClient:        t.config.MACs,
 505  		CompressionClientServer: supportedCompressions,
 506  		CompressionServerClient: supportedCompressions,
 507  	}
 508  	io.ReadFull(t.config.Rand, msg.Cookie[:])
 509  
 510  	// We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm,
 511  	// and possibly to add the ext-info extension algorithm. Since the slice may be the
 512  	// user owned KeyExchanges, we create our own slice in order to avoid using user
 513  	// owned memory by mistake.
 514  	msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info
 515  	msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
 516  
 517  	isServer := len(t.hostKeys) > 0
 518  	if isServer {
 519  		for _, k := range t.hostKeys {
 520  			// If k is a MultiAlgorithmSigner, we restrict the signature
 521  			// algorithms. If k is a AlgorithmSigner, presume it supports all
 522  			// signature algorithms associated with the key format. If k is not
 523  			// an AlgorithmSigner, we can only assume it only supports the
 524  			// algorithms that matches the key format. (This means that Sign
 525  			// can't pick a different default).
 526  			keyFormat := k.PublicKey().Type()
 527  
 528  			switch s := k.(type) {
 529  			case MultiAlgorithmSigner:
 530  				for _, algo := range algorithmsForKeyFormat(keyFormat) {
 531  					if slices.Contains(s.Algorithms(), underlyingAlgo(algo)) {
 532  						msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
 533  					}
 534  				}
 535  			case AlgorithmSigner:
 536  				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
 537  			default:
 538  				msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
 539  			}
 540  		}
 541  
 542  		if t.sessionID == nil {
 543  			msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
 544  		}
 545  	} else {
 546  		msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
 547  
 548  		// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
 549  		// algorithms the server supports for public key authentication. See RFC
 550  		// 8308, Section 2.1.
 551  		//
 552  		// We also send the strict KEX mode extension algorithm, in order to opt
 553  		// into the strict KEX mode.
 554  		if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
 555  			msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
 556  			msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
 557  		}
 558  
 559  	}
 560  
 561  	packet := Marshal(msg)
 562  
 563  	// writePacket destroys the contents, so save a copy.
 564  	packetCopy := make([]byte, len(packet))
 565  	copy(packetCopy, packet)
 566  
 567  	if err := t.pushPacket(packetCopy); err != nil {
 568  		return err
 569  	}
 570  
 571  	t.sentInitMsg = msg
 572  	t.sentInitPacket = packet
 573  
 574  	return nil
 575  }
 576  
 577  var errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase")
 578  
 579  func (t *handshakeTransport) writePacket(p []byte) error {
 580  	t.mu.Lock()
 581  	defer t.mu.Unlock()
 582  
 583  	switch p[0] {
 584  	case msgKexInit:
 585  		return errors.New("ssh: only handshakeTransport can send kexInit")
 586  	case msgNewKeys:
 587  		return errors.New("ssh: only handshakeTransport can send newKeys")
 588  	case msgUserAuthBanner:
 589  		if t.userAuthComplete {
 590  			return errSendBannerPhase
 591  		}
 592  	case msgUserAuthSuccess:
 593  		t.userAuthComplete = true
 594  	}
 595  
 596  	if t.writeError != nil {
 597  		return t.writeError
 598  	}
 599  
 600  	if t.sentInitMsg != nil {
 601  		if len(t.pendingPackets) < maxPendingPackets {
 602  			// Copy the packet so the writer can reuse the buffer.
 603  			cp := make([]byte, len(p))
 604  			copy(cp, p)
 605  			t.pendingPackets = append(t.pendingPackets, cp)
 606  			return nil
 607  		}
 608  		for t.sentInitMsg != nil {
 609  			// Block and wait for KEX to complete or an error.
 610  			t.writeCond.Wait()
 611  			if t.writeError != nil {
 612  				return t.writeError
 613  			}
 614  		}
 615  	}
 616  
 617  	if t.writeBytesLeft > 0 {
 618  		t.writeBytesLeft -= int64(len(p))
 619  	} else {
 620  		t.requestKeyExchange()
 621  	}
 622  
 623  	if t.writePacketsLeft > 0 {
 624  		t.writePacketsLeft--
 625  	} else {
 626  		t.requestKeyExchange()
 627  	}
 628  
 629  	if err := t.pushPacket(p); err != nil {
 630  		t.writeError = err
 631  		t.writeCond.Broadcast()
 632  	}
 633  
 634  	return nil
 635  }
 636  
 637  func (t *handshakeTransport) Close() error {
 638  	// Close the connection. This should cause the readLoop goroutine to wake up
 639  	// and close t.startKex, which will shut down kexLoop if running.
 640  	err := t.conn.Close()
 641  
 642  	// Wait for the kexLoop goroutine to complete.
 643  	// At that point we know that the readLoop goroutine is complete too,
 644  	// because kexLoop itself waits for readLoop to close the startKex channel.
 645  	<-t.kexLoopDone
 646  
 647  	return err
 648  }
 649  
 650  func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 651  	if debugHandshake {
 652  		log.Printf("%s entered key exchange", t.id())
 653  	}
 654  
 655  	otherInit := &kexInitMsg{}
 656  	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
 657  		return err
 658  	}
 659  
 660  	magics := handshakeMagics{
 661  		clientVersion: t.clientVersion,
 662  		serverVersion: t.serverVersion,
 663  		clientKexInit: otherInitPacket,
 664  		serverKexInit: t.sentInitPacket,
 665  	}
 666  
 667  	clientInit := otherInit
 668  	serverInit := t.sentInitMsg
 669  	isClient := len(t.hostKeys) == 0
 670  	if isClient {
 671  		clientInit, serverInit = serverInit, clientInit
 672  
 673  		magics.clientKexInit = t.sentInitPacket
 674  		magics.serverKexInit = otherInitPacket
 675  	}
 676  
 677  	var err error
 678  	t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
 679  	if err != nil {
 680  		return err
 681  	}
 682  
 683  	if t.sessionID == nil && ((isClient && slices.Contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && slices.Contains(clientInit.KexAlgos, kexStrictClient))) {
 684  		t.strictMode = true
 685  		if err := t.conn.setStrictMode(); err != nil {
 686  			return err
 687  		}
 688  	}
 689  
 690  	// We don't send FirstKexFollows, but we handle receiving it.
 691  	//
 692  	// RFC 4253 section 7 defines the kex and the agreement method for
 693  	// first_kex_packet_follows. It states that the guessed packet
 694  	// should be ignored if the "kex algorithm and/or the host
 695  	// key algorithm is guessed wrong (server and client have
 696  	// different preferred algorithm), or if any of the other
 697  	// algorithms cannot be agreed upon". The other algorithms have
 698  	// already been checked above so the kex algorithm and host key
 699  	// algorithm are checked here.
 700  	if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
 701  		// other side sent a kex message for the wrong algorithm,
 702  		// which we have to ignore.
 703  		if _, err := t.conn.readPacket(); err != nil {
 704  			return err
 705  		}
 706  	}
 707  
 708  	kex, ok := kexAlgoMap[t.algorithms.KeyExchange]
 709  	if !ok {
 710  		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.KeyExchange)
 711  	}
 712  
 713  	var result *kexResult
 714  	if len(t.hostKeys) > 0 {
 715  		result, err = t.server(kex, &magics)
 716  	} else {
 717  		result, err = t.client(kex, &magics)
 718  	}
 719  
 720  	if err != nil {
 721  		return err
 722  	}
 723  
 724  	firstKeyExchange := t.sessionID == nil
 725  	if firstKeyExchange {
 726  		t.sessionID = result.H
 727  	}
 728  	result.SessionID = t.sessionID
 729  
 730  	if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
 731  		return err
 732  	}
 733  	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
 734  		return err
 735  	}
 736  
 737  	// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
 738  	// message with the server-sig-algs extension if the client supports it. See
 739  	// RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9.
 740  	if !isClient && firstKeyExchange && slices.Contains(clientInit.KexAlgos, "ext-info-c") {
 741  		supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",")
 742  		extInfo := &extInfoMsg{
 743  			NumExtensions: 2,
 744  			Payload:       make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1),
 745  		}
 746  		extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
 747  		extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
 748  		extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
 749  		extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
 750  		extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com"))
 751  		extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...)
 752  		extInfo.Payload = appendInt(extInfo.Payload, 1)
 753  		extInfo.Payload = append(extInfo.Payload, "0"...)
 754  		if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
 755  			return err
 756  		}
 757  	}
 758  
 759  	if packet, err := t.conn.readPacket(); err != nil {
 760  		return err
 761  	} else if packet[0] != msgNewKeys {
 762  		return unexpectedMessageError(msgNewKeys, packet[0])
 763  	}
 764  
 765  	if firstKeyExchange {
 766  		// Indicates to the transport that the first key exchange is completed
 767  		// after receiving SSH_MSG_NEWKEYS.
 768  		t.conn.setInitialKEXDone()
 769  	}
 770  
 771  	return nil
 772  }
 773  
 774  // algorithmSignerWrapper is an AlgorithmSigner that only supports the default
 775  // key format algorithm.
 776  //
 777  // This is technically a violation of the AlgorithmSigner interface, but it
 778  // should be unreachable given where we use this. Anyway, at least it returns an
 779  // error instead of panicing or producing an incorrect signature.
 780  type algorithmSignerWrapper struct {
 781  	Signer
 782  }
 783  
 784  func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
 785  	if algorithm != underlyingAlgo(a.PublicKey().Type()) {
 786  		return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm")
 787  	}
 788  	return a.Sign(rand, data)
 789  }
 790  
 791  func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
 792  	for _, k := range hostKeys {
 793  		if s, ok := k.(MultiAlgorithmSigner); ok {
 794  			if !slices.Contains(s.Algorithms(), underlyingAlgo(algo)) {
 795  				continue
 796  			}
 797  		}
 798  
 799  		if algo == k.PublicKey().Type() {
 800  			return algorithmSignerWrapper{k}
 801  		}
 802  
 803  		k, ok := k.(AlgorithmSigner)
 804  		if !ok {
 805  			continue
 806  		}
 807  		for _, a := range algorithmsForKeyFormat(k.PublicKey().Type()) {
 808  			if algo == a {
 809  				return k
 810  			}
 811  		}
 812  	}
 813  	return nil
 814  }
 815  
 816  func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
 817  	hostKey := pickHostKey(t.hostKeys, t.algorithms.HostKey)
 818  	if hostKey == nil {
 819  		return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
 820  	}
 821  
 822  	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.HostKey)
 823  	return r, err
 824  }
 825  
 826  func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
 827  	result, err := kex.Client(t.conn, t.config.Rand, magics)
 828  	if err != nil {
 829  		return nil, err
 830  	}
 831  
 832  	hostKey, err := ParsePublicKey(result.HostKey)
 833  	if err != nil {
 834  		return nil, err
 835  	}
 836  
 837  	if err := verifyHostKeySignature(hostKey, t.algorithms.HostKey, result); err != nil {
 838  		return nil, err
 839  	}
 840  
 841  	err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
 842  	if err != nil {
 843  		return nil, err
 844  	}
 845  
 846  	return result, nil
 847  }
 848