schnorr.mjs raw
1 // schnorr.mjs — BIP-340 Schnorr signatures using JS BigInt.
2 // Field/scalar arithmetic over secp256k1 with wNAF precomputed tables
3 // and Shamir's trick for fast verification.
4
5 import { Slice } from './builtin.mjs';
6
// secp256k1 domain parameters (SEC 2 v2, section 2.4.1).
// P: prime modulus of the base field, p = 2^256 - 2^32 - 977.
const P = 2n ** 256n - 2n ** 32n - 977n;
// N: prime order of the subgroup generated by G.
const N = 0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_BAAEDCE6_AF48A03B_BFD25E8C_D0364141n;
// (Gx, Gy): affine coordinates of the generator G.
const Gx = 0x79BE667E_F9DCBBAC_55A06295_CE870B07_029BFCDB_2DCE28D9_59F2815B_16F81798n;
const Gy = 0x483ADA77_26A3C465_5DA4FBFC_0E1108A8_FD17B448_A6855419_9C47D08F_FB10D4B8n;
11
12 // --- Modular arithmetic ---
13
// Canonical remainder of x modulo m. BigInt `%` takes the sign of the
// dividend, so the remainder is shifted by m once and reduced again; for
// m > 0 the result is always in [0, m).
function mod(x, m) {
  const rem = x % m;
  return (rem + m) % m;
}
15
// Modular inverse of a modulo m by the iterative extended Euclidean
// algorithm, tracking only the Bezout coefficient of `a`. The result is
// reduced with mod(). When gcd(a, m) !== 1 no inverse exists and the
// returned value is meaningless (e.g. a = 0 yields 0n).
function modInv(a, m) {
  let r0 = mod(a, m);
  let r1 = m;
  let s0 = 1n;
  let s1 = 0n;
  while (r1 !== 0n) {
    const q = r0 / r1;
    const r2 = r0 - q * r1;
    const s2 = s0 - q * s1;
    r0 = r1;
    r1 = r2;
    s0 = s1;
    s1 = s2;
  }
  return mod(s0, m);
}
26
27 function modPow(base, exp, m) {
28 let result = 1n;
29 base = mod(base, m);
30 while (exp > 0n) {
31 if (exp & 1n) result = result * base % m;
32 exp >>= 1n;
33 base = base * base % m;
34 }
35 return result;
36 }
37
38 // --- Jacobian point: [X, Y, Z] where affine x=X/Z², y=Y/Z³ ---
39 // Point at infinity: Z === 0n.
40
// Doubles the Jacobian point (X1, Y1, Z1) on y^2 = x^3 + 7 (curve
// coefficient a = 0, so the a*Z^4 term of the slope is omitted).
// Y1 === 0n returns the canonical infinity [0n, 1n, 0n]. An infinity input
// (Z1 === 0n) with nonzero Y1 produces Z3 = 0, i.e. infinity again.
// Otherwise every output coordinate is reduced into [0, P).
function jDbl(X1, Y1, Z1) {
  if (Y1 === 0n) {
    return [0n, 1n, 0n];
  }
  const yy = mod(Y1 * Y1, P);          // Y^2
  const s = mod(4n * X1 * yy, P);      // S = 4 X Y^2
  const eightY4 = mod(8n * yy * yy, P); // 8 Y^4
  const slope = mod(3n * X1 * X1, P);  // M = 3 X^2
  const x3 = mod(slope * slope - 2n * s, P);
  const y3 = mod(slope * (s - x3) - eightY4, P);
  const z3 = mod(2n * Y1 * Z1, P);
  return [x3, y3, z3];
}
52
// Adds two Jacobian points on secp256k1.
// - If either input has Z === 0n (infinity), the other input is returned
//   as-is.
// - Equal points are delegated to jDbl.
// - A point plus its negation gives the canonical infinity [0n, 1n, 0n].
// - Otherwise the result coordinates are reduced into [0, P).
function jAdd(X1, Y1, Z1, X2, Y2, Z2) {
  if (Z1 === 0n) return [X2, Y2, Z2];
  if (Z2 === 0n) return [X1, Y1, Z1];

  const z1z1 = mod(Z1 * Z1, P);
  const z2z2 = mod(Z2 * Z2, P);
  // Scale both points to a common denominator so they can be compared.
  const u1 = mod(X1 * z2z2, P);
  const u2 = mod(X2 * z1z1, P);
  const s1 = mod(Y1 * mod(Z2 * z2z2, P), P);
  const s2 = mod(Y2 * mod(Z1 * z1z1, P), P);

  const h = mod(u2 - u1, P);
  const r = mod(s2 - s1, P);
  if (h === 0n) {
    // Same affine x: identical points double, opposite points cancel.
    return r === 0n ? jDbl(X1, Y1, Z1) : [0n, 1n, 0n];
  }

  const hh = mod(h * h, P);
  const hhh = mod(hh * h, P);
  const v = mod(u1 * hh, P); // U1 * H^2
  const x3 = mod(r * r - hhh - 2n * v, P);
  const y3 = mod(r * (v - x3) - s1 * hhh, P);
  const z3 = mod(h * Z1 * Z2, P);
  return [x3, y3, z3];
}
75
// Negates a Jacobian point by reflecting Y across the x-axis; X and Z pass
// through unchanged. The new Y is reduced into [0, P).
function jNeg(X, Y, Z) {
  const reflectedY = mod(-Y, P);
  return [X, reflectedY, Z];
}
79
// Converts Jacobian [X, Y, Z] to affine [X / Z^2, Y / Z^3] mod P.
// Infinity (Z === 0n) maps to the sentinel [0n, 0n]. (0, 0) does not satisfy
// y^2 = x^3 + 7, so the sentinel cannot collide with a real point.
function jToAffine(X, Y, Z) {
  if (Z === 0n) return [0n, 0n];
  const zInv = modInv(Z, P);
  const zInvSq = mod(zInv * zInv, P);
  const zInvCube = mod(zInvSq * zInv, P);
  const x = mod(X * zInvSq, P);
  const y = mod(Y * zInvCube, P);
  return [x, y];
}
87
// --- wNAF precomputed table for generator G ---
// Scalars multiplying G are recoded with wnaf(k, W). Every nonzero digit d is
// odd with |d| <= 2^(W-1) - 1, so only the odd multiples 1G, 3G, ...,
// (2^(W-1) - 1)G are ever looked up: 64 points for W = 8. The scalar is
// still consumed one doubling per bit; the wide window only makes the
// additions sparse (on average about one per W + 1 bits).

