ntt.go raw

   1  package ring
   2  
   3  // Number-Theoretic Transform for R_q = Z_q[x]/(x^n + 1).
   4  //
   5  // The negacyclic NTT evaluates a polynomial at the roots of x^n + 1,
   6  // which are psi^(2k+1) for k = 0..n-1, where psi is a primitive 2n-th
   7  // root of unity mod q.
   8  //
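// Both directions apply a bit-reversal permutation followed by Cooley-Tukey
// style butterflies (u ± tw·v). The inverse uses the inverse twiddle factors,
// then scales by n^{-1} and multiplies coefficient i by psi^{-i}.
// All twiddle factors are precomputed, so the transforms allocate nothing
// once the tables for a parameter set have been built and cached.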
  12  //
  13  // The NTT enables O(n log n) polynomial multiplication:
  14  //   Mul(a, b) = INTT(NTT(a) ⊙ NTT(b))
  15  // where ⊙ is pointwise multiplication.
  16  
  17  // nttTables holds precomputed twiddle factors for a specific parameter set.
  18  type nttTables struct {
  19  	n    int
  20  	q    uint32
  21  	logN int
  22  
  23  	// psiPow[i] = psi^i mod q for i in [0, 2n).
  24  	// Used for pre-multiplication in forward NTT.
  25  	psiPow []uint32
  26  
  27  	// psiInvPow[i] = psi^{-i} mod q for i in [0, 2n).
  28  	// Used for post-multiplication in inverse NTT.
  29  	psiInvPow []uint32
  30  
  31  	// omegaPow[i] = omega^i mod q where omega = psi^2.
  32  	// Twiddle factors for forward butterfly.
  33  	omegaPow []uint32
  34  
  35  	// omegaInvPow[i] = omega^{-i} mod q.
  36  	// Twiddle factors for inverse butterfly.
  37  	omegaInvPow []uint32
  38  
  39  	// bitrevPerm[i] = bit-reversal of i.
  40  	bitrevPerm []int
  41  
  42  	// invN = n^{-1} mod q, for the 1/n scaling in the inverse transform.
  43  	invN uint32
  44  }
  45  
  46  // tableCache maps (N, Q) → precomputed tables. Populated lazily.
  47  var tableCache = make(map[[2]uint32]*nttTables)
  48  
  49  // getTables returns (possibly cached) NTT tables for the given parameters.
  50  func getTables(p Params) *nttTables {
  51  	key := [2]uint32{uint32(p.N), p.Q}
  52  	if t, ok := tableCache[key]; ok {
  53  		return t
  54  	}
  55  	t := newNTTTables(p)
  56  	tableCache[key] = t
  57  	return t
  58  }
  59  
  60  // newNTTTables precomputes all twiddle factors for zero-alloc NTT.
  61  func newNTTTables(p Params) *nttTables {
  62  	n := p.N
  63  	q := p.Q
  64  	psi := p.RootOfUnity
  65  	logN := log2(n)
  66  
  67  	// Powers of psi: psi^0 .. psi^{2n-1}.
  68  	psiPow := make([]uint32, 2*n)
  69  	psiPow[0] = 1
  70  	for i := 1; i < 2*n; i++ {
  71  		psiPow[i] = mulMod(psiPow[i-1], psi, q)
  72  	}
  73  
  74  	// Powers of psi^{-1}.
  75  	psiInv := powMod(psi, q-2, q)
  76  	psiInvPow := make([]uint32, 2*n)
  77  	psiInvPow[0] = 1
  78  	for i := 1; i < 2*n; i++ {
  79  		psiInvPow[i] = mulMod(psiInvPow[i-1], psiInv, q)
  80  	}
  81  
  82  	// omega = psi^2 (primitive n-th root of unity).
  83  	omega := mulMod(psi, psi, q)
  84  	omegaPow := make([]uint32, n)
  85  	omegaPow[0] = 1
  86  	for i := 1; i < n; i++ {
  87  		omegaPow[i] = mulMod(omegaPow[i-1], omega, q)
  88  	}
  89  
  90  	omegaInv := powMod(omega, q-2, q)
  91  	omegaInvPow := make([]uint32, n)
  92  	omegaInvPow[0] = 1
  93  	for i := 1; i < n; i++ {
  94  		omegaInvPow[i] = mulMod(omegaInvPow[i-1], omegaInv, q)
  95  	}
  96  
  97  	// Bit-reversal permutation table.
  98  	bitrevPerm := make([]int, n)
  99  	for i := range n {
 100  		bitrevPerm[i] = bitrev(i, logN)
 101  	}
 102  
 103  	return &nttTables{
 104  		n:           n,
 105  		q:           q,
 106  		logN:        logN,
 107  		psiPow:      psiPow,
 108  		psiInvPow:   psiInvPow,
 109  		omegaPow:    omegaPow,
 110  		omegaInvPow: omegaInvPow,
 111  		bitrevPerm:  bitrevPerm,
 112  		invN:        powMod(uint32(n), q-2, q),
 113  	}
 114  }
 115  
 116  // NTT computes the forward negacyclic NTT in-place.
 117  // Input: polynomial in coefficient form. Output: NTT (evaluation) form.
 118  // Zero allocations.
 119  func NTT(a *Poly) {
 120  	if a.isNTT {
 121  		return
 122  	}
 123  	t := getTables(a.params)
 124  	n := t.n
 125  	q := t.q
 126  	c := a.Coeffs
 127  
 128  	// Pre-multiply by psi^i for negacyclic convolution.
 129  	for i := range n {
 130  		c[i] = mulMod(c[i], t.psiPow[i], q)
 131  	}
 132  
 133  	// Bit-reversal permutation.
 134  	for i := range n {
 135  		j := t.bitrevPerm[i]
 136  		if i < j {
 137  			c[i], c[j] = c[j], c[i]
 138  		}
 139  	}
 140  
 141  	// Cooley-Tukey butterfly stages.
 142  	for length := 1; length < n; length <<= 1 {
 143  		step := n / (2 * length)
 144  		for start := 0; start < n; start += 2 * length {
 145  			for j := 0; j < length; j++ {
 146  				tw := t.omegaPow[(j*step)%n]
 147  				idx0 := start + j
 148  				idx1 := idx0 + length
 149  				u := c[idx0]
 150  				v := mulMod(c[idx1], tw, q)
 151  				c[idx0] = addMod(u, v, q)
 152  				c[idx1] = subMod(u, v, q)
 153  			}
 154  		}
 155  	}
 156  
 157  	a.isNTT = true
 158  }
 159  
 160  // INTT computes the inverse negacyclic NTT in-place.
 161  // Input: NTT form. Output: coefficient form.
 162  // Zero allocations.
 163  func INTT(a *Poly) {
 164  	if !a.isNTT {
 165  		return
 166  	}
 167  	t := getTables(a.params)
 168  	n := t.n
 169  	q := t.q
 170  	c := a.Coeffs
 171  
 172  	// Bit-reversal permutation.
 173  	for i := range n {
 174  		j := t.bitrevPerm[i]
 175  		if i < j {
 176  			c[i], c[j] = c[j], c[i]
 177  		}
 178  	}
 179  
 180  	// Gentleman-Sande inverse butterfly.
 181  	for length := 1; length < n; length <<= 1 {
 182  		step := n / (2 * length)
 183  		for start := 0; start < n; start += 2 * length {
 184  			for j := 0; j < length; j++ {
 185  				tw := t.omegaInvPow[(j*step)%n]
 186  				idx0 := start + j
 187  				idx1 := idx0 + length
 188  				u := c[idx0]
 189  				v := mulMod(c[idx1], tw, q)
 190  				c[idx0] = addMod(u, v, q)
 191  				c[idx1] = subMod(u, v, q)
 192  			}
 193  		}
 194  	}
 195  
 196  	// Scale by n^{-1} and undo psi pre-multiplication.
 197  	for i := range n {
 198  		c[i] = mulMod(c[i], t.invN, q)
 199  		c[i] = mulMod(c[i], t.psiInvPow[i], q)
 200  	}
 201  
 202  	a.isNTT = false
 203  }
 204  
 205  // Mul computes c = a * b in R_q via NTT.
 206  // If inputs are already in NTT form, does pointwise multiply.
 207  // If in coefficient form, transforms, multiplies, and inverse transforms.
 208  func Mul(a, b *Poly) *Poly {
 209  	if a.isNTT && b.isNTT {
 210  		return MulPointwise(a, b)
 211  	}
 212  
 213  	aNTT := a.Clone()
 214  	bNTT := b.Clone()
 215  	NTT(aNTT)
 216  	NTT(bNTT)
 217  	c := MulPointwise(aNTT, bNTT)
 218  	INTT(c)
 219  	return c
 220  }
 221  
 222  // powMod computes base^exp mod q by binary exponentiation.
 223  func powMod(base, exp, q uint32) uint32 {
 224  	result := uint32(1)
 225  	b := base % q
 226  	for e := exp; e > 0; e >>= 1 {
 227  		if e&1 == 1 {
 228  			result = mulMod(result, b, q)
 229  		}
 230  		b = mulMod(b, b, q)
 231  	}
 232  	return result
 233  }
 234  
 235  // log2 returns the base-2 logarithm of n (must be a power of 2).
 236  func log2(n int) int {
 237  	r := 0
 238  	for n >>= 1; n > 0; n >>= 1 {
 239  		r++
 240  	}
 241  	return r
 242  }
 243  
 244  // bitrev reverses the low `bits` bits of x.
 245  func bitrev(x, bits int) int {
 246  	r := 0
 247  	for i := range bits {
 248  		_ = i
 249  		r = (r << 1) | (x & 1)
 250  		x >>= 1
 251  	}
 252  	return r
 253  }
 254