wnaf.go raw

   1  // Package wnaf implements windowed Non-Adjacent Form (wNAF) encoding for
   2  // 256-bit scalars. wNAF is a signed-digit representation that minimizes
   3  // the number of non-zero digits, reducing point additions in elliptic
   4  // curve scalar multiplication.
   5  //
// All values are fixed-size arrays (or structs of them) passed by value; the
// package builds no slices or maps, and only the errors returned by Valid are
// heap-allocated.
   7  package wnaf
   8  
   9  import (
  10  	"fmt"
  11  	"math/bits"
  12  )
  13  
  14  // Digits holds a wNAF representation: up to 257 signed digits.
  15  // Each non-zero digit is an odd integer in [-(2^(w-1)-1), 2^(w-1)-1],
  16  // and non-zero digits are separated by at least w-1 zeros.
  17  type Digits struct {
  18  	D   [257]int8 // signed digits
  19  	Len int       // number of significant positions (highest non-zero index + 1)
  20  }
  21  
  22  // Encode converts a 256-bit scalar (as [4]uint64, little-endian limbs) into
  23  // wNAF representation with window width w. w must be in [2, 8].
  24  //
  25  // The algorithm processes bits from LSB to MSB, extracting w-bit windows
  26  // at each non-zero position and using carry propagation to ensure the
  27  // non-adjacency property.
  28  func Encode(scalar [4]uint64, w int) Digits {
  29  	if w < 2 || w > 8 {
  30  		panic("wnaf: w must be in [2, 8]")
  31  	}
  32  
  33  	var d Digits
  34  	var carry uint32
  35  
  36  	bit := 0
  37  	for bit < 256 {
  38  		if getBits(scalar, uint(bit), 1) == carry {
  39  			bit++
  40  			continue
  41  		}
  42  
  43  		window := uint(w)
  44  		if bit+int(window) > 256 {
  45  			window = uint(256 - bit)
  46  		}
  47  
  48  		word := getBits(scalar, uint(bit), window) + carry
  49  
  50  		carry = (word >> (window - 1)) & 1
  51  		word -= carry << window
  52  
  53  		d.D[bit] = int8(int32(word))
  54  		d.Len = bit + int(window)
  55  
  56  		bit += int(window)
  57  	}
  58  
  59  	if carry != 0 {
  60  		d.D[256] = int8(carry)
  61  		d.Len = 257
  62  	}
  63  
  64  	return d
  65  }
  66  
  67  // EncodeSigned converts a scalar to wNAF, handling sign normalization.
  68  // If the scalar has bit 255 set (is "negative" in modular sense), it is
  69  // negated before encoding. Returns the digits and whether negation occurred.
  70  // The caller must negate the final EC point result if negated is true.
  71  func EncodeSigned(scalar [4]uint64, w int) (Digits, bool) {
  72  	if getBits(scalar, 255, 1) == 1 {
  73  		scalar = negate256(scalar)
  74  		d := Encode(scalar, w)
  75  		return d, true
  76  	}
  77  	return Encode(scalar, w), false
  78  }
  79  
  80  // Reconstruct recovers the original scalar value from a wNAF representation.
  81  // Returns [4]uint64 in little-endian limb order.
  82  //
  83  // Splits digits into positive and negative contributions, accumulates each
  84  // as an unsigned 320-bit value, then subtracts to get the result.
  85  func (d *Digits) Reconstruct() [4]uint64 {
  86  	var pos, neg [5]uint64
  87  
  88  	for i := range 257 {
  89  		if d.D[i] == 0 {
  90  			continue
  91  		}
  92  
  93  		limb := uint(i) / 64
  94  		shift := uint(i) % 64
  95  
  96  		if d.D[i] > 0 {
  97  			addShifted(&pos, uint64(d.D[i]), limb, shift)
  98  		} else {
  99  			addShifted(&neg, uint64(-d.D[i]), limb, shift)
 100  		}
 101  	}
 102  
 103  	return sub320(&pos, &neg)
 104  }
 105  
 106  // addShifted adds val << (limb*64 + shift) to a 320-bit accumulator.
 107  func addShifted(acc *[5]uint64, val uint64, limb, shift uint) {
 108  	lo := val << shift
 109  	old := acc[limb]
 110  	acc[limb] += lo
 111  	if acc[limb] < old {
 112  		for j := limb + 1; j < 5; j++ {
 113  			acc[j]++
 114  			if acc[j] != 0 {
 115  				break
 116  			}
 117  		}
 118  	}
 119  
 120  	if shift > 0 {
 121  		hi := val >> (64 - shift)
 122  		if hi > 0 && limb+1 < 5 {
 123  			old = acc[limb+1]
 124  			acc[limb+1] += hi
 125  			if acc[limb+1] < old {
 126  				for j := limb + 2; j < 5; j++ {
 127  					acc[j]++
 128  					if acc[j] != 0 {
 129  						break
 130  					}
 131  				}
 132  			}
 133  		}
 134  	}
 135  }
 136  
 137  // sub320 computes a - b for 320-bit values, returning the lower 256 bits.
 138  func sub320(a, b *[5]uint64) [4]uint64 {
 139  	var result [4]uint64
 140  	var borrow uint64
 141  	for i := range 4 {
 142  		result[i], borrow = bits.Sub64(a[i], b[i], borrow)
 143  	}
 144  	return result
 145  }
 146  
 147  // Valid checks all wNAF structural invariants for window width w.
 148  // Returns nil if valid, or an error describing the violation.
 149  func (d *Digits) Valid(w int) error {
 150  	if w < 2 || w > 8 {
 151  		return fmt.Errorf("wnaf: invalid window width %d, must be in [2, 8]", w)
 152  	}
 153  
 154  	maxDigit := int8((1 << (w - 1)) - 1)
 155  	lastNonZero := -1
 156  
 157  	for i := range 257 {
 158  		digit := d.D[i]
 159  		if digit == 0 {
 160  			continue
 161  		}
 162  
 163  		// Position 256 is the carry bit: always 1, exempt from spacing rule
 164  		if i == 256 {
 165  			if digit != 1 {
 166  				return fmt.Errorf("wnaf: carry digit at position 256 must be 1, got %d", digit)
 167  			}
 168  			lastNonZero = i
 169  			continue
 170  		}
 171  
 172  		if digit%2 == 0 {
 173  			return fmt.Errorf("wnaf: digit at position %d is even (%d)", i, digit)
 174  		}
 175  
 176  		if digit > maxDigit || digit < -maxDigit {
 177  			return fmt.Errorf("wnaf: digit at position %d out of range: %d (max %d)",
 178  				i, digit, maxDigit)
 179  		}
 180  
 181  		if lastNonZero >= 0 && lastNonZero < 256 && i-lastNonZero < w {
 182  			return fmt.Errorf("wnaf: non-zero digits at positions %d and %d "+
 183  				"are only %d apart (need >= %d)", lastNonZero, i, i-lastNonZero, w)
 184  		}
 185  
 186  		lastNonZero = i
 187  	}
 188  
 189  	return nil
 190  }
 191  
 192  // FromBytes converts a 32-byte big-endian scalar into [4]uint64 little-endian
 193  // limb order (matching the internal Scalar.d layout).
 194  func FromBytes(b [32]byte) [4]uint64 {
 195  	var s [4]uint64
 196  	s[3] = uint64(b[0])<<56 | uint64(b[1])<<48 | uint64(b[2])<<40 | uint64(b[3])<<32 |
 197  		uint64(b[4])<<24 | uint64(b[5])<<16 | uint64(b[6])<<8 | uint64(b[7])
 198  	s[2] = uint64(b[8])<<56 | uint64(b[9])<<48 | uint64(b[10])<<40 | uint64(b[11])<<32 |
 199  		uint64(b[12])<<24 | uint64(b[13])<<16 | uint64(b[14])<<8 | uint64(b[15])
 200  	s[1] = uint64(b[16])<<56 | uint64(b[17])<<48 | uint64(b[18])<<40 | uint64(b[19])<<32 |
 201  		uint64(b[20])<<24 | uint64(b[21])<<16 | uint64(b[22])<<8 | uint64(b[23])
 202  	s[0] = uint64(b[24])<<56 | uint64(b[25])<<48 | uint64(b[26])<<40 | uint64(b[27])<<32 |
 203  		uint64(b[28])<<24 | uint64(b[29])<<16 | uint64(b[30])<<8 | uint64(b[31])
 204  	return s
 205  }
 206  
 207  // getBits extracts count bits starting at offset from a [4]uint64 scalar.
 208  // offset+count must be <= 256. count must be in [1, 32].
 209  func getBits(scalar [4]uint64, offset, count uint) uint32 {
 210  	limbIdx := offset / 64
 211  	bitIdx := offset % 64
 212  
 213  	if bitIdx+count <= 64 {
 214  		return uint32((scalar[limbIdx] >> bitIdx) & ((1 << count) - 1))
 215  	}
 216  	lowBits := 64 - bitIdx
 217  	highBits := count - lowBits
 218  	low := uint32((scalar[limbIdx] >> bitIdx) & ((1 << lowBits) - 1))
 219  	high := uint32(scalar[limbIdx+1] & ((1 << highBits) - 1))
 220  	return low | (high << lowBits)
 221  }
 222  
 223  // negate256 computes the two's complement negation of a 256-bit value.
 224  func negate256(s [4]uint64) [4]uint64 {
 225  	s[0] = ^s[0]
 226  	s[1] = ^s[1]
 227  	s[2] = ^s[2]
 228  	s[3] = ^s[3]
 229  	s[0]++
 230  	if s[0] == 0 {
 231  		s[1]++
 232  		if s[1] == 0 {
 233  			s[2]++
 234  			if s[2] == 0 {
 235  				s[3]++
 236  			}
 237  		}
 238  	}
 239  	return s
 240  }
 241