kem.go raw
1 package ring
2
import (
	"crypto/rand"
	"crypto/subtle"
	"errors"
	"fmt"
	"io"

	"golang.org/x/crypto/sha3"
)
11
12 // Ring-LWE Key Encapsulation Mechanism.
13 //
14 // Security: IND-CCA2 via Fujisaki-Okamoto transform.
15 // Hardness assumption: Ring-LWE → SVP on ideal lattices in Z[x]/(x^n + 1).
16 //
17 // The core IND-CPA scheme:
18 //
19 // KeyGen: s ← small, e ← small, a ← uniform, b = a·s + e
20 // Public key: (a, b) Secret key: s
21 //
22 // Encrypt(pk, m, coins):
23 // r ← small, e1 ← small, e2 ← small
24 // u = a·r + e1
25 // v = b·r + e2 + encode(m)
26 // Ciphertext: (u, v)
27 //
28 // Decrypt(sk, ct):
29 // m' = decode(v - s·u) = decode(e2 + r·e - s·e1 + encode(m))
30 // The noise term (e2 + r·e - s·e1) is small, so rounding recovers m.
31 //
32 // The FO transform wraps this to achieve CCA2 security:
33 //
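// Encapsulate: sample a random message m, derive (K, coins) = G(m, H(pk)), encrypt m
//   with coins, and output the shared key KDF(K, H(ct)).
// Decapsulate: decrypt to get m', re-derive (K', coins') = G(m', H(pk)), re-encrypt, and
//   compare with ct; on a match output KDF(K', H(ct)), otherwise the implicit-rejection
//   key KDF(z, H(ct)).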
36
37 // KEMParams holds KEM-specific parameters on top of ring parameters.
38 type KEMParams struct {
39 Ring Params
40
41 // Eta1 is the CBD parameter for the secret and first noise polynomial.
42 Eta1 int
43
44 // Eta2 is the CBD parameter for the encryption noise.
45 Eta2 int
46
47 // SharedKeyLen is the length of the shared key in bytes.
48 SharedKeyLen int
49 }
50
51 // DefaultKEMParams returns KEM parameters targeting ~128-bit post-quantum security.
52 // Uses Falcon-512 ring (n=512, q=12289) with conservative noise parameters.
53 func DefaultKEMParams() KEMParams {
54 return KEMParams{
55 Ring: Falcon512(),
56 Eta1: 3,
57 Eta2: 3,
58 SharedKeyLen: 32,
59 }
60 }
61
62 // KEMPublicKey is an Ring-LWE public key.
63 type KEMPublicKey struct {
64 A *Poly // uniform element (in NTT form)
65 B *Poly // b = a·s + e (in NTT form)
66 P KEMParams
67 }
68
69 // KEMSecretKey is an Ring-LWE secret key.
70 type KEMSecretKey struct {
71 S *Poly // secret polynomial (in NTT form)
72 PK *KEMPublicKey
73 // Z is a random "implicit rejection" value used in FO transform.
74 // If decapsulation fails consistency check, shared key is derived from Z
75 // instead of the message, preventing chosen-ciphertext attacks.
76 Z []byte
77 }
78
79 // KEMCiphertext is an encrypted key encapsulation.
80 type KEMCiphertext struct {
81 U *Poly // a·r + e1
82 V *Poly // b·r + e2 + encode(m)
83 }
84
85 // KEMKeyGen generates a fresh KEM key pair.
86 func KEMKeyGen(kp KEMParams) (*KEMPublicKey, *KEMSecretKey) {
87 return KEMKeyGenFrom(kp, rand.Reader)
88 }
89
90 // KEMKeyGenFrom generates a key pair from the given randomness source.
91 func KEMKeyGenFrom(kp KEMParams, rng io.Reader) (*KEMPublicKey, *KEMSecretKey) {
92 p := kp.Ring
93
94 // a ← uniform in R_q
95 a := UniformPolyFrom(p, rng)
96 NTT(a)
97
98 // s ← CBD_eta1 (secret key)
99 s := CBDPolyFrom(p, kp.Eta1, rng)
100 NTT(s)
101
102 // e ← CBD_eta1 (key generation noise)
103 e := CBDPolyFrom(p, kp.Eta1, rng)
104 NTT(e)
105
106 // b = a·s + e in NTT domain
107 b := MulPointwise(a, s)
108 b = Add(b, e)
109
110 // z ← random implicit rejection value
111 z := make([]byte, kp.SharedKeyLen)
112 if _, err := io.ReadFull(rng, z); err != nil {
113 panic("ring/kem: randomness source failed: " + err.Error())
114 }
115
116 pk := &KEMPublicKey{A: a, B: b, P: kp}
117 sk := &KEMSecretKey{S: s, PK: pk, Z: z}
118 return pk, sk
119 }
120
121 // cpaPKEEncrypt is the inner IND-CPA encryption.
122 // m is a 32-byte message. coins is the deterministic randomness.
123 func cpaPKEEncrypt(pk *KEMPublicKey, m []byte, coins []byte) *KEMCiphertext {
124 p := pk.P.Ring
125
126 // Derive noise polynomials from coins (deterministic for FO).
127 rng := sha3.NewShake256()
128 rng.Write(coins)
129
130 // r ← CBD_eta1
131 r := CBDPolyFrom(p, pk.P.Eta1, rng)
132 NTT(r)
133
134 // e1 ← CBD_eta2
135 e1 := CBDPolyFrom(p, pk.P.Eta2, rng)
136
137 // e2 ← CBD_eta2
138 e2 := CBDPolyFrom(p, pk.P.Eta2, rng)
139
140 // u = a·r + e1 (compute in NTT, then back to coefficient form)
141 u := MulPointwise(pk.A, r)
142 INTT(u)
143 u = Add(u, e1)
144
145 // v = b·r + e2 + encode(m)
146 v := MulPointwise(pk.B, r)
147 INTT(v)
148 v = Add(v, e2)
149
150 // Encode message: each bit of m maps to floor(q/2) in one coefficient.
151 encoded := encodeMessage(p, m)
152 v = Add(v, encoded)
153
154 return &KEMCiphertext{U: u, V: v}
155 }
156
157 // cpaPKEDecrypt is the inner IND-CPA decryption.
158 func cpaPKEDecrypt(sk *KEMSecretKey, ct *KEMCiphertext) []byte {
159 // Compute v - s·u.
160 // s is in NTT form, u is in coefficient form.
161 uNTT := ct.U.Clone()
162 NTT(uNTT)
163
164 su := MulPointwise(sk.S, uNTT)
165 INTT(su)
166
167 // noisy = v - s·u = e2 + r·e - s·e1 + encode(m)
168 noisy := Sub(ct.V, su)
169
170 return decodeMessage(noisy)
171 }
172
// kemMessageBytes is the fixed KEM message length: 256 bits, carried one bit
// per coefficient in the first 256 coefficients of the encoding polynomial.
const kemMessageBytes = 256 / 8
176
177 // encodeMessage maps message bytes into polynomial coefficients.
178 // Each bit becomes floor(q/2) or 0 in one coefficient.
179 // Only the first 256 coefficients are used (32 bytes × 8 bits).
180 func encodeMessage(p Params, m []byte) *Poly {
181 poly := New(p)
182 half := p.Q / 2
183 bits := kemMessageBytes * 8
184 if bits > p.N {
185 bits = p.N
186 }
187 for i := range bits {
188 byteIdx := i / 8
189 bitIdx := uint(i % 8)
190 if byteIdx < len(m) && m[byteIdx]&(1<<bitIdx) != 0 {
191 poly.Coeffs[i] = half
192 }
193 }
194 return poly
195 }
196
197 // decodeMessage recovers the fixed-size message from noisy coefficients.
198 // Reads the first 256 coefficients, each rounded to 0 or floor(q/2).
199 func decodeMessage(a *Poly) []byte {
200 q := a.params.Q
201 half := q / 2
202 quarter := q / 4
203 m := make([]byte, kemMessageBytes)
204
205 bits := kemMessageBytes * 8
206 if bits > a.params.N {
207 bits = a.params.N
208 }
209 for i := range bits {
210 c := a.Coeffs[i]
211 // Distance to q/2. If within q/4 of q/2, the bit is 1.
212 var distHalf uint32
213 if c > half {
214 distHalf = c - half
215 } else {
216 distHalf = half - c
217 }
218 if distHalf < quarter {
219 m[i/8] |= 1 << uint(i%8)
220 }
221 }
222 return m
223 }
224
225 // Encapsulate generates a shared secret and ciphertext using the public key.
226 // Returns (shared_key, ciphertext, error).
227 // This is the full FO-transformed CCA2-secure encapsulation.
228 func Encapsulate(pk *KEMPublicKey) ([]byte, *KEMCiphertext, error) {
229 return EncapsulateFrom(pk, rand.Reader)
230 }
231
232 // EncapsulateFrom uses the given randomness source.
233 func EncapsulateFrom(pk *KEMPublicKey, rng io.Reader) ([]byte, *KEMCiphertext, error) {
234 // 1. Generate random message m.
235 m := make([]byte, 32)
236 if _, err := io.ReadFull(rng, m); err != nil {
237 return nil, nil, errors.New("ring/kem: randomness failed")
238 }
239
240 // 2. Derive (K, coins) = G(m, H(pk)).
241 pkHash := hashPublicKey(pk)
242 K, coins := deriveKCoins(m, pkHash)
243
244 // 3. Encrypt m with deterministic coins.
245 ct := cpaPKEEncrypt(pk, m, coins)
246
247 // 4. Derive final shared key: K' = KDF(K, H(ct)).
248 ctHash := hashCiphertext(ct)
249 sharedKey := kdf(K, ctHash, pk.P.SharedKeyLen)
250
251 return sharedKey, ct, nil
252 }
253
254 // Decapsulate recovers the shared secret from a ciphertext using the secret key.
255 // Returns the shared key or an implicit rejection key (constant-time).
256 func Decapsulate(sk *KEMSecretKey, ct *KEMCiphertext) ([]byte, error) {
257 if ct == nil || ct.U == nil || ct.V == nil {
258 return nil, errors.New("ring/kem: nil ciphertext")
259 }
260
261 // 1. Decrypt to recover m'.
262 mPrime := cpaPKEDecrypt(sk, ct)
263
264 // 2. Re-derive (K', coins') = G(m', H(pk)).
265 pkHash := hashPublicKey(sk.PK)
266 KPrime, coinsPrime := deriveKCoins(mPrime, pkHash)
267
268 // 3. Re-encrypt with coins' to get ct'.
269 ctPrime := cpaPKEEncrypt(sk.PK, mPrime, coinsPrime)
270
271 // 4. Check ct == ct' (constant-time).
272 match := ciphertextEqual(ct, ctPrime)
273
274 // 5. If match, shared key = KDF(K', H(ct)).
275 // If no match, shared key = KDF(z, H(ct)) (implicit rejection).
276 ctHash := hashCiphertext(ct)
277 realKey := kdf(KPrime, ctHash, sk.PK.P.SharedKeyLen)
278 rejectKey := kdf(sk.Z, ctHash, sk.PK.P.SharedKeyLen)
279
280 // Constant-time select.
281 sharedKey := make([]byte, sk.PK.P.SharedKeyLen)
282 subtle.ConstantTimeCopy(match, sharedKey, realKey)
283 subtle.ConstantTimeCopy(1-match, sharedKey, rejectKey)
284
285 return sharedKey, nil
286 }
287
288 // --- internal helpers ---
289
290 func hashPublicKey(pk *KEMPublicKey) []byte {
291 h := sha3.NewShake256()
292 h.Write([]byte("hamadryad-kem-pk"))
293 h.Write(Serialize(pk.A))
294 h.Write(Serialize(pk.B))
295 out := make([]byte, 32)
296 h.Read(out)
297 return out
298 }
299
300 func hashCiphertext(ct *KEMCiphertext) []byte {
301 h := sha3.NewShake256()
302 h.Write([]byte("hamadryad-kem-ct"))
303 h.Write(Serialize(ct.U))
304 h.Write(Serialize(ct.V))
305 out := make([]byte, 32)
306 h.Read(out)
307 return out
308 }
309
310 func deriveKCoins(m, pkHash []byte) (K, coins []byte) {
311 h := sha3.NewShake256()
312 h.Write([]byte("hamadryad-kem-g"))
313 h.Write(m)
314 h.Write(pkHash)
315 out := make([]byte, 64)
316 h.Read(out)
317 return out[:32], out[32:]
318 }
319
320 func kdf(key, label []byte, outLen int) []byte {
321 h := sha3.NewShake256()
322 h.Write([]byte("hamadryad-kem-kdf"))
323 h.Write(key)
324 h.Write(label)
325 out := make([]byte, outLen)
326 h.Read(out)
327 return out
328 }
329
330 func ciphertextEqual(a, b *KEMCiphertext) int {
331 aU := Serialize(a.U)
332 bU := Serialize(b.U)
333 aV := Serialize(a.V)
334 bV := Serialize(b.V)
335
336 if len(aU) != len(bU) || len(aV) != len(bV) {
337 return 0
338 }
339
340 return subtle.ConstantTimeCompare(aU, bU) & subtle.ConstantTimeCompare(aV, bV)
341 }
342