ecdh_test.go raw

   1  package p256k1
   2  
   3  import (
   4  	"testing"
   5  )
   6  
   7  func TestEcmultConst(t *testing.T) {
   8  	// Test with generator point
   9  	var scalar Scalar
  10  	scalar.setInt(5)
  11  
  12  	var gJac GroupElementJacobian
  13  	gJac.setGE(&Generator)
  14  
  15  	var result GroupElementJacobian
  16  	EcmultConst(&result, &Generator, &scalar)
  17  
  18  	if result.isInfinity() {
  19  		t.Error("5*G should not be infinity")
  20  	}
  21  
  22  	// Verify it matches EcmultGen for generator
  23  	var expected GroupElementJacobian
  24  	EcmultGen(&expected, &scalar)
  25  
  26  	var resultAff, expectedAff GroupElementAffine
  27  	resultAff.setGEJ(&result)
  28  	expectedAff.setGEJ(&expected)
  29  
  30  	resultAff.x.normalize()
  31  	resultAff.y.normalize()
  32  	expectedAff.x.normalize()
  33  	expectedAff.y.normalize()
  34  
  35  	if !resultAff.x.equal(&expectedAff.x) || !resultAff.y.equal(&expectedAff.y) {
  36  		t.Error("EcmultConst result does not match EcmultGen for generator")
  37  	}
  38  }
  39  
  40  func TestEcmult(t *testing.T) {
  41  	// Test with arbitrary point
  42  	var scalar Scalar
  43  	scalar.setInt(3)
  44  
  45  	var point GroupElementAffine
  46  	point.setXY(&Generator.x, &Generator.y)
  47  
  48  	var pointJac GroupElementJacobian
  49  	pointJac.setGE(&point)
  50  
  51  	var result GroupElementJacobian
  52  	Ecmult(&result, &pointJac, &scalar)
  53  
  54  	if result.isInfinity() {
  55  		t.Error("3*P should not be infinity")
  56  	}
  57  
  58  	// Verify it matches EcmultConst
  59  	var expected GroupElementJacobian
  60  	EcmultConst(&expected, &point, &scalar)
  61  
  62  	var resultAff, expectedAff GroupElementAffine
  63  	resultAff.setGEJ(&result)
  64  	expectedAff.setGEJ(&expected)
  65  
  66  	resultAff.x.normalize()
  67  	resultAff.y.normalize()
  68  	expectedAff.x.normalize()
  69  	expectedAff.y.normalize()
  70  
  71  	if !resultAff.x.equal(&expectedAff.x) || !resultAff.y.equal(&expectedAff.y) {
  72  		t.Error("Ecmult result does not match EcmultConst")
  73  	}
  74  }
  75  
  76  func TestECDH(t *testing.T) {
  77  	// Generate two key pairs
  78  	seckey1, pubkey1, err := ECKeyPairGenerate()
  79  	if err != nil {
  80  		t.Fatalf("failed to generate key pair 1: %v", err)
  81  	}
  82  
  83  	seckey2, pubkey2, err := ECKeyPairGenerate()
  84  	if err != nil {
  85  		t.Fatalf("failed to generate key pair 2: %v", err)
  86  	}
  87  
  88  	// Compute shared secrets
  89  	var shared1, shared2 [32]byte
  90  
  91  	// Alice computes shared secret with Bob's public key
  92  	if err := ECDH(shared1[:], pubkey2, seckey1, nil); err != nil {
  93  		t.Fatalf("ECDH failed for Alice: %v", err)
  94  	}
  95  
  96  	// Bob computes shared secret with Alice's public key
  97  	if err := ECDH(shared2[:], pubkey1, seckey2, nil); err != nil {
  98  		t.Fatalf("ECDH failed for Bob: %v", err)
  99  	}
 100  
 101  	// Both should have the same shared secret
 102  	for i := 0; i < 32; i++ {
 103  		if shared1[i] != shared2[i] {
 104  			t.Errorf("shared secrets differ at byte %d: 0x%02x != 0x%02x", i, shared1[i], shared2[i])
 105  		}
 106  	}
 107  }
 108  
 109  func TestECDHZeroKey(t *testing.T) {
 110  	// Test that zero key is rejected
 111  	_, pubkey, err := ECKeyPairGenerate()
 112  	if err != nil {
 113  		t.Fatalf("failed to generate key pair: %v", err)
 114  	}
 115  
 116  	zeroKey := make([]byte, 32)
 117  	var output [32]byte
 118  
 119  	err = ECDH(output[:], pubkey, zeroKey, nil)
 120  	if err == nil {
 121  		t.Error("ECDH should fail with zero key")
 122  	}
 123  }
 124  
 125  func TestECDHInvalidKey(t *testing.T) {
 126  	_, pubkey, err := ECKeyPairGenerate()
 127  	if err != nil {
 128  		t.Fatalf("failed to generate key pair: %v", err)
 129  	}
 130  
 131  	// Test with invalid key (all 0xFF - likely invalid)
 132  	invalidKey := make([]byte, 32)
 133  	for i := range invalidKey {
 134  		invalidKey[i] = 0xFF
 135  	}
 136  
 137  	var output [32]byte
 138  	err = ECDH(output[:], pubkey, invalidKey, nil)
 139  	if err == nil {
 140  		// If it doesn't fail, verify the key is actually valid
 141  		if !ECSeckeyVerify(invalidKey) {
 142  			t.Error("ECDH should fail with invalid key")
 143  		}
 144  	}
 145  }
 146  
 147  func TestECDHCustomHash(t *testing.T) {
 148  	// Test with custom hash function
 149  	seckey1, pubkey1, err := ECKeyPairGenerate()
 150  	if err != nil {
 151  		t.Fatalf("failed to generate key pair 1: %v", err)
 152  	}
 153  
 154  	seckey2, pubkey2, err := ECKeyPairGenerate()
 155  	if err != nil {
 156  		t.Fatalf("failed to generate key pair 2: %v", err)
 157  	}
 158  
 159  	// Custom hash: just XOR x and y
 160  	customHash := func(output []byte, x32 []byte, y32 []byte) bool {
 161  		if len(output) != 32 {
 162  			return false
 163  		}
 164  		for i := 0; i < 32; i++ {
 165  			output[i] = x32[i] ^ y32[i]
 166  		}
 167  		return true
 168  	}
 169  
 170  	var shared1, shared2 [32]byte
 171  
 172  	if err := ECDH(shared1[:], pubkey2, seckey1, customHash); err != nil {
 173  		t.Fatalf("ECDH failed: %v", err)
 174  	}
 175  
 176  	if err := ECDH(shared2[:], pubkey1, seckey2, customHash); err != nil {
 177  		t.Fatalf("ECDH failed: %v", err)
 178  	}
 179  
 180  	for i := 0; i < 32; i++ {
 181  		if shared1[i] != shared2[i] {
 182  			t.Errorf("shared secrets differ at byte %d", i)
 183  		}
 184  	}
 185  }
 186  
 187  func TestHKDF(t *testing.T) {
 188  	// Test HKDF with known inputs
 189  	ikm := []byte("test input key material")
 190  	salt := []byte("test salt")
 191  	info := []byte("test info")
 192  
 193  	output := make([]byte, 64)
 194  	if err := HKDF(output, ikm, salt, info); err != nil {
 195  		t.Fatalf("HKDF failed: %v", err)
 196  	}
 197  
 198  	// Verify output is not all zeros
 199  	allZero := true
 200  	for i := 0; i < len(output); i++ {
 201  		if output[i] != 0 {
 202  			allZero = false
 203  			break
 204  		}
 205  	}
 206  	if allZero {
 207  		t.Error("HKDF output is all zeros")
 208  	}
 209  
 210  	// Test with empty salt
 211  	output2 := make([]byte, 32)
 212  	if err := HKDF(output2, ikm, nil, info); err != nil {
 213  		t.Fatalf("HKDF failed with empty salt: %v", err)
 214  	}
 215  
 216  	// Test with empty info
 217  	output3 := make([]byte, 32)
 218  	if err := HKDF(output3, ikm, salt, nil); err != nil {
 219  		t.Fatalf("HKDF failed with empty info: %v", err)
 220  	}
 221  }
 222  
 223  func TestECDHWithHKDF(t *testing.T) {
 224  	seckey1, pubkey1, err := ECKeyPairGenerate()
 225  	if err != nil {
 226  		t.Fatalf("failed to generate key pair 1: %v", err)
 227  	}
 228  
 229  	seckey2, pubkey2, err := ECKeyPairGenerate()
 230  	if err != nil {
 231  		t.Fatalf("failed to generate key pair 2: %v", err)
 232  	}
 233  
 234  	salt := []byte("test salt")
 235  	info := []byte("test info")
 236  
 237  	// Derive keys
 238  	var key1, key2 [64]byte
 239  	if err := ECDHWithHKDF(key1[:], pubkey2, seckey1, salt, info); err != nil {
 240  		t.Fatalf("ECDHWithHKDF failed: %v", err)
 241  	}
 242  
 243  	if err := ECDHWithHKDF(key2[:], pubkey1, seckey2, salt, info); err != nil {
 244  		t.Fatalf("ECDHWithHKDF failed: %v", err)
 245  	}
 246  
 247  	// Keys should match
 248  	for i := 0; i < 64; i++ {
 249  		if key1[i] != key2[i] {
 250  			t.Errorf("derived keys differ at byte %d", i)
 251  		}
 252  	}
 253  }
 254  
 255  func TestECDHXOnly(t *testing.T) {
 256  	seckey1, pubkey1, err := ECKeyPairGenerate()
 257  	if err != nil {
 258  		t.Fatalf("failed to generate key pair 1: %v", err)
 259  	}
 260  
 261  	seckey2, pubkey2, err := ECKeyPairGenerate()
 262  	if err != nil {
 263  		t.Fatalf("failed to generate key pair 2: %v", err)
 264  	}
 265  
 266  	// Compute X-only shared secrets
 267  	var x1, x2 [32]byte
 268  
 269  	if err := ECDHXOnly(x1[:], pubkey2, seckey1); err != nil {
 270  		t.Fatalf("ECDHXOnly failed: %v", err)
 271  	}
 272  
 273  	if err := ECDHXOnly(x2[:], pubkey1, seckey2); err != nil {
 274  		t.Fatalf("ECDHXOnly failed: %v", err)
 275  	}
 276  
 277  	// X coordinates should match
 278  	for i := 0; i < 32; i++ {
 279  		if x1[i] != x2[i] {
 280  			t.Errorf("X-only shared secrets differ at byte %d", i)
 281  		}
 282  	}
 283  }
 284