const W = 8;
const HALF = 1 << (W - 1); // 128: wNAF digit magnitudes stay strictly below this
const MASK = (1 << W) - 1; // 255
// Odd multiples needed to cover every digit wnaf(k, W) can produce.
const G_TABLE_SIZE = HALF >> 1; // 64

// Precomputed table: _gTable[i] = (2i+1)*G in Jacobian [X, Y, Z], built on
// first use. Negative digits reuse the same entry with Y negated.
let _gTable = null;

// Lazily fills _gTable with 1G, 3G, ..., (2 * G_TABLE_SIZE - 1)G, each entry
// being the previous one plus 2G. Returns immediately once the table exists.
// The table is published only after it is complete, so a failure mid-way
// never leaves a partially filled table behind.
function ensureGTable() {
  if (_gTable) return;
  const table = new Array(G_TABLE_SIZE);
  table[0] = [Gx, Gy, 1n];
  const [dX, dY, dZ] = jDbl(Gx, Gy, 1n); // 2G: step between odd multiples
  for (let i = 1; i < G_TABLE_SIZE; i++) {
    const [x, y, z] = table[i - 1];
    table[i] = jAdd(x, y, z, dX, dY, dZ);
  }
  _gTable = table;
}
112
// Width-w non-adjacent form (wNAF) recoding of a scalar k (w is a Number).
// Digits are little-endian: digit i weighs 2^i, and sum(d_i * 2^i) === k.
// Each digit is 0n or an odd BigInt in [-(2^(w-1) - 1), 2^(w-1) - 1], and
// among any w consecutive digits at most one is nonzero. k <= 0n yields [].
function wnaf(k, w) {
  const full = 1n << BigInt(w); // 2^w
  const half = full >> 1n;      // 2^(w-1)
  const out = [];
  let rest = k;
  while (rest > 0n) {
    let digit = 0n;
    if ((rest & 1n) === 1n) {
      // rest > 0, so % equals masking with 2^w - 1; then centre on zero.
      digit = rest % full;
      if (digit >= half) digit -= full;
      rest -= digit;
    }
    out.push(digit);
    rest >>= 1n;
  }
  return out;
}
133
// Computes k*G by left-to-right double-and-add over the width-W NAF of k.
// Odd multiples of G come from the lazily built _gTable; negative digits use
// the negated entry. k is reduced modulo N first, so negative inputs are
// accepted. Returns affine [x, y]; k ≡ 0 (mod N) yields the sentinel [0n, 0n].
// NOTE(review): the sequence of operations depends on the scalar's digits
// (not constant-time), and signing calls this with secret keys and nonces.
function scalarMultG(k) {
  ensureGTable();
  const scalar = mod(k, N);
  if (scalar === 0n) return [0n, 0n];

  const naf = wnaf(scalar, W);
  let acc = [0n, 1n, 0n]; // point at infinity
  for (let bit = naf.length - 1; bit >= 0; bit--) {
    acc = jDbl(...acc);
    const digit = naf[bit];
    if (digit === 0n) continue;
    const magnitude = digit < 0n ? -digit : digit;
    const entry = _gTable[Number((magnitude - 1n) >> 1n)];
    const addend = digit < 0n ? jNeg(...entry) : entry;
    acc = jAdd(...acc, ...addend);
  }
  return jToAffine(...acc);
}
158
// Computes k*(px, py) for an arbitrary affine point using right-to-left
// binary double-and-add, with no precomputation. k is reduced modulo N
// first, which presumes the point lies in the order-N group. Returns affine
// [x, y], or the sentinel [0n, 0n] when the result is the point at infinity.
function scalarMult(k, px, py) {
  let acc = [0n, 1n, 0n];    // running sum, starts at infinity
  let addend = [px, py, 1n]; // 2^i * (px, py) at bit i
  for (let e = mod(k, N); e > 0n; e >>= 1n) {
    if ((e & 1n) === 1n) {
      acc = jAdd(...acc, ...addend);
    }
    addend = jDbl(...addend);
  }
  return jToAffine(...acc);
}
171
// Computes k1*G + k2*Q for the affine point Q = (px, py) in one left-to-right
// pass (Shamir's trick), so both products share the same doublings.
// - k1 is recoded with width-W NAF and looked up in the precomputed _gTable.
// - k2 is recoded with width-2 NAF (digits in {-1, 0, 1}), so Q or -Q is
//   added directly.
// Both scalars are reduced modulo N first. Returns affine [x, y], or the
// sentinel [0n, 0n] if the sum is the point at infinity.
function shamirMultGP(k1, k2, px, py) {
  ensureGTable();
  const nafG = wnaf(mod(k1, N), W);
  const nafQ = wnaf(mod(k2, N), 2);
  const q = [px, py, 1n];
  const negQ = jNeg(...q);

  let acc = [0n, 1n, 0n]; // point at infinity
  for (let bit = Math.max(nafG.length, nafQ.length) - 1; bit >= 0; bit--) {
    acc = jDbl(...acc);

    // Contribution of k1*G from the precomputed odd multiples.
    const g = nafG[bit] ?? 0n;
    if (g !== 0n) {
      const entry = _gTable[Number(((g < 0n ? -g : g) - 1n) >> 1n)];
      acc = jAdd(...acc, ...(g < 0n ? jNeg(...entry) : entry));
    }

    // Contribution of k2*Q: a single +Q or -Q per nonzero digit.
    const qd = nafQ[bit] ?? 0n;
    if (qd > 0n) {
      acc = jAdd(...acc, ...q);
    } else if (qd < 0n) {
      acc = jAdd(...acc, ...negQ);
    }
  }
  return jToAffine(...acc);
}
214
// BIP-340 lift_x: returns the even y such that (x, y) lies on
// y^2 = x^3 + 7 over F_P. Returns null if x is not a field element
// (x < 0 or x >= P) or if x^3 + 7 is not a square mod P.
// The square root is c^((P+1)/4), which is valid because P ≡ 3 (mod 4).
function liftX(x) {
  // BIP-340 requires failure for x >= p. Without this check a non-canonical
  // 32-byte encoding x + P (which fits for x < 2^32 + 977) would be accepted
  // as an alias of the canonical key.
  if (x < 0n || x >= P) return null;
  const c = mod(x * x * x + 7n, P);
  const y = modPow(c, (P + 1n) / 4n, P);
  if (mod(y * y, P) !== c) return null;
  // Choose the even root; y is odd here only if it is nonzero, so P - y < P.
  return y % 2n === 0n ? y : P - y;
}
223
224 // --- SHA-256 (synchronous, pure JS) ---
225
// SHA-256 round constants K[0..63] (FIPS 180-4, section 4.2.2): the first
// 32 bits of the fractional parts of the cube roots of the first 64 primes.
const K256 = [
  0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
  0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
  0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
  0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
  0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
  0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
  0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
  0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
  0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
  0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
  0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
  0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
  0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
  0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
  0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
  0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];
