package ring

import "testing"

// setupMPC2Party builds a two-party MPC session for use in tests.
//
// Both key pairs are generated against the same shared value derived from the
// fixed seed "mpc-test-crs". The session is created from the two public keys
// only; the remaining NewMPCSession arguments are nil. The test is aborted if
// the session cannot be constructed.
//
// It returns the session followed by the first and second party's secret keys.
func setupMPC2Party(t *testing.T) (*MPCSession, *KEMSecretKey, *KEMSecretKey) {
	t.Helper()

	params := DefaultHEParams()
	sharedA := GenerateSharedA(params, []byte("mpc-test-crs"))
	pubA, secA := HEKeyGenWithA(params, sharedA)
	pubB, secB := HEKeyGenWithA(params, sharedA)

	session, err := NewMPCSession([]*KEMPublicKey{pubA, pubB}, nil, nil, nil)
	if err != nil {
		t.Fatalf("NewMPCSession: %v", err)
	}
	return session, secA, secB
}

// mpcDecrypt2 decrypts ct jointly: each secret key produces a partial
// decryption, and the two shares are combined by DecryptDistributed. The
// combined result is returned as-is.
func mpcDecrypt2(t *testing.T, sk1, sk2 *KEMSecretKey, ct *HECiphertext) int {
	t.Helper()

	shares := []*Poly{
		PartialDecrypt(sk1, ct),
		PartialDecrypt(sk2, ct),
	}
	return DecryptDistributed(ct, shares)
}

// TestMPCEncryptDecrypt checks that each bit value encrypted by the session
// is recovered by the two-party distributed decryption.
func TestMPCEncryptDecrypt(t *testing.T) {
	sess, sk1, sk2 := setupMPC2Party(t)

	for want := 0; want <= 1; want++ {
		got := mpcDecrypt2(t, sk1, sk2, sess.Encrypt(want))
		if got == want {
			continue
		}
		t.Errorf("MPC encrypt/decrypt: got %d, want %d", got, want)
	}
}

// TestMPCXOR runs the session's XOR over every pair of input bits and checks
// that each result both verifies and decrypts to the XOR of the inputs.
func TestMPCXOR(t *testing.T) {
	sess, sk1, sk2 := setupMPC2Party(t)

	// Nested loops visit (0,0), (0,1), (1,0), (1,1) in that order.
	for lhs := 0; lhs <= 1; lhs++ {
		for rhs := 0; rhs <= 1; rhs++ {
			encL := sess.Encrypt(lhs)
			encR := sess.Encrypt(rhs)

			out := sess.XOR(encL, encR)
			if !sess.Verify(out) {
				t.Fatal("XOR result failed verification")
			}

			expected := lhs ^ rhs
			if actual := mpcDecrypt2(t, sk1, sk2, out.Ciphertext); actual != expected {
				t.Errorf("MPC XOR(%d,%d): got %d, want %d", lhs, rhs, actual, expected)
			}
		}
	}
}

// TestMPCNOT checks that the session's NOT yields a result that verifies and
// decrypts to the complement of the encrypted bit.
func TestMPCNOT(t *testing.T) {
	sess, sk1, sk2 := setupMPC2Party(t)

	for in := 0; in <= 1; in++ {
		out := sess.NOT(sess.Encrypt(in))
		if !sess.Verify(out) {
			t.Fatal("NOT result failed verification")
		}

		expected := 1 - in
		if actual := mpcDecrypt2(t, sk1, sk2, out.Ciphertext); actual != expected {
			t.Errorf("MPC NOT(%d): got %d, want %d", in, actual, expected)
		}
	}
}

// TestMPCAdd checks that adding two encryptions of 1 gives a result that
// verifies and decrypts to 0.
func TestMPCAdd(t *testing.T) {
	sess, sk1, sk2 := setupMPC2Party(t)

	// With binary plaintexts, addition is taken mod 2, so 1+1 behaves like XOR.
	first := sess.Encrypt(1)
	second := sess.Encrypt(1)

	sum := sess.Add(first, second)
	if !sess.Verify(sum) {
		t.Fatal("Add result failed verification")
	}

	if got := mpcDecrypt2(t, sk1, sk2, sum.Ciphertext); got != 0 {
		t.Errorf("MPC Add(1,1): got %d, want 0 (1+1 mod 2)", got)
	}
}

