1 // Copyright 2023 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 // Package mlkem implements the quantum-resistant key encapsulation method
6 // ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203].
7 //
8 // [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203
9 package mlkem
10 11 // This package targets security, correctness, simplicity, readability, and
12 // reviewability as its primary goals. All critical operations are performed in
13 // constant time.
14 //
15 // Variable and function names, as well as code layout, are selected to
16 // facilitate reviewing the implementation against the NIST FIPS 203 document.
17 //
18 // Reviewers unfamiliar with polynomials or linear algebra might find the
19 // background at https://words.filippo.io/kyber-math/ useful.
20 //
21 // This file implements the recommended parameter set ML-KEM-768. The ML-KEM-1024
22 // parameter set implementation is auto-generated from this file.
23 //
24 //go:generate go run generate1024.go -input mlkem768.go -output mlkem1024.go
25 26 import (
27 "bytes"
28 "crypto/internal/fips140"
29 "crypto/internal/fips140/drbg"
30 "crypto/internal/fips140/sha3"
31 "crypto/internal/fips140/subtle"
32 "errors"
33 )
const (
	// Global ML-KEM constants: the polynomial degree n and the modulus q.
	n = 256
	q = 3329

	// encodingSizeX is the number of bytes that ByteEncode_X (FIPS 203,
	// Algorithm 5) produces for one ringElement or nttElement, i.e. n
	// coefficients of X bits each.
	encodingSize12 = 12 * n / 8
	encodingSize11 = 11 * n / 8
	encodingSize10 = 10 * n / 8
	encodingSize5  = 5 * n / 8
	encodingSize4  = 4 * n / 8
	encodingSize1  = 1 * n / 8

	// A message is one polynomial with 1-bit coefficients.
	messageSize = encodingSize1

	// SharedKeySize is the byte length of an ML-KEM shared key.
	SharedKeySize = 32
	// SeedSize is the byte length of a "d || z" decapsulation key seed.
	SeedSize = 32 + 32
)

// ML-KEM-768 parameters.
const (
	k = 3 // module dimension

	CiphertextSize768       = k*encodingSize10 + encodingSize4
	EncapsulationKeySize768 = k*encodingSize12 + 32
	// Expanded NIST form: ByteEncode₁₂(s) || ek || H(ek) || z.
	decapsulationKeySize768 = k*encodingSize12 + EncapsulationKeySize768 + 32 + 32
)

