schnorr_batch.go raw
1 //go:build !js && !wasm && !tinygo && !wasm32
2
3 package p256k1
4
5 import (
6 "crypto/sha256"
7 "encoding/binary"
8 )
9
10 // BatchSchnorrItem represents a single signature to be verified in a batch
11 type BatchSchnorrItem struct {
12 Pubkey *XOnlyPubkey // 32-byte x-only public key
13 Message []byte // 32-byte message
14 Signature []byte // 64-byte signature (r || s)
15 }
16
17 // SchnorrBatchVerify verifies multiple Schnorr signatures in a single batch.
18 // This is significantly faster than verifying signatures individually when n > 1.
19 //
20 // The algorithm uses random linear combination to verify all signatures at once:
21 // - Generate random coefficients z_i from hash of all inputs
22 // - Compute: R = (Σ z_i * s_i) * G - Σ (z_i * e_i) * P_i
23 // - If R is the point at infinity, all signatures are valid
24 //
25 // If the batch fails, the caller should fall back to individual verification
26 // to identify which signature(s) are invalid.
27 //
28 // Returns true if all signatures in the batch are valid.
29 func SchnorrBatchVerify(items []BatchSchnorrItem) bool {
30 n := len(items)
31 if n == 0 {
32 return true // Empty batch is trivially valid
33 }
34
35 // For single signature, fall back to individual verification
36 if n == 1 {
37 return SchnorrVerify(items[0].Signature, items[0].Message, items[0].Pubkey)
38 }
39
40 // Validate all inputs first
41 for i := range items {
42 if items[i].Pubkey == nil {
43 return false
44 }
45 if len(items[i].Message) != 32 {
46 return false
47 }
48 if len(items[i].Signature) != 64 {
49 return false
50 }
51 }
52
53 // Generate the batch seed for random coefficient generation
54 // seed = SHA256(all signatures || all messages || all pubkeys)
55 batchSeed := computeBatchSeed(items)
56
57 // Parse all signatures and compute challenges
58 type parsedItem struct {
59 rx FieldElement // r value from signature
60 s Scalar // s value from signature
61 e Scalar // challenge e
62 pk GroupElementAffine // public key point
63 z Scalar // random coefficient
64 }
65
66 parsed := make([]parsedItem, n)
67
68 for i := range items {
69 // Parse r as field element
70 var r32 [32]byte
71 copy(r32[:], items[i].Signature[:32])
72 if err := parsed[i].rx.setB32(r32[:]); err != nil {
73 return false
74 }
75
76 // Parse s as scalar
77 var s32 [32]byte
78 copy(s32[:], items[i].Signature[32:])
79 parsed[i].s.setB32(s32[:])
80 if parsed[i].s.isZero() {
81 return false
82 }
83
84 // Check that s < order (additional validation)
85 if parsed[i].s.checkOverflow() {
86 return false
87 }
88
89 // Parse public key - x-only pubkey with even Y
90 if err := parsed[i].pk.x.setB32(items[i].Pubkey.data[:]); err != nil {
91 return false
92 }
93 if !parsed[i].pk.setXOVar(&parsed[i].pk.x, false) {
94 return false
95 }
96
97 // Compute challenge e = TaggedHash("BIP0340/challenge", r || pk || msg)
98 var challengeInput [96]byte
99 copy(challengeInput[0:32], items[i].Signature[:32])
100 copy(challengeInput[32:64], items[i].Pubkey.data[:])
101 copy(challengeInput[64:96], items[i].Message)
102 challengeHash := TaggedHash(bip340ChallengeTag, challengeInput[:])
103 parsed[i].e.setB32(challengeHash[:])
104
105 // Generate random coefficient z_i from batch seed and index
106 parsed[i].z = generateBatchCoefficient(batchSeed, uint32(i))
107 }
108
109 // Batch verification equation:
110 // (Σ z_i * s_i) * G - Σ (z_i * e_i) * P_i - Σ z_i * R_i = O
111 //
112 // For efficiency, we use Strauss-style multi-scalar multiplication:
113 // 1. Prepare all points and scalars
114 // 2. Use a combined loop that shares doublings
115
116 // Prepare points and scalars for multi-scalar multiplication
117 // We have 2n+1 terms: 1 generator term + n pubkey terms + n R terms
118 points := make([]GroupElementAffine, 2*n)
119 scalars := make([]Scalar, 2*n)
120
121 // Accumulate scalar for generator: sumS = Σ z_i * s_i
122 var sumS Scalar
123 sumS.setInt(0)
124
125 for i := range parsed {
126 // Compute z_i * s_i for generator
127 var zs Scalar
128 zs.mul(&parsed[i].z, &parsed[i].s)
129 sumS.add(&sumS, &zs)
130
131 // Store -z_i * e_i * P_i (negated because we subtract)
132 var ze Scalar
133 ze.mul(&parsed[i].z, &parsed[i].e)
134 ze.negate(&ze)
135 scalars[i] = ze
136 points[i] = parsed[i].pk
137
138 // Lift r to point R_i with even Y
139 var Ri GroupElementAffine
140 if !Ri.setXOVar(&parsed[i].rx, false) {
141 return false
142 }
143
144 // Store -z_i * R_i (negated because we subtract)
145 var negZ Scalar
146 negZ.negate(&parsed[i].z)
147 scalars[n+i] = negZ
148 points[n+i] = Ri
149 }
150
151 // Use multi-scalar multiplication with Strauss algorithm
152 // result = sumS*G + Σ scalars[i] * points[i]
153 result := ecmultMultiStrauss(&sumS, points, scalars)
154
155 // Batch is valid if result is the point at infinity
156 return result.isInfinity()
157 }
158
159 // ecmultMultiStrauss computes ng*G + Σ scalars[i]*points[i] using Strauss algorithm
160 // This shares doublings across all scalar multiplications for improved performance.
161 func ecmultMultiStrauss(ng *Scalar, points []GroupElementAffine, scalars []Scalar) GroupElementJacobian {
162 n := len(points)
163 if n == 0 {
164 var result GroupElementJacobian
165 EcmultGen(&result, ng)
166 return result
167 }
168
169 // Use wNAF representation for all scalars
170 const w = 5
171 const tableSize = 1 << (w - 1) // 16
172
173 // Build tables for all points
174 type tableEntry struct {
175 pre [tableSize]GroupElementAffine
176 }
177 tables := make([]tableEntry, n)
178
179 for i := range points {
180 if points[i].isInfinity() {
181 continue
182 }
183
184 var pJac GroupElementJacobian
185 pJac.setGE(&points[i])
186
187 // Build odd multiples table
188 var preJac [tableSize]GroupElementJacobian
189 preJac[0] = pJac
190
191 if tableSize > 1 {
192 var twoP GroupElementJacobian
193 twoP.double(&pJac)
194
195 for j := 1; j < tableSize; j++ {
196 preJac[j].addVar(&preJac[j-1], &twoP)
197 }
198 }
199
200 // Convert to affine
201 batchNormalize16(&tables[i].pre, &preJac)
202 }
203
204 // Convert all scalars to wNAF
205 wnafs := make([][257]int8, n)
206 wnafBits := make([]int, n)
207
208 maxBits := 0
209 for i := range scalars {
210 wnafBits[i] = scalars[i].wNAF(&wnafs[i], w)
211 if wnafBits[i] > maxBits {
212 maxBits = wnafBits[i]
213 }
214 }
215
216 // Also convert generator scalar to wNAF
217 var ngWNAF [257]int8
218 ngBits := ng.wNAF(&ngWNAF, w)
219 if ngBits > maxBits {
220 maxBits = ngBits
221 }
222
223 // Perform Strauss algorithm
224 var result GroupElementJacobian
225 result.setInfinity()
226
227 for bit := maxBits - 1; bit >= 0; bit-- {
228 // Double
229 if !result.isInfinity() {
230 result.double(&result)
231 }
232
233 // Add generator contribution
234 if bit < ngBits && ngWNAF[bit] != 0 {
235 idx := ngWNAF[bit]
236 if idx > 0 {
237 // Use precomputed generator table
238 var pt GroupElementAffine
239 pt = preGenG[(idx-1)/2]
240 if result.isInfinity() {
241 result.setGE(&pt)
242 } else {
243 result.addGE(&result, &pt)
244 }
245 } else {
246 var pt GroupElementAffine
247 pt = preGenG[(-idx-1)/2]
248 pt.negate(&pt)
249 if result.isInfinity() {
250 result.setGE(&pt)
251 } else {
252 result.addGE(&result, &pt)
253 }
254 }
255 }
256
257 // Add contributions from each point
258 for i := range scalars {
259 if bit >= wnafBits[i] || wnafs[i][bit] == 0 {
260 continue
261 }
262
263 idx := wnafs[i][bit]
264 var pt GroupElementAffine
265
266 if idx > 0 {
267 pt = tables[i].pre[(idx-1)/2]
268 } else {
269 pt = tables[i].pre[(-idx-1)/2]
270 pt.negate(&pt)
271 }
272
273 if result.isInfinity() {
274 result.setGE(&pt)
275 } else {
276 result.addGE(&result, &pt)
277 }
278 }
279 }
280
281 return result
282 }
283
284 // computeBatchSeed computes a deterministic seed for generating random coefficients
285 // seed = SHA256("SchnorrBatchVerify" || n || sig_0 || ... || sig_{n-1} || msg_0 || ... || pk_0 || ...)
286 func computeBatchSeed(items []BatchSchnorrItem) [32]byte {
287 h := sha256.New()
288
289 // Domain separator
290 h.Write([]byte("SchnorrBatchVerify"))
291
292 // Number of items
293 var buf [4]byte
294 binary.BigEndian.PutUint32(buf[:], uint32(len(items)))
295 h.Write(buf[:])
296
297 // All signatures
298 for i := range items {
299 h.Write(items[i].Signature)
300 }
301
302 // All messages
303 for i := range items {
304 h.Write(items[i].Message)
305 }
306
307 // All pubkeys
308 for i := range items {
309 h.Write(items[i].Pubkey.data[:])
310 }
311
312 var result [32]byte
313 h.Sum(result[:0])
314 return result
315 }
316
317 // generateBatchCoefficient generates a random coefficient for batch verification
318 // z_i = SHA256(seed || i) mod n
319 func generateBatchCoefficient(seed [32]byte, index uint32) Scalar {
320 h := sha256.New()
321 h.Write(seed[:])
322
323 var buf [4]byte
324 binary.BigEndian.PutUint32(buf[:], index)
325 h.Write(buf[:])
326
327 var hash [32]byte
328 h.Sum(hash[:0])
329
330 var z Scalar
331 z.setB32(hash[:])
332
333 // Ensure z is non-zero
334 if z.isZero() {
335 z.setInt(1)
336 }
337
338 return z
339 }
340
341 // SchnorrBatchVerifyWithFallback verifies a batch and returns the indices of invalid signatures.
342 // This is useful when you need to identify which specific signatures failed.
343 // Returns (true, nil) if all signatures are valid.
344 // Returns (false, invalidIndices) if some signatures are invalid.
345 func SchnorrBatchVerifyWithFallback(items []BatchSchnorrItem) (bool, []int) {
346 // First try batch verification
347 if SchnorrBatchVerify(items) {
348 return true, nil
349 }
350
351 // Batch failed - identify invalid signatures
352 var invalidIndices []int
353 for i := range items {
354 if !SchnorrVerify(items[i].Signature, items[i].Message, items[i].Pubkey) {
355 invalidIndices = append(invalidIndices, i)
356 }
357 }
358
359 return false, invalidIndices
360 }
361