mpc_test.go raw

   1  package ring
   2  
   3  import "testing"
   4  
   5  func setupMPC2Party(t *testing.T) (*MPCSession, *KEMSecretKey, *KEMSecretKey) {
   6  	t.Helper()
   7  	kp := DefaultHEParams()
   8  	a := GenerateSharedA(kp, []byte("mpc-test-crs"))
   9  	pk1, sk1 := HEKeyGenWithA(kp, a)
  10  	pk2, sk2 := HEKeyGenWithA(kp, a)
  11  
  12  	sess, err := NewMPCSession([]*KEMPublicKey{pk1, pk2}, nil, nil, nil)
  13  	if err != nil {
  14  		t.Fatalf("NewMPCSession: %v", err)
  15  	}
  16  	return sess, sk1, sk2
  17  }
  18  
  19  func mpcDecrypt2(t *testing.T, sk1, sk2 *KEMSecretKey, ct *HECiphertext) int {
  20  	t.Helper()
  21  	d1 := PartialDecrypt(sk1, ct)
  22  	d2 := PartialDecrypt(sk2, ct)
  23  	return DecryptDistributed(ct, []*Poly{d1, d2})
  24  }
  25  
  26  func TestMPCEncryptDecrypt(t *testing.T) {
  27  	sess, sk1, sk2 := setupMPC2Party(t)
  28  
  29  	for _, bit := range []int{0, 1} {
  30  		ct := sess.Encrypt(bit)
  31  		got := mpcDecrypt2(t, sk1, sk2, ct)
  32  		if got != bit {
  33  			t.Errorf("MPC encrypt/decrypt: got %d, want %d", got, bit)
  34  		}
  35  	}
  36  }
  37  
  38  func TestMPCXOR(t *testing.T) {
  39  	sess, sk1, sk2 := setupMPC2Party(t)
  40  
  41  	for _, pair := range [][2]int{{0, 0}, {0, 1}, {1, 0}, {1, 1}} {
  42  		ct0 := sess.Encrypt(pair[0])
  43  		ct1 := sess.Encrypt(pair[1])
  44  
  45  		result := sess.XOR(ct0, ct1)
  46  		if !sess.Verify(result) {
  47  			t.Fatal("XOR result failed verification")
  48  		}
  49  
  50  		got := mpcDecrypt2(t, sk1, sk2, result.Ciphertext)
  51  		want := pair[0] ^ pair[1]
  52  		if got != want {
  53  			t.Errorf("MPC XOR(%d,%d): got %d, want %d", pair[0], pair[1], got, want)
  54  		}
  55  	}
  56  }
  57  
  58  func TestMPCNOT(t *testing.T) {
  59  	sess, sk1, sk2 := setupMPC2Party(t)
  60  
  61  	for _, bit := range []int{0, 1} {
  62  		ct := sess.Encrypt(bit)
  63  		result := sess.NOT(ct)
  64  		if !sess.Verify(result) {
  65  			t.Fatal("NOT result failed verification")
  66  		}
  67  
  68  		got := mpcDecrypt2(t, sk1, sk2, result.Ciphertext)
  69  		want := 1 - bit
  70  		if got != want {
  71  			t.Errorf("MPC NOT(%d): got %d, want %d", bit, got, want)
  72  		}
  73  	}
  74  }
  75  
  76  func TestMPCAdd(t *testing.T) {
  77  	sess, sk1, sk2 := setupMPC2Party(t)
  78  
  79  	// HEAdd is XOR for binary plaintext (mod 2 addition).
  80  	ct0 := sess.Encrypt(1)
  81  	ct1 := sess.Encrypt(1)
  82  	result := sess.Add(ct0, ct1)
  83  	if !sess.Verify(result) {
  84  		t.Fatal("Add result failed verification")
  85  	}
  86  
  87  	got := mpcDecrypt2(t, sk1, sk2, result.Ciphertext)
  88  	if got != 0 {
  89  		t.Errorf("MPC Add(1,1): got %d, want 0 (1+1 mod 2)", got)
  90  	}
  91  }
  92  
  93  func TestMPCANDRequiresRLK(t *testing.T) {
  94  	sess, _, _ := setupMPC2Party(t)
  95  	ct0 := sess.Encrypt(1)
  96  	ct1 := sess.Encrypt(1)
  97  
  98  	_, err := sess.AND(ct0, ct1)
  99  	if err == nil {
 100  		t.Error("AND without RLK should return error")
 101  	}
 102  }
 103  
 104  func TestMPCVerifyRejectsTamperedTag(t *testing.T) {
 105  	sess, _, _ := setupMPC2Party(t)
 106  	ct := sess.Encrypt(1)
 107  	result := sess.XOR(ct, sess.Encrypt(0))
 108  
 109  	// Tamper with the tag.
 110  	result.Tag[0] ^= 0xFF
 111  	if sess.Verify(result) {
 112  		t.Error("tampered tag should fail verification")
 113  	}
 114  }
 115  
 116  func TestMPCUnwrap(t *testing.T) {
 117  	sess, _, _ := setupMPC2Party(t)
 118  	ct := sess.Encrypt(1)
 119  	result := sess.XOR(ct, sess.Encrypt(0))
 120  
 121  	// Valid unwrap.
 122  	unwrapped := sess.Unwrap(result)
 123  	if unwrapped == nil {
 124  		t.Fatal("Unwrap returned nil for valid result")
 125  	}
 126  
 127  	// Tamper and unwrap should return nil.
 128  	result.Tag[0] ^= 0xFF
 129  	if sess.Unwrap(result) != nil {
 130  		t.Error("Unwrap should return nil for tampered result")
 131  	}
 132  }
 133  
 134  func TestMPCSessionID(t *testing.T) {
 135  	sess, _, _ := setupMPC2Party(t)
 136  	sid := sess.SessionID()
 137  	if len(sid) != 32 {
 138  		t.Errorf("SessionID length = %d, want 32", len(sid))
 139  	}
 140  
 141  	// Non-zero.
 142  	allZero := true
 143  	for _, b := range sid {
 144  		if b != 0 {
 145  			allZero = false
 146  			break
 147  		}
 148  	}
 149  	if allZero {
 150  		t.Error("SessionID should not be all zeros")
 151  	}
 152  }
 153  
 154  func TestMPCCrossSessionReject(t *testing.T) {
 155  	kp := DefaultHEParams()
 156  	a := GenerateSharedA(kp, []byte("cross-session-crs"))
 157  	pk1, _ := HEKeyGenWithA(kp, a)
 158  	pk2, _ := HEKeyGenWithA(kp, a)
 159  	pks := []*KEMPublicKey{pk1, pk2}
 160  
 161  	sess1, err := NewMPCSession(pks, nil, nil, nil)
 162  	if err != nil {
 163  		t.Fatalf("NewMPCSession 1: %v", err)
 164  	}
 165  	sess2, err := NewMPCSession(pks, nil, nil, nil)
 166  	if err != nil {
 167  		t.Fatalf("NewMPCSession 2: %v", err)
 168  	}
 169  
 170  	// Result from session 1 should not verify under session 2.
 171  	ct := sess1.Encrypt(1)
 172  	result := sess1.XOR(ct, sess1.Encrypt(0))
 173  
 174  	if sess2.Verify(result) {
 175  		t.Error("result from session 1 should not verify under session 2")
 176  	}
 177  }
 178  
 179  func TestMPCWithGPVSignature(t *testing.T) {
 180  	kp := DefaultHEParams()
 181  	a := GenerateSharedA(kp, []byte("gpv-mpc-crs"))
 182  	pk1, sk1 := HEKeyGenWithA(kp, a)
 183  	pk2, sk2 := HEKeyGenWithA(kp, a)
 184  
 185  	gp := SmallGPVParams()
 186  	gpvPK, gpvSK := GPVKeyGen(gp)
 187  
 188  	sess, err := NewMPCSession([]*KEMPublicKey{pk1, pk2}, nil, gpvPK, gpvSK)
 189  	if err != nil {
 190  		t.Fatalf("NewMPCSession: %v", err)
 191  	}
 192  
 193  	ct := sess.Encrypt(1)
 194  	result := sess.XOR(ct, sess.Encrypt(0))
 195  
 196  	// Verify includes GPV signature check.
 197  	if !sess.Verify(result) {
 198  		t.Fatal("signed result failed verification")
 199  	}
 200  	if result.Signature == nil {
 201  		t.Fatal("expected GPV signature on result")
 202  	}
 203  
 204  	// Decrypt to confirm the computation is correct.
 205  	got := mpcDecrypt2(t, sk1, sk2, result.Ciphertext)
 206  	if got != 1 {
 207  		t.Errorf("signed MPC XOR(1,0): got %d, want 1", got)
 208  	}
 209  }
 210