schnorr.mjs raw

   1  // schnorr.mjs — BIP-340 Schnorr signatures using JS BigInt.
   2  // Field/scalar arithmetic over secp256k1 with wNAF precomputed tables
   3  // and Shamir's trick for fast verification.
   4  
   5  import { Slice } from './builtin.mjs';
   6  
   7  const P  = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2Fn;
   8  const N  = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141n;
   9  const Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798n;
  10  const Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8n;
  11  
  12  // --- Modular arithmetic ---
  13  
  14  function mod(x, m) { return ((x % m) + m) % m; }
  15  
  16  function modInv(a, m) {
  17    let [old_r, r] = [mod(a, m), m];
  18    let [old_s, s] = [1n, 0n];
  19    while (r !== 0n) {
  20      const q = old_r / r;
  21      [old_r, r] = [r, old_r - q * r];
  22      [old_s, s] = [s, old_s - q * s];
  23    }
  24    return mod(old_s, m);
  25  }
  26  
  27  function modPow(base, exp, m) {
  28    let result = 1n;
  29    base = mod(base, m);
  30    while (exp > 0n) {
  31      if (exp & 1n) result = result * base % m;
  32      exp >>= 1n;
  33      base = base * base % m;
  34    }
  35    return result;
  36  }
  37  
  38  // --- Jacobian point: [X, Y, Z] where affine x=X/Z², y=Y/Z³ ---
  39  // Point at infinity: Z === 0n.
  40  
  41  function jDbl(X1, Y1, Z1) {
  42    if (Y1 === 0n) return [0n, 1n, 0n];
  43    const A = mod(Y1 * Y1, P);
  44    const B = mod(4n * X1 * A, P);
  45    const C = mod(8n * A * A, P);
  46    const D = mod(3n * X1 * X1, P); // a=0 for secp256k1
  47    const X3 = mod(D * D - 2n * B, P);
  48    const Y3 = mod(D * (B - X3) - C, P);
  49    const Z3 = mod(2n * Y1 * Z1, P);
  50    return [X3, Y3, Z3];
  51  }
  52  
  53  function jAdd(X1, Y1, Z1, X2, Y2, Z2) {
  54    if (Z1 === 0n) return [X2, Y2, Z2];
  55    if (Z2 === 0n) return [X1, Y1, Z1];
  56    const Z1sq = mod(Z1 * Z1, P);
  57    const Z2sq = mod(Z2 * Z2, P);
  58    const U1 = mod(X1 * Z2sq, P);
  59    const U2 = mod(X2 * Z1sq, P);
  60    const S1 = mod(Y1 * mod(Z2sq * Z2, P), P);
  61    const S2 = mod(Y2 * mod(Z1sq * Z1, P), P);
  62    if (U1 === U2) {
  63      if (S1 === S2) return jDbl(X1, Y1, Z1);
  64      return [0n, 1n, 0n];
  65    }
  66    const H = mod(U2 - U1, P);
  67    const R = mod(S2 - S1, P);
  68    const H2 = mod(H * H, P);
  69    const H3 = mod(H2 * H, P);
  70    const X3 = mod(R * R - H3 - 2n * U1 * H2, P);
  71    const Y3 = mod(R * (U1 * H2 - X3) - S1 * H3, P);
  72    const Z3 = mod(H * Z1 * Z2, P);
  73    return [X3, Y3, Z3];
  74  }
  75  
  76  function jNeg(X, Y, Z) {
  77    return [X, mod(P - Y, P), Z];
  78  }
  79  
  80  function jToAffine(X, Y, Z) {
  81    if (Z === 0n) return [0n, 0n];
  82    const Zi = modInv(Z, P);
  83    const Z2 = mod(Zi * Zi, P);
  84    const Z3 = mod(Z2 * Zi, P);
  85    return [mod(X * Z2, P), mod(Y * Z3, P)];
  86  }
  87  
  88  // --- wNAF precomputed table for generator G ---
  89  // Window size W=8: precompute odd multiples 1G, 3G, 5G, ..., 255G.
  90  // Then scalarMultG processes 8 bits at a time using table lookup.
  91  
  92  const W = 8;
  93  const HALF = 1 << (W - 1); // 128
  94  const MASK = (1 << W) - 1; // 255
  95  
  96  // Precomputed table: _gTable[i] = (2i+1)*G in Jacobian [X, Y, Z].
  97  // Negation (for negative wNAF digits) is free: just negate Y.
  98  let _gTable = null;
  99  
 100  function ensureGTable() {
 101    if (_gTable) return;
 102    _gTable = new Array(HALF);
 103    // 1*G
 104    _gTable[0] = [Gx, Gy, 1n];
 105    // 2*G
 106    const g2 = jDbl(Gx, Gy, 1n);
 107    for (let i = 1; i < HALF; i++) {
 108      _gTable[i] = jAdd(_gTable[i - 1][0], _gTable[i - 1][1], _gTable[i - 1][2],
 109                         g2[0], g2[1], g2[2]);
 110    }
 111  }
 112  
 113  // Encode scalar k into wNAF form with window W.
 114  // Returns array of digits in {-(2^(W-1)-1) .. 2^(W-1)-1}, odd or zero.
 115  function wnaf(k, w) {
 116    const digits = [];
 117    const halfW = 1n << BigInt(w - 1);
 118    const fullW = halfW << 1n;
 119    const mask = fullW - 1n;
 120    while (k > 0n) {
 121      if (k & 1n) {
 122        let d = k & mask;
 123        if (d >= halfW) d -= fullW;
 124        k -= d;
 125        digits.push(d);
 126      } else {
 127        digits.push(0n);
 128      }
 129      k >>= 1n;
 130    }
 131    return digits;
 132  }
 133  
 134  // scalarMultG: fast k*G using wNAF precomputed table.
 135  function scalarMultG(k) {
 136    ensureGTable();
 137    k = ((k % N) + N) % N;
 138    if (k === 0n) return [0n, 0n];
 139  
 140    const digits = wnaf(k, W);
 141    let RX = 0n, RY = 1n, RZ = 0n; // infinity
 142  
 143    for (let i = digits.length - 1; i >= 0; i--) {
 144      [RX, RY, RZ] = jDbl(RX, RY, RZ);
 145      const d = digits[i];
 146      if (d !== 0n) {
 147        const idx = Number(d < 0n ? (-d - 1n) >> 1n : (d - 1n) >> 1n);
 148        const pt = _gTable[idx];
 149        if (d < 0n) {
 150          [RX, RY, RZ] = jAdd(RX, RY, RZ, pt[0], mod(P - pt[1], P), pt[2]);
 151        } else {
 152          [RX, RY, RZ] = jAdd(RX, RY, RZ, pt[0], pt[1], pt[2]);
 153        }
 154      }
 155    }
 156    return jToAffine(RX, RY, RZ);
 157  }
 158  
 159  // Plain scalar multiply for arbitrary points (no precomputation).
 160  function scalarMult(k, px, py) {
 161    let [RX, RY, RZ] = [0n, 1n, 0n];
 162    let [CX, CY, CZ] = [px, py, 1n];
 163    k = ((k % N) + N) % N;
 164    while (k > 0n) {
 165      if (k & 1n) [RX, RY, RZ] = jAdd(RX, RY, RZ, CX, CY, CZ);
 166      [CX, CY, CZ] = jDbl(CX, CY, CZ);
 167      k >>= 1n;
 168    }
 169    return jToAffine(RX, RY, RZ);
 170  }
 171  
 172  // Shamir's trick: compute k1*G + k2*P in one pass.
 173  // Shares doublings between the two scalar multiplications.
 174  function shamirMultGP(k1, k2, px, py) {
 175    ensureGTable();
 176    k1 = ((k1 % N) + N) % N;
 177    k2 = ((k2 % N) + N) % N;
 178  
 179    // wNAF encode k1 (for G table lookup)
 180    const d1 = wnaf(k1, W);
 181    // Simple binary for k2 (no precomputed table for P)
 182    const d2 = wnaf(k2, 2);
 183  
 184    const len = Math.max(d1.length, d2.length);
 185  
 186    let RX = 0n, RY = 1n, RZ = 0n; // infinity
 187  
 188    for (let i = len - 1; i >= 0; i--) {
 189      [RX, RY, RZ] = jDbl(RX, RY, RZ);
 190  
 191      // k1*G part: use precomputed table
 192      const a = i < d1.length ? d1[i] : 0n;
 193      if (a !== 0n) {
 194        const idx = Number(a < 0n ? (-a - 1n) >> 1n : (a - 1n) >> 1n);
 195        const pt = _gTable[idx];
 196        if (a < 0n) {
 197          [RX, RY, RZ] = jAdd(RX, RY, RZ, pt[0], mod(P - pt[1], P), pt[2]);
 198        } else {
 199          [RX, RY, RZ] = jAdd(RX, RY, RZ, pt[0], pt[1], pt[2]);
 200        }
 201      }
 202  
 203      // k2*P part: simple binary (odd digit = ±1 only)
 204      const b = i < d2.length ? d2[i] : 0n;
 205      if (b > 0n) {
 206        [RX, RY, RZ] = jAdd(RX, RY, RZ, px, py, 1n);
 207      } else if (b < 0n) {
 208        [RX, RY, RZ] = jAdd(RX, RY, RZ, px, mod(P - py, P), 1n);
 209      }
 210    }
 211  
 212    return jToAffine(RX, RY, RZ);
 213  }
 214  
 215  function liftX(x) {
 216    const c = mod(x * x * x + 7n, P);
 217    const exp = (P + 1n) / 4n;
 218    let y = modPow(c, exp, P);
 219    if (mod(y * y, P) !== mod(c, P)) return null;
 220    if (y % 2n !== 0n) y = mod(P - y, P);
 221    return y;
 222  }
 223  
 224  // --- SHA-256 (synchronous, pure JS) ---
 225  
 226  const K256 = [
 227    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
 228    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
 229    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
 230    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
 231    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
 232    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
 233    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
 234    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
 235  ];
 236  
 237  function sha256(data) {
 238    const len = data.length;
 239    const bitLen = len * 8;
 240    const padLen = ((56 - (len + 1) % 64) + 64) % 64;
 241    const padded = new Uint8Array(len + 1 + padLen + 8);
 242    padded.set(data);
 243    padded[len] = 0x80;
 244    const dv = new DataView(padded.buffer);
 245    dv.setUint32(padded.length - 4, bitLen, false);
 246  
 247    let [h0, h1, h2, h3, h4, h5, h6, h7] =
 248      [0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19];
 249  
 250    const w = new Int32Array(64);
 251    for (let off = 0; off < padded.length; off += 64) {
 252      for (let i = 0; i < 16; i++) w[i] = dv.getInt32(off + i * 4, false);
 253      for (let i = 16; i < 64; i++) {
 254        const s0 = (ror(w[i-15], 7) ^ ror(w[i-15], 18) ^ (w[i-15] >>> 3));
 255        const s1 = (ror(w[i-2], 17) ^ ror(w[i-2], 19) ^ (w[i-2] >>> 10));
 256        w[i] = (w[i-16] + s0 + w[i-7] + s1) | 0;
 257      }
 258      let [a, b, c, d, e, f, g, h] = [h0, h1, h2, h3, h4, h5, h6, h7];
 259      for (let i = 0; i < 64; i++) {
 260        const S1 = ror(e, 6) ^ ror(e, 11) ^ ror(e, 25);
 261        const ch = (e & f) ^ (~e & g);
 262        const t1 = (h + S1 + ch + K256[i] + w[i]) | 0;
 263        const S0 = ror(a, 2) ^ ror(a, 13) ^ ror(a, 22);
 264        const maj = (a & b) ^ (a & c) ^ (b & c);
 265        const t2 = (S0 + maj) | 0;
 266        h = g; g = f; f = e; e = (d + t1) | 0;
 267        d = c; c = b; b = a; a = (t1 + t2) | 0;
 268      }
 269      h0 = (h0 + a) | 0; h1 = (h1 + b) | 0; h2 = (h2 + c) | 0; h3 = (h3 + d) | 0;
 270      h4 = (h4 + e) | 0; h5 = (h5 + f) | 0; h6 = (h6 + g) | 0; h7 = (h7 + h) | 0;
 271    }
 272    const out = new Uint8Array(32);
 273    const odv = new DataView(out.buffer);
 274    odv.setUint32(0, h0, false); odv.setUint32(4, h1, false);
 275    odv.setUint32(8, h2, false); odv.setUint32(12, h3, false);
 276    odv.setUint32(16, h4, false); odv.setUint32(20, h5, false);
 277    odv.setUint32(24, h6, false); odv.setUint32(28, h7, false);
 278    return out;
 279  }
 280  
 281  function ror(x, n) { return ((x >>> n) | (x << (32 - n))) | 0; }
 282  
 283  function taggedHash(tag, msg) {
 284    const tagH = sha256(new TextEncoder().encode(tag));
 285    const buf = new Uint8Array(64 + msg.length);
 286    buf.set(tagH, 0);
 287    buf.set(tagH, 32);
 288    buf.set(msg, 64);
 289    return sha256(buf);
 290  }
 291  
 292  // --- Slice conversions ---
 293  
 294  function sliceToU8(s) {
 295    if (s instanceof Uint8Array) return s;
 296    if (typeof s === 'string') return new TextEncoder().encode(s);
 297    const u = new Uint8Array(s.$length);
 298    for (let i = 0; i < s.$length; i++) u[i] = s.$array[s.$offset + i];
 299    return u;
 300  }
 301  
 302  function u8ToSlice(u8) {
 303    const arr = new Array(u8.length);
 304    for (let i = 0; i < u8.length; i++) arr[i] = u8[i];
 305    return new Slice(arr, 0, u8.length, u8.length);
 306  }
 307  
 308  function bytesToBigInt(u8) {
 309    let n = 0n;
 310    for (let i = 0; i < u8.length; i++) n = (n << 8n) | BigInt(u8[i]);
 311    return n;
 312  }
 313  
 314  function bigIntTo32(n) {
 315    const out = new Uint8Array(32);
 316    for (let i = 31; i >= 0; i--) { out[i] = Number(n & 0xFFn); n >>= 8n; }
 317    return out;
 318  }
 319  
 320  // --- Exported BIP-340 functions ---
 321  
 322  export function PubKeyFromSecKey(seckey) {
 323    const sk = bytesToBigInt(sliceToU8(seckey));
 324    if (sk === 0n || sk >= N) return [u8ToSlice(new Uint8Array(32)), false];
 325    const [px, _] = scalarMultG(sk);
 326    return [u8ToSlice(bigIntTo32(px)), true];
 327  }
 328  
 329  export function SignSchnorr(seckey, msg, auxRand) {
 330    const sk = bytesToBigInt(sliceToU8(seckey));
 331    const m = sliceToU8(msg);
 332    const aux = sliceToU8(auxRand);
 333  
 334    if (sk === 0n || sk >= N) return [u8ToSlice(new Uint8Array(64)), false];
 335  
 336    const [px, py] = scalarMultG(sk);
 337    let d = sk;
 338    if (py % 2n !== 0n) d = N - d;
 339  
 340    const pkBytes = bigIntTo32(px);
 341  
 342    const auxHash = taggedHash("BIP0340/aux", aux);
 343    const dBytes = bigIntTo32(d);
 344    const t = new Uint8Array(32);
 345    for (let i = 0; i < 32; i++) t[i] = dBytes[i] ^ auxHash[i];
 346  
 347    const nonceInput = new Uint8Array(96);
 348    nonceInput.set(t, 0);
 349    nonceInput.set(pkBytes, 32);
 350    nonceInput.set(m, 64);
 351    const kHash = taggedHash("BIP0340/nonce", nonceInput);
 352    let k = bytesToBigInt(kHash) % N;
 353    if (k === 0n) return [u8ToSlice(new Uint8Array(64)), false];
 354  
 355    const [rx, ry] = scalarMultG(k);
 356    if (ry % 2n !== 0n) k = N - k;
 357  
 358    const rxBytes = bigIntTo32(rx);
 359  
 360    const eInput = new Uint8Array(96);
 361    eInput.set(rxBytes, 0);
 362    eInput.set(pkBytes, 32);
 363    eInput.set(m, 64);
 364    const eHash = taggedHash("BIP0340/challenge", eInput);
 365    const e = bytesToBigInt(eHash) % N;
 366  
 367    const s = ((k + e * d) % N + N) % N;
 368  
 369    const sig = new Uint8Array(64);
 370    sig.set(rxBytes, 0);
 371    sig.set(bigIntTo32(s), 32);
 372  
 373    return [u8ToSlice(sig), true];
 374  }
 375  
 376  // VerifySchnorr — uses Shamir's trick for s*G + (-e)*P.
 377  export function VerifySchnorr(pubkey, msg, sig) {
 378    const pkU8 = sliceToU8(pubkey);
 379    const mU8 = sliceToU8(msg);
 380    const sigU8 = sliceToU8(sig);
 381  
 382    const px = bytesToBigInt(pkU8);
 383    const py = liftX(px);
 384    if (py === null) return false;
 385  
 386    const rx = bytesToBigInt(sigU8.slice(0, 32));
 387    const s = bytesToBigInt(sigU8.slice(32, 64));
 388    if (rx >= P || s >= N) return false;
 389  
 390    const eInput = new Uint8Array(96);
 391    eInput.set(sigU8.slice(0, 32), 0);
 392    eInput.set(pkU8, 32);
 393    eInput.set(mU8, 64);
 394    const eHash = taggedHash("BIP0340/challenge", eInput);
 395    const e = bytesToBigInt(eHash) % N;
 396  
 397    // R = s*G + (-e)*P via Shamir's trick.
 398    const negE = ((N - e) % N + N) % N;
 399    const [Rx, Ry] = shamirMultGP(s, negE, px, py);
 400  
 401    if (Rx === 0n && Ry === 0n) return false;
 402    if (Ry % 2n !== 0n) return false;
 403    if (Rx !== rx) return false;
 404    return true;
 405  }
 406  
 407  export function ECDH(seckey, pubkey) {
 408    const sk = bytesToBigInt(sliceToU8(seckey));
 409    const px = bytesToBigInt(sliceToU8(pubkey));
 410    if (sk === 0n || sk >= N) return [u8ToSlice(new Uint8Array(32)), false];
 411    const py = liftX(px);
 412    if (py === null) return [u8ToSlice(new Uint8Array(32)), false];
 413    const [rx, _] = scalarMult(sk, px, py);
 414    if (rx === 0n) return [u8ToSlice(new Uint8Array(32)), false];
 415    return [u8ToSlice(bigIntTo32(rx)), true];
 416  }
 417  
 418  export function SHA256Sum(data) {
 419    return u8ToSlice(sha256(sliceToU8(data)));
 420  }
 421  
 422  export function ScalarAddModN(a, b) {
 423    const ai = bytesToBigInt(sliceToU8(a));
 424    const bi = bytesToBigInt(sliceToU8(b));
 425    const sum = ((ai + bi) % N + N) % N;
 426    if (sum === 0n) return [u8ToSlice(new Uint8Array(32)), false];
 427    return [u8ToSlice(bigIntTo32(sum)), true];
 428  }
 429  
 430  export function CompressedPubKey(seckey) {
 431    const sk = bytesToBigInt(sliceToU8(seckey));
 432    if (sk === 0n || sk >= N) return [u8ToSlice(new Uint8Array(33)), false];
 433    const [px, py] = scalarMultG(sk);
 434    const out = new Uint8Array(33);
 435    out[0] = (py % 2n === 0n) ? 0x02 : 0x03;
 436    out.set(bigIntTo32(px), 1);
 437    return [u8ToSlice(out), true];
 438  }
 439