236
/**
 * SHA-256 (FIPS 180-4) computed synchronously in pure JS.
 * @param {Uint8Array|ArrayLike<number>} data - message bytes; copied into a
 *   padded buffer with Uint8Array#set, so any array-like of bytes works.
 * @returns {Uint8Array} the 32-byte digest.
 */
function sha256(data) {
  const len = data.length;
  // Padding: 0x80, zeros until the length is 56 mod 64, then 8 length bytes.
  const padLen = ((56 - (len + 1) % 64) + 64) % 64;
  const padded = new Uint8Array(len + 1 + padLen + 8);
  padded.set(data);
  padded[len] = 0x80;
  const dv = new DataView(padded.buffer);
  // The bit length is a 64-bit big-endian field. len * 8 is exact as a double,
  // but setUint32 alone keeps only its low 32 bits, which gives a wrong digest
  // for inputs of 512 MiB or more. Write the high word as well.
  const bitLen = len * 8;
  dv.setUint32(padded.length - 8, Math.floor(bitLen / 0x100000000), false);
  dv.setUint32(padded.length - 4, bitLen >>> 0, false);

  // Initial hash values H(0) (FIPS 180-4, section 5.3.3).
  let [h0, h1, h2, h3, h4, h5, h6, h7] =
    [0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19];

  // Message schedule; Int32Array storage wraps each word to 32 bits.
  const w = new Int32Array(64);
  for (let off = 0; off < padded.length; off += 64) {
    for (let i = 0; i < 16; i++) w[i] = dv.getInt32(off + i * 4, false);
    for (let i = 16; i < 64; i++) {
      const s0 = ror(w[i - 15], 7) ^ ror(w[i - 15], 18) ^ (w[i - 15] >>> 3);
      const s1 = ror(w[i - 2], 17) ^ ror(w[i - 2], 19) ^ (w[i - 2] >>> 10);
      w[i] = (w[i - 16] + s0 + w[i - 7] + s1) | 0;
    }
    let [a, b, c, d, e, f, g, h] = [h0, h1, h2, h3, h4, h5, h6, h7];
    for (let i = 0; i < 64; i++) {
      const S1 = ror(e, 6) ^ ror(e, 11) ^ ror(e, 25);
      const ch = (e & f) ^ (~e & g);
      const t1 = (h + S1 + ch + K256[i] + w[i]) | 0;
      const S0 = ror(a, 2) ^ ror(a, 13) ^ ror(a, 22);
      const maj = (a & b) ^ (a & c) ^ (b & c);
      const t2 = (S0 + maj) | 0;
      h = g;
      g = f;
      f = e;
      e = (d + t1) | 0;
      d = c;
      c = b;
      b = a;
      a = (t1 + t2) | 0;
    }
    h0 = (h0 + a) | 0;
    h1 = (h1 + b) | 0;
    h2 = (h2 + c) | 0;
    h3 = (h3 + d) | 0;
    h4 = (h4 + e) | 0;
    h5 = (h5 + f) | 0;
    h6 = (h6 + g) | 0;
    h7 = (h7 + h) | 0;
  }

  // Serialize the state big-endian; setUint32 maps negative int32 to uint32.
  const out = new Uint8Array(32);
  const odv = new DataView(out.buffer);
  [h0, h1, h2, h3, h4, h5, h6, h7].forEach((word, i) => odv.setUint32(i * 4, word, false));
  return out;
}
280
// Rotates the 32-bit value x right by n bits (0 < n < 32). The bitwise OR
// already produces a signed int32.
function ror(x, n) {
  const low = x >>> n;
  const high = x << (32 - n);
  return low | high;
}
282
// BIP-340 tagged hash: SHA256(SHA256(tag) || SHA256(tag) || msg), where the
// tag string is UTF-8 encoded. Returns the 32-byte digest from sha256().
function taggedHash(tag, msg) {
  const tagDigest = sha256(new TextEncoder().encode(tag));
  const preimage = new Uint8Array(64 + msg.length);
  let offset = 0;
  for (const part of [tagDigest, tagDigest, msg]) {
    preimage.set(part, offset);
    offset += part.length;
  }
  return sha256(preimage);
}
291
292 // --- Slice conversions ---
293
// Returns the bytes of `s` as a Uint8Array. A Uint8Array is returned as-is,
// without copying. Anything else is read as a Slice-like view and copied:
// $length elements of $array starting at $offset.
function sliceToU8(s) {
  if (s instanceof Uint8Array) return s;
  const { $array: backing, $offset: start, $length: count } = s;
  return Uint8Array.from({ length: count }, (_, i) => backing[start + i]);
}
300
// Copies the bytes into a fresh plain Array and wraps it in a Slice,
// passing (array, 0, length, length). Presumably these are the backing
// array, offset, length and capacity — confirm against builtin.mjs.
function u8ToSlice(u8) {
  const backing = Array.from(u8);
  return new Slice(backing, 0, backing.length, backing.length);
}
306
// Reads a byte array as an unsigned big-endian integer; empty input gives 0n.
function bytesToBigInt(u8) {
  return Array.from(u8).reduce((acc, byte) => (acc << 8n) | BigInt(byte), 0n);
}
312
// Encodes n as 32 big-endian bytes. Only the low 256 bits of n's
// two's-complement value are kept, so larger values are truncated and
// negative values wrap (e.g. -1n becomes 32 bytes of 0xff).
function bigIntTo32(n) {
  const hex = BigInt.asUintN(256, n).toString(16).padStart(64, '0');
  const out = new Uint8Array(32);
  for (let i = 0; i < 32; i++) {
    out[i] = Number.parseInt(hex.slice(2 * i, 2 * i + 2), 16);
  }
  return out;
}
318
319 // --- Exported BIP-340 functions ---
320
/**
 * BIP-340 x-only public key for a secret key.
 * @param seckey - secret key bytes (Slice or Uint8Array), read big-endian.
 * @returns {[Slice, boolean]} [32-byte x-coordinate of seckey*G, true], or
 *   [32 zero bytes, false] when the key is 0 or not below the group order N.
 */
