package ring import ( "crypto/rand" "math/big" "testing" ) // TestNTTRoundTrip verifies NTT(INTT(a)) == a and INTT(NTT(a)) == a. func TestNTTRoundTrip(t *testing.T) { for _, p := range []Params{Falcon512(), NewHope256()} { t.Run(paramName(p), func(t *testing.T) { a := randomPoly(p) orig := a.Clone() NTT(a) if Equal(a, orig) { t.Fatal("NTT should change coefficients") } INTT(a) if !Equal(a, orig) { t.Fatal("INTT(NTT(a)) != a") for i := range a.Coeffs { if a.Coeffs[i] != orig.Coeffs[i] { t.Logf(" diff at %d: got %d, want %d", i, a.Coeffs[i], orig.Coeffs[i]) if i > 10 { break } } } } }) } } // TestMulSchoolbook verifies NTT-based multiplication matches schoolbook // multiplication in Z_q[x]/(x^n + 1) for small polynomials. func TestMulSchoolbook(t *testing.T) { for _, p := range []Params{Falcon512(), NewHope256()} { t.Run(paramName(p), func(t *testing.T) { a := randomSmallPoly(p, 10) b := randomSmallPoly(p, 10) // NTT multiplication. cNTT := Mul(a, b) // Schoolbook multiplication mod (x^n + 1). cSchool := schoolbookMul(a, b) if !Equal(cNTT, cSchool) { diffs := 0 for i := range cNTT.Coeffs { if cNTT.Coeffs[i] != cSchool.Coeffs[i] { if diffs < 10 { t.Logf(" diff at %d: NTT=%d school=%d", i, cNTT.Coeffs[i], cSchool.Coeffs[i]) } diffs++ } } t.Fatalf("NTT mul != schoolbook mul (%d diffs)", diffs) } }) } } // TestAddSub verifies Add and Sub are inverses. func TestAddSub(t *testing.T) { p := Falcon512() a := randomPoly(p) b := randomPoly(p) c := Add(a, b) d := Sub(c, b) if !Equal(d, a) { t.Fatal("(a + b) - b != a") } } // TestNeg verifies a + (-a) == 0. func TestNeg(t *testing.T) { p := Falcon512() a := randomPoly(p) b := Neg(a) c := Add(a, b) for i := range c.Coeffs { if c.Coeffs[i] != 0 { t.Fatalf("a + (-a) != 0 at index %d: got %d", i, c.Coeffs[i]) } } } // TestScalarMul verifies scalar multiplication. func TestScalarMul(t *testing.T) { p := Falcon512() a := randomPoly(p) // a * 0 == 0 z := ScalarMul(a, 0) for i := range z.Coeffs { if z.Coeffs[i] != 0 { t.Fatal("a * 0 != 0") } } // a * 1 == a one := ScalarMul(a, 1) if !Equal(one, a) { t.Fatal("a * 1 != a") } // a * 2 == a + a two := ScalarMul(a, 2) sum := Add(a, a) if !Equal(two, sum) { t.Fatal("a * 2 != a + a") } } // TestMulCommutativity verifies a * b == b * a. func TestMulCommutativity(t *testing.T) { p := Falcon512() a := randomSmallPoly(p, 100) b := randomSmallPoly(p, 100) ab := Mul(a, b) ba := Mul(b, a) if !Equal(ab, ba) { t.Fatal("a * b != b * a") } } // TestMulDistributivity verifies a*(b+c) == a*b + a*c. func TestMulDistributivity(t *testing.T) { p := NewHope256() a := randomSmallPoly(p, 50) b := randomSmallPoly(p, 50) c := randomSmallPoly(p, 50) bc := Add(b, c) lhs := Mul(a, bc) ab := Mul(a, b) ac := Mul(a, c) rhs := Add(ab, ac) if !Equal(lhs, rhs) { t.Fatal("a*(b+c) != a*b + a*c") } } // TestQInv verifies the Montgomery constant. func TestQInv(t *testing.T) { for _, p := range []Params{Falcon512(), NewHope256()} { t.Run(paramName(p), func(t *testing.T) { // qinv * q ≡ -1 (mod 2^16) // Equivalently: (qinv * q + 1) mod 2^16 == 0 product := p.QInv * uint64(p.Q) mask := (uint64(1) << 16) - 1 if (product+1)&mask != 0 { t.Fatalf("QInv check failed: QInv=%d, Q=%d, (QInv*Q+1) mod 2^16 = %d", p.QInv, p.Q, (product+1)&mask) } }) } } func BenchmarkNTT(b *testing.B) { for _, p := range []Params{Falcon512(), NewHope256()} { b.Run(paramName(p), func(b *testing.B) { poly := randomPoly(p) b.ResetTimer() for range b.N { a := poly.Clone() NTT(a) } }) } } func BenchmarkINTT(b *testing.B) { for _, p := range []Params{Falcon512(), NewHope256()} { b.Run(paramName(p), func(b *testing.B) { poly := randomPoly(p) NTT(poly) b.ResetTimer() for range b.N { a := poly.Clone() INTT(a) } }) } } func BenchmarkMul(b *testing.B) { for _, p := range []Params{Falcon512(), NewHope256()} { b.Run(paramName(p), func(b *testing.B) { a := randomPoly(p) bp2 := randomPoly(p) b.ResetTimer() for range b.N { Mul(a, bp2) } }) } } // --- helpers --- func paramName(p Params) string { switch { case p.N == 512 && p.Q == 12289: return "Falcon512" case p.N == 256 && p.Q == 7681: return "NewHope256" default: return "Custom" } } func randomPoly(p Params) *Poly { a := New(p) buf := make([]byte, 4) for i := range a.Coeffs { rand.Read(buf) v := uint32(buf[0]) | uint32(buf[1])<<8 a.Coeffs[i] = v % p.Q } return a } func randomSmallPoly(p Params, bound uint32) *Poly { a := New(p) buf := make([]byte, 2) for i := range a.Coeffs { rand.Read(buf) v := uint32(buf[0]) | uint32(buf[1])<<8 a.Coeffs[i] = v % bound } return a } // schoolbookMul computes a * b mod (x^n + 1) using naive O(n²) multiplication. func schoolbookMul(a, b *Poly) *Poly { n := a.params.N q := int64(a.params.Q) // Use big.Int for exact intermediate values. result := make([]int64, n) for i := 0; i < n; i++ { for j := 0; j < n; j++ { prod := int64(a.Coeffs[i]) * int64(b.Coeffs[j]) k := i + j if k < n { result[k] = (result[k] + prod) % q } else { // x^n ≡ -1, so x^{n+r} ≡ -x^r result[k-n] = (result[k-n] - prod) % q } } } c := New(a.params) for i := range result { v := result[i] % q if v < 0 { v += q } c.Coeffs[i] = uint32(v) } return c } // Verify that our parameters are correct by checking root of unity properties. func TestRootOfUnity(t *testing.T) { for _, p := range []Params{Falcon512(), NewHope256()} { t.Run(paramName(p), func(t *testing.T) { psi := big.NewInt(int64(p.RootOfUnity)) q := big.NewInt(int64(p.Q)) twoN := big.NewInt(int64(2 * p.N)) // psi^{2n} ≡ 1 (mod q) order := new(big.Int).Exp(psi, twoN, q) if order.Int64() != 1 { t.Fatalf("psi^{2n} = %d, want 1 (mod %d)", order, p.Q) } // psi^n ≡ -1 (mod q) [negacyclic property] nBig := big.NewInt(int64(p.N)) half := new(big.Int).Exp(psi, nBig, q) want := new(big.Int).Sub(q, big.NewInt(1)) // q-1 ≡ -1 if half.Cmp(want) != 0 { t.Fatalf("psi^n = %d, want %d (-1 mod %d)", half, want, p.Q) } }) } }