schnorr.mjs raw
1 // schnorr.mjs — BIP-340 Schnorr signatures using JS BigInt.
2 // Field/scalar arithmetic over secp256k1 with wNAF precomputed tables
3 // and Shamir's trick for fast verification.
4
5 import { Slice } from './builtin.mjs';
6
// secp256k1 domain parameters (SEC 2).
// Field prime p = 2^256 - 2^32 - 977.
const P = 2n ** 256n - 2n ** 32n - 977n;
// Order n of the generator G (prime).
const N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141n;
// Generator G in affine coordinates.
const Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798n;
const Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8n;
11
12 // --- Modular arithmetic ---
13
// Reduce x into the canonical range [0, m) for positive m
// (BigInt % keeps the sign of the dividend, so shift by m once more).
function mod(x, m) {
  const r = x % m;
  return (r + m) % m;
}
15
// Modular inverse of a modulo m via the extended Euclidean algorithm.
// Only the Bezout coefficient of `a` is tracked. Returns a value in [0, m);
// if gcd(a, m) !== 1 (e.g. a ≡ 0) the result is not a true inverse (a ≡ 0 gives 0n).
function modInv(a, m) {
  let r0 = ((a % m) + m) % m;
  let r1 = m;
  let s0 = 1n;
  let s1 = 0n;
  while (r1 !== 0n) {
    const q = r0 / r1;
    const r2 = r0 - q * r1;
    r0 = r1;
    r1 = r2;
    const s2 = s0 - q * s1;
    s0 = s1;
    s1 = s2;
  }
  return ((s0 % m) + m) % m;
}
26
// base^exp mod m by right-to-left square-and-multiply.
// A non-positive exponent yields 1n (the loop body never runs).
function modPow(base, exp, m) {
  let acc = 1n;
  let sq = ((base % m) + m) % m;
  for (let e = exp; e > 0n; e >>= 1n) {
    if ((e & 1n) === 1n) acc = (acc * sq) % m;
    sq = (sq * sq) % m;
  }
  return acc;
}
37
38 // --- Jacobian point: [X, Y, Z] where affine x=X/Z², y=Y/Z³ ---
39 // Point at infinity: Z === 0n.
40
// Double a Jacobian point [X, Y, Z] on y^2 = x^3 + 7 (curve coefficient a = 0).
// A point with Y = 0 doubles to the canonical infinity [0n, 1n, 0n]; an input
// at infinity (Z = 0) yields Z3 = 2*Y*Z = 0, i.e. still infinity.
function jDbl(X1, Y1, Z1) {
  if (Y1 === 0n) return [0n, 1n, 0n];
  const ySq = mod(Y1 * Y1, P);
  const s = mod(4n * X1 * ySq, P);
  const yQuad8 = mod(8n * ySq * ySq, P);
  const slope = mod(3n * X1 * X1, P); // a = 0, so there is no a*Z^4 term
  const x3 = mod(slope * slope - 2n * s, P);
  const y3 = mod(slope * (s - x3) - yQuad8, P);
  return [x3, y3, mod(2n * Y1 * Z1, P)];
}
52
// Add two Jacobian points. Infinity (Z = 0) is the identity; equal points are
// delegated to jDbl; a point plus its negation gives infinity [0n, 1n, 0n].
function jAdd(X1, Y1, Z1, X2, Y2, Z2) {
  if (Z1 === 0n) return [X2, Y2, Z2];
  if (Z2 === 0n) return [X1, Y1, Z1];

  // Bring both points to a common denominator: U = X*Zother^2, S = Y*Zother^3.
  const z1Sq = mod(Z1 * Z1, P);
  const z2Sq = mod(Z2 * Z2, P);
  const u1 = mod(X1 * z2Sq, P);
  const u2 = mod(X2 * z1Sq, P);
  const s1 = mod(Y1 * mod(z2Sq * Z2, P), P);
  const s2 = mod(Y2 * mod(z1Sq * Z1, P), P);

  if (u1 === u2) {
    // Same x: the chord formula would divide by zero.
    return s1 === s2 ? jDbl(X1, Y1, Z1) : [0n, 1n, 0n];
  }

  const h = mod(u2 - u1, P);
  const r = mod(s2 - s1, P);
  const hSq = mod(h * h, P);
  const hCu = mod(hSq * h, P);
  const v = u1 * hSq; // left unreduced; the final mod() calls normalize
  const x3 = mod(r * r - hCu - 2n * v, P);
  const y3 = mod(r * (v - x3) - s1 * hCu, P);
  return [x3, y3, mod(h * Z1 * Z2, P)];
}
75
// Negate a Jacobian point by negating Y modulo p.
// NOTE(review): no caller in this file uses it; they negate Y inline.
function jNeg(X, Y, Z) {
  const negY = mod(-Y, P);
  return [X, negY, Z];
}
79
// Convert Jacobian [X, Y, Z] to affine [X/Z^2, Y/Z^3].
// The point at infinity (Z = 0) maps to the sentinel [0n, 0n].
function jToAffine(X, Y, Z) {
  if (Z === 0n) return [0n, 0n];
  const zInv = modInv(Z, P);
  const zInvSq = mod(zInv * zInv, P);
  const zInvCu = mod(zInvSq * zInv, P);
  return [mod(X * zInvSq, P), mod(Y * zInvCu, P)];
}
87
// --- wNAF precomputed table for generator G ---
// Window size W=8: wNAF digits are odd values in [-127, 127], so lookups use the
// odd multiples 1G, 3G, ..., 127G (negative digits just negate Y). scalarMultG
// then does one doubling per digit and one table addition per nonzero digit.
91
// wNAF window width used for multiples of the generator G.
const W = 8;
// 2^(W-1) = 128: wNAF digit magnitudes stay strictly below this bound.
const HALF = 2 ** (W - 1);
// 2^W - 1 = 255: low-W-bit mask (wnaf builds its own BigInt mask instead).
const MASK = 2 ** W - 1;
95
// Precomputed table: _gTable[i] = (2i+1)*G in Jacobian [X, Y, Z].
// Width-W wNAF digits are odd values in [-(2^(W-1)-1), 2^(W-1)-1], so the
// largest index ever looked up is (2^(W-1) - 2) / 2 = HALF/2 - 1. The table
// therefore holds HALF/2 entries (1G, 3G, ..., 127G for W = 8).
// Negation (for negative wNAF digits) is free: just negate Y.
let _gTable = null;