// TestMPCANDRequiresRLK checks that AND reports an error on a session built
// by setupMPC2Party, which supplies only the public keys.
func TestMPCANDRequiresRLK(t *testing.T) {
	sess, _, _ := setupMPC2Party(t)

	if _, err := sess.AND(sess.Encrypt(1), sess.Encrypt(1)); err == nil {
		t.Error("AND without RLK should return error")
	}
}

// TestMPCVerifyRejectsTamperedTag checks that Verify fails once the first
// byte of a result's tag has been flipped.
func TestMPCVerifyRejectsTamperedTag(t *testing.T) {
	sess, _, _ := setupMPC2Party(t)

	out := sess.XOR(sess.Encrypt(1), sess.Encrypt(0))

	// Corrupt the tag by inverting every bit of its first byte.
	out.Tag[0] ^= 0xFF

	if sess.Verify(out) {
		t.Error("tampered tag should fail verification")
	}
}

// TestMPCUnwrap checks that Unwrap returns a non-nil value for an untouched
// result and nil after the result's tag has been corrupted.
func TestMPCUnwrap(t *testing.T) {
	sess, _, _ := setupMPC2Party(t)

	out := sess.XOR(sess.Encrypt(1), sess.Encrypt(0))

	// An unmodified result must unwrap successfully.
	if sess.Unwrap(out) == nil {
		t.Fatal("Unwrap returned nil for valid result")
	}

	// After flipping the first tag byte, Unwrap must refuse the result.
	out.Tag[0] ^= 0xFF
	if sess.Unwrap(out) != nil {
		t.Error("Unwrap should return nil for tampered result")
	}
}

// TestMPCSessionID checks that the session identifier is 32 bytes long and
// contains at least one non-zero byte.
func TestMPCSessionID(t *testing.T) {
	sess, _, _ := setupMPC2Party(t)

	id := sess.SessionID()
	if n := len(id); n != 32 {
		t.Errorf("SessionID length = %d, want 32", n)
	}

	// Scan for any non-zero byte; stop at the first one found.
	sawNonZero := false
	for _, b := range id {
		if b != 0 {
			sawNonZero = true
			break
		}
	}
	if !sawNonZero {
		t.Error("SessionID should not be all zeros")
	}
}

// TestMPCCrossSessionReject creates two sessions from the same public keys
// and checks that a result produced by the first does not verify under the
// second.
func TestMPCCrossSessionReject(t *testing.T) {
	params := DefaultHEParams()
	sharedA := GenerateSharedA(params, []byte("cross-session-crs"))
	pubA, _ := HEKeyGenWithA(params, sharedA)
	pubB, _ := HEKeyGenWithA(params, sharedA)
	pubKeys := []*KEMPublicKey{pubA, pubB}

	first, err := NewMPCSession(pubKeys, nil, nil, nil)
	if err != nil {
		t.Fatalf("NewMPCSession 1: %v", err)
	}
	second, err := NewMPCSession(pubKeys, nil, nil, nil)
	if err != nil {
		t.Fatalf("NewMPCSession 2: %v", err)
	}

	// Compute entirely within the first session, then verify with the second.
	out := first.XOR(first.Encrypt(1), first.Encrypt(0))
	if second.Verify(out) {
		t.Error("result from session 1 should not verify under session 2")
	}
}

// TestMPCWithGPVSignature builds a session that is additionally given a GPV
// key pair and checks that an XOR result verifies, carries a signature, and
// decrypts to the expected bit.
func TestMPCWithGPVSignature(t *testing.T) {
	params := DefaultHEParams()
	sharedA := GenerateSharedA(params, []byte("gpv-mpc-crs"))
	pubA, secA := HEKeyGenWithA(params, sharedA)
	pubB, secB := HEKeyGenWithA(params, sharedA)

	signPK, signSK := GPVKeyGen(SmallGPVParams())

	sess, err := NewMPCSession([]*KEMPublicKey{pubA, pubB}, nil, signPK, signSK)
	if err != nil {
		t.Fatalf("NewMPCSession: %v", err)
	}

	out := sess.XOR(sess.Encrypt(1), sess.Encrypt(0))

	// For a session holding GPV keys, Verify is expected to cover the
	// signature as well as the tag.
	if !sess.Verify(out) {
		t.Fatal("signed result failed verification")
	}
	if out.Signature == nil {
		t.Fatal("expected GPV signature on result")
	}

	// Decrypt jointly to confirm the computed value itself: 1 XOR 0 = 1.
	if got := mpcDecrypt2(t, secA, secB, out.Ciphertext); got != 1 {
		t.Errorf("signed MPC XOR(1,0): got %d, want 1", got)
	}
}