package ring // Number-Theoretic Transform for R_q = Z_q[x]/(x^n + 1). // // The negacyclic NTT evaluates a polynomial at the roots of x^n + 1, // which are psi^(2k+1) for k = 0..n-1, where psi is a primitive 2n-th // root of unity mod q. // // Implementation uses Cooley-Tukey (forward) and Gentleman-Sande (inverse) // butterfly structures with fully precomputed twiddle factors. // Zero allocations in the hot path. // // The NTT enables O(n log n) polynomial multiplication: // Mul(a, b) = INTT(NTT(a) ⊙ NTT(b)) // where ⊙ is pointwise multiplication. // nttTables holds precomputed twiddle factors for a specific parameter set. type nttTables struct { n int q uint32 logN int // psiPow[i] = psi^i mod q for i in [0, 2n). // Used for pre-multiplication in forward NTT. psiPow []uint32 // psiInvPow[i] = psi^{-i} mod q for i in [0, 2n). // Used for post-multiplication in inverse NTT. psiInvPow []uint32 // omegaPow[i] = omega^i mod q where omega = psi^2. // Twiddle factors for forward butterfly. omegaPow []uint32 // omegaInvPow[i] = omega^{-i} mod q. // Twiddle factors for inverse butterfly. omegaInvPow []uint32 // bitrevPerm[i] = bit-reversal of i. bitrevPerm []int // invN = n^{-1} mod q, for the 1/n scaling in the inverse transform. invN uint32 } // tableCache maps (N, Q) → precomputed tables. Populated lazily. var tableCache = make(map[[2]uint32]*nttTables) // getTables returns (possibly cached) NTT tables for the given parameters. func getTables(p Params) *nttTables { key := [2]uint32{uint32(p.N), p.Q} if t, ok := tableCache[key]; ok { return t } t := newNTTTables(p) tableCache[key] = t return t } // newNTTTables precomputes all twiddle factors for zero-alloc NTT. func newNTTTables(p Params) *nttTables { n := p.N q := p.Q psi := p.RootOfUnity logN := log2(n) // Powers of psi: psi^0 .. psi^{2n-1}. psiPow := make([]uint32, 2*n) psiPow[0] = 1 for i := 1; i < 2*n; i++ { psiPow[i] = mulMod(psiPow[i-1], psi, q) } // Powers of psi^{-1}. psiInv := powMod(psi, q-2, q) psiInvPow := make([]uint32, 2*n) psiInvPow[0] = 1 for i := 1; i < 2*n; i++ { psiInvPow[i] = mulMod(psiInvPow[i-1], psiInv, q) } // omega = psi^2 (primitive n-th root of unity). omega := mulMod(psi, psi, q) omegaPow := make([]uint32, n) omegaPow[0] = 1 for i := 1; i < n; i++ { omegaPow[i] = mulMod(omegaPow[i-1], omega, q) } omegaInv := powMod(omega, q-2, q) omegaInvPow := make([]uint32, n) omegaInvPow[0] = 1 for i := 1; i < n; i++ { omegaInvPow[i] = mulMod(omegaInvPow[i-1], omegaInv, q) } // Bit-reversal permutation table. bitrevPerm := make([]int, n) for i := range n { bitrevPerm[i] = bitrev(i, logN) } return &nttTables{ n: n, q: q, logN: logN, psiPow: psiPow, psiInvPow: psiInvPow, omegaPow: omegaPow, omegaInvPow: omegaInvPow, bitrevPerm: bitrevPerm, invN: powMod(uint32(n), q-2, q), } } // NTT computes the forward negacyclic NTT in-place. // Input: polynomial in coefficient form. Output: NTT (evaluation) form. // Zero allocations. func NTT(a *Poly) { if a.isNTT { return } t := getTables(a.params) n := t.n q := t.q c := a.Coeffs // Pre-multiply by psi^i for negacyclic convolution. for i := range n { c[i] = mulMod(c[i], t.psiPow[i], q) } // Bit-reversal permutation. for i := range n { j := t.bitrevPerm[i] if i < j { c[i], c[j] = c[j], c[i] } } // Cooley-Tukey butterfly stages. for length := 1; length < n; length <<= 1 { step := n / (2 * length) for start := 0; start < n; start += 2 * length { for j := 0; j < length; j++ { tw := t.omegaPow[(j*step)%n] idx0 := start + j idx1 := idx0 + length u := c[idx0] v := mulMod(c[idx1], tw, q) c[idx0] = addMod(u, v, q) c[idx1] = subMod(u, v, q) } } } a.isNTT = true } // INTT computes the inverse negacyclic NTT in-place. // Input: NTT form. Output: coefficient form. // Zero allocations. func INTT(a *Poly) { if !a.isNTT { return } t := getTables(a.params) n := t.n q := t.q c := a.Coeffs // Bit-reversal permutation. for i := range n { j := t.bitrevPerm[i] if i < j { c[i], c[j] = c[j], c[i] } } // Gentleman-Sande inverse butterfly. for length := 1; length < n; length <<= 1 { step := n / (2 * length) for start := 0; start < n; start += 2 * length { for j := 0; j < length; j++ { tw := t.omegaInvPow[(j*step)%n] idx0 := start + j idx1 := idx0 + length u := c[idx0] v := mulMod(c[idx1], tw, q) c[idx0] = addMod(u, v, q) c[idx1] = subMod(u, v, q) } } } // Scale by n^{-1} and undo psi pre-multiplication. for i := range n { c[i] = mulMod(c[i], t.invN, q) c[i] = mulMod(c[i], t.psiInvPow[i], q) } a.isNTT = false } // Mul computes c = a * b in R_q via NTT. // If inputs are already in NTT form, does pointwise multiply. // If in coefficient form, transforms, multiplies, and inverse transforms. func Mul(a, b *Poly) *Poly { if a.isNTT && b.isNTT { return MulPointwise(a, b) } aNTT := a.Clone() bNTT := b.Clone() NTT(aNTT) NTT(bNTT) c := MulPointwise(aNTT, bNTT) INTT(c) return c } // powMod computes base^exp mod q by binary exponentiation. func powMod(base, exp, q uint32) uint32 { result := uint32(1) b := base % q for e := exp; e > 0; e >>= 1 { if e&1 == 1 { result = mulMod(result, b, q) } b = mulMod(b, b, q) } return result } // log2 returns the base-2 logarithm of n (must be a power of 2). func log2(n int) int { r := 0 for n >>= 1; n > 0; n >>= 1 { r++ } return r } // bitrev reverses the low `bits` bits of x. func bitrev(x, bits int) int { r := 0 for i := range bits { _ = i r = (r << 1) | (x & 1) x >>= 1 } return r }