schnorr_batch.go raw

   1  //go:build !js && !wasm && !tinygo && !wasm32
   2  
   3  package p256k1
   4  
   5  import (
   6  	"crypto/sha256"
   7  	"encoding/binary"
   8  )
   9  
// BatchSchnorrItem represents a single signature to be verified in a batch by
// SchnorrBatchVerify or SchnorrBatchVerifyWithFallback.
//
// For batches of two or more items, SchnorrBatchVerify rejects the whole batch
// if any item has a nil Pubkey, a Message that is not exactly 32 bytes, or a
// Signature that is not exactly 64 bytes.
type BatchSchnorrItem struct {
	Pubkey    *XOnlyPubkey // 32-byte x-only public key; must be non-nil
	Message   []byte       // 32-byte message
	Signature []byte       // 64-byte BIP-340 signature (r || s)
}
  16  
  17  // SchnorrBatchVerify verifies multiple Schnorr signatures in a single batch.
  18  // This is significantly faster than verifying signatures individually when n > 1.
  19  //
  20  // The algorithm uses random linear combination to verify all signatures at once:
  21  //   - Generate random coefficients z_i from hash of all inputs
  22  //   - Compute: R = (Σ z_i * s_i) * G - Σ (z_i * e_i) * P_i
  23  //   - If R is the point at infinity, all signatures are valid
  24  //
  25  // If the batch fails, the caller should fall back to individual verification
  26  // to identify which signature(s) are invalid.
  27  //
  28  // Returns true if all signatures in the batch are valid.
  29  func SchnorrBatchVerify(items []BatchSchnorrItem) bool {
  30  	n := len(items)
  31  	if n == 0 {
  32  		return true // Empty batch is trivially valid
  33  	}
  34  
  35  	// For single signature, fall back to individual verification
  36  	if n == 1 {
  37  		return SchnorrVerify(items[0].Signature, items[0].Message, items[0].Pubkey)
  38  	}
  39  
  40  	// Validate all inputs first
  41  	for i := range items {
  42  		if items[i].Pubkey == nil {
  43  			return false
  44  		}
  45  		if len(items[i].Message) != 32 {
  46  			return false
  47  		}
  48  		if len(items[i].Signature) != 64 {
  49  			return false
  50  		}
  51  	}
  52  
  53  	// Generate the batch seed for random coefficient generation
  54  	// seed = SHA256(all signatures || all messages || all pubkeys)
  55  	batchSeed := computeBatchSeed(items)
  56  
  57  	// Parse all signatures and compute challenges
  58  	type parsedItem struct {
  59  		rx FieldElement     // r value from signature
  60  		s  Scalar           // s value from signature
  61  		e  Scalar           // challenge e
  62  		pk GroupElementAffine // public key point
  63  		z  Scalar           // random coefficient
  64  	}
  65  
  66  	parsed := make([]parsedItem, n)
  67  
  68  	for i := range items {
  69  		// Parse r as field element
  70  		var r32 [32]byte
  71  		copy(r32[:], items[i].Signature[:32])
  72  		if err := parsed[i].rx.setB32(r32[:]); err != nil {
  73  			return false
  74  		}
  75  
  76  		// Parse s as scalar
  77  		var s32 [32]byte
  78  		copy(s32[:], items[i].Signature[32:])
  79  		parsed[i].s.setB32(s32[:])
  80  		if parsed[i].s.isZero() {
  81  			return false
  82  		}
  83  
  84  		// Check that s < order (additional validation)
  85  		if parsed[i].s.checkOverflow() {
  86  			return false
  87  		}
  88  
  89  		// Parse public key - x-only pubkey with even Y
  90  		if err := parsed[i].pk.x.setB32(items[i].Pubkey.data[:]); err != nil {
  91  			return false
  92  		}
  93  		if !parsed[i].pk.setXOVar(&parsed[i].pk.x, false) {
  94  			return false
  95  		}
  96  
  97  		// Compute challenge e = TaggedHash("BIP0340/challenge", r || pk || msg)
  98  		var challengeInput [96]byte
  99  		copy(challengeInput[0:32], items[i].Signature[:32])
 100  		copy(challengeInput[32:64], items[i].Pubkey.data[:])
 101  		copy(challengeInput[64:96], items[i].Message)
 102  		challengeHash := TaggedHash(bip340ChallengeTag, challengeInput[:])
 103  		parsed[i].e.setB32(challengeHash[:])
 104  
 105  		// Generate random coefficient z_i from batch seed and index
 106  		parsed[i].z = generateBatchCoefficient(batchSeed, uint32(i))
 107  	}
 108  
 109  	// Batch verification equation:
 110  	// (Σ z_i * s_i) * G - Σ (z_i * e_i) * P_i - Σ z_i * R_i = O
 111  	//
 112  	// For efficiency, we use Strauss-style multi-scalar multiplication:
 113  	// 1. Prepare all points and scalars
 114  	// 2. Use a combined loop that shares doublings
 115  
 116  	// Prepare points and scalars for multi-scalar multiplication
 117  	// We have 2n+1 terms: 1 generator term + n pubkey terms + n R terms
 118  	points := make([]GroupElementAffine, 2*n)
 119  	scalars := make([]Scalar, 2*n)
 120  
 121  	// Accumulate scalar for generator: sumS = Σ z_i * s_i
 122  	var sumS Scalar
 123  	sumS.setInt(0)
 124  
 125  	for i := range parsed {
 126  		// Compute z_i * s_i for generator
 127  		var zs Scalar
 128  		zs.mul(&parsed[i].z, &parsed[i].s)
 129  		sumS.add(&sumS, &zs)
 130  
 131  		// Store -z_i * e_i * P_i (negated because we subtract)
 132  		var ze Scalar
 133  		ze.mul(&parsed[i].z, &parsed[i].e)
 134  		ze.negate(&ze)
 135  		scalars[i] = ze
 136  		points[i] = parsed[i].pk
 137  
 138  		// Lift r to point R_i with even Y
 139  		var Ri GroupElementAffine
 140  		if !Ri.setXOVar(&parsed[i].rx, false) {
 141  			return false
 142  		}
 143  
 144  		// Store -z_i * R_i (negated because we subtract)
 145  		var negZ Scalar
 146  		negZ.negate(&parsed[i].z)
 147  		scalars[n+i] = negZ
 148  		points[n+i] = Ri
 149  	}
 150  
 151  	// Use multi-scalar multiplication with Strauss algorithm
 152  	// result = sumS*G + Σ scalars[i] * points[i]
 153  	result := ecmultMultiStrauss(&sumS, points, scalars)
 154  
 155  	// Batch is valid if result is the point at infinity
 156  	return result.isInfinity()
 157  }
 158  
 159  // ecmultMultiStrauss computes ng*G + Σ scalars[i]*points[i] using Strauss algorithm
 160  // This shares doublings across all scalar multiplications for improved performance.
 161  func ecmultMultiStrauss(ng *Scalar, points []GroupElementAffine, scalars []Scalar) GroupElementJacobian {
 162  	n := len(points)
 163  	if n == 0 {
 164  		var result GroupElementJacobian
 165  		EcmultGen(&result, ng)
 166  		return result
 167  	}
 168  
 169  	// Use wNAF representation for all scalars
 170  	const w = 5
 171  	const tableSize = 1 << (w - 1) // 16
 172  
 173  	// Build tables for all points
 174  	type tableEntry struct {
 175  		pre [tableSize]GroupElementAffine
 176  	}
 177  	tables := make([]tableEntry, n)
 178  
 179  	for i := range points {
 180  		if points[i].isInfinity() {
 181  			continue
 182  		}
 183  
 184  		var pJac GroupElementJacobian
 185  		pJac.setGE(&points[i])
 186  
 187  		// Build odd multiples table
 188  		var preJac [tableSize]GroupElementJacobian
 189  		preJac[0] = pJac
 190  
 191  		if tableSize > 1 {
 192  			var twoP GroupElementJacobian
 193  			twoP.double(&pJac)
 194  
 195  			for j := 1; j < tableSize; j++ {
 196  				preJac[j].addVar(&preJac[j-1], &twoP)
 197  			}
 198  		}
 199  
 200  		// Convert to affine
 201  		batchNormalize16(&tables[i].pre, &preJac)
 202  	}
 203  
 204  	// Convert all scalars to wNAF
 205  	wnafs := make([][257]int8, n)
 206  	wnafBits := make([]int, n)
 207  
 208  	maxBits := 0
 209  	for i := range scalars {
 210  		wnafBits[i] = scalars[i].wNAF(&wnafs[i], w)
 211  		if wnafBits[i] > maxBits {
 212  			maxBits = wnafBits[i]
 213  		}
 214  	}
 215  
 216  	// Also convert generator scalar to wNAF
 217  	var ngWNAF [257]int8
 218  	ngBits := ng.wNAF(&ngWNAF, w)
 219  	if ngBits > maxBits {
 220  		maxBits = ngBits
 221  	}
 222  
 223  	// Perform Strauss algorithm
 224  	var result GroupElementJacobian
 225  	result.setInfinity()
 226  
 227  	for bit := maxBits - 1; bit >= 0; bit-- {
 228  		// Double
 229  		if !result.isInfinity() {
 230  			result.double(&result)
 231  		}
 232  
 233  		// Add generator contribution
 234  		if bit < ngBits && ngWNAF[bit] != 0 {
 235  			idx := ngWNAF[bit]
 236  			if idx > 0 {
 237  				// Use precomputed generator table
 238  				var pt GroupElementAffine
 239  				pt = preGenG[(idx-1)/2]
 240  				if result.isInfinity() {
 241  					result.setGE(&pt)
 242  				} else {
 243  					result.addGE(&result, &pt)
 244  				}
 245  			} else {
 246  				var pt GroupElementAffine
 247  				pt = preGenG[(-idx-1)/2]
 248  				pt.negate(&pt)
 249  				if result.isInfinity() {
 250  					result.setGE(&pt)
 251  				} else {
 252  					result.addGE(&result, &pt)
 253  				}
 254  			}
 255  		}
 256  
 257  		// Add contributions from each point
 258  		for i := range scalars {
 259  			if bit >= wnafBits[i] || wnafs[i][bit] == 0 {
 260  				continue
 261  			}
 262  
 263  			idx := wnafs[i][bit]
 264  			var pt GroupElementAffine
 265  
 266  			if idx > 0 {
 267  				pt = tables[i].pre[(idx-1)/2]
 268  			} else {
 269  				pt = tables[i].pre[(-idx-1)/2]
 270  				pt.negate(&pt)
 271  			}
 272  
 273  			if result.isInfinity() {
 274  				result.setGE(&pt)
 275  			} else {
 276  				result.addGE(&result, &pt)
 277  			}
 278  		}
 279  	}
 280  
 281  	return result
 282  }
 283  
 284  // computeBatchSeed computes a deterministic seed for generating random coefficients
 285  // seed = SHA256("SchnorrBatchVerify" || n || sig_0 || ... || sig_{n-1} || msg_0 || ... || pk_0 || ...)
 286  func computeBatchSeed(items []BatchSchnorrItem) [32]byte {
 287  	h := sha256.New()
 288  
 289  	// Domain separator
 290  	h.Write([]byte("SchnorrBatchVerify"))
 291  
 292  	// Number of items
 293  	var buf [4]byte
 294  	binary.BigEndian.PutUint32(buf[:], uint32(len(items)))
 295  	h.Write(buf[:])
 296  
 297  	// All signatures
 298  	for i := range items {
 299  		h.Write(items[i].Signature)
 300  	}
 301  
 302  	// All messages
 303  	for i := range items {
 304  		h.Write(items[i].Message)
 305  	}
 306  
 307  	// All pubkeys
 308  	for i := range items {
 309  		h.Write(items[i].Pubkey.data[:])
 310  	}
 311  
 312  	var result [32]byte
 313  	h.Sum(result[:0])
 314  	return result
 315  }
 316  
 317  // generateBatchCoefficient generates a random coefficient for batch verification
 318  // z_i = SHA256(seed || i) mod n
 319  func generateBatchCoefficient(seed [32]byte, index uint32) Scalar {
 320  	h := sha256.New()
 321  	h.Write(seed[:])
 322  
 323  	var buf [4]byte
 324  	binary.BigEndian.PutUint32(buf[:], index)
 325  	h.Write(buf[:])
 326  
 327  	var hash [32]byte
 328  	h.Sum(hash[:0])
 329  
 330  	var z Scalar
 331  	z.setB32(hash[:])
 332  
 333  	// Ensure z is non-zero
 334  	if z.isZero() {
 335  		z.setInt(1)
 336  	}
 337  
 338  	return z
 339  }
 340  
 341  // SchnorrBatchVerifyWithFallback verifies a batch and returns the indices of invalid signatures.
 342  // This is useful when you need to identify which specific signatures failed.
 343  // Returns (true, nil) if all signatures are valid.
 344  // Returns (false, invalidIndices) if some signatures are invalid.
 345  func SchnorrBatchVerifyWithFallback(items []BatchSchnorrItem) (bool, []int) {
 346  	// First try batch verification
 347  	if SchnorrBatchVerify(items) {
 348  		return true, nil
 349  	}
 350  
 351  	// Batch failed - identify invalid signatures
 352  	var invalidIndices []int
 353  	for i := range items {
 354  		if !SchnorrVerify(items[i].Signature, items[i].Message, items[i].Pubkey) {
 355  			invalidIndices = append(invalidIndices, i)
 356  		}
 357  	}
 358  
 359  	return false, invalidIndices
 360  }
 361