wnaf_test.go raw

   1  package wnaf
   2  
   3  import (
   4  	"crypto/rand"
   5  	"math/big"
   6  	"testing"
   7  )
   8  
   9  // smallScalar constructs a [4]uint64 from a small uint64 value.
  10  func smallScalar(v uint64) [4]uint64 {
  11  	return [4]uint64{v, 0, 0, 0}
  12  }
  13  
  14  // limbs4ToBigInt converts [4]uint64 little-endian limbs to *big.Int (for test readability).
  15  func limbs4ToBigInt(s [4]uint64) *big.Int {
  16  	b := new(big.Int)
  17  	for i := 3; i >= 0; i-- {
  18  		b.Lsh(b, 64)
  19  		b.Or(b, new(big.Int).SetUint64(s[i]))
  20  	}
  21  	return b
  22  }
  23  
  24  // bigIntToLimbs4 converts *big.Int to [4]uint64 little-endian limbs.
  25  func bigIntToLimbs4(b *big.Int) [4]uint64 {
  26  	var s [4]uint64
  27  	mask := new(big.Int).SetUint64(^uint64(0))
  28  	tmp := new(big.Int).Set(b)
  29  	for i := range 4 {
  30  		s[i] = new(big.Int).And(tmp, mask).Uint64()
  31  		tmp.Rsh(tmp, 64)
  32  	}
  33  	return s
  34  }
  35  
  36  // randomScalar generates a random 256-bit scalar for testing.
  37  func randomScalar() [4]uint64 {
  38  	var buf [32]byte
  39  	if _, err := rand.Read(buf[:]); err != nil {
  40  		panic(err)
  41  	}
  42  	return FromBytes(buf)
  43  }
  44  
  45  // TestEncodeSmallNAF tests wNAF with w=2 (standard NAF) against hand-computed values.
  46  func TestEncodeSmallNAF(t *testing.T) {
  47  	tests := []struct {
  48  		name     string
  49  		scalar   uint64
  50  		expected map[int]int8 // position -> digit
  51  	}{
  52  		{"zero", 0, nil},
  53  		{"one", 1, map[int]int8{0: 1}},
  54  		{"two", 2, map[int]int8{1: 1}},
  55  		{"three", 3, map[int]int8{0: -1, 2: 1}},
  56  		{"seven", 7, map[int]int8{0: -1, 3: 1}},
  57  		{"eleven", 11, map[int]int8{0: -1, 2: -1, 4: 1}},
  58  		{"four", 4, map[int]int8{2: 1}},
  59  		{"eight", 8, map[int]int8{3: 1}},
  60  		{"sixteen", 16, map[int]int8{4: 1}},
  61  		{"five", 5, map[int]int8{0: 1, 2: 1}},
  62  		{"six", 6, map[int]int8{1: -1, 3: 1}},
  63  	}
  64  
  65  	for _, tc := range tests {
  66  		t.Run(tc.name, func(t *testing.T) {
  67  			d := Encode(smallScalar(tc.scalar), 2)
  68  
  69  			if err := d.Valid(2); err != nil {
  70  				t.Fatalf("invalid wNAF: %v", err)
  71  			}
  72  
  73  			// Check expected non-zero digits
  74  			for pos, digit := range tc.expected {
  75  				if d.D[pos] != digit {
  76  					t.Errorf("position %d: expected %d, got %d", pos, digit, d.D[pos])
  77  				}
  78  			}
  79  
  80  			// Check no unexpected non-zero digits
  81  			for i := range 257 {
  82  				if d.D[i] != 0 {
  83  					if _, ok := tc.expected[i]; !ok {
  84  						t.Errorf("unexpected non-zero digit at position %d: %d", i, d.D[i])
  85  					}
  86  				}
  87  			}
  88  
  89  			// Round-trip
  90  			got := d.Reconstruct()
  91  			want := smallScalar(tc.scalar)
  92  			if got != want {
  93  				t.Errorf("reconstruct mismatch: got %v, want %v", got, want)
  94  			}
  95  		})
  96  	}
  97  }
  98  
  99  // TestEncodeW5 tests wNAF with w=5 (libsecp256k1 default) against hand-computed values.
 100  func TestEncodeW5(t *testing.T) {
 101  	tests := []struct {
 102  		name     string
 103  		scalar   uint64
 104  		expected map[int]int8
 105  	}{
 106  		{"zero", 0, nil},
 107  		{"one", 1, map[int]int8{0: 1}},
 108  		{"fifteen", 15, map[int]int8{0: 15}},
 109  		{"seventeen", 17, map[int]int8{0: -15, 5: 1}},
 110  		{"thirty_one", 31, map[int]int8{0: -1, 5: 1}},
 111  		{"thirty_two", 32, map[int]int8{5: 1}},
 112  		{"thirty_three", 33, map[int]int8{0: 1, 5: 1}},
 113  	}
 114  
 115  	for _, tc := range tests {
 116  		t.Run(tc.name, func(t *testing.T) {
 117  			d := Encode(smallScalar(tc.scalar), 5)
 118  
 119  			if err := d.Valid(5); err != nil {
 120  				t.Fatalf("invalid wNAF: %v", err)
 121  			}
 122  
 123  			for pos, digit := range tc.expected {
 124  				if d.D[pos] != digit {
 125  					t.Errorf("position %d: expected %d, got %d", pos, digit, d.D[pos])
 126  				}
 127  			}
 128  
 129  			for i := range 257 {
 130  				if d.D[i] != 0 {
 131  					if _, ok := tc.expected[i]; !ok {
 132  						t.Errorf("unexpected non-zero digit at position %d: %d", i, d.D[i])
 133  					}
 134  				}
 135  			}
 136  
 137  			got := d.Reconstruct()
 138  			want := smallScalar(tc.scalar)
 139  			if got != want {
 140  				t.Errorf("reconstruct mismatch: got %v, want %v", got, want)
 141  			}
 142  		})
 143  	}
 144  }
 145  
 146  // TestEncodeEdgeCases tests boundary conditions.
 147  func TestEncodeEdgeCases(t *testing.T) {
 148  	for w := 2; w <= 8; w++ {
 149  		t.Run("zero/w="+itoa(w), func(t *testing.T) {
 150  			d := Encode([4]uint64{}, w)
 151  			if err := d.Valid(w); err != nil {
 152  				t.Fatalf("invalid: %v", err)
 153  			}
 154  			if d.Reconstruct() != [4]uint64{} {
 155  				t.Error("zero scalar should reconstruct to zero")
 156  			}
 157  		})
 158  
 159  		t.Run("one/w="+itoa(w), func(t *testing.T) {
 160  			d := Encode(smallScalar(1), w)
 161  			if err := d.Valid(w); err != nil {
 162  				t.Fatalf("invalid: %v", err)
 163  			}
 164  			if d.Reconstruct() != smallScalar(1) {
 165  				t.Error("scalar 1 should reconstruct to 1")
 166  			}
 167  		})
 168  
 169  		t.Run("powers_of_2/w="+itoa(w), func(t *testing.T) {
 170  			for bit := 0; bit < 256; bit++ {
 171  				var s [4]uint64
 172  				s[bit/64] = 1 << (bit % 64)
 173  				d := Encode(s, w)
 174  				if err := d.Valid(w); err != nil {
 175  					t.Fatalf("bit %d: invalid: %v", bit, err)
 176  				}
 177  				if d.Reconstruct() != s {
 178  					t.Errorf("bit %d: reconstruct mismatch", bit)
 179  				}
 180  			}
 181  		})
 182  
 183  		t.Run("all_ones_128/w="+itoa(w), func(t *testing.T) {
 184  			s := [4]uint64{^uint64(0), ^uint64(0), 0, 0}
 185  			d := Encode(s, w)
 186  			if err := d.Valid(w); err != nil {
 187  				t.Fatalf("invalid: %v", err)
 188  			}
 189  			if d.Reconstruct() != s {
 190  				t.Errorf("reconstruct mismatch: got %v, want %v", d.Reconstruct(), s)
 191  			}
 192  		})
 193  
 194  		t.Run("all_ones_256/w="+itoa(w), func(t *testing.T) {
 195  			s := [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}
 196  			d := Encode(s, w)
 197  			if err := d.Valid(w); err != nil {
 198  				t.Fatalf("invalid: %v", err)
 199  			}
 200  			if d.Reconstruct() != s {
 201  				t.Errorf("reconstruct mismatch: got %v, want %v", d.Reconstruct(), s)
 202  			}
 203  		})
 204  	}
 205  }
 206  
 207  // TestEncodePropertyRoundTrip verifies that Encode followed by Reconstruct
 208  // is the identity for random 256-bit scalars across all supported window widths.
 209  func TestEncodePropertyRoundTrip(t *testing.T) {
 210  	const iterations = 500
 211  
 212  	for w := 2; w <= 8; w++ {
 213  		t.Run("w="+itoa(w), func(t *testing.T) {
 214  			for i := range iterations {
 215  				s := randomScalar()
 216  				d := Encode(s, w)
 217  
 218  				if err := d.Valid(w); err != nil {
 219  					t.Fatalf("iteration %d: invalid wNAF: %v", i, err)
 220  				}
 221  
 222  				got := d.Reconstruct()
 223  				if got != s {
 224  					t.Fatalf("iteration %d: round-trip failed\n  input:  %016x %016x %016x %016x\n  output: %016x %016x %016x %016x",
 225  						i, s[3], s[2], s[1], s[0], got[3], got[2], got[1], got[0])
 226  				}
 227  			}
 228  		})
 229  	}
 230  }
 231  
 232  // TestEncodePropertyValidity checks that every encoded wNAF satisfies
 233  // all structural invariants.
 234  func TestEncodePropertyValidity(t *testing.T) {
 235  	const iterations = 500
 236  
 237  	for w := 2; w <= 8; w++ {
 238  		t.Run("w="+itoa(w), func(t *testing.T) {
 239  			for i := range iterations {
 240  				s := randomScalar()
 241  				d := Encode(s, w)
 242  				if err := d.Valid(w); err != nil {
 243  					t.Fatalf("iteration %d: %v\n  scalar: %016x %016x %016x %016x",
 244  						i, err, s[3], s[2], s[1], s[0])
 245  				}
 246  			}
 247  		})
 248  	}
 249  }
 250  
 251  // TestReconstructWithBigInt cross-checks Reconstruct against big.Int arithmetic.
 252  func TestReconstructWithBigInt(t *testing.T) {
 253  	const iterations = 200
 254  
 255  	for w := 2; w <= 8; w++ {
 256  		t.Run("w="+itoa(w), func(t *testing.T) {
 257  			for iter := range iterations {
 258  				s := randomScalar()
 259  				d := Encode(s, w)
 260  
 261  				// Compute expected value using big.Int
 262  				expected := new(big.Int)
 263  				for i := range 257 {
 264  					if d.D[i] == 0 {
 265  						continue
 266  					}
 267  					shifted := new(big.Int).Lsh(big.NewInt(int64(d.D[i])), uint(i))
 268  					expected.Add(expected, shifted)
 269  				}
 270  
 271  				got := limbs4ToBigInt(d.Reconstruct())
 272  				if got.Cmp(expected) != 0 {
 273  					t.Fatalf("iteration %d: big.Int mismatch\n  Reconstruct: %s\n  big.Int:     %s",
 274  						iter, got.Text(16), expected.Text(16))
 275  				}
 276  			}
 277  		})
 278  	}
 279  }
 280  
 281  // TestFromBytes verifies the byte-to-limb conversion.
 282  func TestFromBytes(t *testing.T) {
 283  	// All zeros
 284  	var zero [32]byte
 285  	if FromBytes(zero) != [4]uint64{} {
 286  		t.Error("zero bytes should produce zero limbs")
 287  	}
 288  
 289  	// Single byte at the end (LSB)
 290  	var one [32]byte
 291  	one[31] = 1
 292  	if FromBytes(one) != smallScalar(1) {
 293  		t.Errorf("got %v, want %v", FromBytes(one), smallScalar(1))
 294  	}
 295  
 296  	// 0x0102030405060708 in the highest limb
 297  	var high [32]byte
 298  	high[0] = 0x01
 299  	high[1] = 0x02
 300  	high[2] = 0x03
 301  	high[3] = 0x04
 302  	high[4] = 0x05
 303  	high[5] = 0x06
 304  	high[6] = 0x07
 305  	high[7] = 0x08
 306  	s := FromBytes(high)
 307  	if s[3] != 0x0102030405060708 {
 308  		t.Errorf("high limb: got %016x, want 0102030405060708", s[3])
 309  	}
 310  	if s[0] != 0 || s[1] != 0 || s[2] != 0 {
 311  		t.Error("lower limbs should be zero")
 312  	}
 313  }
 314  
 315  // TestGetBits verifies bit extraction across limb boundaries.
 316  func TestGetBits(t *testing.T) {
 317  	s := [4]uint64{0xDEADBEEFCAFEBABE, 0x1234567890ABCDEF, 0, 0}
 318  
 319  	// Single bit
 320  	if getBits(s, 0, 1) != 0 {
 321  		t.Error("bit 0 should be 0")
 322  	}
 323  	if getBits(s, 1, 1) != 1 {
 324  		t.Error("bit 1 should be 1")
 325  	}
 326  
 327  	// Byte at offset 0
 328  	if getBits(s, 0, 8) != 0xBE {
 329  		t.Errorf("bits [0:8] = %x, want 0xBE", getBits(s, 0, 8))
 330  	}
 331  
 332  	// Cross limb boundary (bits 60-67)
 333  	val := getBits(s, 60, 8)
 334  	// Low 4 bits from limb 0 (bits 60-63): 0xD (top nibble of 0xDEADBEEFCAFEBABE)
 335  	// High 4 bits from limb 1 (bits 0-3): 0xF (low nibble of 0x1234567890ABCDEF)
 336  	if val != 0xFD {
 337  		t.Errorf("cross-limb bits [60:68] = %x, want 0xFD", val)
 338  	}
 339  }
 340  
 341  // TestEncodeSignedNegation verifies that EncodeSigned handles high-bit scalars.
 342  func TestEncodeSignedNegation(t *testing.T) {
 343  	// Scalar with bit 255 set
 344  	s := [4]uint64{1, 0, 0, 1 << 63}
 345  	d, negated := EncodeSigned(s, 5)
 346  	if !negated {
 347  		t.Error("should have negated")
 348  	}
 349  	if err := d.Valid(5); err != nil {
 350  		t.Fatalf("invalid: %v", err)
 351  	}
 352  
 353  	// Scalar without bit 255
 354  	s2 := [4]uint64{0x12345, 0, 0, 0}
 355  	d2, negated2 := EncodeSigned(s2, 5)
 356  	if negated2 {
 357  		t.Error("should not have negated")
 358  	}
 359  	if d2.Reconstruct() != s2 {
 360  		t.Error("non-negated round-trip failed")
 361  	}
 362  }
 363  
 364  // TestValidRejectsInvalid checks that Valid catches structural violations.
 365  func TestValidRejectsInvalid(t *testing.T) {
 366  	// Even non-zero digit
 367  	d := Digits{}
 368  	d.D[0] = 2
 369  	if d.Valid(5) == nil {
 370  		t.Error("should reject even digit")
 371  	}
 372  
 373  	// Out of range digit for w=3 (max is 3)
 374  	d = Digits{}
 375  	d.D[0] = 5
 376  	if d.Valid(3) == nil {
 377  		t.Error("should reject digit 5 for w=3")
 378  	}
 379  
 380  	// Adjacent non-zero digits (spacing violation for w=5)
 381  	d = Digits{}
 382  	d.D[0] = 1
 383  	d.D[3] = 1 // only 3 apart, need >= 5
 384  	if d.Valid(5) == nil {
 385  		t.Error("should reject spacing violation")
 386  	}
 387  
 388  	// Valid: non-zero digits exactly w apart
 389  	d = Digits{}
 390  	d.D[0] = 1
 391  	d.D[5] = 1
 392  	if err := d.Valid(5); err != nil {
 393  		t.Errorf("should accept spacing of exactly w: %v", err)
 394  	}
 395  }
 396  
 397  // TestNonZeroDigitCount verifies that wNAF produces a sparse representation.
 398  func TestNonZeroDigitCount(t *testing.T) {
 399  	const iterations = 200
 400  
 401  	for w := 2; w <= 8; w++ {
 402  		t.Run("w="+itoa(w), func(t *testing.T) {
 403  			maxExpected := 256/w + 2 // floor(256/w) + 1 body digits + 1 possible carry
 404  			for range iterations {
 405  				s := randomScalar()
 406  				d := Encode(s, w)
 407  
 408  				count := 0
 409  				for i := range 257 {
 410  					if d.D[i] != 0 {
 411  						count++
 412  					}
 413  				}
 414  
 415  				if count > maxExpected {
 416  					t.Errorf("too many non-zero digits: %d (expected <= %d) for w=%d",
 417  						count, maxExpected, w)
 418  				}
 419  			}
 420  		})
 421  	}
 422  }
 423  
 424  // itoa is a simple int-to-string for test names without importing strconv.
 425  func itoa(n int) string {
 426  	if n == 0 {
 427  		return "0"
 428  	}
 429  	var buf [20]byte
 430  	i := len(buf)
 431  	for n > 0 {
 432  		i--
 433  		buf[i] = byte('0' + n%10)
 434  		n /= 10
 435  	}
 436  	return string(buf[i:])
 437  }
 438