kem_test.go raw

   1  package ring
   2  
   3  import (
   4  	"bytes"
   5  	"testing"
   6  )
   7  
   8  // TestKEMRoundTrip verifies encapsulate/decapsulate produce the same shared key.
   9  func TestKEMRoundTrip(t *testing.T) {
  10  	kp := DefaultKEMParams()
  11  	pk, sk := KEMKeyGen(kp)
  12  
  13  	sharedKeyEnc, ct, err := Encapsulate(pk)
  14  	if err != nil {
  15  		t.Fatalf("Encapsulate: %v", err)
  16  	}
  17  
  18  	sharedKeyDec, err := Decapsulate(sk, ct)
  19  	if err != nil {
  20  		t.Fatalf("Decapsulate: %v", err)
  21  	}
  22  
  23  	if !bytes.Equal(sharedKeyEnc, sharedKeyDec) {
  24  		t.Fatalf("shared keys don't match:\n  enc: %x\n  dec: %x", sharedKeyEnc, sharedKeyDec)
  25  	}
  26  
  27  	if len(sharedKeyEnc) != kp.SharedKeyLen {
  28  		t.Fatalf("shared key length: got %d, want %d", len(sharedKeyEnc), kp.SharedKeyLen)
  29  	}
  30  }
  31  
  32  // TestKEMMultipleRoundTrips verifies consistency across multiple key pairs.
  33  func TestKEMMultipleRoundTrips(t *testing.T) {
  34  	kp := DefaultKEMParams()
  35  	for i := range 10 {
  36  		pk, sk := KEMKeyGen(kp)
  37  		sharedKeyEnc, ct, err := Encapsulate(pk)
  38  		if err != nil {
  39  			t.Fatalf("round %d: Encapsulate: %v", i, err)
  40  		}
  41  		sharedKeyDec, err := Decapsulate(sk, ct)
  42  		if err != nil {
  43  			t.Fatalf("round %d: Decapsulate: %v", i, err)
  44  		}
  45  		if !bytes.Equal(sharedKeyEnc, sharedKeyDec) {
  46  			t.Fatalf("round %d: shared keys don't match", i)
  47  		}
  48  	}
  49  }
  50  
  51  // TestKEMTamperedCiphertext verifies that a modified ciphertext produces
  52  // a different shared key (implicit rejection, not an error).
  53  func TestKEMTamperedCiphertext(t *testing.T) {
  54  	kp := DefaultKEMParams()
  55  	pk, sk := KEMKeyGen(kp)
  56  
  57  	sharedKeyEnc, ct, err := Encapsulate(pk)
  58  	if err != nil {
  59  		t.Fatal(err)
  60  	}
  61  
  62  	// Tamper: flip a coefficient in U.
  63  	ct.U.Coeffs[0] = (ct.U.Coeffs[0] + 1) % kp.Ring.Q
  64  
  65  	sharedKeyDec, err := Decapsulate(sk, ct)
  66  	if err != nil {
  67  		t.Fatal(err)
  68  	}
  69  
  70  	// The decapsulated key should be different (implicit rejection).
  71  	if bytes.Equal(sharedKeyEnc, sharedKeyDec) {
  72  		t.Fatal("tampered ciphertext produced same shared key — FO transform not working")
  73  	}
  74  }
  75  
  76  // TestKEMDifferentKeys verifies two key pairs produce different shared keys.
  77  func TestKEMDifferentKeys(t *testing.T) {
  78  	kp := DefaultKEMParams()
  79  	pk1, _ := KEMKeyGen(kp)
  80  	_, sk2 := KEMKeyGen(kp)
  81  
  82  	_, ct, err := Encapsulate(pk1)
  83  	if err != nil {
  84  		t.Fatal(err)
  85  	}
  86  
  87  	// Decapsulate with wrong secret key — should produce rejection key.
  88  	sharedKeyWrong, err := Decapsulate(sk2, ct)
  89  	if err != nil {
  90  		t.Fatal(err)
  91  	}
  92  
  93  	// The key should be deterministically derived but wrong.
  94  	if len(sharedKeyWrong) != kp.SharedKeyLen {
  95  		t.Fatalf("wrong key length: got %d, want %d", len(sharedKeyWrong), kp.SharedKeyLen)
  96  	}
  97  }
  98  
  99  // TestCPAEncryptDecrypt verifies the inner CPA scheme directly.
 100  func TestCPAEncryptDecrypt(t *testing.T) {
 101  	kp := DefaultKEMParams()
 102  	pk, sk := KEMKeyGen(kp)
 103  
 104  	msg := make([]byte, 32)
 105  	for i := range msg {
 106  		msg[i] = byte(i)
 107  	}
 108  
 109  	coins := make([]byte, 32)
 110  	for i := range coins {
 111  		coins[i] = byte(i + 100)
 112  	}
 113  
 114  	ct := cpaPKEEncrypt(pk, msg, coins)
 115  	recovered := cpaPKEDecrypt(sk, ct)
 116  
 117  	if !bytes.Equal(msg, recovered) {
 118  		t.Fatalf("CPA decrypt failed:\n  sent: %x\n  recv: %x", msg, recovered)
 119  	}
 120  }
 121  
 122  // TestEncodeDecodeMessage verifies the message encoding round-trip.
 123  func TestEncodeDecodeMessage(t *testing.T) {
 124  	p := Falcon512()
 125  	msg := make([]byte, 32)
 126  	for i := range msg {
 127  		msg[i] = byte(0xAA) // alternating bits
 128  	}
 129  
 130  	encoded := encodeMessage(p, msg)
 131  	decoded := decodeMessage(encoded)
 132  
 133  	if !bytes.Equal(msg, decoded[:32]) {
 134  		t.Fatalf("encode/decode failed:\n  sent: %x\n  recv: %x", msg, decoded[:32])
 135  	}
 136  }
 137  
 138  func BenchmarkKEMKeyGen(b *testing.B) {
 139  	kp := DefaultKEMParams()
 140  	for range b.N {
 141  		KEMKeyGen(kp)
 142  	}
 143  }
 144  
 145  func BenchmarkKEMEncapsulate(b *testing.B) {
 146  	kp := DefaultKEMParams()
 147  	pk, _ := KEMKeyGen(kp)
 148  	b.ResetTimer()
 149  	for range b.N {
 150  		Encapsulate(pk)
 151  	}
 152  }
 153  
 154  func BenchmarkKEMDecapsulate(b *testing.B) {
 155  	kp := DefaultKEMParams()
 156  	pk, sk := KEMKeyGen(kp)
 157  	_, ct, _ := Encapsulate(pk)
 158  	b.ResetTimer()
 159  	for range b.N {
 160  		Decapsulate(sk, ct)
 161  	}
 162  }
 163