// ML-KEM-1024 parameters.
const (
	k1024 = 4 // module dimension

	CiphertextSize1024       = k1024*encodingSize11 + encodingSize5
	EncapsulationKeySize1024 = k1024*encodingSize12 + 32
	// Expanded NIST form: ByteEncode₁₂(s) || ek || H(ek) || z.
	decapsulationKeySize1024 = k1024*encodingSize12 + EncapsulationKeySize1024 + 32 + 32
)
72 73 // A DecapsulationKey768 is the secret key used to decapsulate a shared key from a
74 // ciphertext. It includes various precomputed values.
75 type DecapsulationKey768 struct {
76 d [32]byte // decapsulation key seed
77 z [32]byte // implicit rejection sampling seed
78 79 ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key
80 h [32]byte // H(ek), stored for ML-KEM.Decaps_internal
81 82 encryptionKey
83 decryptionKey
84 }
85 86 // Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
87 //
88 // The decapsulation key must be kept secret.
89 func (dk *DecapsulationKey768) Bytes() []byte {
90 var b [SeedSize]byte
91 copy(b[:], dk.d[:])
92 copy(b[32:], dk.z[:])
93 return b[:]
94 }
95 96 // TestingOnlyExpandedBytes768 returns the decapsulation key as a byte slice
97 // using the full expanded NIST encoding.
98 //
99 // This should only be used for ACVP testing. For all other purposes prefer
100 // the Bytes method that returns the (much smaller) seed.
101 func TestingOnlyExpandedBytes768(dk *DecapsulationKey768) []byte {
102 b := []byte{:0:decapsulationKeySize768}
103 104 // ByteEncode₁₂(s)
105 for i := range dk.s {
106 b = polyByteEncode(b, dk.s[i])
107 }
108 109 // ByteEncode₁₂(t) || ρ
110 for i := range dk.t {
111 b = polyByteEncode(b, dk.t[i])
112 }
113 b = append(b, dk.ρ[:]...)
114 115 // H(ek) || z
116 b = append(b, dk.h[:]...)
117 b = append(b, dk.z[:]...)
118 119 return b
120 }
121 122 // EncapsulationKey returns the public encapsulation key necessary to produce
123 // ciphertexts.
124 func (dk *DecapsulationKey768) EncapsulationKey() *EncapsulationKey768 {
125 return &EncapsulationKey768{
126 ρ: dk.ρ,
127 h: dk.h,
128 encryptionKey: dk.encryptionKey,
129 }
130 }
131 132 // An EncapsulationKey768 is the public key used to produce ciphertexts to be
133 // decapsulated by the corresponding [DecapsulationKey768].
134 type EncapsulationKey768 struct {
135 ρ [32]byte // sampleNTT seed for A
136 h [32]byte // H(ek)
137 encryptionKey
138 }
139 140 // Bytes returns the encapsulation key as a byte slice.
141 func (ek *EncapsulationKey768) Bytes() []byte {
142 // The actual logic is in a separate function to outline this allocation.
143 b := []byte{:0:EncapsulationKeySize768}
144 return ek.bytes(b)
145 }
146 147 func (ek *EncapsulationKey768) bytes(b []byte) []byte {
148 for i := range ek.t {
149 b = polyByteEncode(b, ek.t[i])
150 }
151 b = append(b, ek.ρ[:]...)
152 return b
153 }
154 155 // encryptionKey is the parsed and expanded form of a PKE encryption key.
156 type encryptionKey struct {
157 t [k]nttElement // ByteDecode₁₂(ek[:384k])
158 a [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
159 }
160 161 // decryptionKey is the parsed and expanded form of a PKE decryption key.
162 type decryptionKey struct {
163 s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize])
164 }
165 166 // GenerateKey768 generates a new decapsulation key, drawing random bytes from
167 // a DRBG. The decapsulation key must be kept secret.
168 func GenerateKey768() (*DecapsulationKey768, error) {
169 // The actual logic is in a separate function to outline this allocation.
170 dk := &DecapsulationKey768{}
171 return generateKey(dk)
172 }
173 174 func generateKey(dk *DecapsulationKey768) (*DecapsulationKey768, error) {
175 var d [32]byte
176 drbg.Read(d[:])
177 var z [32]byte
178 drbg.Read(z[:])
179 kemKeyGen(dk, &d, &z)
180 fips140.PCT("ML-KEM PCT", func() error { return kemPCT(dk) })
181 fips140.RecordApproved()
182 return dk, nil
183 }
184 185 // GenerateKeyInternal768 is a derandomized version of GenerateKey768,
186 // exclusively for use in tests.
187 func GenerateKeyInternal768(d, z *[32]byte) *DecapsulationKey768 {
188 dk := &DecapsulationKey768{}
189 kemKeyGen(dk, d, z)
190 return dk
191 }
192 193 // NewDecapsulationKey768 parses a decapsulation key from a 64-byte
194 // seed in the "d || z" form. The seed must be uniformly random.
195 func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) {
196 // The actual logic is in a separate function to outline this allocation.
197 dk := &DecapsulationKey768{}
198 return newKeyFromSeed(dk, seed)
199 }
200 201 func newKeyFromSeed(dk *DecapsulationKey768, seed []byte) (*DecapsulationKey768, error) {
202 if len(seed) != SeedSize {
203 return nil, errors.New("mlkem: invalid seed length")
204 }
205 d := (*[32]byte)(seed[:32])
206 z := (*[32]byte)(seed[32:])
207 kemKeyGen(dk, d, z)
208 fips140.RecordApproved()
209 return dk, nil
210 }
211 212 // TestingOnlyNewDecapsulationKey768 parses a decapsulation key from its expanded NIST format.
213 //
214 // Bytes() must not be called on the returned key, as it will not produce the
215 // original seed.
216 //
217 // This function should only be used for ACVP testing. Prefer NewDecapsulationKey768 for all
218 // other purposes.
219 func TestingOnlyNewDecapsulationKey768(b []byte) (*DecapsulationKey768, error) {
220 if len(b) != decapsulationKeySize768 {
221 return nil, errors.New("mlkem: invalid NIST decapsulation key length")
222 }
223 224 dk := &DecapsulationKey768{}
225 for i := range dk.s {
226 var err error
227 dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
228 if err != nil {
229 return nil, errors.New("mlkem: invalid secret key encoding")
230 }
231 b = b[encodingSize12:]
232 }
233 234 ek, err := NewEncapsulationKey768(b[:EncapsulationKeySize768])
235 if err != nil {
236 return nil, err
237 }
238 dk.ρ = ek.ρ
239 dk.h = ek.h
240 dk.encryptionKey = ek.encryptionKey
241 b = b[EncapsulationKeySize768:]
242 243 if !bytes.Equal(dk.h[:], b[:32]) {
244 return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
245 }
246 b = b[32:]
247 248 copy(dk.z[:], b)
249 250 // Generate a random d value for use in Bytes(). This is a safety mechanism
251 // that avoids returning a broken key vs a random key if this function is
252 // called in contravention of the TestingOnlyNewDecapsulationKey768 function
253 // comment advising against it.
254 drbg.Read(dk.d[:])
255 256 return dk, nil
257 }
258 259 // kemKeyGen generates a decapsulation key.
260 //
261 // It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
262 // K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
263 // copies and allocations.
264 func kemKeyGen(dk *DecapsulationKey768, d, z *[32]byte) {
265 dk.d = *d
266 dk.z = *z
267 268 g := sha3.New512()
269 g.Write(d[:])
270 g.Write([]byte{k}) // Module dimension as a domain separator.
271 G := g.Sum([]byte{:0:64})
272 ρ, σ := G[:32], G[32:]
273 dk.ρ = [32]byte(ρ)
274 275 A := &dk.a
276 for i := byte(0); i < k; i++ {
277 for j := byte(0); j < k; j++ {
278 A[i*k+j] = sampleNTT(ρ, j, i)
279 }
280 }
281 282 var N byte
283 s := &dk.s
284 for i := range s {
285 s[i] = ntt(samplePolyCBD(σ, N))
286 N++
287 }
288 e := []nttElement{:k}
289 for i := range e {
290 e[i] = ntt(samplePolyCBD(σ, N))
291 N++
292 }
293 294 t := &dk.t
295 for i := range t { // t = A ◦ s + e
296 t[i] = e[i]
297 for j := range s {
298 t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
299 }
300 }
301 302 H := sha3.New256()
303 ek := dk.EncapsulationKey().Bytes()
304 H.Write(ek)
305 H.Sum(dk.h[:0])
306 }
307 308 // kemPCT performs a Pairwise Consistency Test per FIPS 140-3 IG 10.3.A
309 // Additional Comment 1: "For key pairs generated for use with approved KEMs in
310 // FIPS 203, the PCT shall consist of applying the encapsulation key ek to
311 // encapsulate a shared secret K leading to ciphertext c, and then applying
312 // decapsulation key dk to retrieve the same shared secret K. The PCT passes if
313 // the two shared secret K values are equal. The PCT shall be performed either
314 // when keys are generated/imported, prior to the first exportation, or prior to
315 // the first operational use (if not exported before the first use)."
316 func kemPCT(dk *DecapsulationKey768) error {
317 ek := dk.EncapsulationKey()
318 K, c := ek.Encapsulate()
319 K1, err := dk.Decapsulate(c)
320 if err != nil {
321 return err
322 }
323 if subtle.ConstantTimeCompare(K, K1) != 1 {
324 return errors.New("mlkem: PCT failed")
325 }
326 return nil
327 }
328 329 // Encapsulate generates a shared key and an associated ciphertext from an
330 // encapsulation key, drawing random bytes from a DRBG.
331 //
332 // The shared key must be kept secret.
333 func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
334 // The actual logic is in a separate function to outline this allocation.
335 var cc [CiphertextSize768]byte
336 return ek.encapsulate(&cc)
337 }
338 339 func (ek *EncapsulationKey768) encapsulate(cc *[CiphertextSize768]byte) (sharedKey, ciphertext []byte) {
340 var m [messageSize]byte
341 drbg.Read(m[:])
342 // Note that the modulus check (step 2 of the encapsulation key check from
343 // FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK.
344 fips140.RecordApproved()
345 return kemEncaps(cc, ek, &m)
346 }
347 348 // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
349 // use in tests.
350 func (ek *EncapsulationKey768) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
351 cc := &[CiphertextSize768]byte{}
352 return kemEncaps(cc, ek, m)
353 }
354 355 // kemEncaps generates a shared key and an associated ciphertext.
356 //
357 // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
358 func kemEncaps(cc *[CiphertextSize768]byte, ek *EncapsulationKey768, m *[messageSize]byte) (K, c []byte) {
359 g := sha3.New512()
360 g.Write(m[:])
361 g.Write(ek.h[:])
362 G := g.Sum(nil)
363 K, r := G[:SharedKeySize], G[SharedKeySize:]
364 c = pkeEncrypt(cc, &ek.encryptionKey, m, r)
365 return K, c
366 }
367 368 // NewEncapsulationKey768 parses an encapsulation key from its encoded form.
369 // If the encapsulation key is not valid, NewEncapsulationKey768 returns an error.
370 func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, error) {
371 // The actual logic is in a separate function to outline this allocation.
372 ek := &EncapsulationKey768{}
373 return parseEK(ek, encapsulationKey)
374 }
375 376 // parseEK parses an encryption key from its encoded form.
377 //
378 // It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
379 // Algorithm 14.
380 func parseEK(ek *EncapsulationKey768, ekPKE []byte) (*EncapsulationKey768, error) {
381 if len(ekPKE) != EncapsulationKeySize768 {
382 return nil, errors.New("mlkem: invalid encapsulation key length")
383 }
384 385 h := sha3.New256()
386 h.Write(ekPKE)
387 h.Sum(ek.h[:0])
388 389 for i := range ek.t {
390 var err error
391 ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
392 if err != nil {
393 return nil, err
394 }
395 ekPKE = ekPKE[encodingSize12:]
396 }
397 copy(ek.ρ[:], ekPKE)
398 399 for i := byte(0); i < k; i++ {
400 for j := byte(0); j < k; j++ {
401 ek.a[i*k+j] = sampleNTT(ek.ρ[:], j, i)
402 }
403 }
404 405 return ek, nil
406 }
407 408 // pkeEncrypt encrypt a plaintext message.
409 //
410 // It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the
411 // computation of t and AT is done in parseEK.
412 func pkeEncrypt(cc *[CiphertextSize768]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte {
413 var N byte
414 r, e1 := []nttElement{:k}, []ringElement{:k}
415 for i := range r {
416 r[i] = ntt(samplePolyCBD(rnd, N))
417 N++
418 }
419 for i := range e1 {
420 e1[i] = samplePolyCBD(rnd, N)
421 N++
422 }
423 e2 := samplePolyCBD(rnd, N)
424 425 u := []ringElement{:k} // NTT⁻¹(AT ◦ r) + e1
426 for i := range u {
427 u[i] = e1[i]
428 for j := range r {
429 // Note that i and j are inverted, as we need the transposed of A.
430 u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k+i], r[j])))
431 }
432 }
433 434 μ := ringDecodeAndDecompress1(m)
435 436 var vNTT nttElement // t⊺ ◦ r
437 for i := range ex.t {
438 vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
439 }
440 v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
441 442 c := cc[:0]
443 for _, f := range u {
444 c = ringCompressAndEncode10(c, f)
445 }
446 c = ringCompressAndEncode4(c, v)
447 448 return c
449 }
450 451 // Decapsulate generates a shared key from a ciphertext and a decapsulation key.
452 // If the ciphertext is not valid, Decapsulate returns an error.
453 //
454 // The shared key must be kept secret.
455 func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
456 if len(ciphertext) != CiphertextSize768 {
457 return nil, errors.New("mlkem: invalid ciphertext length")
458 }
459 c := (*[CiphertextSize768]byte)(ciphertext)
460 // Note that the hash check (step 3 of the decapsulation input check from
461 // FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always
462 // validly generated by ML-KEM.KeyGen_internal.
463 return kemDecaps(dk, c), nil
464 }
465 466 // kemDecaps produces a shared key from a ciphertext.
467 //
468 // It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
469 func kemDecaps(dk *DecapsulationKey768, c *[CiphertextSize768]byte) (K []byte) {
470 fips140.RecordApproved()
471 m := pkeDecrypt(&dk.decryptionKey, c)
472 g := sha3.New512()
473 g.Write(m[:])
474 g.Write(dk.h[:])
475 G := g.Sum([]byte{:0:64})
476 Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
477 J := sha3.NewShake256()
478 J.Write(dk.z[:])
479 J.Write(c[:])
480 Kout := []byte{:SharedKeySize}
481 J.Read(Kout)
482 var cc [CiphertextSize768]byte
483 c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r)
484 485 subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
486 return Kout
487 }
488 489 // pkeDecrypt decrypts a ciphertext.
490 //
491 // It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
492 // although s is retained from kemKeyGen.
493 func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize768]byte) []byte {
494 u := []ringElement{:k}
495 for i := range u {
496 b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)])
497 u[i] = ringDecodeAndDecompress10(b)
498 }
499 500 b := (*[encodingSize4]byte)(c[encodingSize10*k:])
501 v := ringDecodeAndDecompress4(b)
502 503 var mask nttElement // s⊺ ◦ NTT(u)
504 for i := range dx.s {
505 mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
506 }
507 w := polySub(v, inverseNTT(mask))
508 509 return ringCompressAndEncode1(nil, w)
510 }
511