package ring

import (
	"crypto/rand"
	"crypto/subtle"
	"errors"
	"io"

	"golang.org/x/crypto/sha3"
)

// Ring-LWE Key Encapsulation Mechanism.
//
// Security: IND-CCA2 via the Fujisaki-Okamoto (FO) transform.
// Hardness assumption: Ring-LWE → SVP on ideal lattices in Z[x]/(x^n + 1).
//
// The core IND-CPA scheme:
//
//	KeyGen: s ← small, e ← small, a ← uniform, b = a·s + e
//	Public key: (a, b)    Secret key: s
//
//	Encrypt(pk, m, coins):
//	  r ← small, e1 ← small, e2 ← small   (all derived from coins)
//	  u = a·r + e1
//	  v = b·r + e2 + encode(m)
//	  Ciphertext: (u, v)
//
//	Decrypt(sk, ct):
//	  m' = decode(v - s·u) = decode(e2 + r·e - s·e1 + encode(m))
//	  The noise term (e2 + r·e - s·e1) is small, so rounding recovers m.
//
// The FO transform wraps this to achieve CCA2 security:
//
//	Encapsulate: draw random m, derive (K, coins) = G(m, H(pk)), encrypt m
//	             with coins, shared key = KDF(K, H(ct)).
//	Decapsulate: decrypt to get m', re-derive (K', coins') = G(m', H(pk)),
//	             re-encrypt and compare with ct; on mismatch the shared key is
//	             KDF(z, H(ct)) instead (implicit rejection).

// KEMParams holds KEM-specific parameters on top of ring parameters.
type KEMParams struct {
	// Ring is the polynomial ring every KEM polynomial lives in; its N
	// (degree) and Q (modulus) fields are read by the message encoder.
	Ring Params
	// Eta1 is the CBD parameter for the key-generation secret s and noise e,
	// and for the ephemeral r used during encryption.
	Eta1 int
	// Eta2 is the CBD parameter for the encryption noise e1 and e2.
	Eta2 int
	// SharedKeyLen is the length of the shared key in bytes. It is also the
	// length of the implicit-rejection value Z in the secret key.
	SharedKeyLen int
}

// DefaultKEMParams returns KEM parameters targeting ~128-bit post-quantum security.
// Uses Falcon-512 ring (n=512, q=12289) with conservative noise parameters.
// NOTE(review): the security level is the author's estimate; it cannot be
// derived from this file.
func DefaultKEMParams() KEMParams {
	return KEMParams{
		Ring:         Falcon512(),
		Eta1:         3,
		Eta2:         3,
		SharedKeyLen: 32,
	}
}

// KEMPublicKey is a Ring-LWE public key.
type KEMPublicKey struct {
	A *Poly // uniform element (in NTT form)
	B *Poly // b = a·s + e (in NTT form)
	// P is the parameter set the key was generated with; encryption reads
	// P.Ring, P.Eta1, P.Eta2 and P.SharedKeyLen from it.
	P KEMParams
}

// KEMSecretKey is a Ring-LWE secret key.
type KEMSecretKey struct {
	S *Poly // secret polynomial (in NTT form)
	// PK is the matching public key. Decapsulation needs it to re-encrypt
	// the recovered message and to read the parameter set.
	PK *KEMPublicKey
	// Z is a random "implicit rejection" value used in the FO transform
	// (P.SharedKeyLen bytes long). If decapsulation fails the consistency
	// check, the shared key is derived from Z instead of the message, so an
	// attacker submitting malformed ciphertexts learns nothing useful.
	Z []byte
}

// KEMCiphertext is an encrypted key encapsulation. Both polynomials are in
// coefficient (non-NTT) form, as produced by cpaPKEEncrypt.
type KEMCiphertext struct {
	U *Poly // a·r + e1
	V *Poly // b·r + e2 + encode(m)
}

// KEMKeyGen generates a fresh KEM key pair using crypto/rand.Reader.
// It panics if reading the implicit-rejection value fails (see KEMKeyGenFrom).
func KEMKeyGen(kp KEMParams) (*KEMPublicKey, *KEMSecretKey) {
	return KEMKeyGenFrom(kp, rand.Reader)
}

