kem.go raw

   1  package ring
   2  
   3  import (
   4  	"crypto/rand"
   5  	"crypto/subtle"
   6  	"errors"
   7  	"io"
   8  
   9  	"golang.org/x/crypto/sha3"
  10  )
  11  
  12  // Ring-LWE Key Encapsulation Mechanism.
  13  //
  14  // Security: IND-CCA2 via Fujisaki-Okamoto transform.
  15  // Hardness assumption: Ring-LWE → SVP on ideal lattices in Z[x]/(x^n + 1).
  16  //
  17  // The core IND-CPA scheme:
  18  //
  19  //	KeyGen: s ← small, e ← small, a ← uniform, b = a·s + e
  20  //	  Public key: (a, b)    Secret key: s
  21  //
  22  //	Encrypt(pk, m, coins):
  23  //	  r ← small, e1 ← small, e2 ← small
  24  //	  u = a·r + e1
  25  //	  v = b·r + e2 + encode(m)
  26  //	  Ciphertext: (u, v)
  27  //
  28  //	Decrypt(sk, ct):
  29  //	  m' = decode(v - s·u) = decode(e2 + r·e - s·e1 + encode(m))
  30  //	  The noise term (e2 + r·e - s·e1) is small, so rounding recovers m.
  31  //
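// The FO transform wraps this to achieve CCA2 security:
//
//	Encapsulate: generate random m, derive (K, coins) = G(m, H(pk)),
//	  encrypt m with coins, and output KDF(K, H(ct)) as the shared key.
//	Decapsulate: decrypt to get m', re-derive (K', coins') = G(m', H(pk)),
//	  re-encrypt m' with coins' and compare with ct in constant time; output
//	  KDF(K', H(ct)) on a match, otherwise KDF(z, H(ct)) (implicit rejection).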
  36  
  37  // KEMParams holds KEM-specific parameters on top of ring parameters.
  38  type KEMParams struct {
  39  	Ring Params
  40  
  41  	// Eta1 is the CBD parameter for the secret and first noise polynomial.
  42  	Eta1 int
  43  
  44  	// Eta2 is the CBD parameter for the encryption noise.
  45  	Eta2 int
  46  
  47  	// SharedKeyLen is the length of the shared key in bytes.
  48  	SharedKeyLen int
  49  }
  50  
  51  // DefaultKEMParams returns KEM parameters targeting ~128-bit post-quantum security.
  52  // Uses Falcon-512 ring (n=512, q=12289) with conservative noise parameters.
  53  func DefaultKEMParams() KEMParams {
  54  	return KEMParams{
  55  		Ring:         Falcon512(),
  56  		Eta1:         3,
  57  		Eta2:         3,
  58  		SharedKeyLen: 32,
  59  	}
  60  }
  61  
  62  // KEMPublicKey is an Ring-LWE public key.
  63  type KEMPublicKey struct {
  64  	A *Poly // uniform element (in NTT form)
  65  	B *Poly // b = a·s + e (in NTT form)
  66  	P KEMParams
  67  }
  68  
  69  // KEMSecretKey is an Ring-LWE secret key.
  70  type KEMSecretKey struct {
  71  	S  *Poly // secret polynomial (in NTT form)
  72  	PK *KEMPublicKey
  73  	// Z is a random "implicit rejection" value used in FO transform.
  74  	// If decapsulation fails consistency check, shared key is derived from Z
  75  	// instead of the message, preventing chosen-ciphertext attacks.
  76  	Z []byte
  77  }
  78  
  79  // KEMCiphertext is an encrypted key encapsulation.
  80  type KEMCiphertext struct {
  81  	U *Poly // a·r + e1
  82  	V *Poly // b·r + e2 + encode(m)
  83  }
  84  
  85  // KEMKeyGen generates a fresh KEM key pair.
  86  func KEMKeyGen(kp KEMParams) (*KEMPublicKey, *KEMSecretKey) {
  87  	return KEMKeyGenFrom(kp, rand.Reader)
  88  }
  89  
  90  // KEMKeyGenFrom generates a key pair from the given randomness source.
  91  func KEMKeyGenFrom(kp KEMParams, rng io.Reader) (*KEMPublicKey, *KEMSecretKey) {
  92  	p := kp.Ring
  93  
  94  	// a ← uniform in R_q
  95  	a := UniformPolyFrom(p, rng)
  96  	NTT(a)
  97  
  98  	// s ← CBD_eta1 (secret key)
  99  	s := CBDPolyFrom(p, kp.Eta1, rng)
 100  	NTT(s)
 101  
 102  	// e ← CBD_eta1 (key generation noise)
 103  	e := CBDPolyFrom(p, kp.Eta1, rng)
 104  	NTT(e)
 105  
 106  	// b = a·s + e in NTT domain
 107  	b := MulPointwise(a, s)
 108  	b = Add(b, e)
 109  
 110  	// z ← random implicit rejection value
 111  	z := make([]byte, kp.SharedKeyLen)
 112  	if _, err := io.ReadFull(rng, z); err != nil {
 113  		panic("ring/kem: randomness source failed: " + err.Error())
 114  	}
 115  
 116  	pk := &KEMPublicKey{A: a, B: b, P: kp}
 117  	sk := &KEMSecretKey{S: s, PK: pk, Z: z}
 118  	return pk, sk
 119  }
 120  
 121  // cpaPKEEncrypt is the inner IND-CPA encryption.
 122  // m is a 32-byte message. coins is the deterministic randomness.
 123  func cpaPKEEncrypt(pk *KEMPublicKey, m []byte, coins []byte) *KEMCiphertext {
 124  	p := pk.P.Ring
 125  
 126  	// Derive noise polynomials from coins (deterministic for FO).
 127  	rng := sha3.NewShake256()
 128  	rng.Write(coins)
 129  
 130  	// r ← CBD_eta1
 131  	r := CBDPolyFrom(p, pk.P.Eta1, rng)
 132  	NTT(r)
 133  
 134  	// e1 ← CBD_eta2
 135  	e1 := CBDPolyFrom(p, pk.P.Eta2, rng)
 136  
 137  	// e2 ← CBD_eta2
 138  	e2 := CBDPolyFrom(p, pk.P.Eta2, rng)
 139  
 140  	// u = a·r + e1 (compute in NTT, then back to coefficient form)
 141  	u := MulPointwise(pk.A, r)
 142  	INTT(u)
 143  	u = Add(u, e1)
 144  
 145  	// v = b·r + e2 + encode(m)
 146  	v := MulPointwise(pk.B, r)
 147  	INTT(v)
 148  	v = Add(v, e2)
 149  
 150  	// Encode message: each bit of m maps to floor(q/2) in one coefficient.
 151  	encoded := encodeMessage(p, m)
 152  	v = Add(v, encoded)
 153  
 154  	return &KEMCiphertext{U: u, V: v}
 155  }
 156  
 157  // cpaPKEDecrypt is the inner IND-CPA decryption.
 158  func cpaPKEDecrypt(sk *KEMSecretKey, ct *KEMCiphertext) []byte {
 159  	// Compute v - s·u.
 160  	// s is in NTT form, u is in coefficient form.
 161  	uNTT := ct.U.Clone()
 162  	NTT(uNTT)
 163  
 164  	su := MulPointwise(sk.S, uNTT)
 165  	INTT(su)
 166  
 167  	// noisy = v - s·u = e2 + r·e - s·e1 + encode(m)
 168  	noisy := Sub(ct.V, su)
 169  
 170  	return decodeMessage(noisy)
 171  }
 172  
 173  // kemMessageBytes is the fixed message size for KEM operations.
 174  // 256 bits = 32 bytes. We use the first 256 coefficients to encode 256 bits.
 175  const kemMessageBytes = 32
 176  
 177  // encodeMessage maps message bytes into polynomial coefficients.
 178  // Each bit becomes floor(q/2) or 0 in one coefficient.
 179  // Only the first 256 coefficients are used (32 bytes × 8 bits).
 180  func encodeMessage(p Params, m []byte) *Poly {
 181  	poly := New(p)
 182  	half := p.Q / 2
 183  	bits := kemMessageBytes * 8
 184  	if bits > p.N {
 185  		bits = p.N
 186  	}
 187  	for i := range bits {
 188  		byteIdx := i / 8
 189  		bitIdx := uint(i % 8)
 190  		if byteIdx < len(m) && m[byteIdx]&(1<<bitIdx) != 0 {
 191  			poly.Coeffs[i] = half
 192  		}
 193  	}
 194  	return poly
 195  }
 196  
 197  // decodeMessage recovers the fixed-size message from noisy coefficients.
 198  // Reads the first 256 coefficients, each rounded to 0 or floor(q/2).
 199  func decodeMessage(a *Poly) []byte {
 200  	q := a.params.Q
 201  	half := q / 2
 202  	quarter := q / 4
 203  	m := make([]byte, kemMessageBytes)
 204  
 205  	bits := kemMessageBytes * 8
 206  	if bits > a.params.N {
 207  		bits = a.params.N
 208  	}
 209  	for i := range bits {
 210  		c := a.Coeffs[i]
 211  		// Distance to q/2. If within q/4 of q/2, the bit is 1.
 212  		var distHalf uint32
 213  		if c > half {
 214  			distHalf = c - half
 215  		} else {
 216  			distHalf = half - c
 217  		}
 218  		if distHalf < quarter {
 219  			m[i/8] |= 1 << uint(i%8)
 220  		}
 221  	}
 222  	return m
 223  }
 224  
 225  // Encapsulate generates a shared secret and ciphertext using the public key.
 226  // Returns (shared_key, ciphertext, error).
 227  // This is the full FO-transformed CCA2-secure encapsulation.
 228  func Encapsulate(pk *KEMPublicKey) ([]byte, *KEMCiphertext, error) {
 229  	return EncapsulateFrom(pk, rand.Reader)
 230  }
 231  
 232  // EncapsulateFrom uses the given randomness source.
 233  func EncapsulateFrom(pk *KEMPublicKey, rng io.Reader) ([]byte, *KEMCiphertext, error) {
 234  	// 1. Generate random message m.
 235  	m := make([]byte, 32)
 236  	if _, err := io.ReadFull(rng, m); err != nil {
 237  		return nil, nil, errors.New("ring/kem: randomness failed")
 238  	}
 239  
 240  	// 2. Derive (K, coins) = G(m, H(pk)).
 241  	pkHash := hashPublicKey(pk)
 242  	K, coins := deriveKCoins(m, pkHash)
 243  
 244  	// 3. Encrypt m with deterministic coins.
 245  	ct := cpaPKEEncrypt(pk, m, coins)
 246  
 247  	// 4. Derive final shared key: K' = KDF(K, H(ct)).
 248  	ctHash := hashCiphertext(ct)
 249  	sharedKey := kdf(K, ctHash, pk.P.SharedKeyLen)
 250  
 251  	return sharedKey, ct, nil
 252  }
 253  
 254  // Decapsulate recovers the shared secret from a ciphertext using the secret key.
 255  // Returns the shared key or an implicit rejection key (constant-time).
 256  func Decapsulate(sk *KEMSecretKey, ct *KEMCiphertext) ([]byte, error) {
 257  	if ct == nil || ct.U == nil || ct.V == nil {
 258  		return nil, errors.New("ring/kem: nil ciphertext")
 259  	}
 260  
 261  	// 1. Decrypt to recover m'.
 262  	mPrime := cpaPKEDecrypt(sk, ct)
 263  
 264  	// 2. Re-derive (K', coins') = G(m', H(pk)).
 265  	pkHash := hashPublicKey(sk.PK)
 266  	KPrime, coinsPrime := deriveKCoins(mPrime, pkHash)
 267  
 268  	// 3. Re-encrypt with coins' to get ct'.
 269  	ctPrime := cpaPKEEncrypt(sk.PK, mPrime, coinsPrime)
 270  
 271  	// 4. Check ct == ct' (constant-time).
 272  	match := ciphertextEqual(ct, ctPrime)
 273  
 274  	// 5. If match, shared key = KDF(K', H(ct)).
 275  	//    If no match, shared key = KDF(z, H(ct)) (implicit rejection).
 276  	ctHash := hashCiphertext(ct)
 277  	realKey := kdf(KPrime, ctHash, sk.PK.P.SharedKeyLen)
 278  	rejectKey := kdf(sk.Z, ctHash, sk.PK.P.SharedKeyLen)
 279  
 280  	// Constant-time select.
 281  	sharedKey := make([]byte, sk.PK.P.SharedKeyLen)
 282  	subtle.ConstantTimeCopy(match, sharedKey, realKey)
 283  	subtle.ConstantTimeCopy(1-match, sharedKey, rejectKey)
 284  
 285  	return sharedKey, nil
 286  }
 287  
 288  // --- internal helpers ---
 289  
 290  func hashPublicKey(pk *KEMPublicKey) []byte {
 291  	h := sha3.NewShake256()
 292  	h.Write([]byte("hamadryad-kem-pk"))
 293  	h.Write(Serialize(pk.A))
 294  	h.Write(Serialize(pk.B))
 295  	out := make([]byte, 32)
 296  	h.Read(out)
 297  	return out
 298  }
 299  
 300  func hashCiphertext(ct *KEMCiphertext) []byte {
 301  	h := sha3.NewShake256()
 302  	h.Write([]byte("hamadryad-kem-ct"))
 303  	h.Write(Serialize(ct.U))
 304  	h.Write(Serialize(ct.V))
 305  	out := make([]byte, 32)
 306  	h.Read(out)
 307  	return out
 308  }
 309  
 310  func deriveKCoins(m, pkHash []byte) (K, coins []byte) {
 311  	h := sha3.NewShake256()
 312  	h.Write([]byte("hamadryad-kem-g"))
 313  	h.Write(m)
 314  	h.Write(pkHash)
 315  	out := make([]byte, 64)
 316  	h.Read(out)
 317  	return out[:32], out[32:]
 318  }
 319  
 320  func kdf(key, label []byte, outLen int) []byte {
 321  	h := sha3.NewShake256()
 322  	h.Write([]byte("hamadryad-kem-kdf"))
 323  	h.Write(key)
 324  	h.Write(label)
 325  	out := make([]byte, outLen)
 326  	h.Read(out)
 327  	return out
 328  }
 329  
 330  func ciphertextEqual(a, b *KEMCiphertext) int {
 331  	aU := Serialize(a.U)
 332  	bU := Serialize(b.U)
 333  	aV := Serialize(a.V)
 334  	bV := Serialize(b.V)
 335  
 336  	if len(aU) != len(bU) || len(aV) != len(bV) {
 337  		return 0
 338  	}
 339  
 340  	return subtle.ConstantTimeCompare(aU, bU) & subtle.ConstantTimeCompare(aV, bV)
 341  }
 342