// Lazily build _gTable on first use. The table is filled in a local array and
// published only once complete, so a failure part-way never leaves a partial
// table that later calls would treat as ready.
function ensureGTable() {
  if (_gTable) return;
  const size = HALF >> 1; // 64 for W = 8; entries past this are never indexed
  const table = new Array(size);
  table[0] = [Gx, Gy, 1n]; // 1*G
  const twoG = jDbl(Gx, Gy, 1n); // step between consecutive odd multiples
  for (let i = 1; i < size; i++) {
    const [x, y, z] = table[i - 1];
    table[i] = jAdd(x, y, z, twoG[0], twoG[1], twoG[2]);
  }
  _gTable = table;
}
112
// Encode scalar k in width-w NAF (wNAF), least significant digit first.
// Each digit is 0 or an odd value in [-(2^(w-1)-1), 2^(w-1)-1], and
// sum(digits[i] * 2^i) === k. Returns [] when k <= 0.
function wnaf(k, w) {
  const half = 1n << BigInt(w - 1);
  const width = half << 1n;
  const lowBits = width - 1n;
  const out = [];
  let rest = k;
  while (rest > 0n) {
    let digit = 0n;
    if ((rest & 1n) === 1n) {
      // Take the signed residue of the low w bits; subtracting it clears them.
      digit = rest & lowBits;
      if (digit >= half) digit -= width;
      rest -= digit;
    }
    out.push(digit);
    rest >>= 1n;
  }
  return out;
}
133
// scalarMultG: k*G using the precomputed odd-multiple table and width-W wNAF.
// k is reduced mod n first. Returns affine [x, y], or [0n, 0n] when k ≡ 0 (mod n).
// NOTE(review): branches and table indices depend on k's digits, so this is not
// constant-time.
function scalarMultG(k) {
  ensureGTable();
  const scalar = ((k % N) + N) % N;
  if (scalar === 0n) return [0n, 0n];

  const digits = wnaf(scalar, W);
  let acc = [0n, 1n, 0n]; // point at infinity

  for (let i = digits.length - 1; i >= 0; i--) {
    acc = jDbl(acc[0], acc[1], acc[2]);
    const d = digits[i];
    if (d === 0n) continue;
    const negative = d < 0n;
    const [tx, ty, tz] = _gTable[Number(((negative ? -d : d) - 1n) >> 1n)];
    acc = jAdd(acc[0], acc[1], acc[2], tx, negative ? mod(P - ty, P) : ty, tz);
  }
  return jToAffine(acc[0], acc[1], acc[2]);
}
158
// Plain right-to-left double-and-add k*(px, py) for an arbitrary affine point
// (no precomputation). k is reduced mod n. Returns affine [x, y], or [0n, 0n]
// for infinity.
function scalarMult(k, px, py) {
  let acc = [0n, 1n, 0n]; // point at infinity
  let addend = [px, py, 1n];
  for (let e = ((k % N) + N) % N; e > 0n; e >>= 1n) {
    if (e & 1n) acc = jAdd(...acc, ...addend);
    addend = jDbl(...addend);
  }
  return jToAffine(...acc);
}
171
// Shamir's trick: compute k1*G + k2*P (P = (px, py), affine) with a single
// shared chain of doublings. Both scalars are reduced mod n.
//
// k1 is recoded as width-W wNAF and added from the precomputed G table. P has
// no table, so k2 is recoded as width-2 wNAF (plain NAF, digits 0 or ±1) and
// P or -P is added directly. Returns affine [x, y], or [0n, 0n] for infinity.
function shamirMultGP(k1, k2, px, py) {
  ensureGTable();
  const e1 = ((k1 % N) + N) % N;
  const e2 = ((k2 % N) + N) % N;
  const nafG = wnaf(e1, W);
  const nafP = wnaf(e2, 2);

  let acc = [0n, 1n, 0n]; // point at infinity
  for (let i = Math.max(nafG.length, nafP.length) - 1; i >= 0; i--) {
    acc = jDbl(...acc);

    const dg = i < nafG.length ? nafG[i] : 0n;
    if (dg !== 0n) {
      const neg = dg < 0n;
      const [tx, ty, tz] = _gTable[Number(((neg ? -dg : dg) - 1n) >> 1n)];
      acc = jAdd(...acc, tx, neg ? mod(P - ty, P) : ty, tz);
    }

    const dp = i < nafP.length ? nafP[i] : 0n;
    if (dp > 0n) {
      acc = jAdd(...acc, px, py, 1n);
    } else if (dp < 0n) {
      acc = jAdd(...acc, px, mod(P - py, P), 1n);
    }
  }
  return jToAffine(...acc);
}
214
// BIP-340 lift_x: return the even y with y^2 = x^3 + 7 (mod p), or null if x is
// not a canonical field element (x < 0 or x >= p) or x^3 + 7 is not a square.
// Rejecting x >= p matters: otherwise a non-canonical encoding x + p would be
// silently reduced and accepted as the same point.
function liftX(x) {
  if (x < 0n || x >= P) return null;
  const c = mod(x * x * x + 7n, P);
  // p ≡ 3 (mod 4), so c^((p+1)/4) is a square root of c whenever one exists.
  const y = modPow(c, (P + 1n) / 4n, P);
  if (mod(y * y, P) !== c) return null;
  return y % 2n === 0n ? y : mod(P - y, P);
}
223
224 // --- SHA-256 (synchronous, pure JS) ---
225
// SHA-256 round constants (FIPS 180-4, section 4.2.2).
const K256 = [
  0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
  0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
  0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
  0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
  0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
  0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
  0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
  0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];

