ring_test.go raw

   1  package ring
   2  
   3  import (
   4  	"crypto/rand"
   5  	"math/big"
   6  	"testing"
   7  )
   8  
   9  // TestNTTRoundTrip verifies NTT(INTT(a)) == a and INTT(NTT(a)) == a.
  10  func TestNTTRoundTrip(t *testing.T) {
  11  	for _, p := range []Params{Falcon512(), NewHope256()} {
  12  		t.Run(paramName(p), func(t *testing.T) {
  13  			a := randomPoly(p)
  14  			orig := a.Clone()
  15  
  16  			NTT(a)
  17  			if Equal(a, orig) {
  18  				t.Fatal("NTT should change coefficients")
  19  			}
  20  
  21  			INTT(a)
  22  			if !Equal(a, orig) {
  23  				t.Fatal("INTT(NTT(a)) != a")
  24  				for i := range a.Coeffs {
  25  					if a.Coeffs[i] != orig.Coeffs[i] {
  26  						t.Logf("  diff at %d: got %d, want %d", i, a.Coeffs[i], orig.Coeffs[i])
  27  						if i > 10 {
  28  							break
  29  						}
  30  					}
  31  				}
  32  			}
  33  		})
  34  	}
  35  }
  36  
  37  // TestMulSchoolbook verifies NTT-based multiplication matches schoolbook
  38  // multiplication in Z_q[x]/(x^n + 1) for small polynomials.
  39  func TestMulSchoolbook(t *testing.T) {
  40  	for _, p := range []Params{Falcon512(), NewHope256()} {
  41  		t.Run(paramName(p), func(t *testing.T) {
  42  			a := randomSmallPoly(p, 10)
  43  			b := randomSmallPoly(p, 10)
  44  
  45  			// NTT multiplication.
  46  			cNTT := Mul(a, b)
  47  
  48  			// Schoolbook multiplication mod (x^n + 1).
  49  			cSchool := schoolbookMul(a, b)
  50  
  51  			if !Equal(cNTT, cSchool) {
  52  				diffs := 0
  53  				for i := range cNTT.Coeffs {
  54  					if cNTT.Coeffs[i] != cSchool.Coeffs[i] {
  55  						if diffs < 10 {
  56  							t.Logf("  diff at %d: NTT=%d school=%d", i, cNTT.Coeffs[i], cSchool.Coeffs[i])
  57  						}
  58  						diffs++
  59  					}
  60  				}
  61  				t.Fatalf("NTT mul != schoolbook mul (%d diffs)", diffs)
  62  			}
  63  		})
  64  	}
  65  }
  66  
  67  // TestAddSub verifies Add and Sub are inverses.
  68  func TestAddSub(t *testing.T) {
  69  	p := Falcon512()
  70  	a := randomPoly(p)
  71  	b := randomPoly(p)
  72  
  73  	c := Add(a, b)
  74  	d := Sub(c, b)
  75  	if !Equal(d, a) {
  76  		t.Fatal("(a + b) - b != a")
  77  	}
  78  }
  79  
  80  // TestNeg verifies a + (-a) == 0.
  81  func TestNeg(t *testing.T) {
  82  	p := Falcon512()
  83  	a := randomPoly(p)
  84  	b := Neg(a)
  85  	c := Add(a, b)
  86  
  87  	for i := range c.Coeffs {
  88  		if c.Coeffs[i] != 0 {
  89  			t.Fatalf("a + (-a) != 0 at index %d: got %d", i, c.Coeffs[i])
  90  		}
  91  	}
  92  }
  93  
  94  // TestScalarMul verifies scalar multiplication.
  95  func TestScalarMul(t *testing.T) {
  96  	p := Falcon512()
  97  	a := randomPoly(p)
  98  
  99  	// a * 0 == 0
 100  	z := ScalarMul(a, 0)
 101  	for i := range z.Coeffs {
 102  		if z.Coeffs[i] != 0 {
 103  			t.Fatal("a * 0 != 0")
 104  		}
 105  	}
 106  
 107  	// a * 1 == a
 108  	one := ScalarMul(a, 1)
 109  	if !Equal(one, a) {
 110  		t.Fatal("a * 1 != a")
 111  	}
 112  
 113  	// a * 2 == a + a
 114  	two := ScalarMul(a, 2)
 115  	sum := Add(a, a)
 116  	if !Equal(two, sum) {
 117  		t.Fatal("a * 2 != a + a")
 118  	}
 119  }
 120  
 121  // TestMulCommutativity verifies a * b == b * a.
 122  func TestMulCommutativity(t *testing.T) {
 123  	p := Falcon512()
 124  	a := randomSmallPoly(p, 100)
 125  	b := randomSmallPoly(p, 100)
 126  
 127  	ab := Mul(a, b)
 128  	ba := Mul(b, a)
 129  
 130  	if !Equal(ab, ba) {
 131  		t.Fatal("a * b != b * a")
 132  	}
 133  }
 134  
 135  // TestMulDistributivity verifies a*(b+c) == a*b + a*c.
 136  func TestMulDistributivity(t *testing.T) {
 137  	p := NewHope256()
 138  	a := randomSmallPoly(p, 50)
 139  	b := randomSmallPoly(p, 50)
 140  	c := randomSmallPoly(p, 50)
 141  
 142  	bc := Add(b, c)
 143  	lhs := Mul(a, bc)
 144  
 145  	ab := Mul(a, b)
 146  	ac := Mul(a, c)
 147  	rhs := Add(ab, ac)
 148  
 149  	if !Equal(lhs, rhs) {
 150  		t.Fatal("a*(b+c) != a*b + a*c")
 151  	}
 152  }
 153  
 154  // TestQInv verifies the Montgomery constant.
 155  func TestQInv(t *testing.T) {
 156  	for _, p := range []Params{Falcon512(), NewHope256()} {
 157  		t.Run(paramName(p), func(t *testing.T) {
 158  			// qinv * q ≡ -1 (mod 2^16)
 159  			// Equivalently: (qinv * q + 1) mod 2^16 == 0
 160  			product := p.QInv * uint64(p.Q)
 161  			mask := (uint64(1) << 16) - 1
 162  			if (product+1)&mask != 0 {
 163  				t.Fatalf("QInv check failed: QInv=%d, Q=%d, (QInv*Q+1) mod 2^16 = %d",
 164  					p.QInv, p.Q, (product+1)&mask)
 165  			}
 166  		})
 167  	}
 168  }
 169  
 170  func BenchmarkNTT(b *testing.B) {
 171  	for _, p := range []Params{Falcon512(), NewHope256()} {
 172  		b.Run(paramName(p), func(b *testing.B) {
 173  			poly := randomPoly(p)
 174  			b.ResetTimer()
 175  			for range b.N {
 176  				a := poly.Clone()
 177  				NTT(a)
 178  			}
 179  		})
 180  	}
 181  }
 182  
 183  func BenchmarkINTT(b *testing.B) {
 184  	for _, p := range []Params{Falcon512(), NewHope256()} {
 185  		b.Run(paramName(p), func(b *testing.B) {
 186  			poly := randomPoly(p)
 187  			NTT(poly)
 188  			b.ResetTimer()
 189  			for range b.N {
 190  				a := poly.Clone()
 191  				INTT(a)
 192  			}
 193  		})
 194  	}
 195  }
 196  
 197  func BenchmarkMul(b *testing.B) {
 198  	for _, p := range []Params{Falcon512(), NewHope256()} {
 199  		b.Run(paramName(p), func(b *testing.B) {
 200  			a := randomPoly(p)
 201  			bp2 := randomPoly(p)
 202  			b.ResetTimer()
 203  			for range b.N {
 204  				Mul(a, bp2)
 205  			}
 206  		})
 207  	}
 208  }
 209  
 210  // --- helpers ---
 211  
 212  func paramName(p Params) string {
 213  	switch {
 214  	case p.N == 512 && p.Q == 12289:
 215  		return "Falcon512"
 216  	case p.N == 256 && p.Q == 7681:
 217  		return "NewHope256"
 218  	default:
 219  		return "Custom"
 220  	}
 221  }
 222  
 223  func randomPoly(p Params) *Poly {
 224  	a := New(p)
 225  	buf := make([]byte, 4)
 226  	for i := range a.Coeffs {
 227  		rand.Read(buf)
 228  		v := uint32(buf[0]) | uint32(buf[1])<<8
 229  		a.Coeffs[i] = v % p.Q
 230  	}
 231  	return a
 232  }
 233  
 234  func randomSmallPoly(p Params, bound uint32) *Poly {
 235  	a := New(p)
 236  	buf := make([]byte, 2)
 237  	for i := range a.Coeffs {
 238  		rand.Read(buf)
 239  		v := uint32(buf[0]) | uint32(buf[1])<<8
 240  		a.Coeffs[i] = v % bound
 241  	}
 242  	return a
 243  }
 244  
 245  // schoolbookMul computes a * b mod (x^n + 1) using naive O(n²) multiplication.
 246  func schoolbookMul(a, b *Poly) *Poly {
 247  	n := a.params.N
 248  	q := int64(a.params.Q)
 249  
 250  	// Use big.Int for exact intermediate values.
 251  	result := make([]int64, n)
 252  	for i := 0; i < n; i++ {
 253  		for j := 0; j < n; j++ {
 254  			prod := int64(a.Coeffs[i]) * int64(b.Coeffs[j])
 255  			k := i + j
 256  			if k < n {
 257  				result[k] = (result[k] + prod) % q
 258  			} else {
 259  				// x^n ≡ -1, so x^{n+r} ≡ -x^r
 260  				result[k-n] = (result[k-n] - prod) % q
 261  			}
 262  		}
 263  	}
 264  
 265  	c := New(a.params)
 266  	for i := range result {
 267  		v := result[i] % q
 268  		if v < 0 {
 269  			v += q
 270  		}
 271  		c.Coeffs[i] = uint32(v)
 272  	}
 273  	return c
 274  }
 275  
 276  // Verify that our parameters are correct by checking root of unity properties.
 277  func TestRootOfUnity(t *testing.T) {
 278  	for _, p := range []Params{Falcon512(), NewHope256()} {
 279  		t.Run(paramName(p), func(t *testing.T) {
 280  			psi := big.NewInt(int64(p.RootOfUnity))
 281  			q := big.NewInt(int64(p.Q))
 282  			twoN := big.NewInt(int64(2 * p.N))
 283  
 284  			// psi^{2n} ≡ 1 (mod q)
 285  			order := new(big.Int).Exp(psi, twoN, q)
 286  			if order.Int64() != 1 {
 287  				t.Fatalf("psi^{2n} = %d, want 1 (mod %d)", order, p.Q)
 288  			}
 289  
 290  			// psi^n ≡ -1 (mod q) [negacyclic property]
 291  			nBig := big.NewInt(int64(p.N))
 292  			half := new(big.Int).Exp(psi, nBig, q)
 293  			want := new(big.Int).Sub(q, big.NewInt(1)) // q-1 ≡ -1
 294  			if half.Cmp(want) != 0 {
 295  				t.Fatalf("psi^n = %d, want %d (-1 mod %d)", half, want, p.Q)
 296  			}
 297  		})
 298  	}
 299  }
 300