schnorr.mjs raw

   1  // schnorr.mjs — BIP-340 Schnorr signatures using JS BigInt.
   2  // Field/scalar arithmetic over secp256k1 with wNAF precomputed tables
   3  // and Shamir's trick for fast verification.
   4  
   5  import { Slice } from './builtin.mjs';
   6  
   7  const P  = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2Fn;
   8  const N  = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141n;
   9  const Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798n;
  10  const Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8n;
  11  
  12  // --- Modular arithmetic ---
  13  
  14  function mod(x, m) { return ((x % m) + m) % m; }
  15  
  16  function modInv(a, m) {
  17    let [old_r, r] = [mod(a, m), m];
  18    let [old_s, s] = [1n, 0n];
  19    while (r !== 0n) {
  20      const q = old_r / r;
  21      [old_r, r] = [r, old_r - q * r];
  22      [old_s, s] = [s, old_s - q * s];
  23    }
  24    return mod(old_s, m);
  25  }
  26  
  27  function modPow(base, exp, m) {
  28    let result = 1n;
  29    base = mod(base, m);
  30    while (exp > 0n) {
  31      if (exp & 1n) result = result * base % m;
  32      exp >>= 1n;
  33      base = base * base % m;
  34    }
  35    return result;
  36  }
  37  
  38  // --- Jacobian point: [X, Y, Z] where affine x=X/Z², y=Y/Z³ ---
  39  // Point at infinity: Z === 0n.
  40  
  41  function jDbl(X1, Y1, Z1) {
  42    if (Y1 === 0n) return [0n, 1n, 0n];
  43    const A = mod(Y1 * Y1, P);
  44    const B = mod(4n * X1 * A, P);
  45    const C = mod(8n * A * A, P);
  46    const D = mod(3n * X1 * X1, P); // a=0 for secp256k1
  47    const X3 = mod(D * D - 2n * B, P);
  48    const Y3 = mod(D * (B - X3) - C, P);
  49    const Z3 = mod(2n * Y1 * Z1, P);
  50    return [X3, Y3, Z3];
  51  }
  52  
  53  function jAdd(X1, Y1, Z1, X2, Y2, Z2) {
  54    if (Z1 === 0n) return [X2, Y2, Z2];
  55    if (Z2 === 0n) return [X1, Y1, Z1];
  56    const Z1sq = mod(Z1 * Z1, P);
  57    const Z2sq = mod(Z2 * Z2, P);
  58    const U1 = mod(X1 * Z2sq, P);
  59    const U2 = mod(X2 * Z1sq, P);
  60    const S1 = mod(Y1 * mod(Z2sq * Z2, P), P);
  61    const S2 = mod(Y2 * mod(Z1sq * Z1, P), P);
  62    if (U1 === U2) {
  63      if (S1 === S2) return jDbl(X1, Y1, Z1);
  64      return [0n, 1n, 0n];
  65    }
  66    const H = mod(U2 - U1, P);
  67    const R = mod(S2 - S1, P);
  68    const H2 = mod(H * H, P);
  69    const H3 = mod(H2 * H, P);
  70    const X3 = mod(R * R - H3 - 2n * U1 * H2, P);
  71    const Y3 = mod(R * (U1 * H2 - X3) - S1 * H3, P);
  72    const Z3 = mod(H * Z1 * Z2, P);
  73    return [X3, Y3, Z3];
  74  }
  75  
  76  function jNeg(X, Y, Z) {
  77    return [X, mod(P - Y, P), Z];
  78  }
  79  
  80  function jToAffine(X, Y, Z) {
  81    if (Z === 0n) return [0n, 0n];
  82    const Zi = modInv(Z, P);
  83    const Z2 = mod(Zi * Zi, P);
  84    const Z3 = mod(Z2 * Zi, P);
  85    return [mod(X * Z2, P), mod(Y * Z3, P)];
  86  }
  87  
  88  // --- wNAF precomputed table for generator G ---
  89  // Window size W=8: precompute odd multiples 1G, 3G, 5G, ..., 255G.
  90  // Then scalarMultG processes 8 bits at a time using table lookup.
  91  
  92  const W = 8;
  93  const HALF = 1 << (W - 1); // 128
  94  const MASK = (1 << W) - 1; // 255
  95  
  96  // Precomputed table: _gTable[i] = (2i+1)*G in Jacobian [X, Y, Z].
  97  // Negation (for negative wNAF digits) is free: just negate Y.
  98  let _gTable = null;
  99  
 100  function ensureGTable() {
 101    if (_gTable) return;
 102    _gTable = new Array(HALF);
 103    // 1*G
 104    _gTable[0] = [Gx, Gy, 1n];
 105    // 2*G
 106    const g2 = jDbl(Gx, Gy, 1n);
 107    for (let i = 1; i < HALF; i++) {
 108      _gTable[i] = jAdd(_gTable[i - 1][0], _gTable[i - 1][1], _gTable[i - 1][2],
 109                         g2[0], g2[1], g2[2]);
 110    }
 111  }
 112  
 113  // Encode scalar k into wNAF form with window W.
 114  // Returns array of digits in {-(2^(W-1)-1) .. 2^(W-1)-1}, odd or zero.
 115  function wnaf(k, w) {
 116    const digits = [];
 117    const halfW = 1n << BigInt(w - 1);
 118    const fullW = halfW << 1n;
 119    const mask = fullW - 1n;
 120    while (k > 0n) {
 121      if (k & 1n) {
 122        let d = k & mask;
 123        if (d >= halfW) d -= fullW;
 124        k -= d;
 125        digits.push(d);
 126      } else {
 127        digits.push(0n);
 128      }
 129      k >>= 1n;
 130    }
 131    return digits;
 132  }
 133  
 134  // scalarMultG: fast k*G using wNAF precomputed table.
 135  function scalarMultG(k) {
 136    ensureGTable();
 137    k = ((k % N) + N) % N;
 138    if (k === 0n) return [0n, 0n];
 139  
 140    const digits = wnaf(k, W);
 141    let RX = 0n, RY = 1n, RZ = 0n; // infinity
 142  
 143    for (let i = digits.length - 1; i >= 0; i--) {
 144      [RX, RY, RZ] = jDbl(RX, RY, RZ);
 145      const d = digits[i];
 146      if (d !== 0n) {
 147        const idx = Number(d < 0n ? (-d - 1n) >> 1n : (d - 1n) >> 1n);
 148        const pt = _gTable[idx];
 149        if (d < 0n) {
 150          [RX, RY, RZ] = jAdd(RX, RY, RZ, pt[0], mod(P - pt[1], P), pt[2]);
 151        } else {
 152          [RX, RY, RZ] = jAdd(RX, RY, RZ, pt[0], pt[1], pt[2]);
 153        }
 154      }
 155    }
 156    return jToAffine(RX, RY, RZ);
 157  }
 158  
 159  // Plain scalar multiply for arbitrary points (no precomputation).
 160  function scalarMult(k, px, py) {
 161    let [RX, RY, RZ] = [0n, 1n, 0n];
 162    let [CX, CY, CZ] = [px, py, 1n];
 163    k = ((k % N) + N) % N;
 164    while (k > 0n) {
 165      if (k & 1n) [RX, RY, RZ] = jAdd(RX, RY, RZ, CX, CY, CZ);
 166      [CX, CY, CZ] = jDbl(CX, CY, CZ);
 167      k >>= 1n;
 168    }
 169    return jToAffine(RX, RY, RZ);
 170  }
 171  
 172  // Shamir's trick: compute k1*G + k2*P in one pass.
 173  // Shares doublings between the two scalar multiplications.
 174  function shamirMultGP(k1, k2, px, py) {
 175    ensureGTable();
 176    k1 = ((k1 % N) + N) % N;
 177    k2 = ((k2 % N) + N) % N;
 178  
 179    // wNAF encode k1 (for G table lookup)
 180    const d1 = wnaf(k1, W);
 181    // Simple binary for k2 (no precomputed table for P)
 182    const d2 = wnaf(k2, 2);
 183  
 184    const len = Math.max(d1.length, d2.length);
 185  
 186    let RX = 0n, RY = 1n, RZ = 0n; // infinity
 187  
 188    for (let i = len - 1; i >= 0; i--) {
 189      [RX, RY, RZ] = jDbl(RX, RY, RZ);
 190  
 191      // k1*G part: use precomputed table
 192      const a = i < d1.length ? d1[i] : 0n;
 193      if (a !== 0n) {
 194        const idx = Number(a < 0n ? (-a - 1n) >> 1n : (a - 1n) >> 1n);
 195        const pt = _gTable[idx];
 196        if (a < 0n) {
 197          [RX, RY, RZ] = jAdd(RX, RY, RZ, pt[0], mod(P - pt[1], P), pt[2]);
 198        } else {
 199          [RX, RY, RZ] = jAdd(RX, RY, RZ, pt[0], pt[1], pt[2]);
 200        }
 201      }
 202  
 203      // k2*P part: simple binary (odd digit = ±1 only)
 204      const b = i < d2.length ? d2[i] : 0n;
 205      if (b > 0n) {
 206        [RX, RY, RZ] = jAdd(RX, RY, RZ, px, py, 1n);
 207      } else if (b < 0n) {
 208        [RX, RY, RZ] = jAdd(RX, RY, RZ, px, mod(P - py, P), 1n);
 209      }
 210    }
 211  
 212    return jToAffine(RX, RY, RZ);
 213  }
 214  
 215  function liftX(x) {
 216    const c = mod(x * x * x + 7n, P);
 217    const exp = (P + 1n) / 4n;
 218    let y = modPow(c, exp, P);
 219    if (mod(y * y, P) !== mod(c, P)) return null;
 220    if (y % 2n !== 0n) y = mod(P - y, P);
 221    return y;
 222  }
 223  
 224  // --- SHA-256 (synchronous, pure JS) ---
 225  
 226  const K256 = [
 227    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
 228    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
 229    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
 230    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
 231    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
 232    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
 233    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
 234    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
 235  ];
 236  
 237  function sha256(data) {
 238    const len = data.length;
 239    const bitLen = len * 8;
 240    const padLen = ((56 - (len + 1) % 64) + 64) % 64;
 241    const padded = new Uint8Array(len + 1 + padLen + 8);
 242    padded.set(data);
 243    padded[len] = 0x80;
 244    const dv = new DataView(padded.buffer);
 245    dv.setUint32(padded.length - 4, bitLen, false);
 246  
 247    let [h0, h1, h2, h3, h4, h5, h6, h7] =
 248      [0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19];
 249  
 250    const w = new Int32Array(64);
 251    for (let off = 0; off < padded.length; off += 64) {
 252      for (let i = 0; i < 16; i++) w[i] = dv.getInt32(off + i * 4, false);
 253      for (let i = 16; i < 64; i++) {
 254        const s0 = (ror(w[i-15], 7) ^ ror(w[i-15], 18) ^ (w[i-15] >>> 3));
 255        const s1 = (ror(w[i-2], 17) ^ ror(w[i-2], 19) ^ (w[i-2] >>> 10));
 256        w[i] = (w[i-16] + s0 + w[i-7] + s1) | 0;
 257      }
 258      let [a, b, c, d, e, f, g, h] = [h0, h1, h2, h3, h4, h5, h6, h7];
 259      for (let i = 0; i < 64; i++) {
 260        const S1 = ror(e, 6) ^ ror(e, 11) ^ ror(e, 25);
 261        const ch = (e & f) ^ (~e & g);
 262        const t1 = (h + S1 + ch + K256[i] + w[i]) | 0;
 263        const S0 = ror(a, 2) ^ ror(a, 13) ^ ror(a, 22);
 264        const maj = (a & b) ^ (a & c) ^ (b & c);
 265        const t2 = (S0 + maj) | 0;
 266        h = g; g = f; f = e; e = (d + t1) | 0;
 267        d = c; c = b; b = a; a = (t1 + t2) | 0;
 268      }
 269      h0 = (h0 + a) | 0; h1 = (h1 + b) | 0; h2 = (h2 + c) | 0; h3 = (h3 + d) | 0;
 270      h4 = (h4 + e) | 0; h5 = (h5 + f) | 0; h6 = (h6 + g) | 0; h7 = (h7 + h) | 0;
 271    }
 272    const out = new Uint8Array(32);
 273    const odv = new DataView(out.buffer);
 274    odv.setUint32(0, h0, false); odv.setUint32(4, h1, false);
 275    odv.setUint32(8, h2, false); odv.setUint32(12, h3, false);
 276    odv.setUint32(16, h4, false); odv.setUint32(20, h5, false);
 277    odv.setUint32(24, h6, false); odv.setUint32(28, h7, false);
 278    return out;
 279  }
 280  
 281  function ror(x, n) { return ((x >>> n) | (x << (32 - n))) | 0; }
 282  
 283  function taggedHash(tag, msg) {
 284    const tagH = sha256(new TextEncoder().encode(tag));
 285    const buf = new Uint8Array(64 + msg.length);
 286    buf.set(tagH, 0);
 287    buf.set(tagH, 32);
 288    buf.set(msg, 64);
 289    return sha256(buf);
 290  }
 291  
 292  // --- Slice conversions ---
 293  
 294  function sliceToU8(s) {
 295    if (s instanceof Uint8Array) return s;
 296    const u = new Uint8Array(s.$length);
 297    for (let i = 0; i < s.$length; i++) u[i] = s.$array[s.$offset + i];
 298    return u;
 299  }
 300  
 301  function u8ToSlice(u8) {
 302    const arr = new Array(u8.length);
 303    for (let i = 0; i < u8.length; i++) arr[i] = u8[i];
 304    return new Slice(arr, 0, u8.length, u8.length);
 305  }
 306  
 307  function bytesToBigInt(u8) {
 308    let n = 0n;
 309    for (let i = 0; i < u8.length; i++) n = (n << 8n) | BigInt(u8[i]);
 310    return n;
 311  }
 312  
 313  function bigIntTo32(n) {
 314    const out = new Uint8Array(32);
 315    for (let i = 31; i >= 0; i--) { out[i] = Number(n & 0xFFn); n >>= 8n; }
 316    return out;
 317  }
 318  
 319  // --- Exported BIP-340 functions ---
 320  
 321  export function PubKeyFromSecKey(seckey) {
 322    const sk = bytesToBigInt(sliceToU8(seckey));
 323    if (sk === 0n || sk >= N) return [u8ToSlice(new Uint8Array(32)), false];
 324    const [px, _] = scalarMultG(sk);
 325    return [u8ToSlice(bigIntTo32(px)), true];
 326  }
 327  
 328  export function SignSchnorr(seckey, msg, auxRand) {
 329    const sk = bytesToBigInt(sliceToU8(seckey));
 330    const m = sliceToU8(msg);
 331    const aux = sliceToU8(auxRand);
 332  
 333    if (sk === 0n || sk >= N) return [u8ToSlice(new Uint8Array(64)), false];
 334  
 335    const [px, py] = scalarMultG(sk);
 336    let d = sk;
 337    if (py % 2n !== 0n) d = N - d;
 338  
 339    const pkBytes = bigIntTo32(px);
 340  
 341    const auxHash = taggedHash("BIP0340/aux", aux);
 342    const dBytes = bigIntTo32(d);
 343    const t = new Uint8Array(32);
 344    for (let i = 0; i < 32; i++) t[i] = dBytes[i] ^ auxHash[i];
 345  
 346    const nonceInput = new Uint8Array(96);
 347    nonceInput.set(t, 0);
 348    nonceInput.set(pkBytes, 32);
 349    nonceInput.set(m, 64);
 350    const kHash = taggedHash("BIP0340/nonce", nonceInput);
 351    let k = bytesToBigInt(kHash) % N;
 352    if (k === 0n) return [u8ToSlice(new Uint8Array(64)), false];
 353  
 354    const [rx, ry] = scalarMultG(k);
 355    if (ry % 2n !== 0n) k = N - k;
 356  
 357    const rxBytes = bigIntTo32(rx);
 358  
 359    const eInput = new Uint8Array(96);
 360    eInput.set(rxBytes, 0);
 361    eInput.set(pkBytes, 32);
 362    eInput.set(m, 64);
 363    const eHash = taggedHash("BIP0340/challenge", eInput);
 364    const e = bytesToBigInt(eHash) % N;
 365  
 366    const s = ((k + e * d) % N + N) % N;
 367  
 368    const sig = new Uint8Array(64);
 369    sig.set(rxBytes, 0);
 370    sig.set(bigIntTo32(s), 32);
 371  
 372    return [u8ToSlice(sig), true];
 373  }
 374  
 375  // VerifySchnorr — uses Shamir's trick for s*G + (-e)*P.
 376  export function VerifySchnorr(pubkey, msg, sig) {
 377    const pkU8 = sliceToU8(pubkey);
 378    const mU8 = sliceToU8(msg);
 379    const sigU8 = sliceToU8(sig);
 380  
 381    const px = bytesToBigInt(pkU8);
 382    const py = liftX(px);
 383    if (py === null) return false;
 384  
 385    const rx = bytesToBigInt(sigU8.slice(0, 32));
 386    const s = bytesToBigInt(sigU8.slice(32, 64));
 387    if (rx >= P || s >= N) return false;
 388  
 389    const eInput = new Uint8Array(96);
 390    eInput.set(sigU8.slice(0, 32), 0);
 391    eInput.set(pkU8, 32);
 392    eInput.set(mU8, 64);
 393    const eHash = taggedHash("BIP0340/challenge", eInput);
 394    const e = bytesToBigInt(eHash) % N;
 395  
 396    // R = s*G + (-e)*P via Shamir's trick.
 397    const negE = ((N - e) % N + N) % N;
 398    const [Rx, Ry] = shamirMultGP(s, negE, px, py);
 399  
 400    if (Rx === 0n && Ry === 0n) return false;
 401    if (Ry % 2n !== 0n) return false;
 402    if (Rx !== rx) return false;
 403    return true;
 404  }
 405  
 406  export function ECDH(seckey, pubkey) {
 407    const sk = bytesToBigInt(sliceToU8(seckey));
 408    const px = bytesToBigInt(sliceToU8(pubkey));
 409    if (sk === 0n || sk >= N) return [u8ToSlice(new Uint8Array(32)), false];
 410    const py = liftX(px);
 411    if (py === null) return [u8ToSlice(new Uint8Array(32)), false];
 412    const [rx, _] = scalarMult(sk, px, py);
 413    if (rx === 0n) return [u8ToSlice(new Uint8Array(32)), false];
 414    return [u8ToSlice(bigIntTo32(rx)), true];
 415  }
 416  
 417  export function SHA256Sum(data) {
 418    return u8ToSlice(sha256(sliceToU8(data)));
 419  }
 420  
 421  export function ScalarAddModN(a, b) {
 422    const ai = bytesToBigInt(sliceToU8(a));
 423    const bi = bytesToBigInt(sliceToU8(b));
 424    const sum = ((ai + bi) % N + N) % N;
 425    if (sum === 0n) return [u8ToSlice(new Uint8Array(32)), false];
 426    return [u8ToSlice(bigIntTo32(sum)), true];
 427  }
 428  
 429  export function CompressedPubKey(seckey) {
 430    const sk = bytesToBigInt(sliceToU8(seckey));
 431    if (sk === 0n || sk >= N) return [u8ToSlice(new Uint8Array(33)), false];
 432    const [px, py] = scalarMultG(sk);
 433    const out = new Uint8Array(33);
 434    out[0] = (py % 2n === 0n) ? 0x02 : 0x03;
 435    out.set(bigIntTo32(px), 1);
 436    return [u8ToSlice(out), true];
 437  }
 438