/**
 * Synchronous, pure-JS SHA-256.
 *
 * @param {Uint8Array|ArrayLike<number>} data - Message bytes (anything
 *   `Uint8Array.prototype.set` accepts).
 * @returns {Uint8Array} The 32-byte digest.
 */
function sha256(data) {
  const len = data.length;
  // Padding: 0x80, zeros up to 56 mod 64, then the 64-bit big-endian bit length.
  const padLen = ((56 - (len + 1) % 64) + 64) % 64;
  const padded = new Uint8Array(len + 1 + padLen + 8);
  padded.set(data);
  padded[len] = 0x80;
  const dv = new DataView(padded.buffer);
  // Write both 32-bit halves of the length field. For inputs of 2^29 bytes
  // (512 MiB) or more, len * 8 no longer fits in 32 bits; writing only the
  // low word would silently produce a wrong digest.
  const bitLen = len * 8;
  dv.setUint32(padded.length - 8, Math.floor(bitLen / 0x100000000), false);
  dv.setUint32(padded.length - 4, bitLen >>> 0, false);

  let [h0, h1, h2, h3, h4, h5, h6, h7] =
    [0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19];

  const w = new Int32Array(64); // message schedule, reused for every 64-byte block
  for (let off = 0; off < padded.length; off += 64) {
    for (let i = 0; i < 16; i++) w[i] = dv.getInt32(off + i * 4, false);
    for (let i = 16; i < 64; i++) {
      const s0 = (ror(w[i - 15], 7) ^ ror(w[i - 15], 18) ^ (w[i - 15] >>> 3));
      const s1 = (ror(w[i - 2], 17) ^ ror(w[i - 2], 19) ^ (w[i - 2] >>> 10));
      w[i] = (w[i - 16] + s0 + w[i - 7] + s1) | 0;
    }
    let [a, b, c, d, e, f, g, h] = [h0, h1, h2, h3, h4, h5, h6, h7];
    for (let i = 0; i < 64; i++) {
      const S1 = ror(e, 6) ^ ror(e, 11) ^ ror(e, 25);
      const ch = (e & f) ^ (~e & g);
      const t1 = (h + S1 + ch + K256[i] + w[i]) | 0;
      const S0 = ror(a, 2) ^ ror(a, 13) ^ ror(a, 22);
      const maj = (a & b) ^ (a & c) ^ (b & c);
      const t2 = (S0 + maj) | 0;
      h = g; g = f; f = e; e = (d + t1) | 0;
      d = c; c = b; b = a; a = (t1 + t2) | 0;
    }
    h0 = (h0 + a) | 0; h1 = (h1 + b) | 0; h2 = (h2 + c) | 0; h3 = (h3 + d) | 0;
    h4 = (h4 + e) | 0; h5 = (h5 + f) | 0; h6 = (h6 + g) | 0; h7 = (h7 + h) | 0;
  }
  const out = new Uint8Array(32);
  const odv = new DataView(out.buffer);
  odv.setUint32(0, h0, false); odv.setUint32(4, h1, false);
  odv.setUint32(8, h2, false); odv.setUint32(12, h3, false);
  odv.setUint32(16, h4, false); odv.setUint32(20, h5, false);
  odv.setUint32(24, h6, false); odv.setUint32(28, h7, false);
  return out;
}

