ntt_test.go raw

   1  package crypto
   2  
   3  import "testing"
   4  
   5  func TestBitRev6(t *testing.T) {
   6  	// 6-bit reversal: 0b000001 -> 0b100000 = 32
   7  	if got := bitRev6(1); got != 32 {
   8  		t.Errorf("bitRev6(1) = %d, want 32", got)
   9  	}
  10  	// Identity: 0 -> 0, 63 -> 63
  11  	if got := bitRev6(0); got != 0 {
  12  		t.Errorf("bitRev6(0) = %d, want 0", got)
  13  	}
  14  	if got := bitRev6(63); got != 63 {
  15  		t.Errorf("bitRev6(63) = %d, want 63", got)
  16  	}
  17  	// 0b010101 = 21 -> 0b101010 = 42
  18  	if got := bitRev6(21); got != 42 {
  19  		t.Errorf("bitRev6(21) = %d, want 42", got)
  20  	}
  21  }
  22  
  23  func TestNTTRoundTrip(t *testing.T) {
  24  	// Create a polynomial with known coefficients.
  25  	var a [HamN]uint16
  26  	for i := range a {
  27  		a[i] = uint16(i % HamP)
  28  	}
  29  	original := a
  30  
  31  	// Forward NTT then inverse should recover original.
  32  	ntt64(&a)
  33  	intt64(&a)
  34  
  35  	for i := range a {
  36  		if a[i] != original[i] {
  37  			t.Errorf("round-trip failed at [%d]: got %d, want %d", i, a[i], original[i])
  38  		}
  39  	}
  40  }
  41  
  42  func TestNTTZero(t *testing.T) {
  43  	var a [HamN]uint16
  44  	ntt64(&a)
  45  	for i, v := range a {
  46  		if v != 0 {
  47  			t.Errorf("NTT of zero poly: a[%d] = %d, want 0", i, v)
  48  		}
  49  	}
  50  }
  51  
  52  func TestNTTConvolution(t *testing.T) {
  53  	// Verify that pointwise multiply in NTT domain corresponds to
  54  	// polynomial multiplication mod (x^64 + 1) mod 257.
  55  
  56  	// f(x) = 1 + x
  57  	var f [HamN]uint16
  58  	f[0] = 1
  59  	f[1] = 1
  60  
  61  	// g(x) = 1 + x
  62  	var g [HamN]uint16
  63  	g[0] = 1
  64  	g[1] = 1
  65  
  66  	// f*g mod (x^64+1) = 1 + 2x + x^2 (since degree < 64, no wraparound).
  67  	ntt64(&f)
  68  	ntt64(&g)
  69  
  70  	var h [HamN]uint16
  71  	for i := range h {
  72  		h[i] = mod257(int(f[i]) * int(g[i]))
  73  	}
  74  
  75  	intt64(&h)
  76  
  77  	// Expected: h[0]=1, h[1]=2, h[2]=1, rest 0.
  78  	if h[0] != 1 {
  79  		t.Errorf("h[0] = %d, want 1", h[0])
  80  	}
  81  	if h[1] != 2 {
  82  		t.Errorf("h[1] = %d, want 2", h[1])
  83  	}
  84  	if h[2] != 1 {
  85  		t.Errorf("h[2] = %d, want 1", h[2])
  86  	}
  87  	for i := 3; i < HamN; i++ {
  88  		if h[i] != 0 {
  89  			t.Errorf("h[%d] = %d, want 0", i, h[i])
  90  		}
  91  	}
  92  }
  93  
  94  func TestNTTWraparound(t *testing.T) {
  95  	// Test the negacyclic property: x^64 ≡ -1 mod (x^64 + 1).
  96  	// f(x) = x^63, g(x) = x → f*g = x^64 ≡ -1 ≡ 256 mod 257.
  97  
  98  	var f [HamN]uint16
  99  	f[63] = 1
 100  
 101  	var g [HamN]uint16
 102  	g[1] = 1
 103  
 104  	ntt64(&f)
 105  	ntt64(&g)
 106  
 107  	var h [HamN]uint16
 108  	for i := range h {
 109  		h[i] = mod257(int(f[i]) * int(g[i]))
 110  	}
 111  
 112  	intt64(&h)
 113  
 114  	// Expected: h[0] = 256 (= -1 mod 257), rest 0.
 115  	if h[0] != 256 {
 116  		t.Errorf("h[0] = %d, want 256 (-1 mod 257)", h[0])
 117  	}
 118  	for i := 1; i < HamN; i++ {
 119  		if h[i] != 0 {
 120  			t.Errorf("h[%d] = %d, want 0", i, h[i])
 121  		}
 122  	}
 123  }
 124  
 125  func TestPowMod(t *testing.T) {
 126  	// 3 is a generator of Z_257*. 3^256 ≡ 1 mod 257 (Fermat).
 127  	if got := powMod(3, 256); got != 1 {
 128  		t.Errorf("3^256 mod 257 = %d, want 1", got)
 129  	}
 130  	// 3^128 ≡ 256 ≡ -1 mod 257 (since order is 256, half-order gives -1).
 131  	if got := powMod(3, 128); got != 256 {
 132  		t.Errorf("3^128 mod 257 = %d, want 256", got)
 133  	}
 134  	// 9^64 = (3^2)^64 = 3^128 ≡ -1 mod 257.
 135  	if got := powMod(9, 64); got != 256 {
 136  		t.Errorf("9^64 mod 257 = %d, want 256", got)
 137  	}
 138  }
 139