// ntt_test.go
1 package crypto
2
3 import "testing"
4
5 func TestBitRev6(t *testing.T) {
6 // 6-bit reversal: 0b000001 -> 0b100000 = 32
7 if got := bitRev6(1); got != 32 {
8 t.Errorf("bitRev6(1) = %d, want 32", got)
9 }
10 // Identity: 0 -> 0, 63 -> 63
11 if got := bitRev6(0); got != 0 {
12 t.Errorf("bitRev6(0) = %d, want 0", got)
13 }
14 if got := bitRev6(63); got != 63 {
15 t.Errorf("bitRev6(63) = %d, want 63", got)
16 }
17 // 0b010101 = 21 -> 0b101010 = 42
18 if got := bitRev6(21); got != 42 {
19 t.Errorf("bitRev6(21) = %d, want 42", got)
20 }
21 }
22
23 func TestNTTRoundTrip(t *testing.T) {
24 // Create a polynomial with known coefficients.
25 var a [HamN]uint16
26 for i := range a {
27 a[i] = uint16(i % HamP)
28 }
29 original := a
30
31 // Forward NTT then inverse should recover original.
32 ntt64(&a)
33 intt64(&a)
34
35 for i := range a {
36 if a[i] != original[i] {
37 t.Errorf("round-trip failed at [%d]: got %d, want %d", i, a[i], original[i])
38 }
39 }
40 }
41
42 func TestNTTZero(t *testing.T) {
43 var a [HamN]uint16
44 ntt64(&a)
45 for i, v := range a {
46 if v != 0 {
47 t.Errorf("NTT of zero poly: a[%d] = %d, want 0", i, v)
48 }
49 }
50 }
51
52 func TestNTTConvolution(t *testing.T) {
53 // Verify that pointwise multiply in NTT domain corresponds to
54 // polynomial multiplication mod (x^64 + 1) mod 257.
55
56 // f(x) = 1 + x
57 var f [HamN]uint16
58 f[0] = 1
59 f[1] = 1
60
61 // g(x) = 1 + x
62 var g [HamN]uint16
63 g[0] = 1
64 g[1] = 1
65
66 // f*g mod (x^64+1) = 1 + 2x + x^2 (since degree < 64, no wraparound).
67 ntt64(&f)
68 ntt64(&g)
69
70 var h [HamN]uint16
71 for i := range h {
72 h[i] = mod257(int(f[i]) * int(g[i]))
73 }
74
75 intt64(&h)
76
77 // Expected: h[0]=1, h[1]=2, h[2]=1, rest 0.
78 if h[0] != 1 {
79 t.Errorf("h[0] = %d, want 1", h[0])
80 }
81 if h[1] != 2 {
82 t.Errorf("h[1] = %d, want 2", h[1])
83 }
84 if h[2] != 1 {
85 t.Errorf("h[2] = %d, want 1", h[2])
86 }
87 for i := 3; i < HamN; i++ {
88 if h[i] != 0 {
89 t.Errorf("h[%d] = %d, want 0", i, h[i])
90 }
91 }
92 }
93
94 func TestNTTWraparound(t *testing.T) {
95 // Test the negacyclic property: x^64 ≡ -1 mod (x^64 + 1).
96 // f(x) = x^63, g(x) = x → f*g = x^64 ≡ -1 ≡ 256 mod 257.
97
98 var f [HamN]uint16
99 f[63] = 1
100
101 var g [HamN]uint16
102 g[1] = 1
103
104 ntt64(&f)
105 ntt64(&g)
106
107 var h [HamN]uint16
108 for i := range h {
109 h[i] = mod257(int(f[i]) * int(g[i]))
110 }
111
112 intt64(&h)
113
114 // Expected: h[0] = 256 (= -1 mod 257), rest 0.
115 if h[0] != 256 {
116 t.Errorf("h[0] = %d, want 256 (-1 mod 257)", h[0])
117 }
118 for i := 1; i < HamN; i++ {
119 if h[i] != 0 {
120 t.Errorf("h[%d] = %d, want 0", i, h[i])
121 }
122 }
123 }
124
125 func TestPowMod(t *testing.T) {
126 // 3 is a generator of Z_257*. 3^256 ≡ 1 mod 257 (Fermat).
127 if got := powMod(3, 256); got != 1 {
128 t.Errorf("3^256 mod 257 = %d, want 1", got)
129 }
130 // 3^128 ≡ 256 ≡ -1 mod 257 (since order is 256, half-order gives -1).
131 if got := powMod(3, 128); got != 256 {
132 t.Errorf("3^128 mod 257 = %d, want 256", got)
133 }
134 // 9^64 = (3^2)^64 = 3^128 ≡ -1 mod 257.
135 if got := powMod(9, 64); got != 256 {
136 t.Errorf("9^64 mod 257 = %d, want 256", got)
137 }
138 }
139