package ring import ( "bytes" "testing" ) // TestKEMRoundTrip verifies encapsulate/decapsulate produce the same shared key. func TestKEMRoundTrip(t *testing.T) { kp := DefaultKEMParams() pk, sk := KEMKeyGen(kp) sharedKeyEnc, ct, err := Encapsulate(pk) if err != nil { t.Fatalf("Encapsulate: %v", err) } sharedKeyDec, err := Decapsulate(sk, ct) if err != nil { t.Fatalf("Decapsulate: %v", err) } if !bytes.Equal(sharedKeyEnc, sharedKeyDec) { t.Fatalf("shared keys don't match:\n enc: %x\n dec: %x", sharedKeyEnc, sharedKeyDec) } if len(sharedKeyEnc) != kp.SharedKeyLen { t.Fatalf("shared key length: got %d, want %d", len(sharedKeyEnc), kp.SharedKeyLen) } } // TestKEMMultipleRoundTrips verifies consistency across multiple key pairs. func TestKEMMultipleRoundTrips(t *testing.T) { kp := DefaultKEMParams() for i := range 10 { pk, sk := KEMKeyGen(kp) sharedKeyEnc, ct, err := Encapsulate(pk) if err != nil { t.Fatalf("round %d: Encapsulate: %v", i, err) } sharedKeyDec, err := Decapsulate(sk, ct) if err != nil { t.Fatalf("round %d: Decapsulate: %v", i, err) } if !bytes.Equal(sharedKeyEnc, sharedKeyDec) { t.Fatalf("round %d: shared keys don't match", i) } } } // TestKEMTamperedCiphertext verifies that a modified ciphertext produces // a different shared key (implicit rejection, not an error). func TestKEMTamperedCiphertext(t *testing.T) { kp := DefaultKEMParams() pk, sk := KEMKeyGen(kp) sharedKeyEnc, ct, err := Encapsulate(pk) if err != nil { t.Fatal(err) } // Tamper: flip a coefficient in U. ct.U.Coeffs[0] = (ct.U.Coeffs[0] + 1) % kp.Ring.Q sharedKeyDec, err := Decapsulate(sk, ct) if err != nil { t.Fatal(err) } // The decapsulated key should be different (implicit rejection). if bytes.Equal(sharedKeyEnc, sharedKeyDec) { t.Fatal("tampered ciphertext produced same shared key — FO transform not working") } } // TestKEMDifferentKeys verifies two key pairs produce different shared keys. func TestKEMDifferentKeys(t *testing.T) { kp := DefaultKEMParams() pk1, _ := KEMKeyGen(kp) _, sk2 := KEMKeyGen(kp) _, ct, err := Encapsulate(pk1) if err != nil { t.Fatal(err) } // Decapsulate with wrong secret key — should produce rejection key. sharedKeyWrong, err := Decapsulate(sk2, ct) if err != nil { t.Fatal(err) } // The key should be deterministically derived but wrong. if len(sharedKeyWrong) != kp.SharedKeyLen { t.Fatalf("wrong key length: got %d, want %d", len(sharedKeyWrong), kp.SharedKeyLen) } } // TestCPAEncryptDecrypt verifies the inner CPA scheme directly. func TestCPAEncryptDecrypt(t *testing.T) { kp := DefaultKEMParams() pk, sk := KEMKeyGen(kp) msg := make([]byte, 32) for i := range msg { msg[i] = byte(i) } coins := make([]byte, 32) for i := range coins { coins[i] = byte(i + 100) } ct := cpaPKEEncrypt(pk, msg, coins) recovered := cpaPKEDecrypt(sk, ct) if !bytes.Equal(msg, recovered) { t.Fatalf("CPA decrypt failed:\n sent: %x\n recv: %x", msg, recovered) } } // TestEncodeDecodeMessage verifies the message encoding round-trip. func TestEncodeDecodeMessage(t *testing.T) { p := Falcon512() msg := make([]byte, 32) for i := range msg { msg[i] = byte(0xAA) // alternating bits } encoded := encodeMessage(p, msg) decoded := decodeMessage(encoded) if !bytes.Equal(msg, decoded[:32]) { t.Fatalf("encode/decode failed:\n sent: %x\n recv: %x", msg, decoded[:32]) } } func BenchmarkKEMKeyGen(b *testing.B) { kp := DefaultKEMParams() for range b.N { KEMKeyGen(kp) } } func BenchmarkKEMEncapsulate(b *testing.B) { kp := DefaultKEMParams() pk, _ := KEMKeyGen(kp) b.ResetTimer() for range b.N { Encapsulate(pk) } } func BenchmarkKEMDecapsulate(b *testing.B) { kp := DefaultKEMParams() pk, sk := KEMKeyGen(kp) _, ct, _ := Encapsulate(pk) b.ResetTimer() for range b.N { Decapsulate(sk, ct) } }