// KEMKeyGenFrom generates a key pair from the given randomness source.
//
// Randomness is consumed from rng in a fixed order: a, then s, then e, then
// z. Presumably a deterministic rng therefore yields reproducible keys
// (depends on UniformPolyFrom/CBDPolyFrom reading only from rng).
//
// Panics if rng cannot supply SharedKeyLen bytes for z. How UniformPolyFrom
// and CBDPolyFrom react to a failing rng is defined elsewhere.
// NOTE(review): panicking in library code; an error-returning variant would
// let callers handle RNG failure.
func KEMKeyGenFrom(kp KEMParams, rng io.Reader) (*KEMPublicKey, *KEMSecretKey) {
	p := kp.Ring
	// a ← uniform in R_q
	a := UniformPolyFrom(p, rng)
	NTT(a)
	// s ← CBD_eta1 (secret key)
	s := CBDPolyFrom(p, kp.Eta1, rng)
	NTT(s)
	// e ← CBD_eta1 (key generation noise)
	e := CBDPolyFrom(p, kp.Eta1, rng)
	NTT(e)
	// b = a·s + e in NTT domain: all three operands were transformed above,
	// so the ring product is a pointwise multiply.
	b := MulPointwise(a, s)
	b = Add(b, e)
	// z ← random implicit rejection value
	z := make([]byte, kp.SharedKeyLen)
	if _, err := io.ReadFull(rng, z); err != nil {
		panic("ring/kem: randomness source failed: " + err.Error())
	}
	pk := &KEMPublicKey{A: a, B: b, P: kp}
	sk := &KEMSecretKey{S: s, PK: pk, Z: z}
	return pk, sk
}

// cpaPKEEncrypt is the inner IND-CPA encryption.
// m is a 32-byte message. coins is the deterministic randomness.
//
// The only randomness is a SHAKE256 stream seeded with coins. The same
// (pk, m, coins) therefore always yields the same ciphertext, which is what
// lets Decapsulate re-encrypt and compare. r is moved to NTT form to multiply
// against pk.A and pk.B (already NTT). The products are brought back to
// coefficient form before e1, e2 and the encoded message are added.
func cpaPKEEncrypt(pk *KEMPublicKey, m []byte, coins []byte) *KEMCiphertext {
	p := pk.P.Ring
	// Derive noise polynomials from coins (deterministic for FO).
	rng := sha3.NewShake256()
	rng.Write(coins)
	// r ← CBD_eta1
	r := CBDPolyFrom(p, pk.P.Eta1, rng)
	NTT(r)
	// e1 ← CBD_eta2
	e1 := CBDPolyFrom(p, pk.P.Eta2, rng)
	// e2 ← CBD_eta2
	e2 := CBDPolyFrom(p, pk.P.Eta2, rng)
	// u = a·r + e1 (compute in NTT, then back to coefficient form)
	u := MulPointwise(pk.A, r)
	INTT(u)
	u = Add(u, e1)
	// v = b·r + e2 + encode(m)
	v := MulPointwise(pk.B, r)
	INTT(v)
	v = Add(v, e2)
	// Encode message: each bit of m maps to floor(q/2) in one coefficient.
	encoded := encodeMessage(p, m)
	v = Add(v, encoded)
	return &KEMCiphertext{U: u, V: v}
}

// cpaPKEDecrypt is the inner IND-CPA decryption: it computes v - s·u and
// rounds each coefficient back to a message bit.
func cpaPKEDecrypt(sk *KEMSecretKey, ct *KEMCiphertext) []byte {
	// Compute v - s·u.
	// s is in NTT form, u is in coefficient form.
	// u is cloned first because NTT transforms its argument in place (it is
	// called for effect throughout this file). The caller's ciphertext must
	// stay intact, since Decapsulate later hashes and compares it.
	uNTT := ct.U.Clone()
	NTT(uNTT)
	su := MulPointwise(sk.S, uNTT)
	INTT(su)
	// noisy = v - s·u = e2 + r·e - s·e1 + encode(m)
	noisy := Sub(ct.V, su)
	return decodeMessage(noisy)
}

// kemMessageBytes is the fixed message size for KEM operations.
// 256 bits = 32 bytes. We use the first 256 coefficients to encode 256 bits.
const kemMessageBytes = 32

