wnaf_test.go raw
1 package wnaf
2
3 import (
4 "crypto/rand"
5 "math/big"
6 "testing"
7 )
8
// smallScalar places v in the least-significant limb of a 256-bit
// little-endian scalar; the three upper limbs are left zero.
func smallScalar(v uint64) [4]uint64 {
	var limbs [4]uint64
	limbs[0] = v
	return limbs
}
13
// limbs4ToBigInt interprets s as four little-endian 64-bit limbs (s[0] is the
// least-significant) and returns the equivalent non-negative *big.Int. It is
// used to make test failures and cross-checks readable.
func limbs4ToBigInt(s [4]uint64) *big.Int {
	out := new(big.Int)
	limb := new(big.Int)
	for i := range s {
		idx := len(s) - 1 - i // walk from the most-significant limb down
		out.Lsh(out, 64)
		// After the shift the low 64 bits are zero, so adding the limb
		// is equivalent to OR-ing it in.
		out.Add(out, limb.SetUint64(s[idx]))
	}
	return out
}
23
// bigIntToLimbs4 splits b into four little-endian 64-bit limbs (out[0] holds
// the least-significant 64 bits). Only the low 256 bits are kept; anything
// above them is discarded. b itself is not modified.
func bigIntToLimbs4(b *big.Int) [4]uint64 {
	lowMask := new(big.Int).SetUint64(^uint64(0))
	var out [4]uint64
	for limb := 0; limb < len(out); limb++ {
		chunk := new(big.Int).Rsh(b, uint(64*limb))
		out[limb] = chunk.And(chunk, lowMask).Uint64()
	}
	return out
}
35
// randomScalar fills 32 bytes from crypto/rand and converts them with
// FromBytes, yielding a random scalar for property tests. A failure of the
// system randomness source aborts the test run via panic.
func randomScalar() [4]uint64 {
	var seed [32]byte
	_, err := rand.Read(seed[:])
	if err != nil {
		panic(err)
	}
	return FromBytes(seed)
}
44
// TestEncodeSmallNAF tests wNAF with w=2 (standard NAF) against hand-computed values.
func TestEncodeSmallNAF(t *testing.T) {
	tests := []struct {
		name     string
		scalar   uint64
		expected map[int]int8 // position -> digit
	}{
		{"zero", 0, nil},
		{"one", 1, map[int]int8{0: 1}},
		{"two", 2, map[int]int8{1: 1}},
		{"three", 3, map[int]int8{0: -1, 2: 1}},           // 3 = 4 - 1
		{"seven", 7, map[int]int8{0: -1, 3: 1}},           // 7 = 8 - 1
		{"eleven", 11, map[int]int8{0: -1, 2: -1, 4: 1}}, // 11 = 16 - 4 - 1
		{"four", 4, map[int]int8{2: 1}},
		{"eight", 8, map[int]int8{3: 1}},
		{"sixteen", 16, map[int]int8{4: 1}},
		{"five", 5, map[int]int8{0: 1, 2: 1}}, // 5 = 4 + 1
		{"six", 6, map[int]int8{1: -1, 3: 1}}, // 6 = 8 - 2
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			checkSmallEncoding(t, tc.scalar, 2, tc.expected)
		})
	}
}

// TestEncodeW5 tests wNAF with w=5 (libsecp256k1 default) against hand-computed values.
func TestEncodeW5(t *testing.T) {
	tests := []struct {
		name     string
		scalar   uint64
		expected map[int]int8 // position -> digit
	}{
		{"zero", 0, nil},
		{"one", 1, map[int]int8{0: 1}},
		{"fifteen", 15, map[int]int8{0: 15}},
		{"seventeen", 17, map[int]int8{0: -15, 5: 1}}, // 17 = 32 - 15
		{"thirty_one", 31, map[int]int8{0: -1, 5: 1}}, // 31 = 32 - 1
		{"thirty_two", 32, map[int]int8{5: 1}},
		{"thirty_three", 33, map[int]int8{0: 1, 5: 1}},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			checkSmallEncoding(t, tc.scalar, 5, tc.expected)
		})
	}
}

