ntt.go raw
1 package ring
2
3 // Number-Theoretic Transform for R_q = Z_q[x]/(x^n + 1).
4 //
5 // The negacyclic NTT evaluates a polynomial at the roots of x^n + 1,
6 // which are psi^(2k+1) for k = 0..n-1, where psi is a primitive 2n-th
7 // root of unity mod q.
8 //
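// Forward and inverse transforms share one structure: a bit-reversal
// permutation followed by Cooley-Tukey-style butterflies (u +/- w*v) over
// fully precomputed twiddle factors. The inverse uses the inverse twiddles
// and then scales by n^{-1} and psi^{-i}.
// Zero allocations in the hot path once the tables for a parameter set
// have been built and cached.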
12 //
13 // The NTT enables O(n log n) polynomial multiplication:
14 // Mul(a, b) = INTT(NTT(a) ⊙ NTT(b))
15 // where ⊙ is pointwise multiplication.
16
// nttTables bundles the lookup tables that NTT and INTT read for one
// parameter set. newNTTTables fills in every field.
type nttTables struct {
	n    int    // transform length, taken from Params.N
	q    uint32 // modulus, taken from Params.Q
	logN int    // log2(n); the bit width used to build bitrevPerm

	// psiPow holds psi^i mod q for 0 <= i < 2n. NTT multiplies
	// coefficient i by psiPow[i] before running the butterflies.
	psiPow []uint32

	// psiInvPow holds psi^{-i} mod q for 0 <= i < 2n. INTT multiplies
	// coefficient i by psiInvPow[i] after running the butterflies.
	psiInvPow []uint32

	// omegaPow holds omega^i mod q for 0 <= i < n, where omega = psi^2.
	// These are the twiddle factors of the forward butterflies.
	omegaPow []uint32

	// omegaInvPow holds omega^{-i} mod q for 0 <= i < n.
	// These are the twiddle factors of the inverse butterflies.
	omegaInvPow []uint32

	// bitrevPerm[i] is i with its low logN bits reversed.
	bitrevPerm []int

	// invN is n^{-1} mod q, the 1/n factor applied by INTT.
	invN uint32
}
45
46 // tableCache maps (N, Q) → precomputed tables. Populated lazily.
47 var tableCache = make(map[[2]uint32]*nttTables)
48
49 // getTables returns (possibly cached) NTT tables for the given parameters.
50 func getTables(p Params) *nttTables {
51 key := [2]uint32{uint32(p.N), p.Q}
52 if t, ok := tableCache[key]; ok {
53 return t
54 }
55 t := newNTTTables(p)
56 tableCache[key] = t
57 return t
58 }
59
60 // newNTTTables precomputes all twiddle factors for zero-alloc NTT.
61 func newNTTTables(p Params) *nttTables {
62 n := p.N
63 q := p.Q
64 psi := p.RootOfUnity
65 logN := log2(n)
66
67 // Powers of psi: psi^0 .. psi^{2n-1}.
68 psiPow := make([]uint32, 2*n)
69 psiPow[0] = 1
70 for i := 1; i < 2*n; i++ {
71 psiPow[i] = mulMod(psiPow[i-1], psi, q)
72 }
73
74 // Powers of psi^{-1}.
75 psiInv := powMod(psi, q-2, q)
76 psiInvPow := make([]uint32, 2*n)
77 psiInvPow[0] = 1
78 for i := 1; i < 2*n; i++ {
79 psiInvPow[i] = mulMod(psiInvPow[i-1], psiInv, q)
80 }
81
82 // omega = psi^2 (primitive n-th root of unity).
83 omega := mulMod(psi, psi, q)
84 omegaPow := make([]uint32, n)
85 omegaPow[0] = 1
86 for i := 1; i < n; i++ {
87 omegaPow[i] = mulMod(omegaPow[i-1], omega, q)
88 }
89
90 omegaInv := powMod(omega, q-2, q)
91 omegaInvPow := make([]uint32, n)
92 omegaInvPow[0] = 1
93 for i := 1; i < n; i++ {
94 omegaInvPow[i] = mulMod(omegaInvPow[i-1], omegaInv, q)
95 }
96
97 // Bit-reversal permutation table.
98 bitrevPerm := make([]int, n)
99 for i := range n {
100 bitrevPerm[i] = bitrev(i, logN)
101 }
102
103 return &nttTables{
104 n: n,
105 q: q,
106 logN: logN,
107 psiPow: psiPow,
108 psiInvPow: psiInvPow,
109 omegaPow: omegaPow,
110 omegaInvPow: omegaInvPow,
111 bitrevPerm: bitrevPerm,
112 invN: powMod(uint32(n), q-2, q),
113 }
114 }
115
116 // NTT computes the forward negacyclic NTT in-place.
117 // Input: polynomial in coefficient form. Output: NTT (evaluation) form.
118 // Zero allocations.
119 func NTT(a *Poly) {
120 if a.isNTT {
121 return
122 }
123 t := getTables(a.params)
124 n := t.n
125 q := t.q
126 c := a.Coeffs
127
128 // Pre-multiply by psi^i for negacyclic convolution.
129 for i := range n {
130 c[i] = mulMod(c[i], t.psiPow[i], q)
131 }
132
133 // Bit-reversal permutation.
134 for i := range n {
135 j := t.bitrevPerm[i]
136 if i < j {
137 c[i], c[j] = c[j], c[i]
138 }
139 }
140
141 // Cooley-Tukey butterfly stages.
142 for length := 1; length < n; length <<= 1 {
143 step := n / (2 * length)
144 for start := 0; start < n; start += 2 * length {
145 for j := 0; j < length; j++ {
146 tw := t.omegaPow[(j*step)%n]
147 idx0 := start + j
148 idx1 := idx0 + length
149 u := c[idx0]
150 v := mulMod(c[idx1], tw, q)
151 c[idx0] = addMod(u, v, q)
152 c[idx1] = subMod(u, v, q)
153 }
154 }
155 }
156
157 a.isNTT = true
158 }
159
160 // INTT computes the inverse negacyclic NTT in-place.
161 // Input: NTT form. Output: coefficient form.
162 // Zero allocations.
163 func INTT(a *Poly) {
164 if !a.isNTT {
165 return
166 }
167 t := getTables(a.params)
168 n := t.n
169 q := t.q
170 c := a.Coeffs
171
172 // Bit-reversal permutation.
173 for i := range n {
174 j := t.bitrevPerm[i]
175 if i < j {
176 c[i], c[j] = c[j], c[i]
177 }
178 }
179
180 // Gentleman-Sande inverse butterfly.
181 for length := 1; length < n; length <<= 1 {
182 step := n / (2 * length)
183 for start := 0; start < n; start += 2 * length {
184 for j := 0; j < length; j++ {
185 tw := t.omegaInvPow[(j*step)%n]
186 idx0 := start + j
187 idx1 := idx0 + length
188 u := c[idx0]
189 v := mulMod(c[idx1], tw, q)
190 c[idx0] = addMod(u, v, q)
191 c[idx1] = subMod(u, v, q)
192 }
193 }
194 }
195
196 // Scale by n^{-1} and undo psi pre-multiplication.
197 for i := range n {
198 c[i] = mulMod(c[i], t.invN, q)
199 c[i] = mulMod(c[i], t.psiInvPow[i], q)
200 }
201
202 a.isNTT = false
203 }
204
205 // Mul computes c = a * b in R_q via NTT.
206 // If inputs are already in NTT form, does pointwise multiply.
207 // If in coefficient form, transforms, multiplies, and inverse transforms.
208 func Mul(a, b *Poly) *Poly {
209 if a.isNTT && b.isNTT {
210 return MulPointwise(a, b)
211 }
212
213 aNTT := a.Clone()
214 bNTT := b.Clone()
215 NTT(aNTT)
216 NTT(bNTT)
217 c := MulPointwise(aNTT, bNTT)
218 INTT(c)
219 return c
220 }
221
222 // powMod computes base^exp mod q by binary exponentiation.
223 func powMod(base, exp, q uint32) uint32 {
224 result := uint32(1)
225 b := base % q
226 for e := exp; e > 0; e >>= 1 {
227 if e&1 == 1 {
228 result = mulMod(result, b, q)
229 }
230 b = mulMod(b, b, q)
231 }
232 return result
233 }
234
// log2 returns the number of times n can be halved before reaching 1,
// which is the exact base-2 logarithm when n is a power of two and
// floor(log2(n)) for any other n >= 1. It returns 0 for n <= 1.
func log2(n int) int {
	shifts := 0
	for n > 1 {
		n >>= 1
		shifts++
	}
	return shifts
}
243
// bitrev returns the low `bits` bits of x in reversed order: bit 0 of x
// becomes bit bits-1 of the result, and so on. Higher bits of x are
// ignored. The result is 0 when bits <= 0.
func bitrev(x, bits int) int {
	r := 0
	for ; bits > 0; bits-- {
		// Move the current lowest bit of x onto the low end of r.
		r = (r << 1) | (x & 1)
		x >>= 1
	}
	return r
}
254