kem_test.go raw
1 package ring
2
3 import (
4 "bytes"
5 "testing"
6 )
7
8 // TestKEMRoundTrip verifies encapsulate/decapsulate produce the same shared key.
9 func TestKEMRoundTrip(t *testing.T) {
10 kp := DefaultKEMParams()
11 pk, sk := KEMKeyGen(kp)
12
13 sharedKeyEnc, ct, err := Encapsulate(pk)
14 if err != nil {
15 t.Fatalf("Encapsulate: %v", err)
16 }
17
18 sharedKeyDec, err := Decapsulate(sk, ct)
19 if err != nil {
20 t.Fatalf("Decapsulate: %v", err)
21 }
22
23 if !bytes.Equal(sharedKeyEnc, sharedKeyDec) {
24 t.Fatalf("shared keys don't match:\n enc: %x\n dec: %x", sharedKeyEnc, sharedKeyDec)
25 }
26
27 if len(sharedKeyEnc) != kp.SharedKeyLen {
28 t.Fatalf("shared key length: got %d, want %d", len(sharedKeyEnc), kp.SharedKeyLen)
29 }
30 }
31
32 // TestKEMMultipleRoundTrips verifies consistency across multiple key pairs.
33 func TestKEMMultipleRoundTrips(t *testing.T) {
34 kp := DefaultKEMParams()
35 for i := range 10 {
36 pk, sk := KEMKeyGen(kp)
37 sharedKeyEnc, ct, err := Encapsulate(pk)
38 if err != nil {
39 t.Fatalf("round %d: Encapsulate: %v", i, err)
40 }
41 sharedKeyDec, err := Decapsulate(sk, ct)
42 if err != nil {
43 t.Fatalf("round %d: Decapsulate: %v", i, err)
44 }
45 if !bytes.Equal(sharedKeyEnc, sharedKeyDec) {
46 t.Fatalf("round %d: shared keys don't match", i)
47 }
48 }
49 }
50
51 // TestKEMTamperedCiphertext verifies that a modified ciphertext produces
52 // a different shared key (implicit rejection, not an error).
53 func TestKEMTamperedCiphertext(t *testing.T) {
54 kp := DefaultKEMParams()
55 pk, sk := KEMKeyGen(kp)
56
57 sharedKeyEnc, ct, err := Encapsulate(pk)
58 if err != nil {
59 t.Fatal(err)
60 }
61
62 // Tamper: flip a coefficient in U.
63 ct.U.Coeffs[0] = (ct.U.Coeffs[0] + 1) % kp.Ring.Q
64
65 sharedKeyDec, err := Decapsulate(sk, ct)
66 if err != nil {
67 t.Fatal(err)
68 }
69
70 // The decapsulated key should be different (implicit rejection).
71 if bytes.Equal(sharedKeyEnc, sharedKeyDec) {
72 t.Fatal("tampered ciphertext produced same shared key — FO transform not working")
73 }
74 }
75
76 // TestKEMDifferentKeys verifies two key pairs produce different shared keys.
77 func TestKEMDifferentKeys(t *testing.T) {
78 kp := DefaultKEMParams()
79 pk1, _ := KEMKeyGen(kp)
80 _, sk2 := KEMKeyGen(kp)
81
82 _, ct, err := Encapsulate(pk1)
83 if err != nil {
84 t.Fatal(err)
85 }
86
87 // Decapsulate with wrong secret key — should produce rejection key.
88 sharedKeyWrong, err := Decapsulate(sk2, ct)
89 if err != nil {
90 t.Fatal(err)
91 }
92
93 // The key should be deterministically derived but wrong.
94 if len(sharedKeyWrong) != kp.SharedKeyLen {
95 t.Fatalf("wrong key length: got %d, want %d", len(sharedKeyWrong), kp.SharedKeyLen)
96 }
97 }
98
99 // TestCPAEncryptDecrypt verifies the inner CPA scheme directly.
100 func TestCPAEncryptDecrypt(t *testing.T) {
101 kp := DefaultKEMParams()
102 pk, sk := KEMKeyGen(kp)
103
104 msg := make([]byte, 32)
105 for i := range msg {
106 msg[i] = byte(i)
107 }
108
109 coins := make([]byte, 32)
110 for i := range coins {
111 coins[i] = byte(i + 100)
112 }
113
114 ct := cpaPKEEncrypt(pk, msg, coins)
115 recovered := cpaPKEDecrypt(sk, ct)
116
117 if !bytes.Equal(msg, recovered) {
118 t.Fatalf("CPA decrypt failed:\n sent: %x\n recv: %x", msg, recovered)
119 }
120 }
121
122 // TestEncodeDecodeMessage verifies the message encoding round-trip.
123 func TestEncodeDecodeMessage(t *testing.T) {
124 p := Falcon512()
125 msg := make([]byte, 32)
126 for i := range msg {
127 msg[i] = byte(0xAA) // alternating bits
128 }
129
130 encoded := encodeMessage(p, msg)
131 decoded := decodeMessage(encoded)
132
133 if !bytes.Equal(msg, decoded[:32]) {
134 t.Fatalf("encode/decode failed:\n sent: %x\n recv: %x", msg, decoded[:32])
135 }
136 }
137
138 func BenchmarkKEMKeyGen(b *testing.B) {
139 kp := DefaultKEMParams()
140 for range b.N {
141 KEMKeyGen(kp)
142 }
143 }
144
145 func BenchmarkKEMEncapsulate(b *testing.B) {
146 kp := DefaultKEMParams()
147 pk, _ := KEMKeyGen(kp)
148 b.ResetTimer()
149 for range b.N {
150 Encapsulate(pk)
151 }
152 }
153
154 func BenchmarkKEMDecapsulate(b *testing.B) {
155 kp := DefaultKEMParams()
156 pk, sk := KEMKeyGen(kp)
157 _, ct, _ := Encapsulate(pk)
158 b.ResetTimer()
159 for range b.N {
160 Decapsulate(sk, ct)
161 }
162 }
163