ntt27.go raw
1 package crypto
2
3 // Radix-3 Number-Theoretic Transform (NTT) over Z_271 for trinary Hamadryad.
4 //
5 // The trinary analogue of the radix-2 NTT in ntt.go. Transforms a
6 // length-27 polynomial from coefficient form to evaluation form over Z_271,
7 // enabling O(n log n) polynomial multiplication via pointwise multiply.
8 //
9 // For trinary Hamadryad: n=27 = 3^3, p=271.
10 //
11 // 271 is prime and 270 = 271-1 = 2 × 5 × 3^3, so Z_271* has order 270.
12 // g = 6 is a primitive root (generator) of Z_271*.
13 // psi = 6^(270/54) = 6^5 = 188 is a primitive 54th root of unity.
14 //
15 // For polynomials mod (x^27 + 1), we use the negacyclic NTT:
16 // pre-multiply coefficients by powers of psi (a 54th root),
17 // then apply a standard 27-point NTT with omega = psi^2 = 114
18 // (a primitive 27th root of unity).
19 //
20 // The radix-3 butterfly uses the primitive cube root of unity
21 // w3 = omega^9 = 242, which satisfies 1 + w3 + w3^2 = 0 (mod 271).
22
// Parameters of the trinary Hamadryad ring Z_271[x]/(x^27 + 1).
const (
	// GnarlP is the ring modulus for the trinary Hamadryad: the prime 271.
	// Since 271 - 1 = 270 is divisible by 54, Z_271 contains the primitive
	// 54th roots of unity the negacyclic NTT needs.
	GnarlP = 271

	// GnarlN is the ring dimension: degree of x^27 + 1.
	GnarlN = 27
)
28
// mod271 reduces x into [0, 270]. x must be in [0, 73441), i.e. x < 271^2,
// which covers any product of two residues in [0, 270]. The precondition is
// not checked; larger inputs can return values outside [0, 270].
//
// It uses Barrett reduction with shift 17 and magic constant
// 483 = floor(2^17 / 271):
//
//	q = (x * 483) >> 17,  r = x - q*271
//
// Because 483/2^17 slightly underestimates 1/271, q never exceeds
// floor(x/271), and on this domain it falls short by at most one (the gap
// is x*179/(271*2^17) < 0.38). Hence r lies in [0, 542), and a single
// conditional subtraction completes the reduction.
//
// All arithmetic fits in uint32: the largest intermediate is
// 73440 * 483 = 35,471,520 < 2^32.
//
//go:nosplit
func mod271(x uint32) uint16 {
	q := (x * 483) >> 17
	r := x - q*271
	if r >= 271 {
		r -= 271
	}
	return uint16(r)
}
43
44 // mod271s reduces a signed int into [0, 270]. Used only in test/rare paths.
45 func mod271s(x int) uint16 {
46 x %= GnarlP
47 if x < 0 {
48 x += GnarlP
49 }
50 return uint16(x)
51 }
52
53 // powMod271 computes base^exp mod 271 via binary exponentiation.
54 func powMod271(base, exp int) uint16 {
55 result := 1
56 b := base % GnarlP
57 if b < 0 {
58 b += GnarlP
59 }
60 for e := exp; e > 0; e >>= 1 {
61 if e&1 == 1 {
62 result = (result * b) % GnarlP
63 }
64 b = (b * b) % GnarlP
65 }
66 return uint16(result)
67 }
68
69 // invMod271 computes the modular inverse of a mod 271.
70 // Uses Fermat's little theorem: a^(-1) = a^(p-2) mod p.
71 func invMod271(a uint16) uint16 {
72 return powMod271(int(a), GnarlP-2)
73 }
74
// digitRev3 returns the number whose base-3 representation is the lowest
// `digits` base-3 digits of x in reverse order; higher digits of x are
// dropped. For the length-27 transform digits = 3, so e.g. 1 (001 in base 3)
// maps to 9 (100 in base 3). A non-positive digits count yields 0.
// x is presumably non-negative; negative x is not meaningful here.
func digitRev3(x, digits int) int {
	rev := 0
	for n := 0; n < digits; n++ {
		rev = 3*rev + x%3 // append x's current lowest digit
		x /= 3
	}
	return rev
}
85
// Lookup tables for the radix-3 NTT, populated by initNTT27Tables.
var (
	// Powers of psi = 188, a primitive 54th root of unity in Z_271:
	// psiPows27[i] = psi^i mod 271 and psiInvPows27[i] = psi^(-i) mod 271,
	// for i = 0..53.
	psiPows27, psiInvPows27 [54]uint16

	// invN27 = 27^(-1) mod 271 = 261, the 1/n scale of the inverse transform.
	invN27 uint16

	// ntt27TablesReady is set once the tables above have been populated.
	ntt27TablesReady bool
)
100
101 // initNTT27Tables populates the radix-3 NTT lookup tables.
102 func initNTT27Tables() {
103 if ntt27TablesReady {
104 return
105 }
106 const psi = 188 // primitive 54th root of unity in Z_271
107
108 psiPows27[0] = 1
109 for i := 1; i < 54; i++ {
110 psiPows27[i] = mod271(uint32(psiPows27[i-1]) * psi)
111 }
112
113 psiInv := invMod271(psi)
114 psiInvPows27[0] = 1
115 for i := 1; i < 54; i++ {
116 psiInvPows27[i] = mod271(uint32(psiInvPows27[i-1]) * uint32(psiInv))
117 }
118
119 invN27 = invMod271(uint16(GnarlN))
120 ntt27TablesReady = true
121 }
122
123 func init() {
124 initNTT27Tables()
125 }
126
// ntt27 computes, in place, the forward negacyclic NTT of a length-27
// polynomial over Z_271.
//
// The negacyclic NTT evaluates the polynomial at the roots of (x^27 + 1),
// which are psi^(2k+1) for k=0..26. Implementation:
//  1. Pre-multiply: a[i] *= psi^i
//  2. Digit-reversal permutation of the indices (base 3, three digits)
//  3. Radix-3 DIT NTT with omega = psi^2 (primitive 27th root of unity)
//
// Since sum_i a[i]*psi^i*omega^(i*k) = sum_i a[i]*psi^(i*(2k+1)), on return
// a[k] holds the polynomial evaluated at psi^(2k+1), in natural order.
//
// Precondition (not checked): every a[i] must already be in [0, 270]. The
// pre-multiply passes a[i]*c (c up to 266) directly to mod271, whose input
// must stay below 73441; unreduced inputs can produce wrong results.
//
// The radix-3 butterfly for indices [j, j+m, j+2m] with twiddle w:
//
//	a0' = a0 + w*a1 + w^2*a2
//	a1' = a0 + w*a1*w3 + w^2*a2*w3^2
//	a2' = a0 + w*a1*w3^2 + w^2*a2*w3
//
// where w3 = omega^(27/3) = omega^9 = 242 is the primitive cube root of unity
// (w3^2 = 28). The stage that merges length-m sub-transforms uses
// W = omega^(9/m) and passes tw1 = W^j, tw2 = W^(2j) for butterfly j.
func ntt27(a *[GnarlN]uint16) {
	// Fully-unrolled negacyclic NTT over Z_271, n=27 = 3^3.
	// Three stages of radix-3 butterflies with precomputed twiddles.

	// Step 1: pre-multiply by psi^i for negacyclic.
	// psi = 188, psiPows27 = [1,188,114,23,259,183,258,266,144,243,156,60,
	// 169,65,25,93,140,33,242,239,217,146,77,113,106,145,160]
	// (these are psiPows27[0..26]; inlined as constants to avoid table loads).
	// a[0] *= 1 (no-op)
	a[1] = mod271(uint32(a[1]) * 188)
	a[2] = mod271(uint32(a[2]) * 114)
	a[3] = mod271(uint32(a[3]) * 23)
	a[4] = mod271(uint32(a[4]) * 259)
	a[5] = mod271(uint32(a[5]) * 183)
	a[6] = mod271(uint32(a[6]) * 258)
	a[7] = mod271(uint32(a[7]) * 266)
	a[8] = mod271(uint32(a[8]) * 144)
	a[9] = mod271(uint32(a[9]) * 243)
	a[10] = mod271(uint32(a[10]) * 156)
	a[11] = mod271(uint32(a[11]) * 60)
	a[12] = mod271(uint32(a[12]) * 169)
	a[13] = mod271(uint32(a[13]) * 65)
	a[14] = mod271(uint32(a[14]) * 25)
	a[15] = mod271(uint32(a[15]) * 93)
	a[16] = mod271(uint32(a[16]) * 140)
	a[17] = mod271(uint32(a[17]) * 33)
	a[18] = mod271(uint32(a[18]) * 242)
	a[19] = mod271(uint32(a[19]) * 239)
	a[20] = mod271(uint32(a[20]) * 217)
	a[21] = mod271(uint32(a[21]) * 146)
	a[22] = mod271(uint32(a[22]) * 77)
	a[23] = mod271(uint32(a[23]) * 113)
	a[24] = mod271(uint32(a[24]) * 106)
	a[25] = mod271(uint32(a[25]) * 145)
	a[26] = mod271(uint32(a[26]) * 160)

	// Step 2: digit-reversal permutation (9 swaps).
	// Index d0 + 3*d1 + 9*d2 moves to d2 + 3*d1 + 9*d0 (digitRev3(i, 3));
	// the 9 indices that are their own reversal stay in place.
	a[1], a[9] = a[9], a[1]
	a[2], a[18] = a[18], a[2]
	a[4], a[12] = a[12], a[4]
	a[5], a[21] = a[21], a[5]
	a[7], a[15] = a[15], a[7]
	a[8], a[24] = a[24], a[8]
	a[11], a[19] = a[19], a[11]
	a[14], a[22] = a[22], a[14]
	a[17], a[25] = a[25], a[17]

	// Step 3: radix-3 butterfly stages (fully unrolled).
	// w3 = 242, w3sq = 28. All twiddles precomputed.

	// --- Stage 0: stride=1, groupSize=3, all tw1=tw2=1 ---
	// 9 butterflies, each: b0=a0+a1+a2, b1=a0+242*a1+28*a2, b2=a0+28*a1+242*a2
	nttBfly1(a, 0, 1, 2)
	nttBfly1(a, 3, 4, 5)
	nttBfly1(a, 6, 7, 8)
	nttBfly1(a, 9, 10, 11)
	nttBfly1(a, 12, 13, 14)
	nttBfly1(a, 15, 16, 17)
	nttBfly1(a, 18, 19, 20)
	nttBfly1(a, 21, 22, 23)
	nttBfly1(a, 24, 25, 26)

	// --- Stage 1: stride=3, groupSize=9, twiddleStep=3 ---
	// W = omega^3 = psi^6 = 258; butterfly j uses W^j, W^(2j).
	// j=0: tw1=1, tw2=1 (same as stage 0 kernel)
	nttBfly1(a, 0, 3, 6)
	nttBfly1(a, 9, 12, 15)
	nttBfly1(a, 18, 21, 24)
	// j=1: tw1=258, tw2=169 (omega^3, omega^6)
	nttBfly(a, 1, 4, 7, 258, 169)
	nttBfly(a, 10, 13, 16, 258, 169)
	nttBfly(a, 19, 22, 25, 258, 169)
	// j=2: tw1=169, tw2=106 (omega^6, omega^12)
	nttBfly(a, 2, 5, 8, 169, 106)
	nttBfly(a, 11, 14, 17, 169, 106)
	nttBfly(a, 20, 23, 26, 169, 106)

	// --- Stage 2: stride=9, groupSize=27, twiddleStep=1 ---
	// W = omega = 114; tw1 = psi^(2j), tw2 = psi^(4j), folding exponents
	// >= 27 with psi^27 = -1 (e.g. j=7: psi^28 = -188 = 83).
	// j=0: tw1=1, tw2=1
	nttBfly1(a, 0, 9, 18)
	// j=1: tw1=114, tw2=259
	nttBfly(a, 1, 10, 19, 114, 259)
	// j=2: tw1=259, tw2=144
	nttBfly(a, 2, 11, 20, 259, 144)
	// j=3: tw1=258, tw2=169
	nttBfly(a, 3, 12, 21, 258, 169)
	// j=4: tw1=144, tw2=140
	nttBfly(a, 4, 13, 22, 144, 140)
	// j=5: tw1=156, tw2=217
	nttBfly(a, 5, 14, 23, 156, 217)
	// j=6: tw1=169, tw2=106
	nttBfly(a, 6, 15, 24, 169, 106)
	// j=7: tw1=25, tw2=83
	nttBfly(a, 7, 16, 25, 25, 83)
	// j=8: tw1=140, tw2=88
	nttBfly(a, 8, 17, 26, 140, 88)
}
236
237 // nttBfly1 performs a radix-3 butterfly with trivial twiddles (tw1=tw2=1).
238 // b0 = a0 + a1 + a2, b1 = a0 + 242*a1 + 28*a2, b2 = a0 + 28*a1 + 242*a2
239 //
240 //go:nosplit
241 func nttBfly1(a *[27]uint16, i0, i1, i2 int) {
242 v0 := uint32(a[i0])
243 v1 := uint32(a[i1])
244 v2 := uint32(a[i2])
245 a[i0] = mod271(v0 + v1 + v2)
246 a[i1] = mod271(v0 + uint32(mod271(v1*242)) + uint32(mod271(v2*28)))
247 a[i2] = mod271(v0 + uint32(mod271(v1*28)) + uint32(mod271(v2*242)))
248 }
249
250 // nttBfly performs a radix-3 butterfly with given twiddle factors tw1, tw2.
251 //
252 //go:nosplit
253 func nttBfly(a *[27]uint16, i0, i1, i2 int, tw1, tw2 uint32) {
254 v0 := uint32(a[i0])
255 a1tw := uint32(mod271(uint32(a[i1]) * tw1))
256 a2tw := uint32(mod271(uint32(a[i2]) * tw2))
257 a[i0] = mod271(v0 + a1tw + a2tw)
258 a[i1] = mod271(v0 + uint32(mod271(a1tw*242)) + uint32(mod271(a2tw*28)))
259 a[i2] = mod271(v0 + uint32(mod271(a1tw*28)) + uint32(mod271(a2tw*242)))
260 }
261
// intt27 computes, in place, the inverse negacyclic NTT, recovering the
// coefficients from the 27 evaluations produced by ntt27.
//
// Uses the Decimation-In-Frequency (DIF) structure: butterflies top-down
// (largest group first), followed by digit-reversal permutation.
// This is the natural inverse of the forward DIT NTT.
//
// Each DIF butterfly combines its inputs using the inverse cube root
// w3^(-1) = 28 and then scales its second and third outputs by W^(-j) and
// W^(-2j). The final step multiplies a[i] by 27^(-1) * psi^(-i) =
// 261 * 111^i mod 271. This single multiply undoes both the 1/27
// normalisation of the inverse DFT and the psi^i pre-multiply done by ntt27.
//
// Precondition (not checked): every a[i] must be in [0, 270], as ntt27
// produces. The butterflies pass a[i]*242 directly to mod271, whose input
// must stay below 73441.
func intt27(a *[GnarlN]uint16) {
	// Fully-unrolled inverse negacyclic NTT (DIF, top-down).

	// --- Stage 2: groupSize=27, stride=9, twiddleStep=1 ---
	// W^(-1) = omega^(-1) = psi^(-2) = 126; butterfly j passes
	// omega^(-j), omega^(-2j) (e.g. j=1: 126, 158).
	inttBfly1(a, 0, 9, 18)
	inttBfly(a, 1, 10, 19, 126, 158)
	inttBfly(a, 2, 11, 20, 158, 32)
	inttBfly(a, 3, 12, 21, 125, 178)
	inttBfly(a, 4, 13, 22, 32, 211)
	inttBfly(a, 5, 14, 23, 238, 5)
	inttBfly(a, 6, 15, 24, 178, 248)
	inttBfly(a, 7, 16, 25, 206, 160)
	inttBfly(a, 8, 17, 26, 211, 77)

	// --- Stage 1: groupSize=9, stride=3, twiddleStep=3 ---
	// W^(-1) = omega^(-3) = 125; j=1 passes (125, 178), j=2 passes
	// (178, 248), repeated for each of the three 9-element groups.
	inttBfly1(a, 0, 3, 6)
	inttBfly(a, 1, 4, 7, 125, 178)
	inttBfly(a, 2, 5, 8, 178, 248)
	inttBfly1(a, 9, 12, 15)
	inttBfly(a, 10, 13, 16, 125, 178)
	inttBfly(a, 11, 14, 17, 178, 248)
	inttBfly1(a, 18, 21, 24)
	inttBfly(a, 19, 22, 25, 125, 178)
	inttBfly(a, 20, 23, 26, 178, 248)

	// --- Stage 0: groupSize=3, stride=1, twiddleStep=9 (all trivial) ---
	inttBfly1(a, 0, 1, 2)
	inttBfly1(a, 3, 4, 5)
	inttBfly1(a, 6, 7, 8)
	inttBfly1(a, 9, 10, 11)
	inttBfly1(a, 12, 13, 14)
	inttBfly1(a, 15, 16, 17)
	inttBfly1(a, 18, 19, 20)
	inttBfly1(a, 21, 22, 23)
	inttBfly1(a, 24, 25, 26)

	// Digit-reversal permutation (9 swaps).
	// DIF leaves results in base-3 digit-reversed order; this restores
	// natural order using the same swaps as ntt27.
	a[1], a[9] = a[9], a[1]
	a[2], a[18] = a[18], a[2]
	a[4], a[12] = a[12], a[4]
	a[5], a[21] = a[21], a[5]
	a[7], a[15] = a[15], a[7]
	a[8], a[24] = a[24], a[8]
	a[11], a[19] = a[19], a[11]
	a[14], a[22] = a[22], a[14]
	a[17], a[25] = a[25], a[17]

	// Post-multiply: fused invN27 * psiInvPows27[i] for each coefficient.
	// Constant i is 261 * 111^i mod 271 (261 = 27^(-1), 111 = psi^(-1)).
	a[0] = mod271(uint32(a[0]) * 261)
	a[1] = mod271(uint32(a[1]) * 245)
	a[2] = mod271(uint32(a[2]) * 95)
	a[3] = mod271(uint32(a[3]) * 247)
	a[4] = mod271(uint32(a[4]) * 46)
	a[5] = mod271(uint32(a[5]) * 228)
	a[6] = mod271(uint32(a[6]) * 105)
	a[7] = mod271(uint32(a[7]) * 2)
	a[8] = mod271(uint32(a[8]) * 222)
	a[9] = mod271(uint32(a[9]) * 252)
	a[10] = mod271(uint32(a[10]) * 59)
	a[11] = mod271(uint32(a[11]) * 45)
	a[12] = mod271(uint32(a[12]) * 117)
	a[13] = mod271(uint32(a[13]) * 250)
	a[14] = mod271(uint32(a[14]) * 108)
	a[15] = mod271(uint32(a[15]) * 64)
	a[16] = mod271(uint32(a[16]) * 58)
	a[17] = mod271(uint32(a[17]) * 205)
	a[18] = mod271(uint32(a[18]) * 262)
	a[19] = mod271(uint32(a[19]) * 85)
	a[20] = mod271(uint32(a[20]) * 221)
	a[21] = mod271(uint32(a[21]) * 141)
	a[22] = mod271(uint32(a[22]) * 204)
	a[23] = mod271(uint32(a[23]) * 151)
	a[24] = mod271(uint32(a[24]) * 230)
	a[25] = mod271(uint32(a[25]) * 56)
	a[26] = mod271(uint32(a[26]) * 254)
}
343
344 // inttBfly1 performs an inverse radix-3 DIF butterfly with trivial twiddles (tw1inv=tw2inv=1).
345 // b0 = a0 + a1 + a2, b1 = a0 + 28*a1 + 242*a2, b2 = a0 + 242*a1 + 28*a2
346 //
347 //go:nosplit
348 func inttBfly1(a *[27]uint16, i0, i1, i2 int) {
349 v0 := uint32(a[i0])
350 v1 := uint32(a[i1])
351 v2 := uint32(a[i2])
352 a[i0] = mod271(v0 + v1 + v2)
353 a[i1] = mod271(v0 + uint32(mod271(v1*28)) + uint32(mod271(v2*242)))
354 a[i2] = mod271(v0 + uint32(mod271(v1*242)) + uint32(mod271(v2*28)))
355 }
356
357 // inttBfly performs an inverse radix-3 DIF butterfly with given twiddle factors.
358 // After the butterfly, b1 *= tw1inv, b2 *= tw2inv.
359 //
360 //go:nosplit
361 func inttBfly(a *[27]uint16, i0, i1, i2 int, tw1inv, tw2inv uint32) {
362 v0 := uint32(a[i0])
363 v1 := uint32(a[i1])
364 v2 := uint32(a[i2])
365 a[i0] = mod271(v0 + v1 + v2)
366 b1 := mod271(v0 + uint32(mod271(v1*28)) + uint32(mod271(v2*242)))
367 b2 := mod271(v0 + uint32(mod271(v1*242)) + uint32(mod271(v2*28)))
368 a[i1] = mod271(uint32(b1) * tw1inv)
369 a[i2] = mod271(uint32(b2) * tw2inv)
370 }
371