1 // Code generated by generate1024.go. DO NOT EDIT.
2 3 package mlkem
4 5 import (
6 "bytes"
7 "crypto/internal/fips140"
8 "crypto/internal/fips140/drbg"
9 "crypto/internal/fips140/sha3"
10 "crypto/internal/fips140/subtle"
11 "errors"
12 )
13 14 // A DecapsulationKey1024 is the secret key used to decapsulate a shared key from a
15 // ciphertext. It includes various precomputed values.
16 type DecapsulationKey1024 struct {
17 d [32]byte // decapsulation key seed
18 z [32]byte // implicit rejection sampling seed
19 20 ρ [32]byte // sampleNTT seed for A, stored for the encapsulation key
21 h [32]byte // H(ek), stored for ML-KEM.Decaps_internal
22 23 encryptionKey1024
24 decryptionKey1024
25 }
26 27 // Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
28 //
29 // The decapsulation key must be kept secret.
30 func (dk *DecapsulationKey1024) Bytes() []byte {
31 var b [SeedSize]byte
32 copy(b[:], dk.d[:])
33 copy(b[32:], dk.z[:])
34 return b[:]
35 }
36 37 // TestingOnlyExpandedBytes1024 returns the decapsulation key as a byte slice
38 // using the full expanded NIST encoding.
39 //
40 // This should only be used for ACVP testing. For all other purposes prefer
41 // the Bytes method that returns the (much smaller) seed.
42 func TestingOnlyExpandedBytes1024(dk *DecapsulationKey1024) []byte {
43 b := []byte{:0:decapsulationKeySize1024}
44 45 // ByteEncode₁₂(s)
46 for i := range dk.s {
47 b = polyByteEncode(b, dk.s[i])
48 }
49 50 // ByteEncode₁₂(t) || ρ
51 for i := range dk.t {
52 b = polyByteEncode(b, dk.t[i])
53 }
54 b = append(b, dk.ρ[:]...)
55 56 // H(ek) || z
57 b = append(b, dk.h[:]...)
58 b = append(b, dk.z[:]...)
59 60 return b
61 }
62 63 // EncapsulationKey returns the public encapsulation key necessary to produce
64 // ciphertexts.
65 func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 {
66 return &EncapsulationKey1024{
67 ρ: dk.ρ,
68 h: dk.h,
69 encryptionKey1024: dk.encryptionKey1024,
70 }
71 }
72 73 // An EncapsulationKey1024 is the public key used to produce ciphertexts to be
74 // decapsulated by the corresponding [DecapsulationKey1024].
75 type EncapsulationKey1024 struct {
76 ρ [32]byte // sampleNTT seed for A
77 h [32]byte // H(ek)
78 encryptionKey1024
79 }
80 81 // Bytes returns the encapsulation key as a byte slice.
82 func (ek *EncapsulationKey1024) Bytes() []byte {
83 // The actual logic is in a separate function to outline this allocation.
84 b := []byte{:0:EncapsulationKeySize1024}
85 return ek.bytes(b)
86 }
87 88 func (ek *EncapsulationKey1024) bytes(b []byte) []byte {
89 for i := range ek.t {
90 b = polyByteEncode(b, ek.t[i])
91 }
92 b = append(b, ek.ρ[:]...)
93 return b
94 }
95 96 // encryptionKey1024 is the parsed and expanded form of a PKE encryption key.
97 type encryptionKey1024 struct {
98 t [k1024]nttElement // ByteDecode₁₂(ek[:384k])
99 a [k1024 * k1024]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
100 }
101 102 // decryptionKey1024 is the parsed and expanded form of a PKE decryption key.
103 type decryptionKey1024 struct {
104 s [k1024]nttElement // ByteDecode₁₂(dk[:decryptionKey1024Size])
105 }
106 107 // GenerateKey1024 generates a new decapsulation key, drawing random bytes from
108 // a DRBG. The decapsulation key must be kept secret.
109 func GenerateKey1024() (*DecapsulationKey1024, error) {
110 // The actual logic is in a separate function to outline this allocation.
111 dk := &DecapsulationKey1024{}
112 return generateKey1024(dk)
113 }
114 115 func generateKey1024(dk *DecapsulationKey1024) (*DecapsulationKey1024, error) {
116 var d [32]byte
117 drbg.Read(d[:])
118 var z [32]byte
119 drbg.Read(z[:])
120 kemKeyGen1024(dk, &d, &z)
121 fips140.PCT("ML-KEM PCT", func() error { return kemPCT1024(dk) })
122 fips140.RecordApproved()
123 return dk, nil
124 }
125 126 // GenerateKeyInternal1024 is a derandomized version of GenerateKey1024,
127 // exclusively for use in tests.
128 func GenerateKeyInternal1024(d, z *[32]byte) *DecapsulationKey1024 {
129 dk := &DecapsulationKey1024{}
130 kemKeyGen1024(dk, d, z)
131 return dk
132 }
133 134 // NewDecapsulationKey1024 parses a decapsulation key from a 64-byte
135 // seed in the "d || z" form. The seed must be uniformly random.
136 func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) {
137 // The actual logic is in a separate function to outline this allocation.
138 dk := &DecapsulationKey1024{}
139 return newKeyFromSeed1024(dk, seed)
140 }
141 142 func newKeyFromSeed1024(dk *DecapsulationKey1024, seed []byte) (*DecapsulationKey1024, error) {
143 if len(seed) != SeedSize {
144 return nil, errors.New("mlkem: invalid seed length")
145 }
146 d := (*[32]byte)(seed[:32])
147 z := (*[32]byte)(seed[32:])
148 kemKeyGen1024(dk, d, z)
149 fips140.RecordApproved()
150 return dk, nil
151 }
152 153 // TestingOnlyNewDecapsulationKey1024 parses a decapsulation key from its expanded NIST format.
154 //
155 // Bytes() must not be called on the returned key, as it will not produce the
156 // original seed.
157 //
158 // This function should only be used for ACVP testing. Prefer NewDecapsulationKey1024 for all
159 // other purposes.
160 func TestingOnlyNewDecapsulationKey1024(b []byte) (*DecapsulationKey1024, error) {
161 if len(b) != decapsulationKeySize1024 {
162 return nil, errors.New("mlkem: invalid NIST decapsulation key length")
163 }
164 165 dk := &DecapsulationKey1024{}
166 for i := range dk.s {
167 var err error
168 dk.s[i], err = polyByteDecode[nttElement](b[:encodingSize12])
169 if err != nil {
170 return nil, errors.New("mlkem: invalid secret key encoding")
171 }
172 b = b[encodingSize12:]
173 }
174 175 ek, err := NewEncapsulationKey1024(b[:EncapsulationKeySize1024])
176 if err != nil {
177 return nil, err
178 }
179 dk.ρ = ek.ρ
180 dk.h = ek.h
181 dk.encryptionKey1024 = ek.encryptionKey1024
182 b = b[EncapsulationKeySize1024:]
183 184 if !bytes.Equal(dk.h[:], b[:32]) {
185 return nil, errors.New("mlkem: inconsistent H(ek) in encoded bytes")
186 }
187 b = b[32:]
188 189 copy(dk.z[:], b)
190 191 // Generate a random d value for use in Bytes(). This is a safety mechanism
192 // that avoids returning a broken key vs a random key if this function is
193 // called in contravention of the TestingOnlyNewDecapsulationKey1024 function
194 // comment advising against it.
195 drbg.Read(dk.d[:])
196 197 return dk, nil
198 }
199 200 // kemKeyGen1024 generates a decapsulation key.
201 //
202 // It implements ML-KEM.KeyGen_internal according to FIPS 203, Algorithm 16, and
203 // K-PKE.KeyGen according to FIPS 203, Algorithm 13. The two are merged to save
204 // copies and allocations.
205 func kemKeyGen1024(dk *DecapsulationKey1024, d, z *[32]byte) {
206 dk.d = *d
207 dk.z = *z
208 209 g := sha3.New512()
210 g.Write(d[:])
211 g.Write([]byte{k1024}) // Module dimension as a domain separator.
212 G := g.Sum([]byte{:0:64})
213 ρ, σ := G[:32], G[32:]
214 dk.ρ = [32]byte(ρ)
215 216 A := &dk.a
217 for i := byte(0); i < k1024; i++ {
218 for j := byte(0); j < k1024; j++ {
219 A[i*k1024+j] = sampleNTT(ρ, j, i)
220 }
221 }
222 223 var N byte
224 s := &dk.s
225 for i := range s {
226 s[i] = ntt(samplePolyCBD(σ, N))
227 N++
228 }
229 e := []nttElement{:k1024}
230 for i := range e {
231 e[i] = ntt(samplePolyCBD(σ, N))
232 N++
233 }
234 235 t := &dk.t
236 for i := range t { // t = A ◦ s + e
237 t[i] = e[i]
238 for j := range s {
239 t[i] = polyAdd(t[i], nttMul(A[i*k1024+j], s[j]))
240 }
241 }
242 243 H := sha3.New256()
244 ek := dk.EncapsulationKey().Bytes()
245 H.Write(ek)
246 H.Sum(dk.h[:0])
247 }
248 249 // kemPCT1024 performs a Pairwise Consistency Test per FIPS 140-3 IG 10.3.A
250 // Additional Comment 1: "For key pairs generated for use with approved KEMs in
251 // FIPS 203, the PCT shall consist of applying the encapsulation key ek to
252 // encapsulate a shared secret K leading to ciphertext c, and then applying
253 // decapsulation key dk to retrieve the same shared secret K. The PCT passes if
254 // the two shared secret K values are equal. The PCT shall be performed either
255 // when keys are generated/imported, prior to the first exportation, or prior to
256 // the first operational use (if not exported before the first use)."
257 func kemPCT1024(dk *DecapsulationKey1024) error {
258 ek := dk.EncapsulationKey()
259 K, c := ek.Encapsulate()
260 K1, err := dk.Decapsulate(c)
261 if err != nil {
262 return err
263 }
264 if subtle.ConstantTimeCompare(K, K1) != 1 {
265 return errors.New("mlkem: PCT failed")
266 }
267 return nil
268 }
269 270 // Encapsulate generates a shared key and an associated ciphertext from an
271 // encapsulation key, drawing random bytes from a DRBG.
272 //
273 // The shared key must be kept secret.
274 func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
275 // The actual logic is in a separate function to outline this allocation.
276 var cc [CiphertextSize1024]byte
277 return ek.encapsulate(&cc)
278 }
279 280 func (ek *EncapsulationKey1024) encapsulate(cc *[CiphertextSize1024]byte) (sharedKey, ciphertext []byte) {
281 var m [messageSize]byte
282 drbg.Read(m[:])
283 // Note that the modulus check (step 2 of the encapsulation key check from
284 // FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK1024.
285 fips140.RecordApproved()
286 return kemEncaps1024(cc, ek, &m)
287 }
288 289 // EncapsulateInternal is a derandomized version of Encapsulate, exclusively for
290 // use in tests.
291 func (ek *EncapsulationKey1024) EncapsulateInternal(m *[32]byte) (sharedKey, ciphertext []byte) {
292 cc := &[CiphertextSize1024]byte{}
293 return kemEncaps1024(cc, ek, m)
294 }
295 296 // kemEncaps1024 generates a shared key and an associated ciphertext.
297 //
298 // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17.
299 func kemEncaps1024(cc *[CiphertextSize1024]byte, ek *EncapsulationKey1024, m *[messageSize]byte) (K, c []byte) {
300 g := sha3.New512()
301 g.Write(m[:])
302 g.Write(ek.h[:])
303 G := g.Sum(nil)
304 K, r := G[:SharedKeySize], G[SharedKeySize:]
305 c = pkeEncrypt1024(cc, &ek.encryptionKey1024, m, r)
306 return K, c
307 }
308 309 // NewEncapsulationKey1024 parses an encapsulation key from its encoded form.
310 // If the encapsulation key is not valid, NewEncapsulationKey1024 returns an error.
311 func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) {
312 // The actual logic is in a separate function to outline this allocation.
313 ek := &EncapsulationKey1024{}
314 return parseEK1024(ek, encapsulationKey)
315 }
316 317 // parseEK1024 parses an encryption key from its encoded form.
318 //
319 // It implements the initial stages of K-PKE.Encrypt according to FIPS 203,
320 // Algorithm 14.
321 func parseEK1024(ek *EncapsulationKey1024, ekPKE []byte) (*EncapsulationKey1024, error) {
322 if len(ekPKE) != EncapsulationKeySize1024 {
323 return nil, errors.New("mlkem: invalid encapsulation key length")
324 }
325 326 h := sha3.New256()
327 h.Write(ekPKE)
328 h.Sum(ek.h[:0])
329 330 for i := range ek.t {
331 var err error
332 ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12])
333 if err != nil {
334 return nil, err
335 }
336 ekPKE = ekPKE[encodingSize12:]
337 }
338 copy(ek.ρ[:], ekPKE)
339 340 for i := byte(0); i < k1024; i++ {
341 for j := byte(0); j < k1024; j++ {
342 ek.a[i*k1024+j] = sampleNTT(ek.ρ[:], j, i)
343 }
344 }
345 346 return ek, nil
347 }
348 349 // pkeEncrypt1024 encrypt a plaintext message.
350 //
351 // It implements K-PKE.Encrypt according to FIPS 203, Algorithm 14, although the
352 // computation of t and AT is done in parseEK1024.
353 func pkeEncrypt1024(cc *[CiphertextSize1024]byte, ex *encryptionKey1024, m *[messageSize]byte, rnd []byte) []byte {
354 var N byte
355 r, e1 := []nttElement{:k1024}, []ringElement{:k1024}
356 for i := range r {
357 r[i] = ntt(samplePolyCBD(rnd, N))
358 N++
359 }
360 for i := range e1 {
361 e1[i] = samplePolyCBD(rnd, N)
362 N++
363 }
364 e2 := samplePolyCBD(rnd, N)
365 366 u := []ringElement{:k1024} // NTT⁻¹(AT ◦ r) + e1
367 for i := range u {
368 u[i] = e1[i]
369 for j := range r {
370 // Note that i and j are inverted, as we need the transposed of A.
371 u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.a[j*k1024+i], r[j])))
372 }
373 }
374 375 μ := ringDecodeAndDecompress1(m)
376 377 var vNTT nttElement // t⊺ ◦ r
378 for i := range ex.t {
379 vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i]))
380 }
381 v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ)
382 383 c := cc[:0]
384 for _, f := range u {
385 c = ringCompressAndEncode11(c, f)
386 }
387 c = ringCompressAndEncode5(c, v)
388 389 return c
390 }
391 392 // Decapsulate generates a shared key from a ciphertext and a decapsulation key.
393 // If the ciphertext is not valid, Decapsulate returns an error.
394 //
395 // The shared key must be kept secret.
396 func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
397 if len(ciphertext) != CiphertextSize1024 {
398 return nil, errors.New("mlkem: invalid ciphertext length")
399 }
400 c := (*[CiphertextSize1024]byte)(ciphertext)
401 // Note that the hash check (step 3 of the decapsulation input check from
402 // FIPS 203, Section 7.3) is foregone as a DecapsulationKey is always
403 // validly generated by ML-KEM.KeyGen_internal.
404 return kemDecaps1024(dk, c), nil
405 }
406 407 // kemDecaps1024 produces a shared key from a ciphertext.
408 //
409 // It implements ML-KEM.Decaps_internal according to FIPS 203, Algorithm 18.
410 func kemDecaps1024(dk *DecapsulationKey1024, c *[CiphertextSize1024]byte) (K []byte) {
411 fips140.RecordApproved()
412 m := pkeDecrypt1024(&dk.decryptionKey1024, c)
413 g := sha3.New512()
414 g.Write(m[:])
415 g.Write(dk.h[:])
416 G := g.Sum([]byte{:0:64})
417 Kprime, r := G[:SharedKeySize], G[SharedKeySize:]
418 J := sha3.NewShake256()
419 J.Write(dk.z[:])
420 J.Write(c[:])
421 Kout := []byte{:SharedKeySize}
422 J.Read(Kout)
423 var cc [CiphertextSize1024]byte
424 c1 := pkeEncrypt1024(&cc, &dk.encryptionKey1024, (*[32]byte)(m), r)
425 426 subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime)
427 return Kout
428 }
429 430 // pkeDecrypt1024 decrypts a ciphertext.
431 //
432 // It implements K-PKE.Decrypt according to FIPS 203, Algorithm 15,
433 // although s is retained from kemKeyGen1024.
434 func pkeDecrypt1024(dx *decryptionKey1024, c *[CiphertextSize1024]byte) []byte {
435 u := []ringElement{:k1024}
436 for i := range u {
437 b := (*[encodingSize11]byte)(c[encodingSize11*i : encodingSize11*(i+1)])
438 u[i] = ringDecodeAndDecompress11(b)
439 }
440 441 b := (*[encodingSize5]byte)(c[encodingSize11*k1024:])
442 v := ringDecodeAndDecompress5(b)
443 444 var mask nttElement // s⊺ ◦ NTT(u)
445 for i := range dx.s {
446 mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i])))
447 }
448 w := polySub(v, inverseNTT(mask))
449 450 return ringCompressAndEncode1(nil, w)
451 }
452