mlkem768.mx raw

   1  // Copyright 2023 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // Package mlkem implements the quantum-resistant key encapsulation method
   6  // ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203].
   7  //
   8  // [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203
   9  package mlkem
  10  
  11  // This package targets security, correctness, simplicity, readability, and
  12  // reviewability as its primary goals. All critical operations are performed in
  13  // constant time.
  14  //
  15  // Variable and function names, as well as code layout, are selected to
  16  // facilitate reviewing the implementation against the NIST FIPS 203 document.
  17  //
  18  // Reviewers unfamiliar with polynomials or linear algebra might find the
  19  // background at https://words.filippo.io/kyber-math/ useful.
  20  //
  21  // This file implements the recommended parameter set ML-KEM-768. The ML-KEM-1024
  22  // parameter set implementation is auto-generated from this file.
  23  //
  24  //go:generate go run generate1024.go -input mlkem768.go -output mlkem1024.go
  25  
  26  import (
  27  	"bytes"
  28  	"crypto/internal/fips140"
  29  	"crypto/internal/fips140/drbg"
  30  	"crypto/internal/fips140/sha3"
  31  	"crypto/internal/fips140/subtle"
  32  	"errors"
  33  )
  34  
  35  const (
  36  	// ML-KEM global constants.
  37  	n = 256
  38  	q = 3329
  39  
  40  	// encodingSizeX is the byte size of a ringElement or nttElement encoded
  41  	// by ByteEncode_X (FIPS 203, Algorithm 5).
  42  	encodingSize12 = n * 12 / 8
  43  	encodingSize11 = n * 11 / 8
  44  	encodingSize10 = n * 10 / 8
  45  	encodingSize5  = n * 5 / 8
  46  	encodingSize4  = n * 4 / 8
  47  	encodingSize1  = n * 1 / 8
  48  
  49  	messageSize = encodingSize1
  50  
  51  	SharedKeySize = 32
  52  	SeedSize      = 32 + 32
  53  )
  54  
  55  // ML-KEM-768 parameters.
  56  const (
  57  	k = 3
  58  
  59  	CiphertextSize768       = k*encodingSize10 + encodingSize4
  60  	EncapsulationKeySize768 = k*encodingSize12 + 32
  61  	decapsulationKeySize768 = k*encodingSize12 + EncapsulationKeySize768 + 32 + 32
  62  )
  63  
  64  // ML-KEM-1024 parameters.
  65  const (
  66  	k1024 = 4
  67  
  68  	CiphertextSize1024       = k1024*encodingSize11 + encodingSize5
  69  	EncapsulationKeySize1024 = k1024*encodingSize12 + 32
  70  	decapsulationKeySize1024 = k1024*encodingSize12 + EncapsulationKeySize1024 + 32 + 32
  71  )
  72  
  73  // A DecapsulationKey768 is the secret key used to decapsulate a shared key from a
  74  // ciphertext. It includes various precomputed values.
  75  type DecapsulationKey768 struct {
  76  	d [32]byte // decapsulation key seed
  77  	z [32]byte // implicit rejection sampling seed
  78  
  79  	ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key
  80  	h [32]byte // H(ek), stored for ML-KEM.Decaps_internal
  81  
  82  	encryptionKey
  83  	decryptionKey
  84  }
  85  
  86  // Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
  87  //
  88  // The decapsulation key must be kept secret.
  89  func (dk *DecapsulationKey768) Bytes() []byte {
  90  	var b [SeedSize]byte
  91  	copy(b[:], dk.d[:])
  92  	copy(b[32:], dk.z[:])
  93  	return b[:]
  94  }
  95  
  96  // TestingOnlyExpandedBytes768 returns the decapsulation key as a byte slice
  97  // using the full expanded NIST encoding.
  98  //
  99  // This should only be used for ACVP testing. For all other purposes prefer
 100  // the Bytes method that returns the (much smaller) seed.
 101  func TestingOnlyExpandedBytes768(dk *DecapsulationKey768) []byte {
 102  	b := []byte{:0:decapsulationKeySize768}
 103  
 104  	// ByteEncode₁₂(s)
 105  	for i := range dk.s {
 106  		b = polyByteEncode(b, dk.s[i])
 107  	}
 108  
 109  	// ByteEncode₁₂(t) || ρ
 110  	for i := range dk.t {
 111  		b = polyByteEncode(b, dk.t[i])
 112  	}
 113  	b = append(b, dk.ρ[:]...)
 114  
 115  	// H(ek) || z
 116  	b = append(b, dk.h[:]...)
 117  	b = append(b, dk.z[:]...)
 118  
 119  	return b
 120  }
 121  
 122  // EncapsulationKey returns the public encapsulation key necessary to produce
 123  // ciphertexts.
 124  func (dk *DecapsulationKey768) EncapsulationKey() *EncapsulationKey768 {
 125  	return &EncapsulationKey768{
 126  		ρ:             dk.ρ,
 127  		h:             dk.h,
 128  		encryptionKey: dk.encryptionKey,
 129  	}
 130  }
 131  
 132  // An EncapsulationKey768 is the public key used to produce ciphertexts to be
 133  // decapsulated by the corresponding [DecapsulationKey768].
 134  type EncapsulationKey768 struct {
 135  	ρ [32]byte // sampleNTT seed for A
 136  	h [32]byte // H(ek)
 137  	encryptionKey
 138  }
 139  
 140  // Bytes returns the encapsulation key as a byte slice.
 141  func (ek *EncapsulationKey768) Bytes() []byte {
 142  	// The actual logic is in a separate function to outline this allocation.
 143  	b := []byte{:0:EncapsulationKeySize768}
 144  	return ek.bytes(b)
 145  }
 146  
 147  func (ek *EncapsulationKey768) bytes(b []byte) []byte {
 148  	for i := range ek.t {
 149  		b = polyByteEncode(b, ek.t[i])
 150  	}
 151  	b = append(b, ek.ρ[:]...)
 152  	return b
 153  }
 154  
 155  // encryptionKey is the parsed and expanded form of a PKE encryption key.
 156  type encryptionKey struct {
 157  	t [k]nttElement     // ByteDecode₁₂(ek[:384k])
 158  	a [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
 159  }
 160  
 161  // decryptionKey is the parsed and expanded form of a PKE decryption key.
 162  type decryptionKey struct {
 163  	s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize])
 164  }
 165  
 166  // GenerateKey768 generates a new decapsulation key, drawing random bytes from
 167  // a DRBG. The decapsulation key must be kept secret.
 168  func GenerateKey768() (*DecapsulationKey768, error) {
 169  	// The actual logic is in a separate function to outline this allocation.
 170  	dk := &DecapsulationKey768{}
 171  	return generateKey(dk)
 172  }
 173  
 174  func generateKey(dk *DecapsulationKey768) (*DecapsulationKey768, error) {
 175  	var d [32]byte
 176  	drbg.Read(d[:])
 177  	var z [32]byte
 178  	drbg.Read(z[:])
 179  	kemKeyGen(dk, &d, &z)
 180  	fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) })
 181  	fips140.RecordApproved()
 182  	return dk, nil
 183  }
 184  
 185  // GenerateKeyInternal768 is a derandomized version of GenerateKey768,
 186  // exclusively for use in tests.
 187  func GenerateKeyInternal768(d, z *[32]byte) *DecapsulationKey768 {
 188  	dk := &DecapsulationKey768{}
 189  	kemKeyGen(dk, d, z)
 190  	return dk
 191  }
 192  
 193  // NewDecapsulationKey768 parses a decapsulation key from a 64-byte
 194  // seed in the "d || z" form. The seed must be uniformly random.
 195  func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) {
 196  	// The actual logic is in a separate function to outline this allocation.
 197  	dk := &DecapsulationKey768{}
 198  	return newKeyFromSeed(dk, seed)
 199  }
 200  
 201  func newKeyFromSeed(dk *DecapsulationKey768, seed []byte) (*DecapsulationKey768, error) {
 202  	if len(seed) != SeedSize {
 203  		return nil, errors.New("mlkem: invalid seed length")
 204  	}
 205  	d := (*[32]byte)(seed[:32])
 206  	z := (*[32]byte)(seed[32:])
 207  	kemKeyGen(dk, d, z)
 208  	fips140.RecordApproved()
 209  	return dk, nil
 210  }
 211  
 212  // TestingOnlyNewDecapsulationKey768 parses a decapsulation key from its expanded NIST format.
 213  //
 214  // Bytes() must not be called on the returned key, as it will not produce the
 215  // original seed.
 216  //
 217  // This function should only be used for ACVP testing. Prefer NewDecapsulationKey768 for all
 218  // other purposes.
 219  func TestingOnlyNewDecapsulationKey768(b []byte) (*DecapsulationKey768, error) {
 220  	if len(b) != decapsulationKeySize768 {
 221  		return nil, errors.New("mlkem: invalid NIST decapsulation key length")
 222  	}
 223  
 224  	dk := &DecapsulationKey768{}
 225  	for i := range dk.s {
 226  		var err error
 227  		dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
 228  		if err != nil {
 229  			return nil, errors.New("mlkem: invalid secret key encoding")
 230  		}
 231  		b = b[encodingSize12:]
 232  	}
 233  
 234  	ek, err := NewEncapsulationKey768(b[:EncapsulationKeySize768])
 235  	if err != nil {
 236  		return nil, err
 237  	}
 238  	dk.ρ = ek.ρ
 239  	dk.h = ek.h
 240  	dk.encryptionKey = ek.encryptionKey
 241  	b = b[EncapsulationKeySize768:]
 242  
 243  	if !bytes.Equal(dk.h[:], b[:32]) {
 244  		return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
 245  	}
 246  	b = b[32:]
 247  
 248  	copy(dk.z[:], b)
 249  
 250  	// Generate a random d value for use in Bytes(). This is a safety mechanism
 251  	// that avoids returning a broken key vs a random key if this function is
 252  	// called in contravention of the TestingOnlyNewDecapsulationKey768 function
 253  	// comment advising against it.
 254  	drbg.Read(dk.d[:])
 255  
 256  	return dk, nil
 257  }
 258  
