keyagg_test.go raw
1 package ring
2
3 import "testing"
4
5 func TestAggregateHEKeys2Party(t *testing.T) {
6 kp := DefaultHEParams()
7 seed := []byte("test-common-reference-string")
8
9 // Both parties generate keys with the same shared A.
10 a := GenerateSharedA(kp, seed)
11 pk1, sk1 := HEKeyGenWithA(kp, a)
12 pk2, sk2 := HEKeyGenWithA(kp, a)
13
14 // Aggregate public keys.
15 aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2})
16 if err != nil {
17 t.Fatalf("AggregateHEKeys: %v", err)
18 }
19
20 // Encrypt under the aggregate key.
21 for _, bit := range []int{0, 1} {
22 ct := HEEncrypt(aggPK, bit)
23
24 // Each party computes a partial decryption.
25 d1 := PartialDecrypt(sk1, ct)
26 d2 := PartialDecrypt(sk2, ct)
27
28 // Combine and recover.
29 got := CombinePartialDecryptions(ct, []*Poly{d1, d2})
30 if got != bit {
31 t.Errorf("2-party decrypt: got %d, want %d", got, bit)
32 }
33 }
34 }
35
36 func TestAggregateHEKeys3Party(t *testing.T) {
37 kp := DefaultHEParams()
38 seed := []byte("three-party-crs")
39
40 a := GenerateSharedA(kp, seed)
41 pk1, sk1 := HEKeyGenWithA(kp, a)
42 pk2, sk2 := HEKeyGenWithA(kp, a)
43 pk3, sk3 := HEKeyGenWithA(kp, a)
44
45 aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2, pk3})
46 if err != nil {
47 t.Fatalf("AggregateHEKeys: %v", err)
48 }
49
50 for _, bit := range []int{0, 1} {
51 ct := HEEncrypt(aggPK, bit)
52 d1 := PartialDecrypt(sk1, ct)
53 d2 := PartialDecrypt(sk2, ct)
54 d3 := PartialDecrypt(sk3, ct)
55 got := CombinePartialDecryptions(ct, []*Poly{d1, d2, d3})
56 if got != bit {
57 t.Errorf("3-party decrypt: got %d, want %d", got, bit)
58 }
59 }
60 }
61
62 func TestAggregateHEKeysEmpty(t *testing.T) {
63 _, err := AggregateHEKeys(nil)
64 if err == nil {
65 t.Error("expected error for nil input")
66 }
67 _, err = AggregateHEKeys([]*KEMPublicKey{})
68 if err == nil {
69 t.Error("expected error for empty input")
70 }
71 }
72
73 func TestAggregateHEKeysSingle(t *testing.T) {
74 kp := DefaultHEParams()
75 a := GenerateSharedA(kp, []byte("single"))
76 pk, sk := HEKeyGenWithA(kp, a)
77
78 aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk})
79 if err != nil {
80 t.Fatalf("AggregateHEKeys single: %v", err)
81 }
82
83 // Single-party aggregation should still decrypt correctly
84 // (aggregate key == the original key).
85 ct := HEEncrypt(aggPK, 1)
86 got := HEDecrypt(sk, ct)
87 if got != 1 {
88 t.Errorf("single-party decrypt: got %d, want 1", got)
89 }
90 }
91
92 func TestAggregateHEKeysMismatchedA(t *testing.T) {
93 kp := DefaultHEParams()
94 pk1, _ := HEKeyGenWithA(kp, GenerateSharedA(kp, []byte("seed-1")))
95 pk2, _ := HEKeyGenWithA(kp, GenerateSharedA(kp, []byte("seed-2")))
96
97 _, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2})
98 if err == nil {
99 t.Error("expected error for mismatched A elements")
100 }
101 }
102
103 func TestGenerateSharedADeterministic(t *testing.T) {
104 kp := DefaultHEParams()
105 seed := []byte("determinism-test")
106
107 a1 := GenerateSharedA(kp, seed)
108 a2 := GenerateSharedA(kp, seed)
109
110 if !Equal(a1, a2) {
111 t.Error("GenerateSharedA should be deterministic")
112 }
113 }
114
115 func TestGenerateSharedADifferentSeeds(t *testing.T) {
116 kp := DefaultHEParams()
117 a1 := GenerateSharedA(kp, []byte("seed-a"))
118 a2 := GenerateSharedA(kp, []byte("seed-b"))
119
120 if Equal(a1, a2) {
121 t.Error("different seeds should produce different A elements")
122 }
123 }
124
125 func TestDistributedHEXOR(t *testing.T) {
126 kp := DefaultHEParams()
127 a := GenerateSharedA(kp, []byte("xor-test"))
128 pk1, sk1 := HEKeyGenWithA(kp, a)
129 pk2, sk2 := HEKeyGenWithA(kp, a)
130
131 aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2})
132 if err != nil {
133 t.Fatalf("AggregateHEKeys: %v", err)
134 }
135
136 // Encrypt two bits under the aggregate key, XOR them,
137 // then distributed-decrypt the result.
138 for _, pair := range [][2]int{{0, 0}, {0, 1}, {1, 0}, {1, 1}} {
139 ct0 := HEEncrypt(aggPK, pair[0])
140 ct1 := HEEncrypt(aggPK, pair[1])
141 xored := HEXOR(ct0, ct1)
142
143 d1 := PartialDecrypt(sk1, xored)
144 d2 := PartialDecrypt(sk2, xored)
145 got := CombinePartialDecryptions(xored, []*Poly{d1, d2})
146 want := pair[0] ^ pair[1]
147 if got != want {
148 t.Errorf("XOR(%d,%d): got %d, want %d", pair[0], pair[1], got, want)
149 }
150 }
151 }
152