ring_test.go raw
1 package ring
2
3 import (
4 "crypto/rand"
5 "math/big"
6 "testing"
7 )
8
9 // TestNTTRoundTrip verifies NTT(INTT(a)) == a and INTT(NTT(a)) == a.
10 func TestNTTRoundTrip(t *testing.T) {
11 for _, p := range []Params{Falcon512(), NewHope256()} {
12 t.Run(paramName(p), func(t *testing.T) {
13 a := randomPoly(p)
14 orig := a.Clone()
15
16 NTT(a)
17 if Equal(a, orig) {
18 t.Fatal("NTT should change coefficients")
19 }
20
21 INTT(a)
22 if !Equal(a, orig) {
23 t.Fatal("INTT(NTT(a)) != a")
24 for i := range a.Coeffs {
25 if a.Coeffs[i] != orig.Coeffs[i] {
26 t.Logf(" diff at %d: got %d, want %d", i, a.Coeffs[i], orig.Coeffs[i])
27 if i > 10 {
28 break
29 }
30 }
31 }
32 }
33 })
34 }
35 }
36
37 // TestMulSchoolbook verifies NTT-based multiplication matches schoolbook
38 // multiplication in Z_q[x]/(x^n + 1) for small polynomials.
39 func TestMulSchoolbook(t *testing.T) {
40 for _, p := range []Params{Falcon512(), NewHope256()} {
41 t.Run(paramName(p), func(t *testing.T) {
42 a := randomSmallPoly(p, 10)
43 b := randomSmallPoly(p, 10)
44
45 // NTT multiplication.
46 cNTT := Mul(a, b)
47
48 // Schoolbook multiplication mod (x^n + 1).
49 cSchool := schoolbookMul(a, b)
50
51 if !Equal(cNTT, cSchool) {
52 diffs := 0
53 for i := range cNTT.Coeffs {
54 if cNTT.Coeffs[i] != cSchool.Coeffs[i] {
55 if diffs < 10 {
56 t.Logf(" diff at %d: NTT=%d school=%d", i, cNTT.Coeffs[i], cSchool.Coeffs[i])
57 }
58 diffs++
59 }
60 }
61 t.Fatalf("NTT mul != schoolbook mul (%d diffs)", diffs)
62 }
63 })
64 }
65 }
66
67 // TestAddSub verifies Add and Sub are inverses.
68 func TestAddSub(t *testing.T) {
69 p := Falcon512()
70 a := randomPoly(p)
71 b := randomPoly(p)
72
73 c := Add(a, b)
74 d := Sub(c, b)
75 if !Equal(d, a) {
76 t.Fatal("(a + b) - b != a")
77 }
78 }
79
80 // TestNeg verifies a + (-a) == 0.
81 func TestNeg(t *testing.T) {
82 p := Falcon512()
83 a := randomPoly(p)
84 b := Neg(a)
85 c := Add(a, b)
86
87 for i := range c.Coeffs {
88 if c.Coeffs[i] != 0 {
89 t.Fatalf("a + (-a) != 0 at index %d: got %d", i, c.Coeffs[i])
90 }
91 }
92 }
93
94 // TestScalarMul verifies scalar multiplication.
95 func TestScalarMul(t *testing.T) {
96 p := Falcon512()
97 a := randomPoly(p)
98
99 // a * 0 == 0
100 z := ScalarMul(a, 0)
101 for i := range z.Coeffs {
102 if z.Coeffs[i] != 0 {
103 t.Fatal("a * 0 != 0")
104 }
105 }
106
107 // a * 1 == a
108 one := ScalarMul(a, 1)
109 if !Equal(one, a) {
110 t.Fatal("a * 1 != a")
111 }
112
113 // a * 2 == a + a
114 two := ScalarMul(a, 2)
115 sum := Add(a, a)
116 if !Equal(two, sum) {
117 t.Fatal("a * 2 != a + a")
118 }
119 }
120
121 // TestMulCommutativity verifies a * b == b * a.
122 func TestMulCommutativity(t *testing.T) {
123 p := Falcon512()
124 a := randomSmallPoly(p, 100)
125 b := randomSmallPoly(p, 100)
126
127 ab := Mul(a, b)
128 ba := Mul(b, a)
129
130 if !Equal(ab, ba) {
131 t.Fatal("a * b != b * a")
132 }
133 }
134
135 // TestMulDistributivity verifies a*(b+c) == a*b + a*c.
136 func TestMulDistributivity(t *testing.T) {
137 p := NewHope256()
138 a := randomSmallPoly(p, 50)
139 b := randomSmallPoly(p, 50)
140 c := randomSmallPoly(p, 50)
141
142 bc := Add(b, c)
143 lhs := Mul(a, bc)
144
145 ab := Mul(a, b)
146 ac := Mul(a, c)
147 rhs := Add(ab, ac)
148
149 if !Equal(lhs, rhs) {
150 t.Fatal("a*(b+c) != a*b + a*c")
151 }
152 }
153
154 // TestQInv verifies the Montgomery constant.
155 func TestQInv(t *testing.T) {
156 for _, p := range []Params{Falcon512(), NewHope256()} {
157 t.Run(paramName(p), func(t *testing.T) {
158 // qinv * q ≡ -1 (mod 2^16)
159 // Equivalently: (qinv * q + 1) mod 2^16 == 0
160 product := p.QInv * uint64(p.Q)
161 mask := (uint64(1) << 16) - 1
162 if (product+1)&mask != 0 {
163 t.Fatalf("QInv check failed: QInv=%d, Q=%d, (QInv*Q+1) mod 2^16 = %d",
164 p.QInv, p.Q, (product+1)&mask)
165 }
166 })
167 }
168 }
169
170 func BenchmarkNTT(b *testing.B) {
171 for _, p := range []Params{Falcon512(), NewHope256()} {
172 b.Run(paramName(p), func(b *testing.B) {
173 poly := randomPoly(p)
174 b.ResetTimer()
175 for range b.N {
176 a := poly.Clone()
177 NTT(a)
178 }
179 })
180 }
181 }
182
183 func BenchmarkINTT(b *testing.B) {
184 for _, p := range []Params{Falcon512(), NewHope256()} {
185 b.Run(paramName(p), func(b *testing.B) {
186 poly := randomPoly(p)
187 NTT(poly)
188 b.ResetTimer()
189 for range b.N {
190 a := poly.Clone()
191 INTT(a)
192 }
193 })
194 }
195 }
196
197 func BenchmarkMul(b *testing.B) {
198 for _, p := range []Params{Falcon512(), NewHope256()} {
199 b.Run(paramName(p), func(b *testing.B) {
200 a := randomPoly(p)
201 bp2 := randomPoly(p)
202 b.ResetTimer()
203 for range b.N {
204 Mul(a, bp2)
205 }
206 })
207 }
208 }
209
210 // --- helpers ---
211
212 func paramName(p Params) string {
213 switch {
214 case p.N == 512 && p.Q == 12289:
215 return "Falcon512"
216 case p.N == 256 && p.Q == 7681:
217 return "NewHope256"
218 default:
219 return "Custom"
220 }
221 }
222
223 func randomPoly(p Params) *Poly {
224 a := New(p)
225 buf := make([]byte, 4)
226 for i := range a.Coeffs {
227 rand.Read(buf)
228 v := uint32(buf[0]) | uint32(buf[1])<<8
229 a.Coeffs[i] = v % p.Q
230 }
231 return a
232 }
233
234 func randomSmallPoly(p Params, bound uint32) *Poly {
235 a := New(p)
236 buf := make([]byte, 2)
237 for i := range a.Coeffs {
238 rand.Read(buf)
239 v := uint32(buf[0]) | uint32(buf[1])<<8
240 a.Coeffs[i] = v % bound
241 }
242 return a
243 }
244
245 // schoolbookMul computes a * b mod (x^n + 1) using naive O(n²) multiplication.
246 func schoolbookMul(a, b *Poly) *Poly {
247 n := a.params.N
248 q := int64(a.params.Q)
249
250 // Use big.Int for exact intermediate values.
251 result := make([]int64, n)
252 for i := 0; i < n; i++ {
253 for j := 0; j < n; j++ {
254 prod := int64(a.Coeffs[i]) * int64(b.Coeffs[j])
255 k := i + j
256 if k < n {
257 result[k] = (result[k] + prod) % q
258 } else {
259 // x^n ≡ -1, so x^{n+r} ≡ -x^r
260 result[k-n] = (result[k-n] - prod) % q
261 }
262 }
263 }
264
265 c := New(a.params)
266 for i := range result {
267 v := result[i] % q
268 if v < 0 {
269 v += q
270 }
271 c.Coeffs[i] = uint32(v)
272 }
273 return c
274 }
275
276 // Verify that our parameters are correct by checking root of unity properties.
277 func TestRootOfUnity(t *testing.T) {
278 for _, p := range []Params{Falcon512(), NewHope256()} {
279 t.Run(paramName(p), func(t *testing.T) {
280 psi := big.NewInt(int64(p.RootOfUnity))
281 q := big.NewInt(int64(p.Q))
282 twoN := big.NewInt(int64(2 * p.N))
283
284 // psi^{2n} ≡ 1 (mod q)
285 order := new(big.Int).Exp(psi, twoN, q)
286 if order.Int64() != 1 {
287 t.Fatalf("psi^{2n} = %d, want 1 (mod %d)", order, p.Q)
288 }
289
290 // psi^n ≡ -1 (mod q) [negacyclic property]
291 nBig := big.NewInt(int64(p.N))
292 half := new(big.Int).Exp(psi, nBig, q)
293 want := new(big.Int).Sub(q, big.NewInt(1)) // q-1 ≡ -1
294 if half.Cmp(want) != 0 {
295 t.Fatalf("psi^n = %d, want %d (-1 mod %d)", half, want, p.Q)
296 }
297 })
298 }
299 }
300