// 32-bit rotate right, returned as a signed 32-bit integer.
function ror(x, n) { return ((x >>> n) | (x << (32 - n))) | 0; }
282
// SHA256(tag) per tag string. BIP-340 notes that this prefix can be
// precomputed, and this module only uses a few fixed tags.
const TAG_HASHES = new Map();

// BIP-340 tagged hash: SHA256(SHA256(tag) || SHA256(tag) || msg).
// The cached tag digest is only read (copied into buf), never mutated.
function taggedHash(tag, msg) {
  let tagH = TAG_HASHES.get(tag);
  if (tagH === undefined) {
    tagH = sha256(new TextEncoder().encode(tag));
    TAG_HASHES.set(tag, tagH);
  }
  const buf = new Uint8Array(64 + msg.length);
  buf.set(tagH, 0);
  buf.set(tagH, 32);
  buf.set(msg, 64);
  return sha256(buf);
}
291
292 // --- Slice conversions ---
293
// Normalize input bytes to a Uint8Array:
//  - a Uint8Array is returned as-is (not copied);
//  - a string is UTF-8 encoded;
//  - anything else is read as a Slice-like view ($array backing store,
//    $offset start index, $length element count) and copied.
function sliceToU8(s) {
  if (s instanceof Uint8Array) return s;
  if (typeof s === 'string') return new TextEncoder().encode(s);
  const { $array: backing, $offset: start, $length: count } = s;
  const out = new Uint8Array(count);
  for (let i = 0; i < count; i++) out[i] = backing[start + i];
  return out;
}
301
// Wrap bytes in a Slice backed by a plain Array of numbers, with
// offset 0 and length = capacity = u8.length.
function u8ToSlice(u8) {
  const backing = Array.from(u8);
  return new Slice(backing, 0, u8.length, u8.length);
}
307
// Interpret a byte sequence as an unsigned big-endian integer (empty -> 0n).
function bytesToBigInt(u8) {
  let acc = 0n;
  for (const byte of u8) acc = (acc << 8n) | BigInt(byte);
  return acc;
}
313
// Encode n as 32 big-endian bytes. Only the low 256 bits are kept; anything
// above is silently dropped.
function bigIntTo32(n) {
  const out = new Uint8Array(32);
  for (let i = 0; i < 32; i++) {
    out[31 - i] = Number((n >> BigInt(8 * i)) & 0xFFn);
  }
  return out;
}
319
320 // --- Exported BIP-340 functions ---
321
/**
 * Derive the BIP-340 x-only public key for a secret key.
 * @param seckey - Secret key bytes (Slice, Uint8Array or string; see sliceToU8).
 * @returns {[Slice, boolean]} The 32-byte x coordinate of sk*G and true, or
 *   32 zero bytes and false when sk is 0 or >= n.
 */