export function PubKeyFromSecKey(seckey) {
  const secret = bytesToBigInt(sliceToU8(seckey));
  const inRange = secret > 0n && secret < N;
  if (!inRange) return [u8ToSlice(new Uint8Array(32)), false];
  const [x] = scalarMultG(secret);
  return [u8ToSlice(bigIntTo32(x)), true];
}
327
/**
 * BIP-340 Schnorr signing (default signing algorithm).
 * @param seckey - secret key bytes (Slice or Uint8Array), read big-endian.
 * @param msg - message bytes of any length. BIP-340 allows variable-length
 *   messages; the nonce and challenge preimages are sized to fit, instead of
 *   assuming a 32-byte message.
 * @param auxRand - auxiliary randomness hashed under the "BIP0340/aux" tag
 *   (BIP-340 specifies 32 bytes; the length is not checked here).
 * @returns {[Slice, boolean]} [64-byte signature bytes(R.x) || bytes(s), true],
 *   or [64 zero bytes, false] if the key is out of range or the nonce is 0.
 */
export function SignSchnorr(seckey, msg, auxRand) {
  const sk = bytesToBigInt(sliceToU8(seckey));
  const m = sliceToU8(msg);
  const aux = sliceToU8(auxRand);
  const fail = () => [u8ToSlice(new Uint8Array(64)), false];

  if (sk === 0n || sk >= N) return fail();

  // Use the key whose public point has an even y (x-only public keys).
  const [px, py] = scalarMultG(sk);
  const d = py % 2n === 0n ? sk : N - sk;
  const pkBytes = bigIntTo32(px);

  // Nonce derivation: t = bytes(d) XOR hash_aux(a),
  // k' = int(hash_nonce(t || bytes(P) || m)) mod n.
  const auxHash = taggedHash("BIP0340/aux", aux);
  const t = bigIntTo32(d).map((byte, i) => byte ^ auxHash[i]);
  const nonceInput = new Uint8Array(64 + m.length);
  nonceInput.set(t, 0);
  nonceInput.set(pkBytes, 32);
  nonceInput.set(m, 64);
  let k = bytesToBigInt(taggedHash("BIP0340/nonce", nonceInput)) % N;
  if (k === 0n) return fail();

  // Likewise negate the nonce so that R has an even y.
  const [rx, ry] = scalarMultG(k);
  if (ry % 2n !== 0n) k = N - k;
  const rxBytes = bigIntTo32(rx);

  // Challenge e = int(hash_challenge(bytes(R) || bytes(P) || m)) mod n.
  const eInput = new Uint8Array(64 + m.length);
  eInput.set(rxBytes, 0);
  eInput.set(pkBytes, 32);
  eInput.set(m, 64);
  const e = bytesToBigInt(taggedHash("BIP0340/challenge", eInput)) % N;

  // k, e and d are all non-negative, so a single % N suffices.
  const s = (k + e * d) % N;

  const sig = new Uint8Array(64);
  sig.set(rxBytes, 0);
  sig.set(bigIntTo32(s), 32);
  return [u8ToSlice(sig), true];
}
374
/**
 * BIP-340 Schnorr verification.
 * R = s*G + (-e)*P is computed in one pass with Shamir's trick (shamirMultGP).
 * @param pubkey - 32-byte x-only public key (Slice or Uint8Array).
 * @param msg - message bytes of any length (BIP-340 variable-length messages).
 * @param sig - 64-byte signature bytes(r) || bytes(s).
 * @returns {boolean} true iff the signature is valid. Malformed input yields
 *   false: wrong lengths, x >= p, x not on the curve, r >= p, or s >= n.
 */