// kemKeyGen deterministically derives a decapsulation key from the seeds d and
// z and stores it in dk, including the cached encapsulation key fields and
// H(ek).
//
// It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
// K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
// copies and allocations.
func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) {
	// Keep the seeds so that Bytes can return them as "d || z".
	dk.d = *d
	dk.z = *z

	// (ρ, σ) ← G(d || k), with G instantiated as SHA3-512.
	g := sha3.New512()
	g.Write(d[:])
	g.Write([]byte{k}) // Module dimension as a domain separator.
	G := g.Sum([]byte{:0:64})
	ρ, σ := G[:32], G[32:]
	dk.ρ = [32]byte(ρ)

	// Expand the k×k matrix A from ρ, stored row-major as A[i*k+j]. Note that
	// sampleNTT receives the column index j before the row index i.
	A := &dk.a
	for i := byte(0); i < k; i++ {
		for j := byte(0); j < k; j++ {
			A[i*k+j] = sampleNTT(ρ, j, i)
		}
	}

	// N is the running counter passed to samplePolyCBD with σ: s uses
	// N = 0..k-1 and e uses N = k..2k-1, so these loops must stay in order.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	e := []nttElement{:k}
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	t := &dk.t
	for i := range t { // t = A ◦ s + e
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
		}
	}

	// h = H(ek), with H instantiated as SHA3-256 over the encoded
	// encapsulation key; it is cached for decapsulation.
	H := sha3.New256()
	ek := dk.EncapsulationKey().Bytes()
	H.Write(ek)
	H.Sum(dk.h[:0])
}
 307  
 308  // kemPCT performs a Pairwise Consistency Test per FIPS 140-3 IG 10.3.A
 309  // Additional Comment 1: "For key pairs generated for use with approved KEMs in
 310  // FIPS 203, the PCT shall consist of applying the encapsulation key ek to
 311  // encapsulate a shared secret K leading to ciphertext c, and then applying
 312  // decapsulation key dk to retrieve the same shared secret K. The PCT passes if
 313  // the two shared secret K values are equal. The PCT shall be performed either
 314  // when keys are generated/imported, prior to the first exportation, or prior to
 315  // the first operational use (if not exported before the first use)."
 316  func kemPCT(dk *DecapsulationKey768) error {
 317  	ek := dk.EncapsulationKey()
 318  	K, c := ek.Encapsulate()
 319  	K1, err := dk.Decapsulate(c)
 320  	if err != nil {
 321  		return err
 322  	}
 323  	if subtle.ConstantTimeCompare(K, K1) != 1 {
 324  		return errors.New("mlkem: PCT failed")
 325  	}
 326  	return nil
 327  }
 328  
 329  // Encapsulate generates a shared key and an associated ciphertext from an
 330  // encapsulation key, drawing random bytes from a DRBG.
 331  //
 332  // The shared key must be kept secret.
 333  func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
 334  	// The actual logic is in a separate function to outline this allocation.
 335  	var cc [CiphertextSize768]byte
 336  	return ek.encapsulate(&cc)
 337  }
 338  
 339  func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (sharedKey, ciphertext []byte) {
 340  	var m [messageSize]byte
 341  	drbg.Read(m[:])
 342  	// Note that the modulus check (step 2 of the encapsulation key check from
 343  	// FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK.
 344  	fips140.RecordApproved()
 345  	return kemEncaps(cc, ek, &m)
 346  }
 347  
 348  // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
 349  // use in tests.
 350  func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
 351  	cc := &[CiphertextSize768]byte{}
 352  	return kemEncaps(cc, ek, m)
 353  }
 354  
 355  // kemEncaps generates a shared key and an associated ciphertext.
 356  //
 357  // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
 358  func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (K, c []byte) {
 359  	g := sha3.New512()
 360  	g.Write(m[:])
 361  	g.Write(ek.h[:])
 362  	G := g.Sum(nil)
 363  	K, r := G[:SharedKeySize], G[SharedKeySize:]
 364  	c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
 365  	return K, c
 366  }
 367  
 368  // NewEncapsulationKey768 parses an encapsulation key from its encoded form.
 369  // If the encapsulation key is not valid, NewEncapsulationKey768 returns an error.
 370  func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, error) {
 371  	// The actual logic is in a separate function to outline this allocation.
 372  	ek := &EncapsulationKey768{}
 373  	return parseEK(ek, encapsulationKey)
 374  }
 375  
 376  // parseEK parses an encryption key from its encoded form.
 377  //
 378  // It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
 379  // Algorithm 14.
 380  func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
 381  	if len(ekPKE) != EncapsulationKeySize768 {
 382  		return nil, errors.New("mlkem: invalid encapsulation key length")
 383  	}
 384  
 385  	h := sha3.New256()
 386  	h.Write(ekPKE)
 387  	h.Sum(ek.h[:0])
 388  
 389  	for i := range ek.t {
 390  		var err error
 391  		ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
 392  		if err != nil {
 393  			return nil, err
 394  		}
 395  		ekPKE = ekPKE[encodingSize12:]
 396  	}
 397  	copy(ek.ρ[:], ekPKE)
 398  
 399  	for i := byte(0); i < k; i++ {
 400  		for j := byte(0); j < k; j++ {
 401  			ek.a[i*k+j] = sampleNTT(ek.ρ[:], j, i)
 402  		}
 403  	}
 404  
 405  	return ek, nil
 406  }
 407  