export function PubKeyFromSecKey(seckey) {
  const sk = bytesToBigInt(sliceToU8(seckey));
  const inRange = sk > 0n && sk < N;
  if (!inRange) return [u8ToSlice(new Uint8Array(32)), false];
  const [px] = scalarMultG(sk);
  return [u8ToSlice(bigIntTo32(px)), true];
}
328
/**
 * BIP-340 Schnorr signature over `msg` with secret key `seckey`.
 *
 * @param seckey  Secret key bytes (Slice, Uint8Array or string; see sliceToU8);
 *                must encode an integer in [1, n-1].
 * @param msg     Message bytes of any length (BIP-340 allows variable-length messages).
 * @param auxRand Auxiliary randomness; BIP-340 specifies 32 bytes.
 * @returns {[Slice, boolean]} The 64-byte signature bytes(R.x) || bytes(s) and
 *   true, or 64 zero bytes and false if the key is out of range or the derived
 *   nonce is 0.
 *
 * NOTE(review): the point multiplication and BigInt arithmetic used here are
 * not constant-time, so timing may depend on the secret key and nonce.
 */
export function SignSchnorr(seckey, msg, auxRand) {
  const sk = bytesToBigInt(sliceToU8(seckey));
  const m = sliceToU8(msg);
  const aux = sliceToU8(auxRand);
  const failure = () => [u8ToSlice(new Uint8Array(64)), false];

  if (sk === 0n || sk >= N) return failure();

  // Use d = sk if P = sk*G has even y, else n - sk, so the signature verifies
  // against the x-only public key.
  const [px, py] = scalarMultG(sk);
  const d = py % 2n === 0n ? sk : N - sk;
  const pkBytes = bigIntTo32(px);

  // t = bytes(d) XOR hash_BIP0340/aux(aux)
  const auxHash = taggedHash("BIP0340/aux", aux);
  const dBytes = bigIntTo32(d);
  const t = new Uint8Array(32);
  for (let i = 0; i < 32; i++) t[i] = dBytes[i] ^ auxHash[i];

  // Hash inputs are sized from m.length. A fixed 96-byte buffer zero-pads
  // messages shorter than 32 bytes, so m and m || 0x00 would share a
  // signature, and it throws a RangeError for longer messages.
  const nonceInput = new Uint8Array(64 + m.length);
  nonceInput.set(t, 0);
  nonceInput.set(pkBytes, 32);
  nonceInput.set(m, 64);
  let k = bytesToBigInt(taggedHash("BIP0340/nonce", nonceInput)) % N;
  if (k === 0n) return failure();

  // Negate the nonce if R = k*G has odd y, since R is encoded x-only.
  const [rx, ry] = scalarMultG(k);
  if (ry % 2n !== 0n) k = N - k;
  const rxBytes = bigIntTo32(rx);

  const eInput = new Uint8Array(64 + m.length);
  eInput.set(rxBytes, 0);
  eInput.set(pkBytes, 32);
  eInput.set(m, 64);
  const e = bytesToBigInt(taggedHash("BIP0340/challenge", eInput)) % N;

  // k, e and d are all non-negative, so a single reduction is enough.
  const s = (k + e * d) % N;

  const sig = new Uint8Array(64);
  sig.set(rxBytes, 0);
  sig.set(bigIntTo32(s), 32);
  return [u8ToSlice(sig), true];
}
375
/**
 * BIP-340 Schnorr verification. R = s*G - e*P is computed in one pass with
 * Shamir's trick (as s*G + (n - e)*P).
 *
 * @param pubkey 32-byte x-only public key.
 * @param msg    Message bytes of any length (BIP-340 allows variable-length messages).
 * @param sig    64-byte signature bytes(r) || bytes(s).
 * @returns {boolean} true only if the signature is valid.
 */
