mlkem.go raw

   1  // Copyright 2024 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package ssh
   6  
   7  import (
   8  	"crypto"
   9  	"crypto/mlkem"
  10  	"crypto/sha256"
  11  	"errors"
  12  	"fmt"
  13  	"io"
  14  
  15  	"golang.org/x/crypto/curve25519"
  16  )
  17  
  18  // mlkem768WithCurve25519sha256 implements the hybrid ML-KEM768 with
  19  // curve25519-sha256 key exchange method, as described by
  20  // draft-kampanakis-curdle-ssh-pq-ke-05 section 2.3.3.
  21  type mlkem768WithCurve25519sha256 struct{}
  22  
  23  func (kex *mlkem768WithCurve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
  24  	var c25519kp curve25519KeyPair
  25  	if err := c25519kp.generate(rand); err != nil {
  26  		return nil, err
  27  	}
  28  
  29  	seed := make([]byte, mlkem.SeedSize)
  30  	if _, err := io.ReadFull(rand, seed); err != nil {
  31  		return nil, err
  32  	}
  33  
  34  	mlkemDk, err := mlkem.NewDecapsulationKey768(seed)
  35  	if err != nil {
  36  		return nil, err
  37  	}
  38  
  39  	hybridKey := append(mlkemDk.EncapsulationKey().Bytes(), c25519kp.pub[:]...)
  40  	if err := c.writePacket(Marshal(&kexECDHInitMsg{hybridKey})); err != nil {
  41  		return nil, err
  42  	}
  43  
  44  	packet, err := c.readPacket()
  45  	if err != nil {
  46  		return nil, err
  47  	}
  48  
  49  	var reply kexECDHReplyMsg
  50  	if err = Unmarshal(packet, &reply); err != nil {
  51  		return nil, err
  52  	}
  53  
  54  	if len(reply.EphemeralPubKey) != mlkem.CiphertextSize768+32 {
  55  		return nil, errors.New("ssh: peer's mlkem768x25519 public value has wrong length")
  56  	}
  57  
  58  	// Perform KEM decapsulate operation to obtain shared key from ML-KEM.
  59  	mlkem768Secret, err := mlkemDk.Decapsulate(reply.EphemeralPubKey[:mlkem.CiphertextSize768])
  60  	if err != nil {
  61  		return nil, err
  62  	}
  63  
  64  	// Complete Curve25519 ECDH to obtain its shared key.
  65  	c25519Secret, err := curve25519.X25519(c25519kp.priv[:], reply.EphemeralPubKey[mlkem.CiphertextSize768:])
  66  	if err != nil {
  67  		return nil, fmt.Errorf("ssh: peer's mlkem768x25519 public value is not valid: %w", err)
  68  	}
  69  	// Compute actual shared key.
  70  	h := sha256.New()
  71  	h.Write(mlkem768Secret)
  72  	h.Write(c25519Secret)
  73  	secret := h.Sum(nil)
  74  
  75  	h.Reset()
  76  	magics.write(h)
  77  	writeString(h, reply.HostKey)
  78  	writeString(h, hybridKey)
  79  	writeString(h, reply.EphemeralPubKey)
  80  
  81  	K := make([]byte, stringLength(len(secret)))
  82  	marshalString(K, secret)
  83  	h.Write(K)
  84  
  85  	return &kexResult{
  86  		H:         h.Sum(nil),
  87  		K:         K,
  88  		HostKey:   reply.HostKey,
  89  		Signature: reply.Signature,
  90  		Hash:      crypto.SHA256,
  91  	}, nil
  92  }
  93  
  94  func (kex *mlkem768WithCurve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (*kexResult, error) {
  95  	packet, err := c.readPacket()
  96  	if err != nil {
  97  		return nil, err
  98  	}
  99  
 100  	var kexInit kexECDHInitMsg
 101  	if err = Unmarshal(packet, &kexInit); err != nil {
 102  		return nil, err
 103  	}
 104  
 105  	if len(kexInit.ClientPubKey) != mlkem.EncapsulationKeySize768+32 {
 106  		return nil, errors.New("ssh: peer's ML-KEM768/curve25519 public value has wrong length")
 107  	}
 108  
 109  	encapsulationKey, err := mlkem.NewEncapsulationKey768(kexInit.ClientPubKey[:mlkem.EncapsulationKeySize768])
 110  	if err != nil {
 111  		return nil, fmt.Errorf("ssh: peer's ML-KEM768 encapsulation key is not valid: %w", err)
 112  	}
 113  	// Perform KEM encapsulate operation to obtain ciphertext and shared key.
 114  	mlkem768Secret, mlkem768Ciphertext := encapsulationKey.Encapsulate()
 115  
 116  	// Perform server side of Curve25519 ECDH to obtain server public value and
 117  	// shared key.
 118  	var c25519kp curve25519KeyPair
 119  	if err := c25519kp.generate(rand); err != nil {
 120  		return nil, err
 121  	}
 122  	c25519Secret, err := curve25519.X25519(c25519kp.priv[:], kexInit.ClientPubKey[mlkem.EncapsulationKeySize768:])
 123  	if err != nil {
 124  		return nil, fmt.Errorf("ssh: peer's ML-KEM768/curve25519 public value is not valid: %w", err)
 125  	}
 126  	hybridKey := append(mlkem768Ciphertext, c25519kp.pub[:]...)
 127  
 128  	// Compute actual shared key.
 129  	h := sha256.New()
 130  	h.Write(mlkem768Secret)
 131  	h.Write(c25519Secret)
 132  	secret := h.Sum(nil)
 133  
 134  	hostKeyBytes := priv.PublicKey().Marshal()
 135  
 136  	h.Reset()
 137  	magics.write(h)
 138  	writeString(h, hostKeyBytes)
 139  	writeString(h, kexInit.ClientPubKey)
 140  	writeString(h, hybridKey)
 141  
 142  	K := make([]byte, stringLength(len(secret)))
 143  	marshalString(K, secret)
 144  	h.Write(K)
 145  
 146  	H := h.Sum(nil)
 147  
 148  	sig, err := signAndMarshal(priv, rand, H, algo)
 149  	if err != nil {
 150  		return nil, err
 151  	}
 152  
 153  	reply := kexECDHReplyMsg{
 154  		EphemeralPubKey: hybridKey,
 155  		HostKey:         hostKeyBytes,
 156  		Signature:       sig,
 157  	}
 158  	if err := c.writePacket(Marshal(&reply)); err != nil {
 159  		return nil, err
 160  	}
 161  	return &kexResult{
 162  		H:         H,
 163  		K:         K,
 164  		HostKey:   hostKeyBytes,
 165  		Signature: sig,
 166  		Hash:      crypto.SHA256,
 167  	}, nil
 168  }
 169