// checkSmallEncoding encodes the single-limb scalar with window width w and
// asserts that:
//   - the result passes Valid(w);
//   - every position listed in expected (position -> digit) holds that digit;
//   - no other position among the 257 digits is non-zero (nil expected means
//     all digits must be zero);
//   - Reconstruct returns the original scalar.
func checkSmallEncoding(t *testing.T, scalar uint64, w int, expected map[int]int8) {
	t.Helper()

	d := Encode(smallScalar(scalar), w)

	if err := d.Valid(w); err != nil {
		t.Fatalf("invalid wNAF: %v", err)
	}

	// Check expected non-zero digits.
	for pos, digit := range expected {
		if d.D[pos] != digit {
			t.Errorf("position %d: expected %d, got %d", pos, digit, d.D[pos])
		}
	}

	// Check no unexpected non-zero digits.
	for i := range 257 {
		if d.D[i] == 0 {
			continue
		}
		if _, ok := expected[i]; !ok {
			t.Errorf("unexpected non-zero digit at position %d: %d", i, d.D[i])
		}
	}

	// Round-trip.
	if got, want := d.Reconstruct(), smallScalar(scalar); got != want {
		t.Errorf("reconstruct mismatch: got %v, want %v", got, want)
	}
}
145
// TestEncodeEdgeCases runs boundary scalars through Encode for every window
// width from 2 to 8: zero, one, each single-bit power of two, and the
// all-ones 128-bit and 256-bit values. Each result must pass Valid and
// reconstruct to its input.
func TestEncodeEdgeCases(t *testing.T) {
	ones := ^uint64(0)

	for width := 2; width <= 8; width++ {
		tag := "/w=" + itoa(width)

		t.Run("zero"+tag, func(t *testing.T) {
			var zero [4]uint64
			enc := Encode(zero, width)
			if err := enc.Valid(width); err != nil {
				t.Fatalf("invalid: %v", err)
			}
			if enc.Reconstruct() != zero {
				t.Error("zero scalar should reconstruct to zero")
			}
		})

		t.Run("one"+tag, func(t *testing.T) {
			one := smallScalar(1)
			enc := Encode(one, width)
			if err := enc.Valid(width); err != nil {
				t.Fatalf("invalid: %v", err)
			}
			if enc.Reconstruct() != one {
				t.Error("scalar 1 should reconstruct to 1")
			}
		})

		t.Run("powers_of_2"+tag, func(t *testing.T) {
			for bit := range 256 {
				var pow [4]uint64
				pow[bit/64] = uint64(1) << uint(bit%64)
				enc := Encode(pow, width)
				if err := enc.Valid(width); err != nil {
					t.Fatalf("bit %d: invalid: %v", bit, err)
				}
				if enc.Reconstruct() != pow {
					t.Errorf("bit %d: reconstruct mismatch", bit)
				}
			}
		})

		saturated := []struct {
			name string
			s    [4]uint64
		}{
			{"all_ones_128", [4]uint64{ones, ones, 0, 0}},
			{"all_ones_256", [4]uint64{ones, ones, ones, ones}},
		}
		for _, c := range saturated {
			t.Run(c.name+tag, func(t *testing.T) {
				enc := Encode(c.s, width)
				if err := enc.Valid(width); err != nil {
					t.Fatalf("invalid: %v", err)
				}
				if got := enc.Reconstruct(); got != c.s {
					t.Errorf("reconstruct mismatch: got %v, want %v", got, c.s)
				}
			})
		}
	}
}
206
// TestEncodePropertyRoundTrip encodes many random 256-bit scalars at every
// window width from 2 to 8 and requires each encoding to be valid and to
// reconstruct to exactly the scalar that was encoded.
func TestEncodePropertyRoundTrip(t *testing.T) {
	const iterations = 500

	for width := 2; width <= 8; width++ {
		t.Run("w="+itoa(width), func(t *testing.T) {
			for n := 0; n < iterations; n++ {
				in := randomScalar()
				enc := Encode(in, width)

				if err := enc.Valid(width); err != nil {
					t.Fatalf("iteration %d: invalid wNAF: %v", n, err)
				}

				out := enc.Reconstruct()
				if out == in {
					continue
				}
				// Limbs are printed most-significant first.
				t.Fatalf("iteration %d: round-trip failed\n input: %016x %016x %016x %016x\n output: %016x %016x %016x %016x",
					n, in[3], in[2], in[1], in[0], out[3], out[2], out[1], out[0])
			}
		})
	}
}
231
// TestEncodePropertyValidity encodes random 256-bit scalars at every window
// width from 2 to 8 and requires every result to pass Valid.
func TestEncodePropertyValidity(t *testing.T) {
	const iterations = 500

	for width := 2; width <= 8; width++ {
		t.Run("w="+itoa(width), func(t *testing.T) {
			for n := 0; n < iterations; n++ {
				sc := randomScalar()
				enc := Encode(sc, width)
				err := enc.Valid(width)
				if err == nil {
					continue
				}
				// Limbs are printed most-significant first.
				t.Fatalf("iteration %d: %v\n scalar: %016x %016x %016x %016x",
					n, err, sc[3], sc[2], sc[1], sc[0])
			}
		})
	}
}
250
// TestReconstructWithBigInt evaluates sum(D[i] * 2^i) with math/big, an
// arithmetic path independent of Reconstruct's limb code. It requires that
// sum to equal both the scalar that was encoded and Reconstruct's output.
func TestReconstructWithBigInt(t *testing.T) {
	const iterations = 200

	for w := 2; w <= 8; w++ {
		t.Run("w="+itoa(w), func(t *testing.T) {
			for iter := range iterations {
				s := randomScalar()
				d := Encode(s, w)

				// Compute the digit sum using big.Int.
				expected := new(big.Int)
				for i := range 257 {
					if d.D[i] == 0 {
						continue
					}
					shifted := new(big.Int).Lsh(big.NewInt(int64(d.D[i])), uint(i))
					expected.Add(expected, shifted)
				}

				// Comparing only with Reconstruct would show that the two
				// evaluators agree, not that the digits actually encode s.
				if input := limbs4ToBigInt(s); expected.Cmp(input) != 0 {
					t.Fatalf("iteration %d: digit sum does not match input\n digits: %s\n input: %s",
						iter, expected.Text(16), input.Text(16))
				}

				got := limbs4ToBigInt(d.Reconstruct())
				if got.Cmp(expected) != 0 {
					t.Fatalf("iteration %d: big.Int mismatch\n Reconstruct: %s\n big.Int: %s",
						iter, got.Text(16), expected.Text(16))
				}
			}
		})
	}
}
280
// TestFromBytes verifies the byte-to-limb conversion. The cases below
// establish a big-endian layout: byte 0 lands in the most-significant byte
// of limb 3 and byte 31 in the least-significant byte of limb 0.
func TestFromBytes(t *testing.T) {
	// All zeros
	var zero [32]byte
	if FromBytes(zero) != [4]uint64{} {
		t.Error("zero bytes should produce zero limbs")
	}

	// Single byte at the end (LSB)
	var one [32]byte
	one[31] = 1
	if FromBytes(one) != smallScalar(1) {
		t.Errorf("got %v, want %v", FromBytes(one), smallScalar(1))
	}

	// 0x0102030405060708 in the highest limb
	var high [32]byte
	copy(high[:], []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08})
	s := FromBytes(high)
	if s[3] != 0x0102030405060708 {
		t.Errorf("high limb: got %016x, want 0102030405060708", s[3])
	}
	if s[0] != 0 || s[1] != 0 || s[2] != 0 {
		t.Error("lower limbs should be zero")
	}

	// A distinct value in every byte pins the order of all four limbs,
	// including the two middle limbs that the cases above never exercise.
	var ramp [32]byte
	for i := range ramp {
		ramp[i] = byte(i)
	}
	want := [4]uint64{
		0x18191a1b1c1d1e1f,
		0x1011121314151617,
		0x08090a0b0c0d0e0f,
		0x0001020304050607,
	}
	if got := FromBytes(ramp); got != want {
		t.Errorf("ramp: got %016x, want %016x", got, want)
	}
}
314
// TestGetBits checks getBits on single bits, on a whole low byte, and on a
// window that straddles the boundary between limb 0 and limb 1.
func TestGetBits(t *testing.T) {
	scalar := [4]uint64{0xDEADBEEFCAFEBABE, 0x1234567890ABCDEF, 0, 0}

	// The low byte is 0xBE = 0b1011_1110: bit 0 is clear, bit 1 is set.
	if b0 := getBits(scalar, 0, 1); b0 != 0 {
		t.Error("bit 0 should be 0")
	}
	if b1 := getBits(scalar, 1, 1); b1 != 1 {
		t.Error("bit 1 should be 1")
	}

	if lowByte := getBits(scalar, 0, 8); lowByte != 0xBE {
		t.Errorf("bits [0:8] = %x, want 0xBE", lowByte)
	}

	// Bits 60-63 are the top nibble of limb 0 (0xD); bits 64-67 are the
	// bottom nibble of limb 1 (0xF). Packed low-to-high that gives 0xFD.
	if straddle := getBits(scalar, 60, 8); straddle != 0xFD {
		t.Errorf("cross-limb bits [60:68] = %x, want 0xFD", straddle)
	}
}
340
// TestEncodeSignedNegation verifies that EncodeSigned reports negation for a
// scalar with bit 255 set and not for a small scalar. Both results must be
// structurally valid, and the non-negated one must round-trip.
func TestEncodeSignedNegation(t *testing.T) {
	// Scalar with bit 255 set
	s := [4]uint64{1, 0, 0, 1 << 63}
	d, negated := EncodeSigned(s, 5)
	if !negated {
		t.Error("should have negated")
	}
	// NOTE(review): only structural validity is checked for the negated
	// case. The value its digits should reconstruct to depends on
	// EncodeSigned's negation convention, which is not visible from this
	// file.
	if err := d.Valid(5); err != nil {
		t.Fatalf("invalid: %v", err)
	}

	// Scalar without bit 255
	s2 := [4]uint64{0x12345, 0, 0, 0}
	d2, negated2 := EncodeSigned(s2, 5)
	if negated2 {
		t.Error("should not have negated")
	}
	if err := d2.Valid(5); err != nil {
		t.Fatalf("non-negated: invalid: %v", err)
	}
	if d2.Reconstruct() != s2 {
		t.Error("non-negated round-trip failed")
	}
}
363
// TestValidRejectsInvalid checks that Valid catches structural violations:
// even digits of either sign, out-of-range magnitudes of either sign, and
// non-zero digits closer than w apart, including the w-1 boundary. It also
// checks that Valid accepts the tightest legal spacing of exactly w.
func TestValidRejectsInvalid(t *testing.T) {
	// Even non-zero digit
	d := Digits{}
	d.D[0] = 2
	if d.Valid(5) == nil {
		t.Error("should reject even digit")
	}

	// Even negative digit: the parity rule must hold for both signs.
	d = Digits{}
	d.D[0] = -2
	if d.Valid(5) == nil {
		t.Error("should reject even negative digit")
	}

	// Out of range digit for w=3 (max magnitude is 3)
	d = Digits{}
	d.D[0] = 5
	if d.Valid(3) == nil {
		t.Error("should reject digit 5 for w=3")
	}

	// The range bound applies to the magnitude, so -5 is also out of range.
	d = Digits{}
	d.D[0] = -5
	if d.Valid(3) == nil {
		t.Error("should reject digit -5 for w=3")
	}

	// Spacing violation for w=5: non-zero digits only 3 apart, need >= 5
	d = Digits{}
	d.D[0] = 1
	d.D[3] = 1
	if d.Valid(5) == nil {
		t.Error("should reject spacing violation")
	}

	// Off-by-one boundary: digits w-1 apart must still be rejected.
	d = Digits{}
	d.D[0] = 1
	d.D[4] = 1
	if d.Valid(5) == nil {
		t.Error("should reject spacing of w-1")
	}

	// Valid: non-zero digits exactly w apart
	d = Digits{}
	d.D[0] = 1
	d.D[5] = 1
	if err := d.Valid(5); err != nil {
		t.Errorf("should accept spacing of exactly w: %v", err)
	}
}
396
// TestNonZeroDigitCount verifies that wNAF produces a sparse representation.
//
// Non-zero digits must be at least w positions apart (see
// TestValidRejectsInvalid). Over the 257 digit positions 0..256, which
// already include the final carry position, there can therefore be at most
// floor(256/w) + 1 of them: positions 0, w, 2w, ..., floor(256/w)*w.
func TestNonZeroDigitCount(t *testing.T) {
	const iterations = 200

	for w := 2; w <= 8; w++ {
		t.Run("w="+itoa(w), func(t *testing.T) {
			maxExpected := 256/w + 1
			for range iterations {
				s := randomScalar()
				d := Encode(s, w)

				count := 0
				for i := range 257 {
					if d.D[i] != 0 {
						count++
					}
				}

				if count > maxExpected {
					t.Errorf("too many non-zero digits: %d (expected <= %d) for w=%d",
						count, maxExpected, w)
				}
			}
		})
	}
}
423
// itoa formats n in base 10 for building subtest names without importing
// strconv. Negative values get a leading '-'.
func itoa(n int) string {
	if n == 0 {
		return "0"
	}
	// 20 bytes fits the longest 64-bit value, "-9223372036854775808".
	var buf [20]byte
	i := len(buf)

	// Work on the magnitude as uint64 so that negating the most negative
	// int cannot overflow (unsigned negation wraps modulo 2^64).
	mag := uint64(n)
	if n < 0 {
		mag = -mag
	}
	for mag > 0 {
		i--
		buf[i] = byte('0' + mag%10)
		mag /= 10
	}
	if n < 0 {
		i--
		buf[i] = '-'
	}
	return string(buf[i:])
}
438