export function VerifySchnorr(pubkey, msg, sig) {
  const pkU8 = sliceToU8(pubkey);
  const mU8 = sliceToU8(msg);
  const sigU8 = sliceToU8(sig);
  if (pkU8.length !== 32 || sigU8.length !== 64) return false;

  const px = bytesToBigInt(pkU8);
  if (px >= P) return false; // lift_x fails for non-canonical x >= p
  const py = liftX(px);
  if (py === null) return false;

  const rBytes = sigU8.subarray(0, 32);
  const rx = bytesToBigInt(rBytes);
  const s = bytesToBigInt(sigU8.subarray(32, 64));
  if (rx >= P || s >= N) return false;

  // e = int(hash_challenge(r || pk || m)) mod n; the preimage fits any m.
  const eInput = new Uint8Array(64 + mU8.length);
  eInput.set(rBytes, 0);
  eInput.set(pkU8, 32);
  eInput.set(mU8, 64);
  const e = bytesToBigInt(taggedHash("BIP0340/challenge", eInput)) % N;

  // e is in [0, N), so (N - e) % N is its canonical negation.
  const [Rx, Ry] = shamirMultGP(s, (N - e) % N, px, py);
  if (Rx === 0n && Ry === 0n) return false; // R is the point at infinity
  return Ry % 2n === 0n && Rx === rx;
}
405
/**
 * x-only ECDH: the 32-byte x-coordinate of seckey * lift_x(pubkey).
 * The shared x-coordinate is returned raw (not hashed).
 * @returns {[Slice, boolean]} [shared x, true], or [32 zero bytes, false] if
 *   the secret key is 0 or >= N, the public key does not lift to a curve
 *   point, or the resulting x-coordinate is 0n.
 */
