x25519.mjs raw

   1  // x25519.mjs — X25519 Diffie-Hellman (RFC 7748) using JS BigInt.
   2  // Montgomery ladder over Curve25519: y² = x³ + 486662x² + x (mod 2²⁵⁵ − 19).
   3  
   4  import { Slice } from './builtin.mjs';
   5  
   6  const P   = (1n << 255n) - 19n;
   7  const A24 = 121665n; // (486662 - 2) / 4
   8  
   9  function mod(x) { return ((x % P) + P) % P; }
  10  
  11  function modPow(base, exp) {
  12    let r = 1n;
  13    base = mod(base);
  14    while (exp > 0n) {
  15      if (exp & 1n) r = r * base % P;
  16      exp >>= 1n;
  17      base = base * base % P;
  18    }
  19    return r;
  20  }
  21  
  22  function modInv(a) { return modPow(a, P - 2n); }
  23  
  24  // RFC 7748 §5: Montgomery ladder, constant-time in the bit-scan pattern.
  25  function ladder(k, u) {
  26    let x_2 = 1n, z_2 = 0n;
  27    let x_3 = u,  z_3 = 1n;
  28    let swap = 0n;
  29  
  30    for (let t = 254; t >= 0; t--) {
  31      const k_t = (k >> BigInt(t)) & 1n;
  32      swap ^= k_t;
  33      if (swap) { [x_2, x_3] = [x_3, x_2]; [z_2, z_3] = [z_3, z_2]; }
  34      swap = k_t;
  35  
  36      const A  = mod(x_2 + z_2);
  37      const AA = mod(A * A);
  38      const B  = mod(x_2 - z_2);
  39      const BB = mod(B * B);
  40      const E  = mod(AA - BB);
  41      const C  = mod(x_3 + z_3);
  42      const D  = mod(x_3 - z_3);
  43      const DA = mod(D * A);
  44      const CB = mod(C * B);
  45      x_3 = mod((DA + CB) * (DA + CB));
  46      z_3 = mod(u * mod((DA - CB) * (DA - CB)));
  47      x_2 = mod(AA * BB);
  48      z_2 = mod(E * (AA + A24 * E));
  49    }
  50  
  51    if (swap) { [x_2, x_3] = [x_3, x_2]; [z_2, z_3] = [z_3, z_2]; }
  52    return mod(x_2 * modInv(z_2));
  53  }
  54  
  55  // RFC 7748 §5: decode scalar, clamp, decode u-coordinate.
  56  function decodeScalar(u8) {
  57    const s = new Uint8Array(u8);
  58    s[0]  &= 248;
  59    s[31] &= 127;
  60    s[31] |= 64;
  61    let n = 0n;
  62    for (let i = 31; i >= 0; i--) n = (n << 8n) | BigInt(s[i]);
  63    return n;
  64  }
  65  
  66  function decodeU(u8) {
  67    // Mask bit 255 per RFC 7748.
  68    let n = 0n;
  69    for (let i = 31; i >= 0; i--) n = (n << 8n) | BigInt(u8[i]);
  70    return n & ((1n << 255n) - 1n);
  71  }
  72  
  73  function encodeU(n) {
  74    const out = new Uint8Array(32);
  75    for (let i = 0; i < 32; i++) { out[i] = Number(n & 0xFFn); n >>= 8n; }
  76    return out;
  77  }
  78  
  79  // --- Slice conversions (same pattern as schnorr.mjs) ---
  80  
  81  function sliceToU8(s) {
  82    if (s instanceof Uint8Array) return s;
  83    if (typeof s === 'string') return new TextEncoder().encode(s);
  84    const u = new Uint8Array(s.$length);
  85    for (let i = 0; i < s.$length; i++) u[i] = s.$array[s.$offset + i];
  86    return u;
  87  }
  88  
  89  function u8ToSlice(u8) {
  90    const arr = new Array(u8.length);
  91    for (let i = 0; i < u8.length; i++) arr[i] = u8[i];
  92    return new Slice(arr, 0, u8.length, u8.length);
  93  }
  94  
  95  // --- Exports ---
  96  
  97  // ScalarMult computes X25519(scalar, point).
  98  // scalar and point are 32-byte little-endian. Returns 32-byte result.
  99  export function ScalarMult(scalar, point) {
 100    const k = decodeScalar(sliceToU8(scalar));
 101    const u = decodeU(sliceToU8(point));
 102    return u8ToSlice(encodeU(ladder(k, u)));
 103  }
 104  
 105  // ScalarBaseMult computes X25519(scalar, 9).
 106  export function ScalarBaseMult(scalar) {
 107    const k = decodeScalar(sliceToU8(scalar));
 108    return u8ToSlice(encodeU(ladder(k, 9n)));
 109  }
 110