he_test.go raw
1 package ring
2
3 import "testing"
4
5 // TestHEEncryptDecrypt verifies single-bit encrypt/decrypt.
6 func TestHEEncryptDecrypt(t *testing.T) {
7 kp := DefaultHEParams()
8 pk, sk, _ := HEKeyGen(kp)
9
10 for _, bit := range []int{0, 1} {
11 ct := HEEncrypt(pk, bit)
12 got := HEDecrypt(sk, ct)
13 if got != bit {
14 t.Fatalf("bit=%d: decrypted to %d", bit, got)
15 }
16 }
17 }
18
19 // TestHEAdd verifies homomorphic addition (XOR).
20 func TestHEAdd(t *testing.T) {
21 kp := DefaultHEParams()
22 pk, sk, _ := HEKeyGen(kp)
23
24 // Test all XOR combinations.
25 for _, a := range []int{0, 1} {
26 for _, b := range []int{0, 1} {
27 ctA := HEEncrypt(pk, a)
28 ctB := HEEncrypt(pk, b)
29 ctSum := HEAdd(ctA, ctB)
30 got := HEDecrypt(sk, ctSum)
31 want := a ^ b
32 if got != want {
33 t.Fatalf("%d XOR %d: got %d, want %d", a, b, got, want)
34 }
35 }
36 }
37 }
38
39 // TestHENot verifies homomorphic NOT.
40 func TestHENot(t *testing.T) {
41 kp := DefaultHEParams()
42 pk, sk, _ := HEKeyGen(kp)
43
44 for _, bit := range []int{0, 1} {
45 ct := HEEncrypt(pk, bit)
46 ctNot := HENot(ct)
47 got := HEDecrypt(sk, ctNot)
48 want := 1 - bit
49 if got != want {
50 t.Fatalf("NOT(%d): got %d, want %d", bit, got, want)
51 }
52 }
53 }
54
55 // TestHEMul verifies homomorphic multiplication (AND).
56 func TestHEMul(t *testing.T) {
57 kp := DefaultHEParams()
58 pk, sk, rlk := HEKeyGen(kp)
59
60 for _, a := range []int{0, 1} {
61 for _, b := range []int{0, 1} {
62 ctA := HEEncrypt(pk, a)
63 ctB := HEEncrypt(pk, b)
64 ctProd := HEMul(ctA, ctB, rlk)
65 got := HEDecrypt(sk, ctProd)
66 want := a & b
67 if got != want {
68 t.Fatalf("%d AND %d: got %d, want %d", a, b, got, want)
69 }
70 }
71 }
72 }
73
74 // TestHERerandomize verifies that rerandomization preserves the plaintext.
75 func TestHERerandomize(t *testing.T) {
76 kp := DefaultHEParams()
77 pk, sk, _ := HEKeyGen(kp)
78
79 for _, bit := range []int{0, 1} {
80 ct := HEEncrypt(pk, bit)
81 ctRerand := Rerandomize(pk, ct)
82
83 got := HEDecrypt(sk, ctRerand)
84 if got != bit {
85 t.Fatalf("rerandomize(%d): decrypted to %d", bit, got)
86 }
87
88 // Verify ciphertexts are actually different.
89 if Equal(ct.U, ctRerand.U) && Equal(ct.V, ctRerand.V) {
90 t.Fatal("rerandomized ciphertext is identical — randomness failure")
91 }
92 }
93 }
94
95 // TestHEChainedAdd verifies multiple additions stay within noise budget.
96 func TestHEChainedAdd(t *testing.T) {
97 kp := DefaultHEParams()
98 pk, sk, _ := HEKeyGen(kp)
99
100 // XOR 16 bits together: result should be 0 (even count of 1s) or 1 (odd).
101 bits := []int{1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0}
102 expected := 0
103 for _, b := range bits {
104 expected ^= b
105 }
106
107 acc := HEEncrypt(pk, bits[0])
108 for _, b := range bits[1:] {
109 ct := HEEncrypt(pk, b)
110 acc = HEAdd(acc, ct)
111 }
112
113 got := HEDecrypt(sk, acc)
114 if got != expected {
115 t.Fatalf("chained XOR: got %d, want %d (noise estimate: %.1f)", got, expected, acc.NoiseEstimate)
116 }
117 }
118
119 // TestHEBGVNoiseParity verifies that BGV encryption produces even noise + message.
120 func TestHEBGVNoiseParity(t *testing.T) {
121 kp := DefaultHEParams()
122 pk, sk, _ := HEKeyGen(kp)
123 p := kp.Ring
124
125 for _, bit := range []int{0, 1} {
126 ct := HEEncrypt(pk, bit)
127
128 // Compute v - s·u
129 uNTT := ct.U.Clone()
130 NTT(uNTT)
131 su := MulPointwise(sk.S, uNTT)
132 INTT(su)
133 noisy := Sub(ct.V, su)
134
135 c := noisy.Coeffs[0]
136 q := p.Q
137 if c > q/2 {
138 c = q - c
139 }
140 parity := c % 2
141
142 t.Logf("bit=%d: coeff[0]=%d, centered=%d, parity=%d (want parity=%d)",
143 bit, noisy.Coeffs[0], c, parity, bit)
144
145 if int(parity) != bit {
146 t.Errorf("bit=%d: wrong parity %d", bit, parity)
147 }
148 }
149 }
150
151 // TestHEMulDeg2 tests the tensor product without relinearization.
152 // This isolates whether the issue is in the tensor product or relinearization.
153 func TestHEMulDeg2(t *testing.T) {
154 kp := DefaultHEParams()
155 pk, sk, _ := HEKeyGen(kp)
156 p := kp.Ring
157
158 for _, a := range []int{0, 1} {
159 for _, b := range []int{0, 1} {
160 ctA := HEEncrypt(pk, a)
161 ctB := HEEncrypt(pk, b)
162
163 // Tensor product: c0 = v_a * v_b, c1 = v_a*u_b + u_a*v_b, c2 = u_a * u_b
164 c0 := Mul(ctA.V, ctB.V)
165 c1a := Mul(ctA.V, ctB.U)
166 c1b := Mul(ctA.U, ctB.V)
167 c1 := Add(c1a, c1b)
168 c2 := Mul(ctA.U, ctB.U)
169
170 // Degree-2 decryption: c0 - s*c1 + s²*c2
171 // s is in NTT form
172 sCoeff := sk.S.Clone()
173
174 c1NTT := c1.Clone()
175 NTT(c1NTT)
176 sc1 := MulPointwise(sCoeff, c1NTT)
177 INTT(sc1)
178
179 // s²
180 s2 := MulPointwise(sCoeff, sCoeff)
181 c2NTT := c2.Clone()
182 NTT(c2NTT)
183 s2c2 := MulPointwise(s2, c2NTT)
184 INTT(s2c2)
185
186 result := Sub(c0, sc1)
187 result = Add(result, s2c2)
188
189 // Decode: result[0] mod 2
190 q := p.Q
191 c := result.Coeffs[0]
192 if c > q/2 {
193 c = q - c
194 }
195 got := int(c % 2)
196 want := a & b
197
198 t.Logf("%d AND %d: coeff[0]=%d, centered=%d, mod2=%d (want %d)",
199 a, b, result.Coeffs[0], c, got, want)
200
201 if got != want {
202 t.Errorf("%d AND %d: got %d, want %d", a, b, got, want)
203 }
204 }
205 }
206 }
207
208 func BenchmarkHEEncrypt(b *testing.B) {
209 kp := DefaultHEParams()
210 pk, _, _ := HEKeyGen(kp)
211 b.ResetTimer()
212 for range b.N {
213 HEEncrypt(pk, 1)
214 }
215 }
216
217 func BenchmarkHEAdd(b *testing.B) {
218 kp := DefaultHEParams()
219 pk, _, _ := HEKeyGen(kp)
220 ct0 := HEEncrypt(pk, 0)
221 ct1 := HEEncrypt(pk, 1)
222 b.ResetTimer()
223 for range b.N {
224 HEAdd(ct0, ct1)
225 }
226 }
227
228 func BenchmarkHEMul(b *testing.B) {
229 kp := DefaultHEParams()
230 pk, _, rlk := HEKeyGen(kp)
231 ct0 := HEEncrypt(pk, 1)
232 ct1 := HEEncrypt(pk, 1)
233 b.ResetTimer()
234 for range b.N {
235 HEMul(ct0, ct1, rlk)
236 }
237 }
238