he_test.go raw

   1  package ring
   2  
   3  import "testing"
   4  
   5  // TestHEEncryptDecrypt verifies single-bit encrypt/decrypt.
   6  func TestHEEncryptDecrypt(t *testing.T) {
   7  	kp := DefaultHEParams()
   8  	pk, sk, _ := HEKeyGen(kp)
   9  
  10  	for _, bit := range []int{0, 1} {
  11  		ct := HEEncrypt(pk, bit)
  12  		got := HEDecrypt(sk, ct)
  13  		if got != bit {
  14  			t.Fatalf("bit=%d: decrypted to %d", bit, got)
  15  		}
  16  	}
  17  }
  18  
  19  // TestHEAdd verifies homomorphic addition (XOR).
  20  func TestHEAdd(t *testing.T) {
  21  	kp := DefaultHEParams()
  22  	pk, sk, _ := HEKeyGen(kp)
  23  
  24  	// Test all XOR combinations.
  25  	for _, a := range []int{0, 1} {
  26  		for _, b := range []int{0, 1} {
  27  			ctA := HEEncrypt(pk, a)
  28  			ctB := HEEncrypt(pk, b)
  29  			ctSum := HEAdd(ctA, ctB)
  30  			got := HEDecrypt(sk, ctSum)
  31  			want := a ^ b
  32  			if got != want {
  33  				t.Fatalf("%d XOR %d: got %d, want %d", a, b, got, want)
  34  			}
  35  		}
  36  	}
  37  }
  38  
  39  // TestHENot verifies homomorphic NOT.
  40  func TestHENot(t *testing.T) {
  41  	kp := DefaultHEParams()
  42  	pk, sk, _ := HEKeyGen(kp)
  43  
  44  	for _, bit := range []int{0, 1} {
  45  		ct := HEEncrypt(pk, bit)
  46  		ctNot := HENot(ct)
  47  		got := HEDecrypt(sk, ctNot)
  48  		want := 1 - bit
  49  		if got != want {
  50  			t.Fatalf("NOT(%d): got %d, want %d", bit, got, want)
  51  		}
  52  	}
  53  }
  54  
  55  // TestHEMul verifies homomorphic multiplication (AND).
  56  func TestHEMul(t *testing.T) {
  57  	kp := DefaultHEParams()
  58  	pk, sk, rlk := HEKeyGen(kp)
  59  
  60  	for _, a := range []int{0, 1} {
  61  		for _, b := range []int{0, 1} {
  62  			ctA := HEEncrypt(pk, a)
  63  			ctB := HEEncrypt(pk, b)
  64  			ctProd := HEMul(ctA, ctB, rlk)
  65  			got := HEDecrypt(sk, ctProd)
  66  			want := a & b
  67  			if got != want {
  68  				t.Fatalf("%d AND %d: got %d, want %d", a, b, got, want)
  69  			}
  70  		}
  71  	}
  72  }
  73  
  74  // TestHERerandomize verifies that rerandomization preserves the plaintext.
  75  func TestHERerandomize(t *testing.T) {
  76  	kp := DefaultHEParams()
  77  	pk, sk, _ := HEKeyGen(kp)
  78  
  79  	for _, bit := range []int{0, 1} {
  80  		ct := HEEncrypt(pk, bit)
  81  		ctRerand := Rerandomize(pk, ct)
  82  
  83  		got := HEDecrypt(sk, ctRerand)
  84  		if got != bit {
  85  			t.Fatalf("rerandomize(%d): decrypted to %d", bit, got)
  86  		}
  87  
  88  		// Verify ciphertexts are actually different.
  89  		if Equal(ct.U, ctRerand.U) && Equal(ct.V, ctRerand.V) {
  90  			t.Fatal("rerandomized ciphertext is identical — randomness failure")
  91  		}
  92  	}
  93  }
  94  
  95  // TestHEChainedAdd verifies multiple additions stay within noise budget.
  96  func TestHEChainedAdd(t *testing.T) {
  97  	kp := DefaultHEParams()
  98  	pk, sk, _ := HEKeyGen(kp)
  99  
 100  	// XOR 16 bits together: result should be 0 (even count of 1s) or 1 (odd).
 101  	bits := []int{1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0}
 102  	expected := 0
 103  	for _, b := range bits {
 104  		expected ^= b
 105  	}
 106  
 107  	acc := HEEncrypt(pk, bits[0])
 108  	for _, b := range bits[1:] {
 109  		ct := HEEncrypt(pk, b)
 110  		acc = HEAdd(acc, ct)
 111  	}
 112  
 113  	got := HEDecrypt(sk, acc)
 114  	if got != expected {
 115  		t.Fatalf("chained XOR: got %d, want %d (noise estimate: %.1f)", got, expected, acc.NoiseEstimate)
 116  	}
 117  }
 118  
 119  // TestHEBGVNoiseParity verifies that BGV encryption produces even noise + message.
 120  func TestHEBGVNoiseParity(t *testing.T) {
 121  	kp := DefaultHEParams()
 122  	pk, sk, _ := HEKeyGen(kp)
 123  	p := kp.Ring
 124  
 125  	for _, bit := range []int{0, 1} {
 126  		ct := HEEncrypt(pk, bit)
 127  
 128  		// Compute v - s·u
 129  		uNTT := ct.U.Clone()
 130  		NTT(uNTT)
 131  		su := MulPointwise(sk.S, uNTT)
 132  		INTT(su)
 133  		noisy := Sub(ct.V, su)
 134  
 135  		c := noisy.Coeffs[0]
 136  		q := p.Q
 137  		if c > q/2 {
 138  			c = q - c
 139  		}
 140  		parity := c % 2
 141  
 142  		t.Logf("bit=%d: coeff[0]=%d, centered=%d, parity=%d (want parity=%d)",
 143  			bit, noisy.Coeffs[0], c, parity, bit)
 144  
 145  		if int(parity) != bit {
 146  			t.Errorf("bit=%d: wrong parity %d", bit, parity)
 147  		}
 148  	}
 149  }
 150  
 151  // TestHEMulDeg2 tests the tensor product without relinearization.
 152  // This isolates whether the issue is in the tensor product or relinearization.
 153  func TestHEMulDeg2(t *testing.T) {
 154  	kp := DefaultHEParams()
 155  	pk, sk, _ := HEKeyGen(kp)
 156  	p := kp.Ring
 157  
 158  	for _, a := range []int{0, 1} {
 159  		for _, b := range []int{0, 1} {
 160  			ctA := HEEncrypt(pk, a)
 161  			ctB := HEEncrypt(pk, b)
 162  
 163  			// Tensor product: c0 = v_a * v_b, c1 = v_a*u_b + u_a*v_b, c2 = u_a * u_b
 164  			c0 := Mul(ctA.V, ctB.V)
 165  			c1a := Mul(ctA.V, ctB.U)
 166  			c1b := Mul(ctA.U, ctB.V)
 167  			c1 := Add(c1a, c1b)
 168  			c2 := Mul(ctA.U, ctB.U)
 169  
 170  			// Degree-2 decryption: c0 - s*c1 + s²*c2
 171  			// s is in NTT form
 172  			sCoeff := sk.S.Clone()
 173  
 174  			c1NTT := c1.Clone()
 175  			NTT(c1NTT)
 176  			sc1 := MulPointwise(sCoeff, c1NTT)
 177  			INTT(sc1)
 178  
 179  			// s²
 180  			s2 := MulPointwise(sCoeff, sCoeff)
 181  			c2NTT := c2.Clone()
 182  			NTT(c2NTT)
 183  			s2c2 := MulPointwise(s2, c2NTT)
 184  			INTT(s2c2)
 185  
 186  			result := Sub(c0, sc1)
 187  			result = Add(result, s2c2)
 188  
 189  			// Decode: result[0] mod 2
 190  			q := p.Q
 191  			c := result.Coeffs[0]
 192  			if c > q/2 {
 193  				c = q - c
 194  			}
 195  			got := int(c % 2)
 196  			want := a & b
 197  
 198  			t.Logf("%d AND %d: coeff[0]=%d, centered=%d, mod2=%d (want %d)",
 199  				a, b, result.Coeffs[0], c, got, want)
 200  
 201  			if got != want {
 202  				t.Errorf("%d AND %d: got %d, want %d", a, b, got, want)
 203  			}
 204  		}
 205  	}
 206  }
 207  
 208  func BenchmarkHEEncrypt(b *testing.B) {
 209  	kp := DefaultHEParams()
 210  	pk, _, _ := HEKeyGen(kp)
 211  	b.ResetTimer()
 212  	for range b.N {
 213  		HEEncrypt(pk, 1)
 214  	}
 215  }
 216  
 217  func BenchmarkHEAdd(b *testing.B) {
 218  	kp := DefaultHEParams()
 219  	pk, _, _ := HEKeyGen(kp)
 220  	ct0 := HEEncrypt(pk, 0)
 221  	ct1 := HEEncrypt(pk, 1)
 222  	b.ResetTimer()
 223  	for range b.N {
 224  		HEAdd(ct0, ct1)
 225  	}
 226  }
 227  
 228  func BenchmarkHEMul(b *testing.B) {
 229  	kp := DefaultHEParams()
 230  	pk, _, rlk := HEKeyGen(kp)
 231  	ct0 := HEEncrypt(pk, 1)
 232  	ct1 := HEEncrypt(pk, 1)
 233  	b.ResetTimer()
 234  	for range b.N {
 235  		HEMul(ct0, ct1, rlk)
 236  	}
 237  }
 238