p256.mjs raw

   1  // p256.mjs — NIST P-256 ECDH + ECDSA using JS BigInt.
   2  // Short-Weierstrass curve y² = x³ - 3x + b over GF(p).
   3  // Used by MLS cipher suite 0x0001 (DHKEM-P256, ECDSA-P256-SHA256).
   4  
   5  import { Slice } from './builtin.mjs';
   6  
   7  // --- Constants (FIPS 186-4, §D.1.2.3) ---
   8  
   9  const P = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffffn;
  10  const N = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551n;
  11  const B = 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604bn;
  12  const Gx = 0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296n;
  13  const Gy = 0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5n;
  14  
  15  // --- Field arithmetic mod P ---
  16  
  17  function modP(x) { return ((x % P) + P) % P; }
  18  
  19  function modPow(base, exp, m) {
  20    let r = 1n;
  21    base = ((base % m) + m) % m;
  22    while (exp > 0n) {
  23      if (exp & 1n) r = r * base % m;
  24      exp >>= 1n;
  25      base = base * base % m;
  26    }
  27    return r;
  28  }
  29  
  30  function modInv(a, m) { return modPow(a, m - 2n, m); }
  31  
  32  // --- Point ops in Jacobian coordinates [X, Y, Z] where x=X/Z², y=Y/Z³ ---
  33  // Identity: Z = 0.
  34  
  35  function jDouble(p) {
  36    const [X1, Y1, Z1] = p;
  37    if (Z1 === 0n) return p;
  38    // a = -3 optimization (NIST curves).
  39    const A = modP((X1 - Z1 * Z1 % P) * ((X1 + Z1 * Z1 % P) % P));
  40    const M = modP(A + A + A);
  41    const Y2 = modP(Y1 * Y1);
  42    const S = modP(4n * X1 * Y2);
  43    const X3 = modP(M * M - 2n * S);
  44    const Y3 = modP(M * (S - X3) - 8n * Y2 * Y2);
  45    const Z3 = modP(2n * Y1 * Z1);
  46    return [X3, Y3, Z3];
  47  }
  48  
  49  function jAdd(p, q) {
  50    const [X1, Y1, Z1] = p;
  51    const [X2, Y2, Z2] = q;
  52    if (Z1 === 0n) return q;
  53    if (Z2 === 0n) return p;
  54    const Z1Z1 = modP(Z1 * Z1);
  55    const Z2Z2 = modP(Z2 * Z2);
  56    const U1 = modP(X1 * Z2Z2);
  57    const U2 = modP(X2 * Z1Z1);
  58    const S1 = modP(Y1 * Z2 * Z2Z2);
  59    const S2 = modP(Y2 * Z1 * Z1Z1);
  60    const H = modP(U2 - U1);
  61    const R = modP(S2 - S1);
  62    if (H === 0n) {
  63      if (R === 0n) return jDouble(p);
  64      return [0n, 0n, 0n]; // identity (shouldn't happen for distinct valid points)
  65    }
  66    const HH = modP(H * H);
  67    const HHH = modP(H * HH);
  68    const V = modP(U1 * HH);
  69    const X3 = modP(R * R - HHH - 2n * V);
  70    const Y3 = modP(R * (V - X3) - S1 * HHH);
  71    const Z3 = modP(Z1 * Z2 * H);
  72    return [X3, Y3, Z3];
  73  }
  74  
  75  function scalarMult(k, p) {
  76    let r = [0n, 0n, 0n];
  77    let q = p;
  78    k = ((k % N) + N) % N;
  79    while (k > 0n) {
  80      if (k & 1n) r = jAdd(r, q);
  81      q = jDouble(q);
  82      k >>= 1n;
  83    }
  84    return r;
  85  }
  86  
  87  function toAffine(p) {
  88    const [X, Y, Z] = p;
  89    if (Z === 0n) return null;
  90    const zInv = modInv(Z, P);
  91    const zInv2 = modP(zInv * zInv);
  92    const zInv3 = modP(zInv2 * zInv);
  93    return [modP(X * zInv2), modP(Y * zInv3)];
  94  }
  95  
  96  const G = [Gx, Gy, 1n];
  97  
  98  // --- Encoding ---
  99  
 100  function i2osp(x, len) {
 101    const out = new Uint8Array(len);
 102    for (let i = len - 1; i >= 0; i--) {
 103      out[i] = Number(x & 0xffn);
 104      x >>= 8n;
 105    }
 106    return out;
 107  }
 108  
 109  function os2ip(b) {
 110    let x = 0n;
 111    for (let i = 0; i < b.length; i++) x = (x << 8n) | BigInt(b[i]);
 112    return x;
 113  }
 114  
 115  // Uncompressed SEC1: 0x04 || X(32) || Y(32) = 65 bytes.
 116  function encodeUncompressed(p) {
 117    const aff = toAffine(p);
 118    if (!aff) return new Uint8Array(65);
 119    const out = new Uint8Array(65);
 120    out[0] = 0x04;
 121    out.set(i2osp(aff[0], 32), 1);
 122    out.set(i2osp(aff[1], 32), 33);
 123    return out;
 124  }
 125  
 126  function decodeUncompressed(u8) {
 127    if (u8.length !== 65 || u8[0] !== 0x04) return null;
 128    const x = os2ip(u8.subarray(1, 33));
 129    const y = os2ip(u8.subarray(33, 65));
 130    if (x >= P || y >= P) return null;
 131    // Verify point is on curve: y² = x³ - 3x + b
 132    const lhs = modP(y * y);
 133    const rhs = modP(x * x * x - 3n * x + B);
 134    if (lhs !== rhs) return null;
 135    return [x, y, 1n];
 136  }
 137  
 138  // --- Slice interop ---
 139  
 140  function sliceToU8(s) {
 141    if (s instanceof Uint8Array) return s;
 142    if (typeof s === 'string') return new TextEncoder().encode(s);
 143    const u = new Uint8Array(s.$length);
 144    for (let i = 0; i < s.$length; i++) u[i] = s.$array[s.$offset + i];
 145    return u;
 146  }
 147  
 148  function u8ToSlice(u8) {
 149    const arr = new Array(u8.length);
 150    for (let i = 0; i < u8.length; i++) arr[i] = u8[i];
 151    return new Slice(arr, 0, u8.length, u8.length);
 152  }
 153  
 154  // --- SHA-256 (minimal, for ECDSA) ---
 155  
 156  function sha256(msg) {
 157    const H = new Uint32Array([
 158      0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
 159      0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
 160    ]);
 161    const K = new Uint32Array([
 162      0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
 163      0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
 164      0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
 165      0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
 166      0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
 167      0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
 168      0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
 169      0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
 170    ]);
 171    const bitLen = BigInt(msg.length) * 8n;
 172    const padLen = ((msg.length + 9 + 63) >> 6) << 6;
 173    const buf = new Uint8Array(padLen);
 174    buf.set(msg);
 175    buf[msg.length] = 0x80;
 176    for (let i = 0; i < 8; i++) buf[padLen - 1 - i] = Number((bitLen >> BigInt(8 * i)) & 0xffn);
 177    const W = new Uint32Array(64);
 178    for (let off = 0; off < padLen; off += 64) {
 179      for (let i = 0; i < 16; i++) {
 180        W[i] = (buf[off + 4*i] << 24) | (buf[off + 4*i+1] << 16) | (buf[off + 4*i+2] << 8) | buf[off + 4*i+3];
 181        W[i] >>>= 0;
 182      }
 183      for (let i = 16; i < 64; i++) {
 184        const s0 = ((W[i-15] >>> 7) | (W[i-15] << 25)) ^ ((W[i-15] >>> 18) | (W[i-15] << 14)) ^ (W[i-15] >>> 3);
 185        const s1 = ((W[i-2] >>> 17) | (W[i-2] << 15)) ^ ((W[i-2] >>> 19) | (W[i-2] << 13)) ^ (W[i-2] >>> 10);
 186        W[i] = (W[i-16] + s0 + W[i-7] + s1) >>> 0;
 187      }
 188      let [a, b, c, d, e, f, g, h] = H;
 189      for (let i = 0; i < 64; i++) {
 190        const S1 = ((e >>> 6) | (e << 26)) ^ ((e >>> 11) | (e << 21)) ^ ((e >>> 25) | (e << 7));
 191        const ch = (e & f) ^ (~e & g);
 192        const t1 = (h + S1 + ch + K[i] + W[i]) >>> 0;
 193        const S0 = ((a >>> 2) | (a << 30)) ^ ((a >>> 13) | (a << 19)) ^ ((a >>> 22) | (a << 10));
 194        const mj = (a & b) ^ (a & c) ^ (b & c);
 195        const t2 = (S0 + mj) >>> 0;
 196        h = g; g = f; f = e; e = (d + t1) >>> 0; d = c; c = b; b = a; a = (t1 + t2) >>> 0;
 197      }
 198      H[0] = (H[0] + a) >>> 0;
 199      H[1] = (H[1] + b) >>> 0;
 200      H[2] = (H[2] + c) >>> 0;
 201      H[3] = (H[3] + d) >>> 0;
 202      H[4] = (H[4] + e) >>> 0;
 203      H[5] = (H[5] + f) >>> 0;
 204      H[6] = (H[6] + g) >>> 0;
 205      H[7] = (H[7] + h) >>> 0;
 206    }
 207    const out = new Uint8Array(32);
 208    for (let i = 0; i < 8; i++) {
 209      out[4*i+0] = (H[i] >>> 24) & 0xff;
 210      out[4*i+1] = (H[i] >>> 16) & 0xff;
 211      out[4*i+2] = (H[i] >>> 8) & 0xff;
 212      out[4*i+3] = H[i] & 0xff;
 213    }
 214    return out;
 215  }
 216  
 217  // --- ECDH exports ---
 218  
 219  // ScalarBaseMult computes k*G and returns the 65-byte uncompressed point.
 220  export function ScalarBaseMult(scalar) {
 221    const k = os2ip(sliceToU8(scalar));
 222    return u8ToSlice(encodeUncompressed(scalarMult(k, G)));
 223  }
 224  
 225  // ScalarMult computes k*P and returns the 32-byte X coordinate (ECDH shared secret).
 226  export function ScalarMult(scalar, point) {
 227    const k = os2ip(sliceToU8(scalar));
 228    const pt = decodeUncompressed(sliceToU8(point));
 229    if (!pt) return u8ToSlice(new Uint8Array(32));
 230    const r = scalarMult(k, pt);
 231    const aff = toAffine(r);
 232    if (!aff) return u8ToSlice(new Uint8Array(32));
 233    return u8ToSlice(i2osp(aff[0], 32));
 234  }
 235  
 236  // --- ECDSA exports ---
 237  // Signatures are DER-encoded per MLS spec (RFC 9420 §5.1.2 signature_algorithm
 238  // ecdsa_secp256r1_sha256 uses ASN.1 SEQUENCE{r,s}).
 239  
 240  function derEncodeScalar(x) {
 241    const b = i2osp(x, 32);
 242    // Trim leading zeros
 243    let start = 0;
 244    while (start < 31 && b[start] === 0) start++;
 245    let trimmed = b.subarray(start);
 246    // If high bit set, prepend 0x00 to mark as positive.
 247    if (trimmed[0] & 0x80) {
 248      const out = new Uint8Array(trimmed.length + 1);
 249      out.set(trimmed, 1);
 250      trimmed = out;
 251    }
 252    const wrap = new Uint8Array(trimmed.length + 2);
 253    wrap[0] = 0x02;
 254    wrap[1] = trimmed.length;
 255    wrap.set(trimmed, 2);
 256    return wrap;
 257  }
 258  
 259  function derEncodeSig(r, s) {
 260    const rEnc = derEncodeScalar(r);
 261    const sEnc = derEncodeScalar(s);
 262    const inner = new Uint8Array(rEnc.length + sEnc.length);
 263    inner.set(rEnc);
 264    inner.set(sEnc, rEnc.length);
 265    const out = new Uint8Array(inner.length + 2);
 266    out[0] = 0x30;
 267    out[1] = inner.length;
 268    out.set(inner, 2);
 269    return out;
 270  }
 271  
 272  function derDecodeSig(b) {
 273    if (b.length < 8 || b[0] !== 0x30) return null;
 274    let off = 2;
 275    if (b[off++] !== 0x02) return null;
 276    let rLen = b[off++];
 277    const r = os2ip(b.subarray(off, off + rLen));
 278    off += rLen;
 279    if (b[off++] !== 0x02) return null;
 280    let sLen = b[off++];
 281    const s = os2ip(b.subarray(off, off + sLen));
 282    return [r, s];
 283  }
 284  
 285  // hashToScalar reduces SHA-256(msg) to a scalar mod N.
 286  function hashToScalar(msg) {
 287    return os2ip(sha256(msg)) % N;
 288  }
 289  
 290  // Deterministic nonce per RFC 6979 §3.2 using HMAC-SHA-256.
 291  function hmacSha256(key, data) {
 292    const bs = 64;
 293    let k = key;
 294    if (k.length > bs) k = sha256(k);
 295    if (k.length < bs) { const p = new Uint8Array(bs); p.set(k); k = p; }
 296    const ipad = new Uint8Array(bs);
 297    const opad = new Uint8Array(bs);
 298    for (let i = 0; i < bs; i++) { ipad[i] = k[i] ^ 0x36; opad[i] = k[i] ^ 0x5c; }
 299    const inner = new Uint8Array(bs + data.length);
 300    inner.set(ipad); inner.set(data, bs);
 301    const innerHash = sha256(inner);
 302    const outer = new Uint8Array(bs + 32);
 303    outer.set(opad); outer.set(innerHash, bs);
 304    return sha256(outer);
 305  }
 306  
 307  function rfc6979Nonce(x, hMsg) {
 308    // V = 0x01*32, K = 0x00*32
 309    let V = new Uint8Array(32).fill(0x01);
 310    let K = new Uint8Array(32);
 311    const xBytes = i2osp(x, 32);
 312    // K = HMAC_K(V || 0x00 || x || h)
 313    let buf = new Uint8Array(32 + 1 + 32 + 32);
 314    buf.set(V); buf[32] = 0x00; buf.set(xBytes, 33); buf.set(hMsg, 65);
 315    K = hmacSha256(K, buf);
 316    V = hmacSha256(K, V);
 317    buf = new Uint8Array(32 + 1 + 32 + 32);
 318    buf.set(V); buf[32] = 0x01; buf.set(xBytes, 33); buf.set(hMsg, 65);
 319    K = hmacSha256(K, buf);
 320    V = hmacSha256(K, V);
 321    while (true) {
 322      V = hmacSha256(K, V);
 323      const k = os2ip(V);
 324      if (k >= 1n && k < N) return k;
 325      const b = new Uint8Array(33);
 326      b.set(V); b[32] = 0x00;
 327      K = hmacSha256(K, b);
 328      V = hmacSha256(K, V);
 329    }
 330  }
 331  
 332  // Sign produces a DER-encoded ECDSA-P256-SHA256 signature.
 333  // seed: 32-byte private key scalar (big-endian).
 334  export function Sign(seed, message) {
 335    const d = os2ip(sliceToU8(seed)) % N;
 336    const msg = sliceToU8(message);
 337    const h = sha256(msg);
 338    const z = os2ip(h) % N;
 339    for (let iter = 0; iter < 32; iter++) {
 340      const k = rfc6979Nonce(d, h);
 341      const kG = scalarMult(k, G);
 342      const aff = toAffine(kG);
 343      if (!aff) continue;
 344      const r = aff[0] % N;
 345      if (r === 0n) continue;
 346      const kInv = modInv(k, N);
 347      const s = (kInv * ((z + r * d) % N)) % N;
 348      if (s === 0n) continue;
 349      return u8ToSlice(derEncodeSig(r, s));
 350    }
 351    return u8ToSlice(new Uint8Array(0));
 352  }
 353  
 354  // Verify validates a DER-encoded ECDSA-P256-SHA256 signature.
 355  // pubkey: 65-byte uncompressed SEC1 (0x04 || X || Y).
 356  export function Verify(pubkey, message, sig) {
 357    const pk = decodeUncompressed(sliceToU8(pubkey));
 358    if (!pk) return false;
 359    const msg = sliceToU8(message);
 360    const parsed = derDecodeSig(sliceToU8(sig));
 361    if (!parsed) return false;
 362    const [r, s] = parsed;
 363    if (r <= 0n || r >= N || s <= 0n || s >= N) return false;
 364    const z = os2ip(sha256(msg)) % N;
 365    const w = modInv(s, N);
 366    const u1 = (z * w) % N;
 367    const u2 = (r * w) % N;
 368    const p1 = scalarMult(u1, G);
 369    const p2 = scalarMult(u2, pk);
 370    const sum = jAdd(p1, p2);
 371    const aff = toAffine(sum);
 372    if (!aff) return false;
 373    return (aff[0] % N) === r;
 374  }
 375  
 376  // DeriveKeyPair per RFC 9180 §7.1.3 for DHKEM(P-256, HKDF-SHA256).
 377  // Uses rejection sampling: sk = OS2IP(LabeledExpand(dkp_prk,"candidate",I2OSP(counter,1),32)) mod n.
 378  // But the MLS caller does LabeledExpand — here we just expose a scalar reducer.
 379  // ReduceScalar clamps x into [1, N-1] by mod, returning 32 big-endian bytes; callers must
 380  // iterate with rejection if needed.
 381  export function ReduceScalar(x) {
 382    const v = os2ip(sliceToU8(x));
 383    const r = v % N;
 384    if (r === 0n) return u8ToSlice(new Uint8Array(32));
 385    return u8ToSlice(i2osp(r, 32));
 386  }
 387  
 388  // ValidScalar returns true iff OS2IP(x) is in [1, N-1].
 389  // Used by RFC 9180 DeriveKeyPair rejection sampling and signature key gen.
 390  export function ValidScalar(x) {
 391    const v = os2ip(sliceToU8(x));
 392    return v >= 1n && v < N;
 393  }
 394