// pkeEncrypt encrypts the message m under the expanded encryption key ex,
// using rnd as the seed for all sampled noise, and writes the ciphertext into
// cc. The returned slice aliases cc.
//
// It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although
// the decoding of t and the expansion of A are done ahead of time (in parseEK,
// or in kemKeyGen for locally generated keys).
func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
	// N is the counter passed to samplePolyCBD with rnd: r uses N = 0..k-1,
	// e1 uses N = k..2k-1 and e2 uses N = 2k, so statement order matters.
	var N byte
	r, e1 := []nttElement{:k}, []ringElement{:k}
	for i := range r {
		r[i] = ntt(samplePolyCBD(rnd, N))
		N++
	}
	for i := range e1 {
		e1[i] = samplePolyCBD(rnd, N)
		N++
	}
	e2 := samplePolyCBD(rnd, N)

	u := []ringElement{:k} // NTT⁻¹(AT ◦ r) + e1
	for i := range u {
		u[i] = e1[i]
		for j := range r {
			// Note that i and j are inverted, as we need the transposed of A.
			u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k+i], r[j])))
		}
	}

	// μ is the message decoded and decompressed from one bit per coefficient.
	μ := ringDecodeAndDecompress1(m)

	var vNTT nttElement // t⊺ ◦ r
	for i := range ex.t {
		vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
	}
	v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)

	// c = Compress₁₀(u) || Compress₄(v), encoded in place into cc.
	c := cc[:0]
	for _, f := range u {
		c = ringCompressAndEncode10(c, f)
	}
	c = ringCompressAndEncode4(c, v)

	return c
}
 450  
 451  // Decapsulate generates a shared key from a ciphertext and a decapsulation key.
 452  // If the ciphertext is not valid, Decapsulate returns an error.
 453  //
 454  // The shared key must be kept secret.
 455  func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
 456  	if len(ciphertext) != CiphertextSize768 {
 457  		return nil, errors.New("mlkem: invalid ciphertext length")
 458  	}
 459  	c := (*[CiphertextSize768]byte)(ciphertext)
 460  	// Note that the hash check (step 3 of the decapsulation input check from
 461  	// FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always
 462  	// validly generated by ML-KEM.KeyGen_internal.
 463  	return kemDecaps(dk, c), nil
 464  }
 465  