// encodeMessage maps message bytes into polynomial coefficients.
// Each bit becomes floor(q/2) or 0 in one coefficient.
// Only the first 256 coefficients are used (32 bytes × 8 bits).
// If the ring degree N is below 256, only the first N bits are encoded.
// NOTE(review): in that case the remaining message bits are lost, so
// Decapsulate's re-encryption would presumably never match and every
// decapsulation would fall back to the rejection key. Consider rejecting
// such parameter sets.
func encodeMessage(p Params, m []byte) *Poly {
	poly := New(p)
	half := p.Q / 2
	bits := kemMessageBytes * 8
	if bits > p.N {
		bits = p.N
	}
	for i := range bits {
		byteIdx := i / 8
		bitIdx := uint(i % 8)
		// NOTE(review): SOURCE is corrupted on the next line. Everything
		// between `(1<` and `a.params.N` was lost. That is most likely the
		// `<<bitIdx) != 0` bit test and coefficient assignment, the end of
		// encodeMessage, and the doc comment and opening of decodeMessage,
		// including the declarations of m, half, quarter and bits used
		// below. The text is kept verbatim here; restore it from version
		// control rather than reconstructing it by hand.
		if byteIdx < len(m) && m[byteIdx]&(1< a.params.N {
		bits = a.params.N
	}
	// decodeMessage (surviving tail): a coefficient within q/4 of q/2
	// decodes to a 1 bit; anything else (closer to 0 mod q) decodes to 0.
	for i := range bits {
		c := a.Coeffs[i]
		// Distance to q/2. If within q/4 of q/2, the bit is 1.
		var distHalf uint32
		if c > half {
			distHalf = c - half
		} else {
			distHalf = half - c
		}
		if distHalf < quarter {
			m[i/8] |= 1 << uint(i%8)
		}
	}
	return m
}

// Encapsulate generates a shared secret and ciphertext using the public key.
// Returns (shared_key, ciphertext, error).
// This is the full FO-transformed CCA2-secure encapsulation. Randomness comes
// from crypto/rand.Reader; the error is non-nil only if that reader fails.
func Encapsulate(pk *KEMPublicKey) ([]byte, *KEMCiphertext, error) {
	return EncapsulateFrom(pk, rand.Reader)
}

// EncapsulateFrom is Encapsulate with an explicit randomness source.
// It reads exactly 32 bytes from rng (the FO message m). Everything else is
// derived deterministically from m and pk. The returned shared key is
// pk.P.SharedKeyLen bytes long. On RNG failure it returns (nil, nil, err).
// NOTE(review): the literal 32 duplicates kemMessageBytes, and the
// underlying reader error is discarded instead of being wrapped with %w.
func EncapsulateFrom(pk *KEMPublicKey, rng io.Reader) ([]byte, *KEMCiphertext, error) {
	// 1. Generate random message m.
	m := make([]byte, 32)
	if _, err := io.ReadFull(rng, m); err != nil {
		return nil, nil, errors.New("ring/kem: randomness failed")
	}
	// 2. Derive (K, coins) = G(m, H(pk)).
	pkHash := hashPublicKey(pk)
	K, coins := deriveKCoins(m, pkHash)
	// 3. Encrypt m with deterministic coins.
	ct := cpaPKEEncrypt(pk, m, coins)
	// 4. Derive final shared key: K' = KDF(K, H(ct)).
	ctHash := hashCiphertext(ct)
	sharedKey := kdf(K, ctHash, pk.P.SharedKeyLen)
	return sharedKey, ct, nil
}

// Decapsulate recovers the shared secret from a ciphertext using the secret key.
// Returns the shared key or an implicit rejection key (constant-time).
//
// The error is non-nil only when ct, ct.U or ct.V is nil. A non-nil but
// invalid ciphertext is not reported as an error: it yields the rejection
// key KDF(Z, H(ct)), indistinguishable to the caller from a real key.
// NOTE(review): the sizes of ct.U/ct.V are not checked against
// sk.PK.P.Ring before NTT. What NTT/MulPointwise do with a mis-sized
// polynomial is defined elsewhere; consider validating here.
func Decapsulate(sk *KEMSecretKey, ct *KEMCiphertext) ([]byte, error) {
	if ct == nil || ct.U == nil || ct.V == nil {
		return nil, errors.New("ring/kem: nil ciphertext")
	}
	// 1. Decrypt to recover m'.
	mPrime := cpaPKEDecrypt(sk, ct)
	// 2. Re-derive (K', coins') = G(m', H(pk)).
	pkHash := hashPublicKey(sk.PK)
	KPrime, coinsPrime := deriveKCoins(mPrime, pkHash)
	// 3. Re-encrypt with coins' to get ct'.
	ctPrime := cpaPKEEncrypt(sk.PK, mPrime, coinsPrime)
	// 4. Check ct == ct' (constant-time).
	match := ciphertextEqual(ct, ctPrime)
	// 5. If match, shared key = KDF(K', H(ct)).
	// If no match, shared key = KDF(z, H(ct)) (implicit rejection).
	// Both keys are always computed so the work done does not depend on match.
	ctHash := hashCiphertext(ct)
	realKey := kdf(KPrime, ctHash, sk.PK.P.SharedKeyLen)
	rejectKey := kdf(sk.Z, ctHash, sk.PK.P.SharedKeyLen)
	// Constant-time select. match is always 0 or 1 (ciphertextEqual's
	// contract), which subtle.ConstantTimeCopy requires. All three slices
	// are SharedKeyLen long, as ConstantTimeCopy also requires.
	sharedKey := make([]byte, sk.PK.P.SharedKeyLen)
	subtle.ConstantTimeCopy(match, sharedKey, realKey)
	subtle.ConstantTimeCopy(1-match, sharedKey, rejectKey)
	return sharedKey, nil
}

// --- internal helpers ---
//
// All hashes below are SHAKE256 with a distinct ASCII domain-separation tag.
// Write/Read return values are ignored because sha3's ShakeHash never
// returns an error from either.

// hashPublicKey returns H(pk): a 32-byte digest of the serialized A and B.
func hashPublicKey(pk *KEMPublicKey) []byte {
	h := sha3.NewShake256()
	h.Write([]byte("hamadryad-kem-pk"))
	h.Write(Serialize(pk.A))
	h.Write(Serialize(pk.B))
	out := make([]byte, 32)
	h.Read(out)
	return out
}

// hashCiphertext returns H(ct): a 32-byte digest of the serialized U and V.
func hashCiphertext(ct *KEMCiphertext) []byte {
	h := sha3.NewShake256()
	h.Write([]byte("hamadryad-kem-ct"))
	h.Write(Serialize(ct.U))
	h.Write(Serialize(ct.V))
	out := make([]byte, 32)
	h.Read(out)
	return out
}

// deriveKCoins is G in the FO construction. It squeezes 64 bytes from
// SHAKE256(tag || m || pkHash) and splits them into the pre-key K (first 32
// bytes) and the encryption coins (last 32 bytes). Both results share one
// backing array; K has spare capacity, so appending to K would overwrite
// coins.
func deriveKCoins(m, pkHash []byte) (K, coins []byte) {
	h := sha3.NewShake256()
	h.Write([]byte("hamadryad-kem-g"))
	h.Write(m)
	h.Write(pkHash)
	out := make([]byte, 64)
	h.Read(out)
	return out[:32], out[32:]
}

// kdf derives outLen bytes from SHAKE256(tag || key || label). Callers pass
// either the FO pre-key K or the rejection value Z as key, and H(ct) as
// label.
func kdf(key, label []byte, outLen int) []byte {
	h := sha3.NewShake256()
	h.Write([]byte("hamadryad-kem-kdf"))
	h.Write(key)
	h.Write(label)
	out := make([]byte, outLen)
	h.Read(out)
	return out
}

// ciphertextEqual returns 1 if the serialized forms of a and b are identical
// and 0 otherwise. The comparison time depends only on the serialized
// lengths, not their contents.
// NOTE(review): equality is judged on Serialize output. If Serialize is not
// injective on attacker-supplied coefficients (e.g. it reduces or masks
// out-of-range values), distinct ciphertexts could compare equal; confirm.
func ciphertextEqual(a, b *KEMCiphertext) int {
	aU := Serialize(a.U)
	bU := Serialize(b.U)
	aV := Serialize(a.V)
	bV := Serialize(b.V)
	if len(aU) != len(bU) || len(aV) != len(bV) {
		return 0
	}
	return subtle.ConstantTimeCompare(aU, bU) & subtle.ConstantTimeCompare(aV, bV)
}