wnaf.go raw
// Package wnaf implements windowed Non-Adjacent Form (wNAF) encoding for
// 256-bit scalars. wNAF is a signed-digit representation that minimizes
// the number of non-zero digits, reducing point additions in elliptic
// curve scalar multiplication.
//
// Scalars are passed as [4]uint64 in little-endian limb order (limb 0 holds
// the least significant 64 bits), and encodings are returned as fixed-size
// Digits values. Encoding and reconstruction work on fixed-length arrays
// only; the package never builds slices or maps.
package wnaf
8
import (
	"encoding/binary"
	"fmt"
	"math/bits"
)
13
// Digits is a fixed-size wNAF encoding of a 256-bit scalar.
//
// D[i] is the signed digit of weight 2^i, so the encoded value is the sum of
// D[i]·2^i over i in [0, 256]; index 256 exists to hold a carry out of bit
// 255. Every non-zero digit is an odd integer in [-(2^(w-1)-1), 2^(w-1)-1],
// and non-zero digits are separated by at least w-1 zeros. (Valid does not
// enforce the spacing rule for the carry digit at index 256.)
type Digits struct {
	// D holds up to 257 signed digits, least significant first.
	D [257]int8

	// Len counts the significant positions: the index of the highest
	// non-zero digit plus one.
	Len int
}
21
22 // Encode converts a 256-bit scalar (as [4]uint64, little-endian limbs) into
23 // wNAF representation with window width w. w must be in [2, 8].
24 //
25 // The algorithm processes bits from LSB to MSB, extracting w-bit windows
26 // at each non-zero position and using carry propagation to ensure the
27 // non-adjacency property.
28 func Encode(scalar [4]uint64, w int) Digits {
29 if w < 2 || w > 8 {
30 panic("wnaf: w must be in [2, 8]")
31 }
32
33 var d Digits
34 var carry uint32
35
36 bit := 0
37 for bit < 256 {
38 if getBits(scalar, uint(bit), 1) == carry {
39 bit++
40 continue
41 }
42
43 window := uint(w)
44 if bit+int(window) > 256 {
45 window = uint(256 - bit)
46 }
47
48 word := getBits(scalar, uint(bit), window) + carry
49
50 carry = (word >> (window - 1)) & 1
51 word -= carry << window
52
53 d.D[bit] = int8(int32(word))
54 d.Len = bit + int(window)
55
56 bit += int(window)
57 }
58
59 if carry != 0 {
60 d.D[256] = int8(carry)
61 d.Len = 257
62 }
63
64 return d
65 }
66
67 // EncodeSigned converts a scalar to wNAF, handling sign normalization.
68 // If the scalar has bit 255 set (is "negative" in modular sense), it is
69 // negated before encoding. Returns the digits and whether negation occurred.
70 // The caller must negate the final EC point result if negated is true.
71 func EncodeSigned(scalar [4]uint64, w int) (Digits, bool) {
72 if getBits(scalar, 255, 1) == 1 {
73 scalar = negate256(scalar)
74 d := Encode(scalar, w)
75 return d, true
76 }
77 return Encode(scalar, w), false
78 }
79
80 // Reconstruct recovers the original scalar value from a wNAF representation.
81 // Returns [4]uint64 in little-endian limb order.
82 //
83 // Splits digits into positive and negative contributions, accumulates each
84 // as an unsigned 320-bit value, then subtracts to get the result.
85 func (d *Digits) Reconstruct() [4]uint64 {
86 var pos, neg [5]uint64
87
88 for i := range 257 {
89 if d.D[i] == 0 {
90 continue
91 }
92
93 limb := uint(i) / 64
94 shift := uint(i) % 64
95
96 if d.D[i] > 0 {
97 addShifted(&pos, uint64(d.D[i]), limb, shift)
98 } else {
99 addShifted(&neg, uint64(-d.D[i]), limb, shift)
100 }
101 }
102
103 return sub320(&pos, &neg)
104 }
105
// addShifted adds val·2^(64·limb+shift) into the 320-bit little-endian
// accumulator acc. shift is the bit offset within the limb (0..63). The part
// of val shifted past bit 63 of acc[limb] goes into acc[limb+1]. Bits that
// would land above acc[4], and any carry out of acc[4], are discarded.
func addShifted(acc *[5]uint64, val uint64, limb, shift uint) {
	// addAt adds x into acc[at] and ripples a carry upward while limbs wrap.
	addAt := func(at uint, x uint64) {
		sum := acc[at] + x
		acc[at] = sum
		if sum >= x {
			return // no wraparound, nothing to carry
		}
		for k := at + 1; k < uint(len(acc)); k++ {
			acc[k]++
			if acc[k] != 0 {
				return
			}
		}
	}

	addAt(limb, val<<shift)

	if shift == 0 {
		return
	}
	if spill := val >> (64 - shift); spill != 0 && limb+1 < uint(len(acc)) {
		addAt(limb+1, spill)
	}
}
136
// sub320 returns the low 256 bits of a - b, where a and b are 320-bit
// little-endian values. The top limb (index 4) of each operand cannot affect
// the low 256 bits of the difference, so it is never read.
func sub320(a, b *[5]uint64) [4]uint64 {
	d0, bw := bits.Sub64(a[0], b[0], 0)
	d1, bw := bits.Sub64(a[1], b[1], bw)
	d2, bw := bits.Sub64(a[2], b[2], bw)
	d3, _ := bits.Sub64(a[3], b[3], bw)
	return [4]uint64{d0, d1, d2, d3}
}
146
147 // Valid checks all wNAF structural invariants for window width w.
148 // Returns nil if valid, or an error describing the violation.
149 func (d *Digits) Valid(w int) error {
150 if w < 2 || w > 8 {
151 return fmt.Errorf("wnaf: invalid window width %d, must be in [2, 8]", w)
152 }
153
154 maxDigit := int8((1 << (w - 1)) - 1)
155 lastNonZero := -1
156
157 for i := range 257 {
158 digit := d.D[i]
159 if digit == 0 {
160 continue
161 }
162
163 // Position 256 is the carry bit: always 1, exempt from spacing rule
164 if i == 256 {
165 if digit != 1 {
166 return fmt.Errorf("wnaf: carry digit at position 256 must be 1, got %d", digit)
167 }
168 lastNonZero = i
169 continue
170 }
171
172 if digit%2 == 0 {
173 return fmt.Errorf("wnaf: digit at position %d is even (%d)", i, digit)
174 }
175
176 if digit > maxDigit || digit < -maxDigit {
177 return fmt.Errorf("wnaf: digit at position %d out of range: %d (max %d)",
178 i, digit, maxDigit)
179 }
180
181 if lastNonZero >= 0 && lastNonZero < 256 && i-lastNonZero < w {
182 return fmt.Errorf("wnaf: non-zero digits at positions %d and %d "+
183 "are only %d apart (need >= %d)", lastNonZero, i, i-lastNonZero, w)
184 }
185
186 lastNonZero = i
187 }
188
189 return nil
190 }
191
// FromBytes converts a 32-byte big-endian scalar into [4]uint64 little-endian
// limb order, the layout Encode and getBits read. s[0] holds the least
// significant 64 bits (b[24:32]) and s[3] the most significant (b[0:8]).
// Each limb is decoded big-endian, matching the byte order of b.
func FromBytes(b [32]byte) [4]uint64 {
	return [4]uint64{
		binary.BigEndian.Uint64(b[24:32]),
		binary.BigEndian.Uint64(b[16:24]),
		binary.BigEndian.Uint64(b[8:16]),
		binary.BigEndian.Uint64(b[0:8]),
	}
}
206
// getBits returns count bits of scalar beginning at bit offset, right-aligned
// in a uint32. Bit 0 is the least significant bit of scalar[0]. The caller
// must keep offset+count <= 256 and count in [1, 32].
func getBits(scalar [4]uint64, offset, count uint) uint32 {
	limb, pos := offset/64, offset%64
	mask := uint64(1)<<count - 1

	v := scalar[limb] >> pos
	if pos+count > 64 {
		// The window straddles a limb boundary: place the next limb's low
		// bits directly above the 64-pos bits taken from this limb.
		v |= scalar[limb+1] << (64 - pos)
	}
	return uint32(v & mask)
}
222
// negate256 returns the two's-complement negation of s, i.e. 2^256 - s
// reduced modulo 2^256 (so negate256 of 0 is 0). It complements every limb
// and adds one, rippling the carry only while limbs wrap to zero. Any carry
// out of the top limb is dropped.
func negate256(s [4]uint64) [4]uint64 {
	var out [4]uint64
	carry := uint64(1)
	for i, limb := range s {
		sum := ^limb + carry
		if sum != 0 {
			carry = 0 // the +1 has been absorbed; higher limbs only flip
		}
		out[i] = sum
	}
	return out
}
241