export function ECDH(seckey, pubkey) {
  const failure = () => [u8ToSlice(new Uint8Array(32)), false];
  const secret = bytesToBigInt(sliceToU8(seckey));
  const peerX = bytesToBigInt(sliceToU8(pubkey));
  if (secret === 0n || secret >= N) return failure();
  const peerY = liftX(peerX);
  if (peerY === null) return failure();
  const [sharedX] = scalarMult(secret, peerX, peerY);
  return sharedX === 0n ? failure() : [u8ToSlice(bigIntTo32(sharedX)), true];
}
416
/**
 * SHA-256 digest of `data` (Slice or Uint8Array), returned as a 32-byte Slice.
 */
export function SHA256Sum(data) {
  const bytes = sliceToU8(data);
  const digest = sha256(bytes);
  return u8ToSlice(digest);
}
420
/**
 * Adds two big-endian scalars modulo the group order N. Inputs need not be
 * below N.
 * @returns {[Slice, boolean]} [32-byte (a + b) mod N, true], or
 *   [32 zero bytes, false] when the sum is 0 mod N.
 */
export function ScalarAddModN(a, b) {
  const lhs = bytesToBigInt(sliceToU8(a));
  const rhs = bytesToBigInt(sliceToU8(b));
  // Both operands are non-negative, so one % N yields the canonical residue.
  const sum = (lhs + rhs) % N;
  return sum === 0n
    ? [u8ToSlice(new Uint8Array(32)), false]
    : [u8ToSlice(bigIntTo32(sum)), true];
}
428
/**
 * SEC1 compressed public key for a secret key: one prefix byte (0x02 for
 * even y, 0x03 for odd y) followed by the 32-byte x-coordinate of seckey*G.
 * @returns {[Slice, boolean]} [33-byte key, true], or [33 zero bytes, false]
 *   when the secret key is 0 or not below N.
 */
export function CompressedPubKey(seckey) {
  const secret = bytesToBigInt(sliceToU8(seckey));
  if (secret === 0n || secret >= N) {
    return [u8ToSlice(new Uint8Array(33)), false];
  }
  const [x, y] = scalarMultG(secret);
  const encoded = new Uint8Array(33);
  const isEvenY = y % 2n === 0n;
  encoded[0] = isEvenY ? 0x02 : 0x03;
  encoded.set(bigIntTo32(x), 1);
  return [u8ToSlice(encoded), true];
}
438