mpc_test.go raw
1 package ring
2
3 import "testing"
4
5 func setupMPC2Party(t *testing.T) (*MPCSession, *KEMSecretKey, *KEMSecretKey) {
6 t.Helper()
7 kp := DefaultHEParams()
8 a := GenerateSharedA(kp, []byte("mpc-test-crs"))
9 pk1, sk1 := HEKeyGenWithA(kp, a)
10 pk2, sk2 := HEKeyGenWithA(kp, a)
11
12 sess, err := NewMPCSession([]*KEMPublicKey{pk1, pk2}, nil, nil, nil)
13 if err != nil {
14 t.Fatalf("NewMPCSession: %v", err)
15 }
16 return sess, sk1, sk2
17 }
18
19 func mpcDecrypt2(t *testing.T, sk1, sk2 *KEMSecretKey, ct *HECiphertext) int {
20 t.Helper()
21 d1 := PartialDecrypt(sk1, ct)
22 d2 := PartialDecrypt(sk2, ct)
23 return DecryptDistributed(ct, []*Poly{d1, d2})
24 }
25
26 func TestMPCEncryptDecrypt(t *testing.T) {
27 sess, sk1, sk2 := setupMPC2Party(t)
28
29 for _, bit := range []int{0, 1} {
30 ct := sess.Encrypt(bit)
31 got := mpcDecrypt2(t, sk1, sk2, ct)
32 if got != bit {
33 t.Errorf("MPC encrypt/decrypt: got %d, want %d", got, bit)
34 }
35 }
36 }
37
38 func TestMPCXOR(t *testing.T) {
39 sess, sk1, sk2 := setupMPC2Party(t)
40
41 for _, pair := range [][2]int{{0, 0}, {0, 1}, {1, 0}, {1, 1}} {
42 ct0 := sess.Encrypt(pair[0])
43 ct1 := sess.Encrypt(pair[1])
44
45 result := sess.XOR(ct0, ct1)
46 if !sess.Verify(result) {
47 t.Fatal("XOR result failed verification")
48 }
49
50 got := mpcDecrypt2(t, sk1, sk2, result.Ciphertext)
51 want := pair[0] ^ pair[1]
52 if got != want {
53 t.Errorf("MPC XOR(%d,%d): got %d, want %d", pair[0], pair[1], got, want)
54 }
55 }
56 }
57
58 func TestMPCNOT(t *testing.T) {
59 sess, sk1, sk2 := setupMPC2Party(t)
60
61 for _, bit := range []int{0, 1} {
62 ct := sess.Encrypt(bit)
63 result := sess.NOT(ct)
64 if !sess.Verify(result) {
65 t.Fatal("NOT result failed verification")
66 }
67
68 got := mpcDecrypt2(t, sk1, sk2, result.Ciphertext)
69 want := 1 - bit
70 if got != want {
71 t.Errorf("MPC NOT(%d): got %d, want %d", bit, got, want)
72 }
73 }
74 }
75
76 func TestMPCAdd(t *testing.T) {
77 sess, sk1, sk2 := setupMPC2Party(t)
78
79 // HEAdd is XOR for binary plaintext (mod 2 addition).
80 ct0 := sess.Encrypt(1)
81 ct1 := sess.Encrypt(1)
82 result := sess.Add(ct0, ct1)
83 if !sess.Verify(result) {
84 t.Fatal("Add result failed verification")
85 }
86
87 got := mpcDecrypt2(t, sk1, sk2, result.Ciphertext)
88 if got != 0 {
89 t.Errorf("MPC Add(1,1): got %d, want 0 (1+1 mod 2)", got)
90 }
91 }
92
93 func TestMPCANDRequiresRLK(t *testing.T) {
94 sess, _, _ := setupMPC2Party(t)
95 ct0 := sess.Encrypt(1)
96 ct1 := sess.Encrypt(1)
97
98 _, err := sess.AND(ct0, ct1)
99 if err == nil {
100 t.Error("AND without RLK should return error")
101 }
102 }
103
104 func TestMPCVerifyRejectsTamperedTag(t *testing.T) {
105 sess, _, _ := setupMPC2Party(t)
106 ct := sess.Encrypt(1)
107 result := sess.XOR(ct, sess.Encrypt(0))
108
109 // Tamper with the tag.
110 result.Tag[0] ^= 0xFF
111 if sess.Verify(result) {
112 t.Error("tampered tag should fail verification")
113 }
114 }
115
116 func TestMPCUnwrap(t *testing.T) {
117 sess, _, _ := setupMPC2Party(t)
118 ct := sess.Encrypt(1)
119 result := sess.XOR(ct, sess.Encrypt(0))
120
121 // Valid unwrap.
122 unwrapped := sess.Unwrap(result)
123 if unwrapped == nil {
124 t.Fatal("Unwrap returned nil for valid result")
125 }
126
127 // Tamper and unwrap should return nil.
128 result.Tag[0] ^= 0xFF
129 if sess.Unwrap(result) != nil {
130 t.Error("Unwrap should return nil for tampered result")
131 }
132 }
133
134 func TestMPCSessionID(t *testing.T) {
135 sess, _, _ := setupMPC2Party(t)
136 sid := sess.SessionID()
137 if len(sid) != 32 {
138 t.Errorf("SessionID length = %d, want 32", len(sid))
139 }
140
141 // Non-zero.
142 allZero := true
143 for _, b := range sid {
144 if b != 0 {
145 allZero = false
146 break
147 }
148 }
149 if allZero {
150 t.Error("SessionID should not be all zeros")
151 }
152 }
153
154 func TestMPCCrossSessionReject(t *testing.T) {
155 kp := DefaultHEParams()
156 a := GenerateSharedA(kp, []byte("cross-session-crs"))
157 pk1, _ := HEKeyGenWithA(kp, a)
158 pk2, _ := HEKeyGenWithA(kp, a)
159 pks := []*KEMPublicKey{pk1, pk2}
160
161 sess1, err := NewMPCSession(pks, nil, nil, nil)
162 if err != nil {
163 t.Fatalf("NewMPCSession 1: %v", err)
164 }
165 sess2, err := NewMPCSession(pks, nil, nil, nil)
166 if err != nil {
167 t.Fatalf("NewMPCSession 2: %v", err)
168 }
169
170 // Result from session 1 should not verify under session 2.
171 ct := sess1.Encrypt(1)
172 result := sess1.XOR(ct, sess1.Encrypt(0))
173
174 if sess2.Verify(result) {
175 t.Error("result from session 1 should not verify under session 2")
176 }
177 }
178
179 func TestMPCWithGPVSignature(t *testing.T) {
180 kp := DefaultHEParams()
181 a := GenerateSharedA(kp, []byte("gpv-mpc-crs"))
182 pk1, sk1 := HEKeyGenWithA(kp, a)
183 pk2, sk2 := HEKeyGenWithA(kp, a)
184
185 gp := SmallGPVParams()
186 gpvPK, gpvSK := GPVKeyGen(gp)
187
188 sess, err := NewMPCSession([]*KEMPublicKey{pk1, pk2}, nil, gpvPK, gpvSK)
189 if err != nil {
190 t.Fatalf("NewMPCSession: %v", err)
191 }
192
193 ct := sess.Encrypt(1)
194 result := sess.XOR(ct, sess.Encrypt(0))
195
196 // Verify includes GPV signature check.
197 if !sess.Verify(result) {
198 t.Fatal("signed result failed verification")
199 }
200 if result.Signature == nil {
201 t.Fatal("expected GPV signature on result")
202 }
203
204 // Decrypt to confirm the computation is correct.
205 got := mpcDecrypt2(t, sk1, sk2, result.Ciphertext)
206 if got != 1 {
207 t.Errorf("signed MPC XOR(1,0): got %d, want 1", got)
208 }
209 }
210