keyagg_test.go raw

   1  package ring
   2  
   3  import "testing"
   4  
   5  func TestAggregateHEKeys2Party(t *testing.T) {
   6  	kp := DefaultHEParams()
   7  	seed := []byte("test-common-reference-string")
   8  
   9  	// Both parties generate keys with the same shared A.
  10  	a := GenerateSharedA(kp, seed)
  11  	pk1, sk1 := HEKeyGenWithA(kp, a)
  12  	pk2, sk2 := HEKeyGenWithA(kp, a)
  13  
  14  	// Aggregate public keys.
  15  	aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2})
  16  	if err != nil {
  17  		t.Fatalf("AggregateHEKeys: %v", err)
  18  	}
  19  
  20  	// Encrypt under the aggregate key.
  21  	for _, bit := range []int{0, 1} {
  22  		ct := HEEncrypt(aggPK, bit)
  23  
  24  		// Each party computes a partial decryption.
  25  		d1 := PartialDecrypt(sk1, ct)
  26  		d2 := PartialDecrypt(sk2, ct)
  27  
  28  		// Combine and recover.
  29  		got := CombinePartialDecryptions(ct, []*Poly{d1, d2})
  30  		if got != bit {
  31  			t.Errorf("2-party decrypt: got %d, want %d", got, bit)
  32  		}
  33  	}
  34  }
  35  
  36  func TestAggregateHEKeys3Party(t *testing.T) {
  37  	kp := DefaultHEParams()
  38  	seed := []byte("three-party-crs")
  39  
  40  	a := GenerateSharedA(kp, seed)
  41  	pk1, sk1 := HEKeyGenWithA(kp, a)
  42  	pk2, sk2 := HEKeyGenWithA(kp, a)
  43  	pk3, sk3 := HEKeyGenWithA(kp, a)
  44  
  45  	aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2, pk3})
  46  	if err != nil {
  47  		t.Fatalf("AggregateHEKeys: %v", err)
  48  	}
  49  
  50  	for _, bit := range []int{0, 1} {
  51  		ct := HEEncrypt(aggPK, bit)
  52  		d1 := PartialDecrypt(sk1, ct)
  53  		d2 := PartialDecrypt(sk2, ct)
  54  		d3 := PartialDecrypt(sk3, ct)
  55  		got := CombinePartialDecryptions(ct, []*Poly{d1, d2, d3})
  56  		if got != bit {
  57  			t.Errorf("3-party decrypt: got %d, want %d", got, bit)
  58  		}
  59  	}
  60  }
  61  
  62  func TestAggregateHEKeysEmpty(t *testing.T) {
  63  	_, err := AggregateHEKeys(nil)
  64  	if err == nil {
  65  		t.Error("expected error for nil input")
  66  	}
  67  	_, err = AggregateHEKeys([]*KEMPublicKey{})
  68  	if err == nil {
  69  		t.Error("expected error for empty input")
  70  	}
  71  }
  72  
  73  func TestAggregateHEKeysSingle(t *testing.T) {
  74  	kp := DefaultHEParams()
  75  	a := GenerateSharedA(kp, []byte("single"))
  76  	pk, sk := HEKeyGenWithA(kp, a)
  77  
  78  	aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk})
  79  	if err != nil {
  80  		t.Fatalf("AggregateHEKeys single: %v", err)
  81  	}
  82  
  83  	// Single-party aggregation should still decrypt correctly
  84  	// (aggregate key == the original key).
  85  	ct := HEEncrypt(aggPK, 1)
  86  	got := HEDecrypt(sk, ct)
  87  	if got != 1 {
  88  		t.Errorf("single-party decrypt: got %d, want 1", got)
  89  	}
  90  }
  91  
  92  func TestAggregateHEKeysMismatchedA(t *testing.T) {
  93  	kp := DefaultHEParams()
  94  	pk1, _ := HEKeyGenWithA(kp, GenerateSharedA(kp, []byte("seed-1")))
  95  	pk2, _ := HEKeyGenWithA(kp, GenerateSharedA(kp, []byte("seed-2")))
  96  
  97  	_, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2})
  98  	if err == nil {
  99  		t.Error("expected error for mismatched A elements")
 100  	}
 101  }
 102  
 103  func TestGenerateSharedADeterministic(t *testing.T) {
 104  	kp := DefaultHEParams()
 105  	seed := []byte("determinism-test")
 106  
 107  	a1 := GenerateSharedA(kp, seed)
 108  	a2 := GenerateSharedA(kp, seed)
 109  
 110  	if !Equal(a1, a2) {
 111  		t.Error("GenerateSharedA should be deterministic")
 112  	}
 113  }
 114  
 115  func TestGenerateSharedADifferentSeeds(t *testing.T) {
 116  	kp := DefaultHEParams()
 117  	a1 := GenerateSharedA(kp, []byte("seed-a"))
 118  	a2 := GenerateSharedA(kp, []byte("seed-b"))
 119  
 120  	if Equal(a1, a2) {
 121  		t.Error("different seeds should produce different A elements")
 122  	}
 123  }
 124  
 125  func TestDistributedHEXOR(t *testing.T) {
 126  	kp := DefaultHEParams()
 127  	a := GenerateSharedA(kp, []byte("xor-test"))
 128  	pk1, sk1 := HEKeyGenWithA(kp, a)
 129  	pk2, sk2 := HEKeyGenWithA(kp, a)
 130  
 131  	aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2})
 132  	if err != nil {
 133  		t.Fatalf("AggregateHEKeys: %v", err)
 134  	}
 135  
 136  	// Encrypt two bits under the aggregate key, XOR them,
 137  	// then distributed-decrypt the result.
 138  	for _, pair := range [][2]int{{0, 0}, {0, 1}, {1, 0}, {1, 1}} {
 139  		ct0 := HEEncrypt(aggPK, pair[0])
 140  		ct1 := HEEncrypt(aggPK, pair[1])
 141  		xored := HEXOR(ct0, ct1)
 142  
 143  		d1 := PartialDecrypt(sk1, xored)
 144  		d2 := PartialDecrypt(sk2, xored)
 145  		got := CombinePartialDecryptions(xored, []*Poly{d1, d2})
 146  		want := pair[0] ^ pair[1]
 147  		if got != want {
 148  			t.Errorf("XOR(%d,%d): got %d, want %d", pair[0], pair[1], got, want)
 149  		}
 150  	}
 151  }
 152