package ring import "testing" func TestAggregateHEKeys2Party(t *testing.T) { kp := DefaultHEParams() seed := []byte("test-common-reference-string") // Both parties generate keys with the same shared A. a := GenerateSharedA(kp, seed) pk1, sk1 := HEKeyGenWithA(kp, a) pk2, sk2 := HEKeyGenWithA(kp, a) // Aggregate public keys. aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2}) if err != nil { t.Fatalf("AggregateHEKeys: %v", err) } // Encrypt under the aggregate key. for _, bit := range []int{0, 1} { ct := HEEncrypt(aggPK, bit) // Each party computes a partial decryption. d1 := PartialDecrypt(sk1, ct) d2 := PartialDecrypt(sk2, ct) // Combine and recover. got := CombinePartialDecryptions(ct, []*Poly{d1, d2}) if got != bit { t.Errorf("2-party decrypt: got %d, want %d", got, bit) } } } func TestAggregateHEKeys3Party(t *testing.T) { kp := DefaultHEParams() seed := []byte("three-party-crs") a := GenerateSharedA(kp, seed) pk1, sk1 := HEKeyGenWithA(kp, a) pk2, sk2 := HEKeyGenWithA(kp, a) pk3, sk3 := HEKeyGenWithA(kp, a) aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2, pk3}) if err != nil { t.Fatalf("AggregateHEKeys: %v", err) } for _, bit := range []int{0, 1} { ct := HEEncrypt(aggPK, bit) d1 := PartialDecrypt(sk1, ct) d2 := PartialDecrypt(sk2, ct) d3 := PartialDecrypt(sk3, ct) got := CombinePartialDecryptions(ct, []*Poly{d1, d2, d3}) if got != bit { t.Errorf("3-party decrypt: got %d, want %d", got, bit) } } } func TestAggregateHEKeysEmpty(t *testing.T) { _, err := AggregateHEKeys(nil) if err == nil { t.Error("expected error for nil input") } _, err = AggregateHEKeys([]*KEMPublicKey{}) if err == nil { t.Error("expected error for empty input") } } func TestAggregateHEKeysSingle(t *testing.T) { kp := DefaultHEParams() a := GenerateSharedA(kp, []byte("single")) pk, sk := HEKeyGenWithA(kp, a) aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk}) if err != nil { t.Fatalf("AggregateHEKeys single: %v", err) } // Single-party aggregation should still decrypt correctly // (aggregate key == the original key). ct := HEEncrypt(aggPK, 1) got := HEDecrypt(sk, ct) if got != 1 { t.Errorf("single-party decrypt: got %d, want 1", got) } } func TestAggregateHEKeysMismatchedA(t *testing.T) { kp := DefaultHEParams() pk1, _ := HEKeyGenWithA(kp, GenerateSharedA(kp, []byte("seed-1"))) pk2, _ := HEKeyGenWithA(kp, GenerateSharedA(kp, []byte("seed-2"))) _, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2}) if err == nil { t.Error("expected error for mismatched A elements") } } func TestGenerateSharedADeterministic(t *testing.T) { kp := DefaultHEParams() seed := []byte("determinism-test") a1 := GenerateSharedA(kp, seed) a2 := GenerateSharedA(kp, seed) if !Equal(a1, a2) { t.Error("GenerateSharedA should be deterministic") } } func TestGenerateSharedADifferentSeeds(t *testing.T) { kp := DefaultHEParams() a1 := GenerateSharedA(kp, []byte("seed-a")) a2 := GenerateSharedA(kp, []byte("seed-b")) if Equal(a1, a2) { t.Error("different seeds should produce different A elements") } } func TestDistributedHEXOR(t *testing.T) { kp := DefaultHEParams() a := GenerateSharedA(kp, []byte("xor-test")) pk1, sk1 := HEKeyGenWithA(kp, a) pk2, sk2 := HEKeyGenWithA(kp, a) aggPK, err := AggregateHEKeys([]*KEMPublicKey{pk1, pk2}) if err != nil { t.Fatalf("AggregateHEKeys: %v", err) } // Encrypt two bits under the aggregate key, XOR them, // then distributed-decrypt the result. for _, pair := range [][2]int{{0, 0}, {0, 1}, {1, 0}, {1, 1}} { ct0 := HEEncrypt(aggPK, pair[0]) ct1 := HEEncrypt(aggPK, pair[1]) xored := HEXOR(ct0, ct1) d1 := PartialDecrypt(sk1, xored) d2 := PartialDecrypt(sk2, xored) got := CombinePartialDecryptions(xored, []*Poly{d1, d2}) want := pair[0] ^ pair[1] if got != want { t.Errorf("XOR(%d,%d): got %d, want %d", pair[0], pair[1], got, want) } } }