export function VerifySchnorr(pubkey, msg, sig) {
  const pkU8 = sliceToU8(pubkey);
  const mU8 = sliceToU8(msg);
  const sigU8 = sliceToU8(sig);

  // BIP-340 takes a 32-byte pk and a 64-byte sig; other lengths cannot verify
  // (they previously produced garbage integers or corrupted the hash input).
  if (pkU8.length !== 32 || sigU8.length !== 64) return false;

  const px = bytesToBigInt(pkU8);
  // lift_x fails for x >= p; enforced here so non-canonical keys are rejected.
  if (px >= P) return false;
  const py = liftX(px);
  if (py === null) return false;

  const rBytes = sigU8.subarray(0, 32);
  const rx = bytesToBigInt(rBytes);
  const s = bytesToBigInt(sigU8.subarray(32, 64));
  if (rx >= P || s >= N) return false;

  // Sized from the message length; see SignSchnorr for why a fixed 96 bytes is wrong.
  const eInput = new Uint8Array(64 + mU8.length);
  eInput.set(rBytes, 0);
  eInput.set(pkU8, 32);
  eInput.set(mU8, 64);
  const e = bytesToBigInt(taggedHash("BIP0340/challenge", eInput)) % N;

  const negE = (N - e) % N;
  const [Rx, Ry] = shamirMultGP(s, negE, px, py);

  // jToAffine encodes the point at infinity as (0, 0).
  if (Rx === 0n && Ry === 0n) return false;
  if (Ry % 2n !== 0n) return false;
  return Rx === rx;
}
406
/**
 * Elliptic-curve Diffie-Hellman against an x-only public key.
 * The key is lifted to its even-y point via liftX, multiplied by the secret
 * scalar, and the raw x coordinate of the product is returned (not hashed).
 * @returns {[Slice, boolean]} 32-byte shared x and true, or 32 zero bytes and
 *   false for an invalid secret key, unliftable public key, or zero result.
 */
export function ECDH(seckey, pubkey) {
  const sk = bytesToBigInt(sliceToU8(seckey));
  const px = bytesToBigInt(sliceToU8(pubkey));
  const reject = () => [u8ToSlice(new Uint8Array(32)), false];

  if (sk === 0n || sk >= N) return reject();
  const py = liftX(px);
  if (py === null) return reject();

  const [rx] = scalarMult(sk, px, py);
  return rx === 0n ? reject() : [u8ToSlice(bigIntTo32(rx)), true];
}
417
// SHA-256 digest of `data` (Slice, Uint8Array or string), returned as a 32-byte Slice.
export function SHA256Sum(data) {
  const digest = sha256(sliceToU8(data));
  return u8ToSlice(digest);
}
421
/**
 * (a + b) mod n for two big-endian byte strings.
 * @returns {[Slice, boolean]} The 32-byte sum and true, or 32 zero bytes and
 *   false when the sum is 0 mod n.
 */
export function ScalarAddModN(a, b) {
  // bytesToBigInt never returns a negative value, so one % reduction suffices.
  const sum = (bytesToBigInt(sliceToU8(a)) + bytesToBigInt(sliceToU8(b))) % N;
  if (sum !== 0n) return [u8ToSlice(bigIntTo32(sum)), true];
  return [u8ToSlice(new Uint8Array(32)), false];
}
429
/**
 * SEC1 compressed public key for a secret key: a prefix byte (0x02 for even y,
 * 0x03 for odd y) followed by the 32-byte x coordinate of sk*G.
 * @returns {[Slice, boolean]} The 33-byte key and true, or 33 zero bytes and
 *   false when sk is 0 or >= n.
 */
export function CompressedPubKey(seckey) {
  const sk = bytesToBigInt(sliceToU8(seckey));
  if (sk === 0n || sk >= N) return [u8ToSlice(new Uint8Array(33)), false];
  const [px, py] = scalarMultG(sk);
  const out = new Uint8Array(33);
  out[0] = 0x02 | Number(py & 1n);
  out.set(bigIntTo32(px), 1);
  return [u8ToSlice(out), true];
}
439