// kemDecaps produces a shared key from a ciphertext.
//
// It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
// If re-encrypting the decrypted message does not reproduce c, the result is
// the implicit-rejection key J(z || c) instead of K'. The choice between the
// two uses a constant-time compare and copy so that it does not branch on
// secret data.
func kemDecaps(dk *DecapsulationKey768, c *[CiphertextSize768]byte) (K []byte) {
	fips140.RecordApproved()
	// m' ← K-PKE.Decrypt(dk_PKE, c)
	m := pkeDecrypt(&dk.decryptionKey, c)
	// (K', r') ← G(m' || h), with G instantiated as SHA3-512.
	g := sha3.New512()
	g.Write(m[:])
	g.Write(dk.h[:])
	G := g.Sum([]byte{:0:64})
	Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
	// K̄ ← J(z || c), with J instantiated as SHAKE256; this is the rejection key.
	J := sha3.NewShake256()
	J.Write(dk.z[:])
	J.Write(c[:])
	Kout := []byte{:SharedKeySize}
	J.Read(Kout)
	// c' ← K-PKE.Encrypt(ek_PKE, m', r')
	var cc [CiphertextSize768]byte
	c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r)

	// Overwrite K̄ with K' only if c == c'.
	subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
	return Kout
}
 488  
 489  // pkeDecrypt decrypts a ciphertext.
 490  //
 491  // It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
 492  // although s is retained from kemKeyGen.
 493  func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize768]byte) []byte {
 494  	u := []ringElement{:k}
 495  	for i := range u {
 496  		b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
 497  		u[i] = ringDecodeAndDecompress10(b)
 498  	}
 499  
 500  	b := (*[encodingSize4]byte)(c[encodingSize10*k:])
 501  	v := ringDecodeAndDecompress4(b)
 502  
 503  	var mask nttElement // s⊺ ◦ NTT(u)
 504  	for i := range dx.s {
 505  		mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
 506  	}
 507  	w := polySub(v, inverseNTT(mask))
 508  
 509  	return ringCompressAndEncode1(nil, w)
 510  }
 511