From be0b2a32a5622c71f93c7552b944c9fa68eac11c Mon Sep 17 00:00:00 2001 From: Paul Miller Date: Thu, 26 Jan 2023 02:07:45 +0000 Subject: [PATCH] Fp rename. Edwards refactor. Weierstrass Fn instead of mod --- README.md | 19 +++ src/abstract/bls.ts | 24 +-- src/abstract/edwards.ts | 302 +++++++++++++--------------------- src/abstract/group.ts | 6 +- src/abstract/hash-to-curve.ts | 1 + src/abstract/modular.ts | 68 ++++---- src/abstract/poseidon.ts | 4 +- src/abstract/utils.ts | 1 + src/abstract/weierstrass.ts | 141 +++++++--------- src/bls12-381.ts | 93 +++++------ src/ed25519.ts | 30 ++-- src/ed448.ts | 32 ++-- src/secp256k1.ts | 40 +++-- src/stark.ts | 2 +- 14 files changed, 347 insertions(+), 416 deletions(-) diff --git a/README.md b/README.md index 1b97a23..68e1c36 100644 --- a/README.md +++ b/README.md @@ -458,6 +458,25 @@ verify noble x 698 ops/sec @ 1ms/op ``` +## Upgrading + +Differences from @noble/secp256k1 1.7: + +1. Different double() formula (but same addition) +2. Different sqrt() function +3. DRBG supports outputLen bigger than outputLen of hmac +4. Support for different hash functions + +Differences from @noble/ed25519 1.7: + +1. Variable field element lengths between EDDSA/ECDH: + EDDSA (RFC8032) is 456 bits / 57 bytes, ECDH (RFC7748) is 448 bits / 56 bytes +2. Different addition formula (doubling is same) +3. uvRatio differs between curves (half-expected, not only pow fn changes) +4. Point decompression code is different (unexpected), now using generalized formula +5. Domain function was no-op for ed25519, but adds some data even with empty context for ed448 + + ## Contributing & testing 1. Clone the repository diff --git a/src/abstract/bls.ts b/src/abstract/bls.ts index 70f4e3a..2ca93ca 100644 --- a/src/abstract/bls.ts +++ b/src/abstract/bls.ts @@ -129,18 +129,18 @@ export function bls( let ell_coeff: [Fp2, Fp2, Fp2][] = []; for (let i = BLS_X_LEN - 2; i >= 0; i--) { // Double - let t0 = Fp2.square(Ry); // Ry² - let t1 = Fp2.square(Rz); // Rz² + let t0 = Fp2.sqr(Ry); // Ry² + let t1 = Fp2.sqr(Rz); // Rz² let t2 = Fp2.multiplyByB(Fp2.mul(t1, 3n)); // 3 * T1 * B let t3 = Fp2.mul(t2, 3n); // 3 * T2 - let t4 = Fp2.sub(Fp2.sub(Fp2.square(Fp2.add(Ry, Rz)), t1), t0); // (Ry + Rz)² - T1 - T0 + let t4 = Fp2.sub(Fp2.sub(Fp2.sqr(Fp2.add(Ry, Rz)), t1), t0); // (Ry + Rz)² - T1 - T0 ell_coeff.push([ Fp2.sub(t2, t0), // T2 - T0 - Fp2.mul(Fp2.square(Rx), 3n), // 3 * Rx² - Fp2.negate(t4), // -T4 + Fp2.mul(Fp2.sqr(Rx), 3n), // 3 * Rx² + Fp2.neg(t4), // -T4 ]); Rx = Fp2.div(Fp2.mul(Fp2.mul(Fp2.sub(t0, t3), Rx), Ry), 2n); // ((T0 - T3) * Rx * Ry) / 2 - Ry = Fp2.sub(Fp2.square(Fp2.div(Fp2.add(t0, t3), 2n)), Fp2.mul(Fp2.square(t2), 3n)); // ((T0 + T3) / 2)² - 3 * T2² + Ry = Fp2.sub(Fp2.sqr(Fp2.div(Fp2.add(t0, t3), 2n)), Fp2.mul(Fp2.sqr(t2), 3n)); // ((T0 + T3) / 2)² - 3 * T2² Rz = Fp2.mul(t0, t4); // T0 * T4 if (ut.bitGet(CURVE.x, i)) { // Addition @@ -148,13 +148,13 @@ export function bls( let t1 = Fp2.sub(Rx, Fp2.mul(Qx, Rz)); // Rx - Qx * Rz ell_coeff.push([ Fp2.sub(Fp2.mul(t0, Qx), Fp2.mul(t1, Qy)), // T0 * Qx - T1 * Qy - Fp2.negate(t0), // -T0 + Fp2.neg(t0), // -T0 t1, // T1 ]); - let t2 = Fp2.square(t1); // T1² + let t2 = Fp2.sqr(t1); // T1² let t3 = Fp2.mul(t2, t1); // T2 * T1 let t4 = Fp2.mul(t2, Rx); // T2 * Rx - let t5 = Fp2.add(Fp2.sub(t3, Fp2.mul(t4, 2n)), Fp2.mul(Fp2.square(t0), Rz)); // T3 - 2 * T4 + T0² * Rz + let t5 = Fp2.add(Fp2.sub(t3, Fp2.mul(t4, 2n)), Fp2.mul(Fp2.sqr(t0), Rz)); // T3 - 2 * T4 + T0² * Rz Rx = Fp2.mul(t1, t5); // T1 * T5 Ry = Fp2.sub(Fp2.mul(Fp2.sub(t4, t5), t0), Fp2.mul(t3, Ry)); // (T4 - T5) * T0 - T3 * Ry Rz = Fp2.mul(Rz, t3); // Rz * T3 @@ -176,7 +176,7 @@ export function bls( const F = ell[j]; f12 = Fp12.multiplyBy014(f12, F[0], Fp2.mul(F[1], Px), Fp2.mul(F[2], Py)); } - if (i !== 0) f12 = Fp12.square(f12); + if (i !== 0) f12 = Fp12.sqr(f12); } return Fp12.conjugate(f12); } @@ -300,7 +300,7 @@ export function bls( const ePHm = pairing(P.negate(), Hm, false); const eGS = pairing(G, S, false); const exp = Fp12.finalExponentiate(Fp12.mul(eGS, ePHm)); - return Fp12.equals(exp, Fp12.ONE); + return Fp12.eql(exp, Fp12.ONE); } // Adds a bunch of public key points together. @@ -365,7 +365,7 @@ export function bls( paired.push(pairing(G1.ProjectivePoint.BASE.negate(), sig, false)); const product = paired.reduce((a, b) => Fp12.mul(a, b), Fp12.ONE); const exp = Fp12.finalExponentiate(product); - return Fp12.equals(exp, Fp12.ONE); + return Fp12.eql(exp, Fp12.ONE); } catch { return false; } diff --git a/src/abstract/edwards.ts b/src/abstract/edwards.ts index e1496a2..0f6831e 100644 --- a/src/abstract/edwards.ts +++ b/src/abstract/edwards.ts @@ -1,17 +1,8 @@ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ // Twisted Edwards curve. The formula is: ax² + y² = 1 + dx²y² - -// Differences from @noble/ed25519 1.7: -// 1. Variable field element lengths between EDDSA/ECDH: -// EDDSA (RFC8032) is 456 bits / 57 bytes, ECDH (RFC7748) is 448 bits / 56 bytes -// 2. Different addition formula (doubling is same) -// 3. uvRatio differs between curves (half-expected, not only pow fn changes) -// 4. Point decompression code is different (unexpected), now using generalized formula -// 5. Domain function was no-op for ed25519, but adds some data even with empty context for ed448 - import * as mod from './modular.js'; import * as ut from './utils.js'; -import { ensureBytes, Hex, PrivKey } from './utils.js'; +import { ensureBytes, Hex } from './utils.js'; import { Group, GroupConstructor, wNAF } from './group.js'; // Be friendly to bad ECMAScript parsers by not using bigint literals like 123n @@ -22,29 +13,20 @@ const _8n = BigInt(8); // Edwards curves must declare params a & d. export type CurveType = ut.BasicCurve & { - // Params: a, d - a: bigint; - d: bigint; - // Hashes - // The interface, because we need outputLen for DRBG - hash: ut.CHash; - // CSPRNG - randomBytes: (bytesLength?: number) => Uint8Array; - // Probably clears bits in a byte array to produce a valid field element - adjustScalarBytes?: (bytes: Uint8Array) => Uint8Array; - // Used during hashing - domain?: (data: Uint8Array, ctx: Uint8Array, phflag: boolean) => Uint8Array; - // Ratio √(u/v) - uvRatio?: (u: bigint, v: bigint) => { isValid: boolean; value: bigint }; - // RFC 8032 pre-hashing of messages to sign() / verify() - preHash?: ut.CHash; - mapToCurve?: (scalar: bigint[]) => AffinePoint; + a: bigint; // curve param a + d: bigint; // curve param d + hash: ut.FHash; // Hashing + randomBytes: (bytesLength?: number) => Uint8Array; // CSPRNG + adjustScalarBytes?: (bytes: Uint8Array) => Uint8Array; // clears bits to get valid field elemtn + domain?: (data: Uint8Array, ctx: Uint8Array, phflag: boolean) => Uint8Array; // Used for hashing + uvRatio?: (u: bigint, v: bigint) => { isValid: boolean; value: bigint }; // Ratio √(u/v) + preHash?: ut.FHash; // RFC 8032 pre-hashing of messages to sign() / verify() + mapToCurve?: (scalar: bigint[]) => AffinePoint; // for hash-to-curve standard }; function validateOpts(curve: CurveType) { const opts = ut.validateOpts(curve); - if (typeof opts.hash !== 'function' || !ut.isPositiveInt(opts.hash.outputLen)) - throw new Error('Invalid hash function'); + if (typeof opts.hash !== 'function') throw new Error('Invalid hash function'); for (const i of ['a', 'd'] as const) { const val = opts[i]; if (typeof val !== 'bigint') throw new Error(`Invalid curve param ${i}=${val} (${typeof val})`); @@ -84,7 +66,7 @@ export interface ExtPointConstructor extends GroupConstructor { new (x: bigint, y: bigint, z: bigint, t: bigint): ExtPointType; fromAffine(p: AffinePoint): ExtPointType; fromHex(hex: Hex): ExtPointType; - fromPrivateKey(privateKey: PrivKey): ExtPointType; // TODO: remove + fromPrivateKey(privateKey: Hex): ExtPointType; // TODO: remove } export type CurveFn = { @@ -105,23 +87,19 @@ export type CurveFn = { }; }; -// NOTE: it is not generic twisted curve for now, but ed25519/ed448 generic implementation +// It is not generic twisted curve for now, but ed25519/ed448 generic implementation export function twistedEdwards(curveDef: CurveType): CurveFn { const CURVE = validateOpts(curveDef) as ReturnType; - const Fp = CURVE.Fp; - const CURVE_ORDER = CURVE.n; - const MASK = _2n ** BigInt(CURVE.nByteLength * 8); - - // Function overrides - const { randomBytes } = CURVE; - const modP = Fp.create; + const { Fp, n: CURVE_ORDER, preHash, hash: cHash, randomBytes, nByteLength, h: cofactor } = CURVE; + const MASK = _2n ** BigInt(nByteLength * 8); + const modP = Fp.create; // Function overrides // sqrt(u/v) const uvRatio = CURVE.uvRatio || ((u: bigint, v: bigint) => { try { - return { isValid: true, value: Fp.sqrt(u * Fp.invert(v)) }; + return { isValid: true, value: Fp.sqrt(u * Fp.inv(v)) }; } catch (e) { return { isValid: false, value: _0n }; } @@ -133,35 +111,24 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { if (ctx.length || phflag) throw new Error('Contexts/pre-hash are not supported'); return data; }); // NOOP - function inBig(n: bigint) { - return typeof n === 'bigint' && 0n < n; - } - function assertInMask(n: bigint) { - if (inBig(n) && n < MASK) return n; - throw new Error(`Expected valid scalar < MASK, got ${typeof n} ${n}`); - } - function assertFE(n: bigint) { - if (inBig(n) && n < Fp.ORDER) return n; - throw new Error(`Expected valid scalar < P, got ${typeof n} ${n}`); - } - function assertGE(n: bigint) { - // GE = subgroup element, not full group - if (inBig(n) && n < CURVE_ORDER) return n; - throw new Error(`Expected valid scalar < N, got ${typeof n} ${n}`); + const inBig = (n: bigint) => typeof n === 'bigint' && 0n < n; // n in [1..] + const inRange = (n: bigint, max: bigint) => inBig(n) && inBig(max) && n < max; // n in [1..max-1] + const in0MaskRange = (n: bigint) => n === _0n || inRange(n, MASK); // n in [0..MASK-1] + function assertInRange(n: bigint, max: bigint) { + // n in [1..max-1] + if (inRange(n, max)) return n; + throw new Error(`Expected valid scalar < ${max}, got ${typeof n} ${n}`); } function assertGE0(n: bigint) { - // GE = subgroup element, not full group - return n === _0n ? n : assertGE(n); + // n in [0..CURVE_ORDER-1] + return n === _0n ? n : assertInRange(n, CURVE_ORDER); // GE = prime subgroup, not full group } - const coord = (n: bigint) => _0n <= n && n < MASK; // not < P because of ZIP215 - const pointPrecomputes = new Map(); - - /** - * Extended Point works in extended coordinates: (x, y, z, t) ∋ (x=x/z, y=y/z, t=xy). - * Default Point works in affine coordinates: (x, y) - * https://en.wikipedia.org/wiki/Twisted_Edwards_curve#Extended_coordinates - */ + function isPoint(other: unknown) { + if (!(other instanceof Point)) throw new Error('ExtendedPoint expected'); + } + // Extended Point works in extended coordinates: (x, y, z, t) ∋ (x=x/z, y=y/z, t=xy). + // https://en.wikipedia.org/wiki/Twisted_Edwards_curve#Extended_coordinates class Point implements ExtPointType { static readonly BASE = new Point(CURVE.Gx, CURVE.Gy, _1n, modP(CURVE.Gx * CURVE.Gy)); static readonly ZERO = new Point(_0n, _1n, _1n, _0n); // 0, 1, 1, 0 @@ -172,10 +139,10 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { readonly ez: bigint, readonly et: bigint ) { - if (!coord(ex)) throw new Error('x required'); - if (!coord(ey)) throw new Error('y required'); - if (!coord(ez)) throw new Error('z required'); - if (!coord(et)) throw new Error('t required'); + if (!in0MaskRange(ex)) throw new Error('x required'); + if (!in0MaskRange(ey)) throw new Error('y required'); + if (!in0MaskRange(ez)) throw new Error('z required'); + if (!in0MaskRange(et)) throw new Error('t required'); } get x(): bigint { @@ -186,9 +153,9 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { } static fromAffine(p: AffinePoint): Point { + if (p instanceof Point) throw new Error('extended point not allowed'); const { x, y } = p || {}; - if (p instanceof Point) throw new Error('fromAffine: extended point not allowed'); - if (!ut.big(x) || !ut.big(y)) throw new Error('fromAffine: invalid affine point'); + if (!in0MaskRange(x) || !in0MaskRange(y)) throw new Error('invalid affine point'); return new Point(x, y, _1n, modP(x * y)); } static normalizeZ(points: Point[]): Point[] { @@ -209,7 +176,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { // Compare one point to another. equals(other: Point): boolean { - assertExtPoint(other); + isPoint(other); const { ex: X1, ey: Y1, ez: Z1 } = this; const { ex: X2, ey: Y2, ez: Z2 } = other; const X1Z2 = modP(X1 * Z2); @@ -223,8 +190,8 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { return this.equals(Point.ZERO); } - // Inverses point to one corresponding to (x, -y) in Affine coordinates. negate(): Point { + // Flips point sign to a negative one (-x, y in affine coords) return new Point(modP(-this.ex), this.ey, this.ez, modP(-this.et)); } @@ -254,7 +221,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { // https://hyperelliptic.org/EFD/g1p/auto-twisted-extended.html#addition-add-2008-hwcd // Cost: 9M + 1*a + 1*d + 7add. add(other: Point) { - assertExtPoint(other); + isPoint(other); const { a, d } = CURVE; const { ex: X1, ey: Y1, ez: Z1, et: T1 } = this; const { ex: X2, ey: Y2, ez: Z2, et: T2 } = other; @@ -302,11 +269,9 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { return wnaf.wNAFCached(this, pointPrecomputes, n, Point.normalizeZ); } - // Constant time multiplication. - // Uses wNAF method. Windowed method may be 10% faster, - // but takes 2x longer to generate and consumes 2x memory. + // Constant-time multiplication. multiply(scalar: bigint): Point { - const { p, f } = this.wNAF(assertGE(scalar)); + const { p, f } = this.wNAF(assertInRange(scalar, CURVE_ORDER)); return Point.normalizeZ([p, f])[0]; } @@ -326,10 +291,10 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { // point with torsion component. // Multiplies point by cofactor and checks if the result is 0. isSmallOrder(): boolean { - return this.multiplyUnsafe(CURVE.h).is0(); + return this.multiplyUnsafe(cofactor).is0(); } - // Multiplies point by curve order (very big scalar CURVE.n) and checks if the result is 0. + // Multiplies point by curve order and checks if the result is 0. // Returns `false` is the point is dirty. isTorsionFree(): boolean { return wnaf.unsafeLadder(this, CURVE_ORDER).is0(); @@ -340,7 +305,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { toAffine(iz?: bigint): AffinePoint { const { ex: x, ey: y, ez: z } = this; const is0 = this.is0(); - if (iz == null) iz = is0 ? _8n : (Fp.invert(z) as bigint); // 8 was chosen arbitrarily + if (iz == null) iz = is0 ? _8n : (Fp.inv(z) as bigint); // 8 was chosen arbitrarily const ax = modP(x * iz); const ay = modP(y * iz); const zz = modP(z * iz); @@ -348,6 +313,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { if (zz !== _1n) throw new Error('invZ was invalid'); return { x: ax, y: ay }; } + clearCofactor(): Point { const { h: cofactor } = CURVE; if (cofactor === _1n) return this; @@ -356,58 +322,41 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { // Converts hash string or Uint8Array to Point. // Uses algo from RFC8032 5.1.3. - static fromHex(hex: Hex, strict = true) { + static fromHex(hex: Hex, strict = true): Point { const { d, a } = CURVE; const len = Fp.BYTES; - hex = ensureBytes(hex, len); - // 1. First, interpret the string as an integer in little-endian - // representation. Bit 255 of this number is the least significant - // bit of the x-coordinate and denote this value x_0. The - // y-coordinate is recovered simply by clearing this bit. If the - // resulting value is >= p, decoding fails. - const normed = hex.slice(); - const lastByte = hex[len - 1]; - normed[len - 1] = lastByte & ~0x80; + hex = ensureBytes(hex, len); // copy hex to a new array + const normed = hex.slice(); // copy again, we'll manipulate it + const lastByte = hex[len - 1]; // select last byte + normed[len - 1] = lastByte & ~0x80; // clear last bit const y = ut.bytesToNumberLE(normed); - if (y === _0n) { // y=0 is allowed } else { - if (strict) assertFE(y); // strict=true [0..CURVE.Fp.P] (2^255-19 for ed25519) - else assertInMask(y); // strict=false [0..MASK] (2^256 for ed25519) + // RFC8032 prohibits >= p, but ZIP215 doesn't + if (strict) assertInRange(y, Fp.ORDER); // strict=true [1..P-1] (2^255-19-1 for ed25519) + else assertInRange(y, MASK); // strict=false [1..MASK-1] (2^256-1 for ed25519) } - // 2. To recover the x-coordinate, the curve equation implies - // Ed25519: x² = (y² - 1) / (d y² + 1) (mod p). - // Ed448: x² = (y² - 1) / (d y² - 1) (mod p). - // For generic case: - // a*x²+y²=1+d*x²*y² - // -> y²-1 = d*x²*y²-a*x² - // -> y²-1 = x² (d*y²-a) - // -> x² = (y²-1) / (d*y²-a) - - // The denominator is always non-zero mod p. Let u = y² - 1 and v = d y² + 1. - const y2 = modP(y * y); - const u = modP(y2 - _1n); - const v = modP(d * y2 - a); - let { isValid, value: x } = uvRatio(u, v); + // Ed25519: x² = (y²-1)/(dy²+1) mod p. Ed448: x² = (y²-1)/(dy²-1) mod p. Generic case: + // ax²+y²=1+dx²y² => y²-1=dx²y²-ax² => y²-1=x²(dy²-a) => x²=(y²-1)/(dy²-a) + const y2 = modP(y * y); // denominator is always non-0 mod p. + const u = modP(y2 - _1n); // u = y² - 1 + const v = modP(d * y2 - a); // v = d y² + 1. + let { isValid, value: x } = uvRatio(u, v); // √(u/v) if (!isValid) throw new Error('Point.fromHex: invalid y coordinate'); - // 4. Finally, use the x_0 bit to select the right square root. If - // x = 0, and x_0 = 1, decoding fails. Otherwise, if x_0 != x mod - // 2, set x <-- p - x. Return the decoded point (x,y). - const isXOdd = (x & _1n) === _1n; - const isLastByteOdd = (lastByte & 0x80) !== 0; - if (isLastByteOdd !== isXOdd) x = modP(-x); + const isXOdd = (x & _1n) === _1n; // There are 2 square roots. Use x_0 bit to select proper + const isLastByteOdd = (lastByte & 0x80) !== 0; // if x=0 and x_0 = 1, fail + if (isLastByteOdd !== isXOdd) x = modP(-x); // if x_0 != x mod 2, set x = p-x return Point.fromAffine({ x, y }); } - static fromPrivateKey(privateKey: PrivKey) { - return getExtendedPublicKey(privateKey).point; + static fromPrivateKey(privKey: Hex) { + return getExtendedPublicKey(privKey).point; } toRawBytes(): Uint8Array { const { x, y } = this.toAffine(); - const len = Fp.BYTES; - const bytes = ut.numberToBytesLE(y, len); // each y has 2 x values (x, -y) - bytes[len - 1] |= x & _1n ? 0x80 : 0; // when compressing, it's enough to store y + const bytes = ut.numberToBytesLE(y, Fp.BYTES); // each y has 2 x values (x, -y) + bytes[bytes.length - 1] |= x & _1n ? 0x80 : 0; // when compressing, it's enough to store y return bytes; // and use the last byte to encode sign of x } toHex(): string { @@ -415,105 +364,80 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { } } const { BASE: G, ZERO: I } = Point; - const wnaf = wNAF(Point, CURVE.nByteLength * 8); + const wnaf = wNAF(Point, nByteLength * 8); - function assertExtPoint(other: unknown) { - if (!(other instanceof Point)) throw new Error('ExtendedPoint expected'); - } // Little-endian SHA512 with modulo n function modnLE(hash: Uint8Array): bigint { return mod.mod(ut.bytesToNumberLE(hash), CURVE_ORDER); } + function isHex(item: Hex, err: string) { + if (typeof item !== 'string' && !(item instanceof Uint8Array)) + throw new Error(`${err} must be hex string or Uint8Array`); + } /** Convenience method that creates public key and other stuff. RFC8032 5.1.5 */ - function getExtendedPublicKey(key: PrivKey) { - const groupLen = CURVE.nByteLength; - // Normalize bigint / number / string to Uint8Array - const keyb = typeof key === 'bigint' ? ut.numberToBytesLE(assertInMask(key), groupLen) : key; + function getExtendedPublicKey(key: Hex) { + isHex(key, 'private key'); + const len = nByteLength; // Hash private key with curve's hash function to produce uniformingly random input - // We check byte lengths e.g.: ensureBytes(64, hash(ensureBytes(32, key))) - const hashed = ensureBytes(CURVE.hash(ensureBytes(keyb, groupLen)), 2 * groupLen); - - // First half's bits are cleared to produce a random field element. - const head = adjustScalarBytes(hashed.slice(0, groupLen)); - // Second half is called key prefix (5.1.6) - const prefix = hashed.slice(groupLen, 2 * groupLen); - // The actual private scalar - const scalar = modnLE(head); - // Point on Edwards curve aka public key - const point = G.multiply(scalar); - // Uint8Array representation - const pointBytes = point.toRawBytes(); + // Check byte lengths: ensure(64, h(ensure(32, key))) + const hashed = ensureBytes(cHash(ensureBytes(key, len)), 2 * len); + const head = adjustScalarBytes(hashed.slice(0, len)); // clear first half bits, produce FE + const prefix = hashed.slice(len, 2 * len); // second half is called key prefix (5.1.6) + const scalar = modnLE(head); // The actual private scalar + const point = G.multiply(scalar); // Point on Edwards curve aka public key + const pointBytes = point.toRawBytes(); // Uint8Array representation return { head, prefix, scalar, point, pointBytes }; } - /** - * Calculates ed25519 public key. RFC8032 5.1.5 - * 1. private key is hashed with sha512, then first 32 bytes are taken from the hash - * 2. 3 least significant bits of the first byte are cleared - */ - function getPublicKey(privateKey: PrivKey): Uint8Array { - return getExtendedPublicKey(privateKey).pointBytes; + // Calculates EdDSA pub key. RFC8032 5.1.5. Privkey is hashed. Use first half with 3 bits cleared + function getPublicKey(privKey: Hex): Uint8Array { + return getExtendedPublicKey(privKey).pointBytes; } - const EMPTY = new Uint8Array(); - function hashDomainToScalar(message: Uint8Array, context: Hex = EMPTY) { - context = ensureBytes(context); - return modnLE(CURVE.hash(domain(message, context, !!CURVE.preHash))); + // int('LE', SHA512(dom2(F, C) || msgs)) mod N + function hashDomainToScalar(context: Hex = new Uint8Array(), ...msgs: Uint8Array[]) { + const msg = ut.concatBytes(...msgs); + return modnLE(cHash(domain(msg, ensureBytes(context), !!preHash))); } /** Signs message with privateKey. RFC8032 5.1.6 */ - function sign(message: Hex, privateKey: Hex, context?: Hex): Uint8Array { - message = ensureBytes(message); - if (CURVE.preHash) message = CURVE.preHash(message); - const { prefix, scalar, pointBytes } = getExtendedPublicKey(privateKey); - const r = hashDomainToScalar(ut.concatBytes(prefix, message), context); - const R = G.multiply(r); // R = rG - const k = hashDomainToScalar(ut.concatBytes(R.toRawBytes(), pointBytes, message), context); // k = hash(R+P+msg) - const s = mod.mod(r + k * scalar, CURVE_ORDER); // s = r + kp + function sign(msg: Hex, privKey: Hex, context?: Hex): Uint8Array { + isHex(msg, 'message'); + msg = ensureBytes(msg); + if (preHash) msg = preHash(msg); // for ed25519ph etc. + const { prefix, scalar, pointBytes } = getExtendedPublicKey(privKey); + const r = hashDomainToScalar(context, prefix, msg); // r = dom2(F, C) || prefix || PH(M) + const R = G.multiply(r).toRawBytes(); // R = rG + const k = hashDomainToScalar(context, R, pointBytes, msg); // R || A || PH(M) + const s = mod.mod(r + k * scalar, CURVE_ORDER); // S = (r + k * s) mod L assertGE0(s); // 0 <= s < l - return ut.concatBytes(R.toRawBytes(), ut.numberToBytesLE(s, Fp.BYTES)); + const res = ut.concatBytes(R, ut.numberToBytesLE(s, Fp.BYTES)); + return ensureBytes(res, nByteLength * 2); // 64-byte signature } - /** - * Verifies EdDSA signature against message and public key. - * An extended group equation is checked. - * RFC8032 5.1.7 - * Compliant with ZIP215: - * 0 <= sig.R/publicKey < 2**256 (can be >= curve.P) - * 0 <= sig.s < l - * Not compliant with RFC8032: it's not possible to comply to both ZIP & RFC at the same time. - */ - function verify(sig: Hex, message: Hex, publicKey: Hex, context?: Hex): boolean { - const len = Fp.BYTES; - sig = ensureBytes(sig, 2 * len); - message = ensureBytes(message); - if (CURVE.preHash) message = CURVE.preHash(message); - const R = Point.fromHex(sig.slice(0, len), false); // non-strict; allows 0..MASK - const s = ut.bytesToNumberLE(sig.slice(len, 2 * len)); + function verify(sig: Hex, msg: Hex, publicKey: Hex, context?: Hex): boolean { + isHex(sig, 'sig'); + isHex(msg, 'message'); + const len = Fp.BYTES; // Verifies EdDSA signature against message and public key. RFC8032 5.1.7. + sig = ensureBytes(sig, 2 * len); // An extended group equation is checked. + msg = ensureBytes(msg); // ZIP215 compliant, which means not fully RFC8032 compliant. + if (preHash) msg = preHash(msg); // for ed25519ph, etc const A = Point.fromHex(publicKey, false); // Check for s bounds, hex validity + const R = Point.fromHex(sig.slice(0, len), false); // 0 <= R < 2^256: ZIP215 R can be >= P + const s = ut.bytesToNumberLE(sig.slice(len, 2 * len)); // 0 <= s < l const SB = G.multiplyUnsafe(s); - const k = hashDomainToScalar(ut.concatBytes(R.toRawBytes(), A.toRawBytes(), message), context); - const kA = A.multiplyUnsafe(k); - const RkA = R.add(kA); + const k = hashDomainToScalar(context, R.toRawBytes(), A.toRawBytes(), msg); + const RkA = R.add(A.multiplyUnsafe(k)); // [8][S]B = [8]R + [8][k]A' return RkA.subtract(SB).clearCofactor().equals(Point.ZERO); } - // Enable precomputes. Slows down first publicKey computation by 20ms. - G._setWindowSize(8); + G._setWindowSize(8); // Enable precomputes. Slows down first publicKey computation by 20ms. const utils = { getExtendedPublicKey, - /** - * Not needed for ed25519 private keys. Needed if you use scalars directly (rare). - */ - hashToPrivateScalar: (hash: Hex): bigint => ut.hashToPrivateScalar(hash, CURVE_ORDER, true), - - /** - * ed25519 private keys are uniform 32-bit strings. We do not need to check for - * modulo bias like we do in secp256k1 randomPrivateKey() - */ + // ed25519 private keys are uniform 32b. No need to check for modulo bias, like in secp256k1. randomPrivateKey: (): Uint8Array => randomBytes(Fp.BYTES), /** diff --git a/src/abstract/group.ts b/src/abstract/group.ts index 60a10e6..d07b2d5 100644 --- a/src/abstract/group.ts +++ b/src/abstract/group.ts @@ -17,7 +17,9 @@ export type GroupConstructor = { ZERO: T; }; export type Mapper = (i: T[]) => T[]; -// Not big, but pretty complex and it is easy to break stuff. To avoid too much copy paste + +// Elliptic curve multiplication of Point by scalar. Complicated and fragile. Uses wNAF method. +// Windowed method is 10% faster, but takes 2x longer to generate & consumes 2x memory. export function wNAF>(c: GroupConstructor, bits: number) { const constTimeNegate = (condition: boolean, item: T): T => { const neg = item.negate(); @@ -129,7 +131,7 @@ export function wNAF>(c: GroupConstructor, bits: number) { wNAFCached(P: T, precomputesMap: Map, n: bigint, transform: Mapper): { p: T; f: T } { // @ts-ignore - const W: number = '_WINDOW_SIZE' in P ? P._WINDOW_SIZE : 1; + const W: number = P._WINDOW_SIZE || 1; // Calculate precomputes on a first run, reuse them after let comp = precomputesMap.get(P); if (!comp) { diff --git a/src/abstract/hash-to-curve.ts b/src/abstract/hash-to-curve.ts index 57706a2..1e67865 100644 --- a/src/abstract/hash-to-curve.ts +++ b/src/abstract/hash-to-curve.ts @@ -41,6 +41,7 @@ export function validateOpts(opts: Opts) { // UTF8 to ui8a // TODO: looks broken, ASCII only, why not TextEncoder/TextDecoder? it is in hashes anyway export function stringToBytes(str: string) { + // return new TextEncoder().encode(str); const bytes = new Uint8Array(str.length); for (let i = 0; i < str.length; i++) bytes[i] = str.charCodeAt(i); return bytes; diff --git a/src/abstract/modular.ts b/src/abstract/modular.ts index 0d39f5a..7f73dd9 100644 --- a/src/abstract/modular.ts +++ b/src/abstract/modular.ts @@ -92,7 +92,7 @@ export function tonelliShanks(P: bigint) { const p1div4 = (P + _1n) / _4n; return function tonelliFast(Fp: Field, n: T) { const root = Fp.pow(n, p1div4); - if (!Fp.equals(Fp.square(root), n)) throw new Error('Cannot find square root'); + if (!Fp.eql(Fp.sqr(root), n)) throw new Error('Cannot find square root'); return root; }; } @@ -101,24 +101,24 @@ export function tonelliShanks(P: bigint) { const Q1div2 = (Q + _1n) / _2n; return function tonelliSlow(Fp: Field, n: T): T { // Step 0: Check that n is indeed a square: (n | p) should not be ≡ -1 - if (Fp.pow(n, legendreC) === Fp.negate(Fp.ONE)) throw new Error('Cannot find square root'); + if (Fp.pow(n, legendreC) === Fp.neg(Fp.ONE)) throw new Error('Cannot find square root'); let r = S; // TODO: will fail at Fp2/etc let g = Fp.pow(Fp.mul(Fp.ONE, Z), Q); // will update both x and b let x = Fp.pow(n, Q1div2); // first guess at the square root let b = Fp.pow(n, Q); // first guess at the fudge factor - while (!Fp.equals(b, Fp.ONE)) { - if (Fp.equals(b, Fp.ZERO)) return Fp.ZERO; // https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm (4. If t = 0, return r = 0) + while (!Fp.eql(b, Fp.ONE)) { + if (Fp.eql(b, Fp.ZERO)) return Fp.ZERO; // https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm (4. If t = 0, return r = 0) // Find m such b^(2^m)==1 let m = 1; - for (let t2 = Fp.square(b); m < r; m++) { - if (Fp.equals(t2, Fp.ONE)) break; - t2 = Fp.square(t2); // t2 *= t2 + for (let t2 = Fp.sqr(b); m < r; m++) { + if (Fp.eql(t2, Fp.ONE)) break; + t2 = Fp.sqr(t2); // t2 *= t2 } // NOTE: r-m-1 can be bigger than 32, need to convert to bigint before shift, otherwise there will be overflow const ge = Fp.pow(g, _1n << BigInt(r - m - 1)); // ge = 2^(r-m-1) - g = Fp.square(ge); // g = ge * ge + g = Fp.sqr(ge); // g = ge * ge x = Fp.mul(x, ge); // x *= ge b = Fp.mul(b, g); // b *= g r = m; @@ -142,7 +142,7 @@ export function FpSqrt(P: bigint) { return function sqrt3mod4(Fp: Field, n: T) { const root = Fp.pow(n, p1div4); // Throw if root**2 != n - if (!Fp.equals(Fp.square(root), n)) throw new Error('Cannot find square root'); + if (!Fp.eql(Fp.sqr(root), n)) throw new Error('Cannot find square root'); return root; }; } @@ -156,7 +156,7 @@ export function FpSqrt(P: bigint) { const nv = Fp.mul(n, v); const i = Fp.mul(Fp.mul(nv, _2n), v); const root = Fp.mul(nv, Fp.sub(i, Fp.ONE)); - if (!Fp.equals(Fp.square(root), n)) throw new Error('Cannot find square root'); + if (!Fp.eql(Fp.sqr(root), n)) throw new Error('Cannot find square root'); return root; }; } @@ -206,13 +206,13 @@ export interface Field { // 1-arg create: (num: T) => T; isValid: (num: T) => boolean; - isZero: (num: T) => boolean; - negate(num: T): T; - invert(num: T): T; + is0: (num: T) => boolean; + neg(num: T): T; + inv(num: T): T; sqrt(num: T): T; - square(num: T): T; + sqr(num: T): T; // 2-args - equals(lhs: T, rhs: T): boolean; + eql(lhs: T, rhs: T): boolean; add(lhs: T, rhs: T): T; sub(lhs: T, rhs: T): T; mul(lhs: T, rhs: T | bigint): T; @@ -222,13 +222,13 @@ export interface Field { addN(lhs: T, rhs: T): T; subN(lhs: T, rhs: T): T; mulN(lhs: T, rhs: T | bigint): T; - squareN(num: T): T; + sqrN(num: T): T; // Optional // Should be same as sgn0 function in https://datatracker.ietf.org/doc/draft-irtf-cfrg-hash-to-curve/ // NOTE: sgn0 is 'negative in LE', which is same as odd. And negative in LE is kinda strange definition anyway. isOdd?(num: T): boolean; // Odd instead of even since we have it for Fp2 - legendre?(num: T): T; + // legendre?(num: T): T; pow(lhs: T, power: bigint): T; invertBatch: (lst: T[]) => T[]; toBytes(num: T): Uint8Array; @@ -238,9 +238,9 @@ export interface Field { } // prettier-ignore const FIELD_FIELDS = [ - 'create', 'isValid', 'isZero', 'negate', 'invert', 'sqrt', 'square', - 'equals', 'add', 'sub', 'mul', 'pow', 'div', - 'addN', 'subN', 'mulN', 'squareN' + 'create', 'isValid', 'is0', 'neg', 'inv', 'sqrt', 'sqr', + 'eql', 'add', 'sub', 'mul', 'pow', 'div', + 'addN', 'subN', 'mulN', 'sqrN' ] as const; export function validateField(field: Field) { for (const i of ['ORDER', 'MASK'] as const) { @@ -268,7 +268,7 @@ export function FpPow(f: Field, num: T, power: bigint): T { let d = num; while (power > _0n) { if (power & _1n) p = f.mul(p, d); - d = f.square(d); + d = f.sqr(d); power >>= 1n; } return p; @@ -278,15 +278,15 @@ export function FpInvertBatch(f: Field, nums: T[]): T[] { const tmp = new Array(nums.length); // Walk from first to last, multiply them by each other MOD p const lastMultiplied = nums.reduce((acc, num, i) => { - if (f.isZero(num)) return acc; + if (f.is0(num)) return acc; tmp[i] = acc; return f.mul(acc, num); }, f.ONE); // Invert last element - const inverted = f.invert(lastMultiplied); + const inverted = f.inv(lastMultiplied); // Walk from last to first, multiply them by inverted each other MOD p nums.reduceRight((acc, num, i) => { - if (f.isZero(num)) return acc; + if (f.is0(num)) return acc; tmp[i] = f.mul(acc, tmp[i]); return f.mul(acc, num); }, inverted); @@ -294,7 +294,7 @@ export function FpInvertBatch(f: Field, nums: T[]): T[] { } export function FpDiv(f: Field, lhs: T, rhs: T | bigint): T { - return f.mul(lhs, typeof rhs === 'bigint' ? invert(rhs, f.ORDER) : f.invert(rhs)); + return f.mul(lhs, typeof rhs === 'bigint' ? invert(rhs, f.ORDER) : f.inv(rhs)); } // This function returns True whenever the value x is a square in the field F. @@ -302,7 +302,7 @@ export function FpIsSquare(f: Field) { const legendreConst = (f.ORDER - _1n) / _2n; // Integer arithmetic return (x: T): boolean => { const p = f.pow(x, legendreConst); - return f.equals(p, f.ZERO) || f.equals(p, f.ONE); + return f.eql(p, f.ZERO) || f.eql(p, f.ONE); }; } @@ -334,12 +334,12 @@ export function Fp( throw new Error(`Invalid field element: expected bigint, got ${typeof num}`); return _0n <= num && num < ORDER; // 0 is valid element, but it's not invertible }, - isZero: (num) => num === _0n, + is0: (num) => num === _0n, isOdd: (num) => (num & _1n) === _1n, - negate: (num) => mod(-num, ORDER), - equals: (lhs, rhs) => lhs === rhs, + neg: (num) => mod(-num, ORDER), + eql: (lhs, rhs) => lhs === rhs, - square: (num) => mod(num * num, ORDER), + sqr: (num) => mod(num * num, ORDER), add: (lhs, rhs) => mod(lhs + rhs, ORDER), sub: (lhs, rhs) => mod(lhs - rhs, ORDER), mul: (lhs, rhs) => mod(lhs * rhs, ORDER), @@ -347,12 +347,12 @@ export function Fp( div: (lhs, rhs) => mod(lhs * invert(rhs, ORDER), ORDER), // Same as above, but doesn't normalize - squareN: (num) => num * num, + sqrN: (num) => num * num, addN: (lhs, rhs) => lhs + rhs, subN: (lhs, rhs) => lhs - rhs, mulN: (lhs, rhs) => lhs * rhs, - invert: (num) => invert(num, ORDER), + inv: (num) => invert(num, ORDER), sqrt: redef.sqrt || ((n) => sqrtP(f, n)), invertBatch: (lst) => FpInvertBatch(f, lst), // TODO: do we really need constant cmov? @@ -372,11 +372,11 @@ export function Fp( export function FpSqrtOdd(Fp: Field, elm: T) { if (!Fp.isOdd) throw new Error(`Field doesn't have isOdd`); const root = Fp.sqrt(elm); - return Fp.isOdd(root) ? root : Fp.negate(root); + return Fp.isOdd(root) ? root : Fp.neg(root); } export function FpSqrtEven(Fp: Field, elm: T) { if (!Fp.isOdd) throw new Error(`Field doesn't have isOdd`); const root = Fp.sqrt(elm); - return Fp.isOdd(root) ? Fp.negate(root) : root; + return Fp.isOdd(root) ? Fp.neg(root) : root; } diff --git a/src/abstract/poseidon.ts b/src/abstract/poseidon.ts index 021564f..fc8b66e 100644 --- a/src/abstract/poseidon.ts +++ b/src/abstract/poseidon.ts @@ -33,8 +33,8 @@ export function validateOpts(opts: PoseidonOpts) { const _sboxPower = BigInt(sboxPower); let sboxFn = (n: bigint) => mod.FpPow(Fp, n, _sboxPower); // Unwrapped sbox power for common cases (195->142μs) - if (sboxPower === 3) sboxFn = (n: bigint) => Fp.mul(Fp.squareN(n), n); - else if (sboxPower === 5) sboxFn = (n: bigint) => Fp.mul(Fp.squareN(Fp.squareN(n)), n); + if (sboxPower === 3) sboxFn = (n: bigint) => Fp.mul(Fp.sqrN(n), n); + else if (sboxPower === 5) sboxFn = (n: bigint) => Fp.mul(Fp.sqrN(Fp.sqrN(n)), n); if (opts.roundsFull % 2 !== 0) throw new Error(`Poseidon roundsFull is not even: ${opts.roundsFull}`); diff --git a/src/abstract/utils.ts b/src/abstract/utils.ts index d34210f..751e56e 100644 --- a/src/abstract/utils.ts +++ b/src/abstract/utils.ts @@ -18,6 +18,7 @@ export type CHash = { outputLen: number; create(opts?: { dkLen?: number }): any; // For shake }; +export type FHash = (message: Uint8Array | string) => Uint8Array; // NOTE: these are generic, even if curve is on some polynominal field (bls), it will still have P/n/h // But generator can be different (Fp2/Fp6 for bls?) diff --git a/src/abstract/weierstrass.ts b/src/abstract/weierstrass.ts index c262307..75b5eff 100644 --- a/src/abstract/weierstrass.ts +++ b/src/abstract/weierstrass.ts @@ -1,16 +1,8 @@ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ // Short Weierstrass curve. The formula is: y² = x³ + ax + b - -// Differences from @noble/secp256k1 1.7: -// 1. Different double() formula (but same addition) -// 2. Different sqrt() function -// 3. truncateHash() truncateOnly mode -// 4. DRBG supports outputLen bigger than outputLen of hmac -// 5. Support for different hash functions - import * as mod from './modular.js'; import * as ut from './utils.js'; -import { bytesToHex, Hex, PrivKey } from './utils.js'; +import { Hex, PrivKey } from './utils.js'; import { Group, GroupConstructor, wNAF } from './group.js'; type HmacFnSync = (key: Uint8Array, ...messages: Uint8Array[]) => Uint8Array; @@ -52,7 +44,7 @@ const DER = { }, parseInt(data: Uint8Array): { data: bigint; left: Uint8Array } { if (data.length < 2 || data[0] !== 0x02) { - throw new DERError(`Invalid signature integer tag: ${bytesToHex(data)}`); + throw new DERError(`Invalid signature integer tag: ${ut.bytesToHex(data)}`); } const len = data[1]; const res = data.subarray(2, len + 2); @@ -67,7 +59,7 @@ const DER = { }, parseSig(data: Uint8Array): { r: bigint; s: bigint } { if (data.length < 2 || data[0] != 0x30) { - throw new DERError(`Invalid signature tag: ${bytesToHex(data)}`); + throw new DERError(`Invalid signature tag: ${ut.bytesToHex(data)}`); } if (data[1] !== data.length - 2) { throw new DERError('Invalid signature: incorrect length'); @@ -75,7 +67,9 @@ const DER = { const { data: r, left: sBytes } = DER.parseInt(data.subarray(2)); const { data: s, left: rBytesLeft } = DER.parseInt(sBytes); if (rBytesLeft.length) { - throw new DERError(`Invalid signature: left bytes after parsing: ${bytesToHex(rBytesLeft)}`); + throw new DERError( + `Invalid signature: left bytes after parsing: ${ut.bytesToHex(rBytesLeft)}` + ); } return { r, s }; }, @@ -156,7 +150,7 @@ function validatePointOpts(curve: CurvePointsType) { } const endo = opts.endo; if (endo) { - if (!Fp.equals(opts.a, Fp.ZERO)) { + if (!Fp.eql(opts.a, Fp.ZERO)) { throw new Error('Endomorphism can only be defined for Koblitz curves that have a=0'); } if ( @@ -194,7 +188,7 @@ export function weierstrassPoints(opts: CurvePointsType) { */ function weierstrassEquation(x: T): T { const { a, b } = CURVE; - const x2 = Fp.square(x); // x * x + const x2 = Fp.sqr(x); // x * x const x3 = Fp.mul(x2, x); // x2 * x return Fp.add(Fp.add(x3, Fp.mul(x, a)), b); // x3 + a * x + b } @@ -213,7 +207,7 @@ export function weierstrassPoints(opts: CurvePointsType) { * - `wrapPrivateKey` when true, executed after most checks, but before `0 < key < n` */ function normalizePrivateKey(key: PrivKey): bigint { - const { normalizePrivateKey: custom, nByteLength: groupLen, wrapPrivateKey, n: order } = CURVE; + const { normalizePrivateKey: custom, nByteLength: groupLen, wrapPrivateKey, n } = CURVE; if (typeof custom === 'function') key = custom(key); let num: bigint; if (typeof key === 'bigint') { @@ -230,7 +224,7 @@ export function weierstrassPoints(opts: CurvePointsType) { throw new Error('private key must be bytes, hex or bigint, not ' + typeof key); } // Useful for curves with cofactor != 1 - if (wrapPrivateKey) num = mod.mod(num, order); + if (wrapPrivateKey) num = mod.mod(num, n); assertGE(num); return num; } @@ -258,7 +252,7 @@ export function weierstrassPoints(opts: CurvePointsType) { const { x, y } = p || {}; if (!p || !Fp.isValid(x) || !Fp.isValid(y)) throw new Error('invalid affine point'); if (p instanceof ProjectivePoint) throw new Error('projective point not allowed'); - const is0 = (i: T) => Fp.equals(i, Fp.ZERO); + const is0 = (i: T) => Fp.eql(i, Fp.ZERO); // fromAffine(x:0, y:0) would produce (x:0, y:0, z:1), but we need (x:0, y:1, z:0) if (is0(x) && is0(y)) return ProjectivePoint.ZERO; return new ProjectivePoint(x, y, Fp.ONE); @@ -319,9 +313,9 @@ export function weierstrassPoints(opts: CurvePointsType) { const { x, y } = this.toAffine(); // Check if x, y are valid field elements if (!Fp.isValid(x) || !Fp.isValid(y)) throw new Error('bad point: x or y not FE'); - const left = Fp.square(y); // y² + const left = Fp.sqr(y); // y² const right = weierstrassEquation(x); // x³ + ax + b - if (!Fp.equals(left, right)) throw new Error('bad point: equation left != right'); + if (!Fp.eql(left, right)) throw new Error('bad point: equation left != right'); if (!this.isTorsionFree()) throw new Error('bad point: not in prime-order subgroup'); } hasEvenY(): boolean { @@ -337,8 +331,8 @@ export function weierstrassPoints(opts: CurvePointsType) { assertPrjPoint(other); const { px: X1, py: Y1, pz: Z1 } = this; const { px: X2, py: Y2, pz: Z2 } = other; - const U1 = Fp.equals(Fp.mul(X1, Z2), Fp.mul(X2, Z1)); - const U2 = Fp.equals(Fp.mul(Y1, Z2), Fp.mul(Y2, Z1)); + const U1 = Fp.eql(Fp.mul(X1, Z2), Fp.mul(X2, Z1)); + const U2 = Fp.eql(Fp.mul(Y1, Z2), Fp.mul(Y2, Z1)); return U1 && U2; } @@ -346,7 +340,7 @@ export function weierstrassPoints(opts: CurvePointsType) { * Flips point to one corresponding to (x, -y) in Affine coordinates. */ negate(): ProjectivePoint { - return new ProjectivePoint(this.px, Fp.negate(this.py), this.pz); + return new ProjectivePoint(this.px, Fp.neg(this.py), this.pz); } // Renes-Costello-Batina exception-free doubling formula. @@ -544,12 +538,12 @@ export function weierstrassPoints(opts: CurvePointsType) { const is0 = this.is0(); // If invZ was 0, we return zero point. However we still want to execute // all operations, so we replace invZ with a random number, 1. - if (iz == null) iz = is0 ? Fp.ONE : Fp.invert(z); + if (iz == null) iz = is0 ? Fp.ONE : Fp.inv(z); const ax = Fp.mul(x, iz); const ay = Fp.mul(y, iz); const zz = Fp.mul(z, iz); if (is0) return { x: Fp.ZERO, y: Fp.ZERO }; - if (!Fp.equals(zz, Fp.ONE)) throw new Error('invZ was invalid'); + if (!Fp.eql(zz, Fp.ONE)) throw new Error('invZ was invalid'); return { x: ax, y: ay }; } isTorsionFree(): boolean { @@ -571,7 +565,7 @@ export function weierstrassPoints(opts: CurvePointsType) { } toHex(isCompressed = true): string { - return bytesToHex(this.toRawBytes(isCompressed)); + return ut.bytesToHex(this.toRawBytes(isCompressed)); } } const _bits = CURVE.nBitLength; @@ -743,7 +737,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { const isYOdd = (y & _1n) === _1n; // ECDSA const isHeadOdd = (head & 1) === 1; - if (isHeadOdd !== isYOdd) y = Fp.negate(y); + if (isHeadOdd !== isYOdd) y = Fp.neg(y); return { x, y }; } else if (len === uncompressedLen && head === 0x04) { const x = Fp.fromBytes(tail.subarray(0, Fp.BYTES)); @@ -756,15 +750,8 @@ export function weierstrass(curveDef: CurveType): CurveFn { } }, }); - // type Point = typeof ProjectivePoint.BASE; - - // Do we need these functions at all? - function numToField(num: bigint): Uint8Array { - if (typeof num !== 'bigint') throw new Error('Expected bigint'); - if (!(_0n <= num && num < Fp.MASK)) throw new Error(`Expected number < 2^${Fp.BYTES * 8}`); - return Fp.toBytes(num); - } - const numToFieldStr = (num: bigint): string => bytesToHex(numToField(num)); + const numToNByteStr = (num: bigint): string => + ut.bytesToHex(ut.numberToBytesBE(num, CURVE.nByteLength)); function isBiggerThanHalfOrder(number: bigint) { const HALF = CURVE_ORDER >> _1n; @@ -820,21 +807,18 @@ export function weierstrass(curveDef: CurveType): CurveFn { const radj = rec === 2 || rec === 3 ? r + N : r; if (radj >= Fp.ORDER) throw new Error('recovery id 2 or 3 invalid'); const prefix = (rec & 1) === 0 ? '02' : '03'; - const R = Point.fromHex(prefix + numToFieldStr(radj)); - const ir = mod.invert(radj, N); // r^-1 - const u1 = mod.mod(-h * ir, N); // -hr^-1 - const u2 = mod.mod(s * ir, N); // sr^-1 + const R = Point.fromHex(prefix + numToNByteStr(radj)); + const Fn = mod.Fp(N); + const ir = Fn.inv(radj); // r^-1 + const u1 = Fn.mul(-h, ir); // -hr^-1 + const u2 = Fn.mul(s, ir); // sr^-1 const Q = Point.BASE.multiplyAndAddUnsafe(R, u1, u2); // (sr^-1)R-(hr^-1)G = -(hr^-1)G + (sr^-1) if (!Q) throw new Error('point at infinify'); // unsafe is fine: no priv data leaked Q.assertValidity(); return Q; } - /** - * Default signatures are always low-s, to prevent malleability. - * `sign(lowS: true)` always produces low-s sigs. - * `verify(lowS: true)` always fails for high-s. - */ + // Signatures should be low-s, to prevent malleability. hasHighS(): boolean { return isBiggerThanHalfOrder(this.s); } @@ -866,7 +850,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { return ut.hexToBytes(this.toCompactHex()); } toCompactHex() { - return numToFieldStr(this.r) + numToFieldStr(this.s); + return numToNByteStr(this.r) + numToNByteStr(this.s); } } @@ -885,7 +869,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { * Converts some bytes to a valid private key. Needs at least (nBitLength+64) bytes. */ hashToPrivateKey: (hash: Hex): Uint8Array => - numToField(ut.hashToPrivateScalar(hash, CURVE_ORDER)), + ut.numberToBytesBE(ut.hashToPrivateScalar(hash, CURVE_ORDER), CURVE.nByteLength), /** * Produces cryptographically secure private key from random of size (nBitLength+64) @@ -1014,27 +998,26 @@ export function weierstrass(curveDef: CurveType): CurveFn { const m = h1int; // NOTE: no need to call bits2int second time here, it is inside truncateHash! // Converts signature params into point w r/s, checks result for validity. function k2sig(kBytes: Uint8Array): Signature | undefined { - const { n } = CURVE; + // const { n } = CURVE; // RFC 6979 Section 3.2, step 3: k = bits2int(T) const k = bits2int(kBytes); // Cannot use fields methods, since it is group element if (!isWithinCurveOrder(k)) return; // Important: all mod() calls in the function must be done over `n` - const ik = mod.invert(k, n); - const q = Point.BASE.multiply(k).toAffine(); - // r = x mod n - const r = mod.mod(q.x, n); - if (r === _0n) return; - // s = (m + dr)/k mod n where x/k == x*inv(k) - const s = mod.mod(ik * mod.mod(m + mod.mod(d * r, n), n), n); - if (s === _0n) return; - // recovery bit is usually 0 or 1; rarely it's 2 or 3, when q.x > n - let recovery = (q.x === r ? 0 : 2) | Number(q.y & _1n); + const Fn = mod.Fp(CURVE.n); + const ik = Fn.inv(k); // k^-1 mod n + const q = Point.BASE.multiply(k).toAffine(); // q = Gk + const r = Fn.create(q.x); // r = q.x mod n + if (r === _0n) return; // r=0 invalid + const s = Fn.mul(ik, Fn.add(m, Fn.mul(d, r))); // s = k^-1(m + rd) mod n + if (s === _0n) return; // s=0 invalid + let recovery = (q.x === r ? 0 : 2) | Number(q.y & _1n); // recovery bit (2 or 3, when q.x > n) let normS = s; if (lowS && isBiggerThanHalfOrder(s)) { - normS = normalizeS(s); + // if lowS was passed, ensure s is always + normS = normalizeS(s); // in the bottom half of CURVE.n recovery ^= 1; } - return new Signature(r, normS, recovery); + return new Signature(r, normS, recovery); // use normS, not s } return { seed, k2sig }; } @@ -1049,9 +1032,8 @@ export function weierstrass(curveDef: CurveType): CurveFn { * r = x mod n * s = (m + dr)/k mod n * ``` - * @param opts `lowS, extraEntropy` + * @param opts `lowS, extraEntropy, prehash` */ - // TODO: add opts.prehashed = True, if !opts.prehashed do hash on msg? function sign(msgHash: Hex, privKey: PrivKey, opts = defaultSigOpts): Signature { const { seed, k2sig } = prepSig(msgHash, privKey, opts); // Steps A, D of RFC6979 3.2. const genUntil = hmacDrbg(CURVE.hash.outputLen, CURVE.nByteLength, CURVE.hmac); @@ -1108,12 +1090,13 @@ export function weierstrass(curveDef: CurveType): CurveFn { const { n: N } = CURVE; const { r, s } = _sig; const h = bits2int_modN(msgHash); // Cannot use fields methods, since it is group element - const is = mod.invert(s, N); // s^-1 - const u1 = mod.mod(h * is, N); // u1 = hs^-1 mod n - const u2 = mod.mod(r * is, N); // u2 = rs^-1 mod n + const Fn = mod.Fp(N); + const is = Fn.inv(s); // s^-1 + const u1 = Fn.mul(h, is); // u1 = hs^-1 mod n + const u2 = Fn.mul(r, is); // u2 = rs^-1 mod n const R = Point.BASE.multiplyAndAddUnsafe(P, u1, u2)?.toAffine(); // R = u1⋅G + u2⋅P if (!R) return false; - const v = mod.mod(R.x, N); + const v = Fn.create(R.x); return v === r; } return { @@ -1131,7 +1114,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { // Implementation of the Shallue and van de Woestijne method for any Weierstrass curve -// TODO: check if there is a way to merge this with uvRation in Edwards && move to modular? +// TODO: check if there is a way to merge this with uvRatio in Edwards && move to modular? // b = True and y = sqrt(u / v) if (u / v) is square in F, and // b = False and y = sqrt(Z * (u / v)) otherwise. export function SWUFpSqrtRatio(Fp: mod.Field, Z: T) { @@ -1149,7 +1132,7 @@ export function SWUFpSqrtRatio(Fp: mod.Field, Z: T) { let sqrtRatio = (u: T, v: T): { isValid: boolean; value: T } => { let tv1 = c6; // 1. tv1 = c6 let tv2 = Fp.pow(v, c4); // 2. tv2 = v^c4 - let tv3 = Fp.square(tv2); // 3. tv3 = tv2^2 + let tv3 = Fp.sqr(tv2); // 3. tv3 = tv2^2 tv3 = Fp.mul(tv3, v); // 4. tv3 = tv3 * v let tv5 = Fp.mul(u, tv3); // 5. tv5 = u * tv3 tv5 = Fp.pow(tv5, c3); // 6. tv5 = tv5^c3 @@ -1158,7 +1141,7 @@ export function SWUFpSqrtRatio(Fp: mod.Field, Z: T) { tv3 = Fp.mul(tv5, u); // 9. tv3 = tv5 * u let tv4 = Fp.mul(tv3, tv2); // 10. tv4 = tv3 * tv2 tv5 = Fp.pow(tv4, c5); // 11. tv5 = tv4^c5 - let isQR = Fp.equals(tv5, Fp.ONE); // 12. isQR = tv5 == 1 + let isQR = Fp.eql(tv5, Fp.ONE); // 12. isQR = tv5 == 1 tv2 = Fp.mul(tv3, c7); // 13. tv2 = tv3 * c7 tv5 = Fp.mul(tv4, tv1); // 14. tv5 = tv4 * tv1 tv3 = Fp.cmov(tv2, tv3, isQR); // 15. tv3 = CMOV(tv2, tv3, isQR) @@ -1167,7 +1150,7 @@ export function SWUFpSqrtRatio(Fp: mod.Field, Z: T) { for (let i = c1; i > 1; i--) { let tv5 = 2n ** (i - 2n); // 18. tv5 = i - 2; 19. tv5 = 2^tv5 let tvv5 = Fp.pow(tv4, tv5); // 20. tv5 = tv4^tv5 - const e1 = Fp.equals(tvv5, Fp.ONE); // 21. e1 = tv5 == 1 + const e1 = Fp.eql(tvv5, Fp.ONE); // 21. e1 = tv5 == 1 tv2 = Fp.mul(tv3, tv1); // 22. tv2 = tv3 * tv1 tv1 = Fp.mul(tv1, tv1); // 23. tv1 = tv1 * tv1 tvv5 = Fp.mul(tv4, tv1); // 24. tv5 = tv4 * tv1 @@ -1179,16 +1162,16 @@ export function SWUFpSqrtRatio(Fp: mod.Field, Z: T) { if (Fp.ORDER % 4n === 3n) { // sqrt_ratio_3mod4(u, v) const c1 = (Fp.ORDER - 3n) / 4n; // 1. c1 = (q - 3) / 4 # Integer arithmetic - const c2 = Fp.sqrt(Fp.negate(Z)); // 2. c2 = sqrt(-Z) + const c2 = Fp.sqrt(Fp.neg(Z)); // 2. c2 = sqrt(-Z) sqrtRatio = (u: T, v: T) => { - let tv1 = Fp.square(v); // 1. tv1 = v^2 + let tv1 = Fp.sqr(v); // 1. tv1 = v^2 const tv2 = Fp.mul(u, v); // 2. tv2 = u * v tv1 = Fp.mul(tv1, tv2); // 3. tv1 = tv1 * tv2 let y1 = Fp.pow(tv1, c1); // 4. y1 = tv1^c1 y1 = Fp.mul(y1, tv2); // 5. y1 = y1 * tv2 const y2 = Fp.mul(y1, c2); // 6. y2 = y1 * c2 - const tv3 = Fp.mul(Fp.square(y1), v); // 7. tv3 = y1^2; 8. tv3 = tv3 * v - const isQR = Fp.equals(tv3, u); // 9. isQR = tv3 == u + const tv3 = Fp.mul(Fp.sqr(y1), v); // 7. tv3 = y1^2; 8. tv3 = tv3 * v + const isQR = Fp.eql(tv3, u); // 9. isQR = tv3 == u let y = Fp.cmov(y2, y1, isQR); // 10. y = CMOV(y2, y1, isQR) return { isValid: isQR, value: y }; // 11. return (isQR, y) isQR ? y : y*c2 }; @@ -1216,16 +1199,16 @@ export function mapToCurveSimpleSWU( return (u: T): { x: T; y: T } => { // prettier-ignore let tv1, tv2, tv3, tv4, tv5, tv6, x, y; - tv1 = Fp.square(u); // 1. tv1 = u^2 + tv1 = Fp.sqr(u); // 1. tv1 = u^2 tv1 = Fp.mul(tv1, opts.Z); // 2. tv1 = Z * tv1 - tv2 = Fp.square(tv1); // 3. tv2 = tv1^2 + tv2 = Fp.sqr(tv1); // 3. tv2 = tv1^2 tv2 = Fp.add(tv2, tv1); // 4. tv2 = tv2 + tv1 tv3 = Fp.add(tv2, Fp.ONE); // 5. tv3 = tv2 + 1 tv3 = Fp.mul(tv3, opts.B); // 6. tv3 = B * tv3 - tv4 = Fp.cmov(opts.Z, Fp.negate(tv2), !Fp.equals(tv2, Fp.ZERO)); // 7. tv4 = CMOV(Z, -tv2, tv2 != 0) + tv4 = Fp.cmov(opts.Z, Fp.neg(tv2), !Fp.eql(tv2, Fp.ZERO)); // 7. tv4 = CMOV(Z, -tv2, tv2 != 0) tv4 = Fp.mul(tv4, opts.A); // 8. tv4 = A * tv4 - tv2 = Fp.square(tv3); // 9. tv2 = tv3^2 - tv6 = Fp.square(tv4); // 10. tv6 = tv4^2 + tv2 = Fp.sqr(tv3); // 9. tv2 = tv3^2 + tv6 = Fp.sqr(tv4); // 10. tv6 = tv4^2 tv5 = Fp.mul(tv6, opts.A); // 11. tv5 = A * tv6 tv2 = Fp.add(tv2, tv5); // 12. tv2 = tv2 + tv5 tv2 = Fp.mul(tv2, tv3); // 13. tv2 = tv2 * tv3 @@ -1239,7 +1222,7 @@ export function mapToCurveSimpleSWU( x = Fp.cmov(x, tv3, isValid); // 21. x = CMOV(x, tv3, is_gx1_square) y = Fp.cmov(y, value, isValid); // 22. y = CMOV(y, y1, is_gx1_square) const e1 = Fp.isOdd!(u) === Fp.isOdd!(y); // 23. e1 = sgn0(u) == sgn0(y) - y = Fp.cmov(Fp.negate(y), y, e1); // 24. y = CMOV(-y, y, e1) + y = Fp.cmov(Fp.neg(y), y, e1); // 24. y = CMOV(-y, y, e1) x = Fp.div(x, tv4); // 25. x = x / tv4 return { x, y }; }; diff --git a/src/bls12-381.ts b/src/bls12-381.ts index 464bd82..9d21e1e 100644 --- a/src/bls12-381.ts +++ b/src/bls12-381.ts @@ -99,25 +99,24 @@ const Fp2: mod.Field & Fp2Utils = { ONE: { c0: Fp.ONE, c1: Fp.ZERO }, create: (num) => num, isValid: ({ c0, c1 }) => typeof c0 === 'bigint' && typeof c1 === 'bigint', - isZero: ({ c0, c1 }) => Fp.isZero(c0) && Fp.isZero(c1), - equals: ({ c0, c1 }: Fp2, { c0: r0, c1: r1 }: Fp2) => Fp.equals(c0, r0) && Fp.equals(c1, r1), - negate: ({ c0, c1 }) => ({ c0: Fp.negate(c0), c1: Fp.negate(c1) }), + is0: ({ c0, c1 }) => Fp.is0(c0) && Fp.is0(c1), + eql: ({ c0, c1 }: Fp2, { c0: r0, c1: r1 }: Fp2) => Fp.eql(c0, r0) && Fp.eql(c1, r1), + neg: ({ c0, c1 }) => ({ c0: Fp.neg(c0), c1: Fp.neg(c1) }), pow: (num, power) => mod.FpPow(Fp2, num, power), invertBatch: (nums) => mod.FpInvertBatch(Fp2, nums), // Normalized add: Fp2Add, sub: Fp2Subtract, mul: Fp2Multiply, - square: Fp2Square, + sqr: Fp2Square, // NonNormalized stuff addN: Fp2Add, subN: Fp2Subtract, mulN: Fp2Multiply, - squareN: Fp2Square, + sqrN: Fp2Square, // Why inversion for bigint inside Fp instead of Fp2? it is even used in that context? - div: (lhs, rhs) => - Fp2.mul(lhs, typeof rhs === 'bigint' ? Fp.invert(Fp.create(rhs)) : Fp2.invert(rhs)), - invert: ({ c0: a, c1: b }) => { + div: (lhs, rhs) => Fp2.mul(lhs, typeof rhs === 'bigint' ? Fp.inv(Fp.create(rhs)) : Fp2.inv(rhs)), + inv: ({ c0: a, c1: b }) => { // We wish to find the multiplicative inverse of a nonzero // element a + bu in Fp2. We leverage an identity // @@ -131,11 +130,11 @@ const Fp2: mod.Field & Fp2Utils = { // This gives that (a - bu)/(a² + b²) is the inverse // of (a + bu). Importantly, this can be computing using // only a single inversion in Fp. - const factor = Fp.invert(Fp.create(a * a + b * b)); + const factor = Fp.inv(Fp.create(a * a + b * b)); return { c0: Fp.mul(factor, Fp.create(a)), c1: Fp.mul(factor, Fp.create(-b)) }; }, sqrt: (num) => { - if (Fp2.equals(num, Fp2.ZERO)) return Fp2.ZERO; // Algo doesn't handles this case + if (Fp2.eql(num, Fp2.ZERO)) return Fp2.ZERO; // Algo doesn't handles this case // TODO: Optimize this line. It's extremely slow. // Speeding this up would boost aggregateSignatures. // https://eprint.iacr.org/2012/685.pdf applicable? @@ -143,15 +142,15 @@ const Fp2: mod.Field & Fp2Utils = { // https://github.com/supranational/blst/blob/aae0c7d70b799ac269ff5edf29d8191dbd357876/src/exp2.c#L1 // Inspired by https://github.com/dalek-cryptography/curve25519-dalek/blob/17698df9d4c834204f83a3574143abacb4fc81a5/src/field.rs#L99 const candidateSqrt = Fp2.pow(num, (Fp2.ORDER + 8n) / 16n); - const check = Fp2.div(Fp2.square(candidateSqrt), num); // candidateSqrt.square().div(this); + const check = Fp2.div(Fp2.sqr(candidateSqrt), num); // candidateSqrt.square().div(this); const R = FP2_ROOTS_OF_UNITY; - const divisor = [R[0], R[2], R[4], R[6]].find((r) => Fp2.equals(r, check)); + const divisor = [R[0], R[2], R[4], R[6]].find((r) => Fp2.eql(r, check)); if (!divisor) throw new Error('No root'); const index = R.indexOf(divisor); const root = R[index / 2]; if (!root) throw new Error('Invalid root'); const x1 = Fp2.div(candidateSqrt, root); - const x2 = Fp2.negate(x1); + const x2 = Fp2.neg(x1); const { re: re1, im: im1 } = Fp2.reim(x1); const { re: re2, im: im2 } = Fp2.reim(x2); if (im1 > im2 || (im1 === im2 && re1 > re2)) return x1; @@ -280,18 +279,15 @@ const Fp6Multiply = ({ c0, c1, c2 }: Fp6, rhs: Fp6 | bigint) => { }; }; const Fp6Square = ({ c0, c1, c2 }: Fp6) => { - let t0 = Fp2.square(c0); // c0² + let t0 = Fp2.sqr(c0); // c0² let t1 = Fp2.mul(Fp2.mul(c0, c1), 2n); // 2 * c0 * c1 let t3 = Fp2.mul(Fp2.mul(c1, c2), 2n); // 2 * c1 * c2 - let t4 = Fp2.square(c2); // c2² + let t4 = Fp2.sqr(c2); // c2² return { c0: Fp2.add(Fp2.mulByNonresidue(t3), t0), // T3 * (u + 1) + T0 c1: Fp2.add(Fp2.mulByNonresidue(t4), t1), // T4 * (u + 1) + T1 // T1 + (c0 - c1 + c2)² + T3 - T0 - T4 - c2: Fp2.sub( - Fp2.sub(Fp2.add(Fp2.add(t1, Fp2.square(Fp2.add(Fp2.sub(c0, c1), c2))), t3), t0), - t4 - ), + c2: Fp2.sub(Fp2.sub(Fp2.add(Fp2.add(t1, Fp2.sqr(Fp2.add(Fp2.sub(c0, c1), c2))), t3), t0), t4), }; }; type Fp6Utils = { @@ -312,35 +308,34 @@ const Fp6: mod.Field & Fp6Utils = { ONE: { c0: Fp2.ONE, c1: Fp2.ZERO, c2: Fp2.ZERO }, create: (num) => num, isValid: ({ c0, c1, c2 }) => Fp2.isValid(c0) && Fp2.isValid(c1) && Fp2.isValid(c2), - isZero: ({ c0, c1, c2 }) => Fp2.isZero(c0) && Fp2.isZero(c1) && Fp2.isZero(c2), - negate: ({ c0, c1, c2 }) => ({ c0: Fp2.negate(c0), c1: Fp2.negate(c1), c2: Fp2.negate(c2) }), - equals: ({ c0, c1, c2 }, { c0: r0, c1: r1, c2: r2 }) => - Fp2.equals(c0, r0) && Fp2.equals(c1, r1) && Fp2.equals(c2, r2), + is0: ({ c0, c1, c2 }) => Fp2.is0(c0) && Fp2.is0(c1) && Fp2.is0(c2), + neg: ({ c0, c1, c2 }) => ({ c0: Fp2.neg(c0), c1: Fp2.neg(c1), c2: Fp2.neg(c2) }), + eql: ({ c0, c1, c2 }, { c0: r0, c1: r1, c2: r2 }) => + Fp2.eql(c0, r0) && Fp2.eql(c1, r1) && Fp2.eql(c2, r2), sqrt: () => { throw new Error('Not implemented'); }, // Do we need division by bigint at all? Should be done via order: - div: (lhs, rhs) => - Fp6.mul(lhs, typeof rhs === 'bigint' ? Fp.invert(Fp.create(rhs)) : Fp6.invert(rhs)), + div: (lhs, rhs) => Fp6.mul(lhs, typeof rhs === 'bigint' ? Fp.inv(Fp.create(rhs)) : Fp6.inv(rhs)), pow: (num, power) => mod.FpPow(Fp6, num, power), invertBatch: (nums) => mod.FpInvertBatch(Fp6, nums), // Normalized add: Fp6Add, sub: Fp6Subtract, mul: Fp6Multiply, - square: Fp6Square, + sqr: Fp6Square, // NonNormalized stuff addN: Fp6Add, subN: Fp6Subtract, mulN: Fp6Multiply, - squareN: Fp6Square, + sqrN: Fp6Square, - invert: ({ c0, c1, c2 }) => { - let t0 = Fp2.sub(Fp2.square(c0), Fp2.mulByNonresidue(Fp2.mul(c2, c1))); // c0² - c2 * c1 * (u + 1) - let t1 = Fp2.sub(Fp2.mulByNonresidue(Fp2.square(c2)), Fp2.mul(c0, c1)); // c2² * (u + 1) - c0 * c1 - let t2 = Fp2.sub(Fp2.square(c1), Fp2.mul(c0, c2)); // c1² - c0 * c2 + inv: ({ c0, c1, c2 }) => { + let t0 = Fp2.sub(Fp2.sqr(c0), Fp2.mulByNonresidue(Fp2.mul(c2, c1))); // c0² - c2 * c1 * (u + 1) + let t1 = Fp2.sub(Fp2.mulByNonresidue(Fp2.sqr(c2)), Fp2.mul(c0, c1)); // c2² * (u + 1) - c0 * c1 + let t2 = Fp2.sub(Fp2.sqr(c1), Fp2.mul(c0, c2)); // c1² - c0 * c2 // 1/(((c2 * T1 + c1 * T2) * v) + c0 * T0) - let t4 = Fp2.invert( + let t4 = Fp2.inv( Fp2.add(Fp2.mulByNonresidue(Fp2.add(Fp2.mul(c2, t1), Fp2.mul(c1, t2))), Fp2.mul(c0, t0)) ); return { c0: Fp2.mul(t4, t0), c1: Fp2.mul(t4, t1), c2: Fp2.mul(t4, t2) }; @@ -498,11 +493,11 @@ const Fp12Square = ({ c0, c1 }: Fp12) => { }; // AB + AB }; function Fp4Square(a: Fp2, b: Fp2): { first: Fp2; second: Fp2 } { - const a2 = Fp2.square(a); - const b2 = Fp2.square(b); + const a2 = Fp2.sqr(a); + const b2 = Fp2.sqr(b); return { first: Fp2.add(Fp2.mulByNonresidue(b2), a2), // b² * Nonresidue + a² - second: Fp2.sub(Fp2.sub(Fp2.square(Fp2.add(a, b)), a2), b2), // (a + b)² - a² - b² + second: Fp2.sub(Fp2.sub(Fp2.sqr(Fp2.add(a, b)), a2), b2), // (a + b)² - a² - b² }; } type Fp12Utils = { @@ -525,30 +520,30 @@ const Fp12: mod.Field & Fp12Utils = { ONE: { c0: Fp6.ONE, c1: Fp6.ZERO }, create: (num) => num, isValid: ({ c0, c1 }) => Fp6.isValid(c0) && Fp6.isValid(c1), - isZero: ({ c0, c1 }) => Fp6.isZero(c0) && Fp6.isZero(c1), - negate: ({ c0, c1 }) => ({ c0: Fp6.negate(c0), c1: Fp6.negate(c1) }), - equals: ({ c0, c1 }, { c0: r0, c1: r1 }) => Fp6.equals(c0, r0) && Fp6.equals(c1, r1), + is0: ({ c0, c1 }) => Fp6.is0(c0) && Fp6.is0(c1), + neg: ({ c0, c1 }) => ({ c0: Fp6.neg(c0), c1: Fp6.neg(c1) }), + eql: ({ c0, c1 }, { c0: r0, c1: r1 }) => Fp6.eql(c0, r0) && Fp6.eql(c1, r1), sqrt: () => { throw new Error('Not implemented'); }, - invert: ({ c0, c1 }) => { - let t = Fp6.invert(Fp6.sub(Fp6.square(c0), Fp6.mulByNonresidue(Fp6.square(c1)))); // 1 / (c0² - c1² * v) - return { c0: Fp6.mul(c0, t), c1: Fp6.negate(Fp6.mul(c1, t)) }; // ((C0 * T) * T) + (-C1 * T) * w + inv: ({ c0, c1 }) => { + let t = Fp6.inv(Fp6.sub(Fp6.sqr(c0), Fp6.mulByNonresidue(Fp6.sqr(c1)))); // 1 / (c0² - c1² * v) + return { c0: Fp6.mul(c0, t), c1: Fp6.neg(Fp6.mul(c1, t)) }; // ((C0 * T) * T) + (-C1 * T) * w }, div: (lhs, rhs) => - Fp12.mul(lhs, typeof rhs === 'bigint' ? Fp.invert(Fp.create(rhs)) : Fp12.invert(rhs)), + Fp12.mul(lhs, typeof rhs === 'bigint' ? Fp.inv(Fp.create(rhs)) : Fp12.inv(rhs)), pow: (num, power) => mod.FpPow(Fp12, num, power), invertBatch: (nums) => mod.FpInvertBatch(Fp12, nums), // Normalized add: Fp12Add, sub: Fp12Subtract, mul: Fp12Multiply, - square: Fp12Square, + sqr: Fp12Square, // NonNormalized stuff addN: Fp12Add, subN: Fp12Subtract, mulN: Fp12Multiply, - squareN: Fp12Square, + sqrN: Fp12Square, // Bytes utils fromBytes: (b: Uint8Array): Fp12 => { @@ -602,7 +597,7 @@ const Fp12: mod.Field & Fp12Utils = { c0: Fp6.multiplyByFp2(c0, rhs), c1: Fp6.multiplyByFp2(c1, rhs), }), - conjugate: ({ c0, c1 }): Fp12 => ({ c0, c1: Fp6.negate(c1) }), + conjugate: ({ c0, c1 }): Fp12 => ({ c0, c1: Fp6.neg(c1) }), // A cyclotomic group is a subgroup of Fp^n defined by // GΦₙ(p) = {α ∈ Fpⁿ : α^Φₙ(p) = 1} @@ -897,7 +892,7 @@ const PSI2_C1 = 0x1a0111ea397fe699ec02408663d4de85aa0d857d89759ad4897d29650fb85f9b409427eb4f49fffd8bfd00000000aaacn; function psi2(x: Fp2, y: Fp2): [Fp2, Fp2] { - return [Fp2.mul(x, PSI2_C1), Fp2.negate(y)]; + return [Fp2.mul(x, PSI2_C1), Fp2.neg(y)]; } function G2psi2(c: ProjConstructor, P: ProjPointType) { const affine = P.toAffine(); @@ -1031,7 +1026,7 @@ export const bls12_381: CurveFn = bls({ let y = Fp.sqrt(right); if (!y) throw new Error('Invalid compressed G1 point'); const aflag = bitGet(compressedValue, C_BIT_POS); - if ((y * 2n) / P !== aflag) y = Fp.negate(y); + if ((y * 2n) / P !== aflag) y = Fp.neg(y); return { x: Fp.create(x), y: Fp.create(y) }; } else if (bytes.length === 96) { // Check if the infinity flag is set @@ -1149,7 +1144,7 @@ export const bls12_381: CurveFn = bls({ const right = Fp2.add(Fp2.pow(x, 3n), b); // y² = x³ + 4 * (u+1) = x³ + b let y = Fp2.sqrt(right); const Y_bit = y.c1 === 0n ? (y.c0 * 2n) / P : (y.c1 * 2n) / P ? 1n : 0n; - y = bitS > 0 && Y_bit > 0 ? y : Fp2.negate(y); + y = bitS > 0 && Y_bit > 0 ? y : Fp2.neg(y); return { x, y }; } else if (bytes.length === 192 && !bitC) { // Check if the infinity flag is set @@ -1216,7 +1211,7 @@ export const bls12_381: CurveFn = bls({ const aflag1 = bitGet(z1, 381); const isGreater = y1 > 0n && (y1 * 2n) / P !== aflag1; const isZero = y1 === 0n && (y0 * 2n) / P !== aflag1; - if (isGreater || isZero) y = Fp2.negate(y); + if (isGreater || isZero) y = Fp2.neg(y); const point = bls12_381.G2.ProjectivePoint.fromAffine({ x, y }); // console.log('Signature.decode', point); point.assertValidity(); diff --git a/src/ed25519.ts b/src/ed25519.ts index 7600b17..a2f1812 100644 --- a/src/ed25519.ts +++ b/src/ed25519.ts @@ -101,54 +101,54 @@ const Fp = Field(ED25519_P, undefined, true); const ELL2_C1 = (Fp.ORDER + BigInt(3)) / BigInt(8); // 1. c1 = (q + 3) / 8 # Integer arithmetic const ELL2_C2 = Fp.pow(_2n, ELL2_C1); // 2. c2 = 2^c1 -const ELL2_C3 = Fp.sqrt(Fp.negate(Fp.ONE)); // 3. c3 = sqrt(-1) +const ELL2_C3 = Fp.sqrt(Fp.neg(Fp.ONE)); // 3. c3 = sqrt(-1) const ELL2_C4 = (Fp.ORDER - BigInt(5)) / BigInt(8); // 4. c4 = (q - 5) / 8 # Integer arithmetic const ELL2_J = BigInt(486662); // prettier-ignore function map_to_curve_elligator2_curve25519(u: bigint) { - let tv1 = Fp.square(u); // 1. tv1 = u^2 + let tv1 = Fp.sqr(u); // 1. tv1 = u^2 tv1 = Fp.mul(tv1, _2n); // 2. tv1 = 2 * tv1 let xd = Fp.add(tv1, Fp.ONE); // 3. xd = tv1 + 1 # Nonzero: -1 is square (mod p), tv1 is not - let x1n = Fp.negate(ELL2_J); // 4. x1n = -J # x1 = x1n / xd = -J / (1 + 2 * u^2) - let tv2 = Fp.square(xd); // 5. tv2 = xd^2 + let x1n = Fp.neg(ELL2_J); // 4. x1n = -J # x1 = x1n / xd = -J / (1 + 2 * u^2) + let tv2 = Fp.sqr(xd); // 5. tv2 = xd^2 let gxd = Fp.mul(tv2, xd); // 6. gxd = tv2 * xd # gxd = xd^3 let gx1 = Fp.mul(tv1, ELL2_J); // 7. gx1 = J * tv1 # x1n + J * xd gx1 = Fp.mul(gx1, x1n); // 8. gx1 = gx1 * x1n # x1n^2 + J * x1n * xd gx1 = Fp.add(gx1, tv2); // 9. gx1 = gx1 + tv2 # x1n^2 + J * x1n * xd + xd^2 gx1 = Fp.mul(gx1, x1n); // 10. gx1 = gx1 * x1n # x1n^3 + J * x1n^2 * xd + x1n * xd^2 - let tv3 = Fp.square(gxd); // 11. tv3 = gxd^2 - tv2 = Fp.square(tv3); // 12. tv2 = tv3^2 # gxd^4 + let tv3 = Fp.sqr(gxd); // 11. tv3 = gxd^2 + tv2 = Fp.sqr(tv3); // 12. tv2 = tv3^2 # gxd^4 tv3 = Fp.mul(tv3, gxd); // 13. tv3 = tv3 * gxd # gxd^3 tv3 = Fp.mul(tv3, gx1); // 14. tv3 = tv3 * gx1 # gx1 * gxd^3 tv2 = Fp.mul(tv2, tv3); // 15. tv2 = tv2 * tv3 # gx1 * gxd^7 let y11 = Fp.pow(tv2, ELL2_C4); // 16. y11 = tv2^c4 # (gx1 * gxd^7)^((p - 5) / 8) y11 = Fp.mul(y11, tv3); // 17. y11 = y11 * tv3 # gx1*gxd^3*(gx1*gxd^7)^((p-5)/8) let y12 = Fp.mul(y11, ELL2_C3); // 18. y12 = y11 * c3 - tv2 = Fp.square(y11); // 19. tv2 = y11^2 + tv2 = Fp.sqr(y11); // 19. tv2 = y11^2 tv2 = Fp.mul(tv2, gxd); // 20. tv2 = tv2 * gxd - let e1 = Fp.equals(tv2, gx1); // 21. e1 = tv2 == gx1 + let e1 = Fp.eql(tv2, gx1); // 21. e1 = tv2 == gx1 let y1 = Fp.cmov(y12, y11, e1); // 22. y1 = CMOV(y12, y11, e1) # If g(x1) is square, this is its sqrt let x2n = Fp.mul(x1n, tv1); // 23. x2n = x1n * tv1 # x2 = x2n / xd = 2 * u^2 * x1n / xd let y21 = Fp.mul(y11, u); // 24. y21 = y11 * u y21 = Fp.mul(y21, ELL2_C2); // 25. y21 = y21 * c2 let y22 = Fp.mul(y21, ELL2_C3); // 26. y22 = y21 * c3 let gx2 = Fp.mul(gx1, tv1); // 27. gx2 = gx1 * tv1 # g(x2) = gx2 / gxd = 2 * u^2 * g(x1) - tv2 = Fp.square(y21); // 28. tv2 = y21^2 + tv2 = Fp.sqr(y21); // 28. tv2 = y21^2 tv2 = Fp.mul(tv2, gxd); // 29. tv2 = tv2 * gxd - let e2 = Fp.equals(tv2, gx2); // 30. e2 = tv2 == gx2 + let e2 = Fp.eql(tv2, gx2); // 30. e2 = tv2 == gx2 let y2 = Fp.cmov(y22, y21, e2); // 31. y2 = CMOV(y22, y21, e2) # If g(x2) is square, this is its sqrt - tv2 = Fp.square(y1); // 32. tv2 = y1^2 + tv2 = Fp.sqr(y1); // 32. tv2 = y1^2 tv2 = Fp.mul(tv2, gxd); // 33. tv2 = tv2 * gxd - let e3 = Fp.equals(tv2, gx1); // 34. e3 = tv2 == gx1 + let e3 = Fp.eql(tv2, gx1); // 34. e3 = tv2 == gx1 let xn = Fp.cmov(x2n, x1n, e3); // 35. xn = CMOV(x2n, x1n, e3) # If e3, x = x1, else x = x2 let y = Fp.cmov(y2, y1, e3); // 36. y = CMOV(y2, y1, e3) # If e3, y = y1, else y = y2 let e4 = Fp.isOdd(y); // 37. e4 = sgn0(y) == 1 # Fix sign of y - y = Fp.cmov(y, Fp.negate(y), e3 !== e4); // 38. y = CMOV(y, -y, e3 XOR e4) + y = Fp.cmov(y, Fp.neg(y), e3 !== e4); // 38. y = CMOV(y, -y, e3 XOR e4) return { xMn: xn, xMd: xd, yMn: y, yMd: 1n }; // 39. return (xn, xd, y, 1) } -const ELL2_C1_EDWARDS = FpSqrtEven(Fp, Fp.negate(BigInt(486664))); // sgn0(c1) MUST equal 0 +const ELL2_C1_EDWARDS = FpSqrtEven(Fp, Fp.neg(BigInt(486664))); // sgn0(c1) MUST equal 0 function map_to_curve_elligator2_edwards25519(u: bigint) { const { xMn, xMd, yMn, yMd } = map_to_curve_elligator2_curve25519(u); // 1. (xMn, xMd, yMn, yMd) = map_to_curve_elligator2_curve25519(u) let xn = Fp.mul(xMn, yMd); // 2. xn = xMn * yMd @@ -157,7 +157,7 @@ function map_to_curve_elligator2_edwards25519(u: bigint) { let yn = Fp.sub(xMn, xMd); // 5. yn = xMn - xMd let yd = Fp.add(xMn, xMd); // 6. yd = xMn + xMd # (n / d - 1) / (n / d + 1) = (n - d) / (n + d) let tv1 = Fp.mul(xd, yd); // 7. tv1 = xd * yd - let e = Fp.equals(tv1, Fp.ZERO); // 8. e = tv1 == 0 + let e = Fp.eql(tv1, Fp.ZERO); // 8. e = tv1 == 0 xn = Fp.cmov(xn, Fp.ZERO, e); // 9. xn = CMOV(xn, 0, e) xd = Fp.cmov(xd, Fp.ONE, e); // 10. xd = CMOV(xd, 1, e) yn = Fp.cmov(yn, Fp.ONE, e); // 11. yn = CMOV(yn, 1, e) diff --git a/src/ed448.ts b/src/ed448.ts index 630f227..bfe2301 100644 --- a/src/ed448.ts +++ b/src/ed448.ts @@ -59,41 +59,41 @@ const Fp = Field(ed448P, 456, true); const ELL2_C1 = (Fp.ORDER - BigInt(3)) / BigInt(4); // 1. c1 = (q - 3) / 4 # Integer arithmetic const ELL2_J = BigInt(156326); function map_to_curve_elligator2_curve448(u: bigint) { - let tv1 = Fp.square(u); // 1. tv1 = u^2 - let e1 = Fp.equals(tv1, Fp.ONE); // 2. e1 = tv1 == 1 + let tv1 = Fp.sqr(u); // 1. tv1 = u^2 + let e1 = Fp.eql(tv1, Fp.ONE); // 2. e1 = tv1 == 1 tv1 = Fp.cmov(tv1, Fp.ZERO, e1); // 3. tv1 = CMOV(tv1, 0, e1) # If Z * u^2 == -1, set tv1 = 0 let xd = Fp.sub(Fp.ONE, tv1); // 4. xd = 1 - tv1 - let x1n = Fp.negate(ELL2_J); // 5. x1n = -J - let tv2 = Fp.square(xd); // 6. tv2 = xd^2 + let x1n = Fp.neg(ELL2_J); // 5. x1n = -J + let tv2 = Fp.sqr(xd); // 6. tv2 = xd^2 let gxd = Fp.mul(tv2, xd); // 7. gxd = tv2 * xd # gxd = xd^3 - let gx1 = Fp.mul(tv1, Fp.negate(ELL2_J)); // 8. gx1 = -J * tv1 # x1n + J * xd + let gx1 = Fp.mul(tv1, Fp.neg(ELL2_J)); // 8. gx1 = -J * tv1 # x1n + J * xd gx1 = Fp.mul(gx1, x1n); // 9. gx1 = gx1 * x1n # x1n^2 + J * x1n * xd gx1 = Fp.add(gx1, tv2); // 10. gx1 = gx1 + tv2 # x1n^2 + J * x1n * xd + xd^2 gx1 = Fp.mul(gx1, x1n); // 11. gx1 = gx1 * x1n # x1n^3 + J * x1n^2 * xd + x1n * xd^2 - let tv3 = Fp.square(gxd); // 12. tv3 = gxd^2 + let tv3 = Fp.sqr(gxd); // 12. tv3 = gxd^2 tv2 = Fp.mul(gx1, gxd); // 13. tv2 = gx1 * gxd # gx1 * gxd tv3 = Fp.mul(tv3, tv2); // 14. tv3 = tv3 * tv2 # gx1 * gxd^3 let y1 = Fp.pow(tv3, ELL2_C1); // 15. y1 = tv3^c1 # (gx1 * gxd^3)^((p - 3) / 4) y1 = Fp.mul(y1, tv2); // 16. y1 = y1 * tv2 # gx1 * gxd * (gx1 * gxd^3)^((p - 3) / 4) - let x2n = Fp.mul(x1n, Fp.negate(tv1)); // 17. x2n = -tv1 * x1n # x2 = x2n / xd = -1 * u^2 * x1n / xd + let x2n = Fp.mul(x1n, Fp.neg(tv1)); // 17. x2n = -tv1 * x1n # x2 = x2n / xd = -1 * u^2 * x1n / xd let y2 = Fp.mul(y1, u); // 18. y2 = y1 * u y2 = Fp.cmov(y2, Fp.ZERO, e1); // 19. y2 = CMOV(y2, 0, e1) - tv2 = Fp.square(y1); // 20. tv2 = y1^2 + tv2 = Fp.sqr(y1); // 20. tv2 = y1^2 tv2 = Fp.mul(tv2, gxd); // 21. tv2 = tv2 * gxd - let e2 = Fp.equals(tv2, gx1); // 22. e2 = tv2 == gx1 + let e2 = Fp.eql(tv2, gx1); // 22. e2 = tv2 == gx1 let xn = Fp.cmov(x2n, x1n, e2); // 23. xn = CMOV(x2n, x1n, e2) # If e2, x = x1, else x = x2 let y = Fp.cmov(y2, y1, e2); // 24. y = CMOV(y2, y1, e2) # If e2, y = y1, else y = y2 let e3 = Fp.isOdd(y); // 25. e3 = sgn0(y) == 1 # Fix sign of y - y = Fp.cmov(y, Fp.negate(y), e2 !== e3); // 26. y = CMOV(y, -y, e2 XOR e3) + y = Fp.cmov(y, Fp.neg(y), e2 !== e3); // 26. y = CMOV(y, -y, e2 XOR e3) return { xn, xd, yn: y, yd: Fp.ONE }; // 27. return (xn, xd, y, 1) } function map_to_curve_elligator2_edwards448(u: bigint) { let { xn, xd, yn, yd } = map_to_curve_elligator2_curve448(u); // 1. (xn, xd, yn, yd) = map_to_curve_elligator2_curve448(u) - let xn2 = Fp.square(xn); // 2. xn2 = xn^2 - let xd2 = Fp.square(xd); // 3. xd2 = xd^2 - let xd4 = Fp.square(xd2); // 4. xd4 = xd2^2 - let yn2 = Fp.square(yn); // 5. yn2 = yn^2 - let yd2 = Fp.square(yd); // 6. yd2 = yd^2 + let xn2 = Fp.sqr(xn); // 2. xn2 = xn^2 + let xd2 = Fp.sqr(xd); // 3. xd2 = xd^2 + let xd4 = Fp.sqr(xd2); // 4. xd4 = xd2^2 + let yn2 = Fp.sqr(yn); // 5. yn2 = yn^2 + let yd2 = Fp.sqr(yd); // 6. yd2 = yd^2 let xEn = Fp.sub(xn2, xd2); // 7. xEn = xn2 - xd2 let tv2 = Fp.sub(xEn, xd2); // 8. tv2 = xEn - xd2 xEn = Fp.mul(xEn, xd2); // 9. xEn = xEn * xd2 @@ -120,7 +120,7 @@ function map_to_curve_elligator2_edwards448(u: bigint) { tv4 = Fp.mul(tv4, yd2); // 30. tv4 = tv4 * yd2 yEd = Fp.add(yEd, tv4); // 31. yEd = yEd + tv4 tv1 = Fp.mul(xEd, yEd); // 32. tv1 = xEd * yEd - let e = Fp.equals(tv1, Fp.ZERO); // 33. e = tv1 == 0 + let e = Fp.eql(tv1, Fp.ZERO); // 33. e = tv1 == 0 xEn = Fp.cmov(xEn, Fp.ZERO, e); // 34. xEn = CMOV(xEn, 0, e) xEd = Fp.cmov(xEd, Fp.ONE, e); // 35. xEd = CMOV(xEd, 1, e) yEn = Fp.cmov(yEn, Fp.ONE, e); // 36. yEn = CMOV(yEn, 1, e) diff --git a/src/secp256k1.ts b/src/secp256k1.ts index 01c36fa..90ed701 100644 --- a/src/secp256k1.ts +++ b/src/secp256k1.ts @@ -29,10 +29,7 @@ const _2n = BigInt(2); const divNearest = (a: bigint, b: bigint) => (a + b / _2n) / b; /** - * Allows to compute square root √y 2x faster. - * To calculate √y, we need to exponentiate it to a very big number: - * `y² = x³ + ax + b; y = y² ^ (p+1)/4` - * We are unwrapping the loop and multiplying it bit-by-bit. + * √n = n^((p+1)/4) for fields p = 3 mod 4. We unwrap the loop and multiply bit-by-bit. * (P+1n/4n).toString(2) would produce bits [223x 1, 0, 22x 1, 4x 0, 11, 00] */ function sqrtMod(y: bigint): bigint { @@ -55,7 +52,7 @@ function sqrtMod(y: bigint): bigint { const t1 = (pow2(b223, _23n, P) * b22) % P; const t2 = (pow2(t1, _6n, P) * b2) % P; const root = pow2(t2, _2n, P); - if (!Fp.equals(Fp.square(root), y)) throw new Error('Cannot find square root'); + if (!Fp.eql(Fp.sqr(root), y)) throw new Error('Cannot find square root'); return root; } @@ -108,7 +105,9 @@ export const secp256k1 = createCurve( sha256 ); -// Schnorr +// Schnorr signatures are superior to ECDSA from above. +// Below is Schnorr-specific code as per BIP0340. +// https://github.com/bitcoin/bips/blob/master/bip-0340.mediawiki const _0n = BigInt(0); const fe = (x: bigint) => typeof x === 'bigint' && _0n < x && x < secp256k1P; const ge = (x: bigint) => typeof x === 'bigint' && _0n < x && x < secp256k1N; @@ -130,8 +129,6 @@ export function taggedHash(tag: string, ...messages: Uint8Array[]): Uint8Array { } return sha256(concatBytes(tagP, ...messages)); } -// Schnorr signatures are superior to ECDSA from above. -// Below is Schnorr-specific code as per BIP0340. const tag = taggedHash; const toRawX = (point: PointType) => point.toRawBytes(true).slice(1); @@ -142,22 +139,29 @@ const Gmul = (priv: PrivKey) => _Point.fromPrivateKey(priv); const GmulAdd = (Q: PointType, a: bigint, b: bigint) => _Point.BASE.multiplyAndAddUnsafe(Q, a, b); function schnorrGetScalar(priv: bigint) { + // Let d' = int(sk) + // Fail if d' = 0 or d' ≥ n + // Let P = d'⋅G + // Let d = d' if has_even_y(P), otherwise let d = n - d' . const point = Gmul(priv); const scalar = point.hasEvenY() ? priv : modN(-priv); return { point, scalar, x: toRawX(point) }; } -function lift_x(x: bigint) { +function lift_x(x: bigint): PointType { if (!fe(x)) throw new Error('not fe'); // Fail if x ≥ p. - const c = mod(x * x * x + BigInt(7), secp256k1P); // Let c = x3 + 7 mod p. + const c = mod(x * x * x + BigInt(7), secp256k1P); // Let c = x³ + 7 mod p. let y = sqrtMod(c); // Let y = c^(p+1)/4 mod p. if (y % 2n !== 0n) y = mod(-y, secp256k1P); // Return the unique point P such that x(P) = x and const p = new _Point(x, y, _1n); // y(P) = y if y mod 2 = 0 or y(P) = p-y otherwise. p.assertValidity(); return p; } -function challenge(...args: Uint8Array[]) { +function challenge(...args: Uint8Array[]): bigint { return modN(bytesToNum(tag(TAGS.challenge, ...args))); } +function schnorrGetPublicKey(privateKey: PrivKey): Uint8Array { + return toRawX(Gmul(privateKey)); // Let d' = int(sk). Fail if d' = 0 or d' ≥ n. Return bytes(d'⋅G) +} /** * Synchronously creates Schnorr signature. Improved security: verifies itself before * producing an output. @@ -169,18 +173,20 @@ function schnorrSign(message: Hex, privateKey: Hex, auxRand: Hex = randomBytes(3 if (message == null) throw new Error(`sign: Expected valid message, not "${message}"`); const m = ensureBytes(message); // checks for isWithinCurveOrder + const { x: px, scalar: d } = schnorrGetScalar(bytesToNum(ensureBytes(privateKey, 32))); const a = ensureBytes(auxRand, 32); // Auxiliary random data a: a 32-byte array // TODO: replace with proper xor? - const t = numTo32b(d ^ bytesToNum(tag(TAGS.aux, a))); // Let t be the byte-wise xor of bytes(d) and hashBIP0340/aux(a) - const rand = tag(TAGS.nonce, t, px, m); // Let rand = hashBIP0340/nonce(t || bytes(P) || m) + const t = numTo32b(d ^ bytesToNum(tag(TAGS.aux, a))); // Let t be the byte-wise xor of bytes(d) and hash/aux(a) + const rand = tag(TAGS.nonce, t, px, m); // Let rand = hash/nonce(t || bytes(P) || m) const k_ = modN(bytesToNum(rand)); // Let k' = int(rand) mod n if (k_ === _0n) throw new Error('sign failed: k is zero'); // Fail if k' = 0. - const { point: R, x: rx, scalar: k } = schnorrGetScalar(k_); - const e = challenge(rx, px, m); + const { point: R, x: rx, scalar: k } = schnorrGetScalar(k_); // Let R = k'⋅G. + const e = challenge(rx, px, m); // Let e = int(hash/challenge(bytes(R) || bytes(P) || m)) mod n. const sig = new Uint8Array(64); // Let sig = bytes(R) || bytes((k + ed) mod n). sig.set(numTo32b(R.px), 0); sig.set(numTo32b(modN(k + e * d)), 32); + // If Verify(bytes(P), m, sig) (see below) returns failure, abort if (!schnorrVerify(sig, m, px)) throw new Error('sign: Invalid signature produced'); return sig; } @@ -200,7 +206,7 @@ function schnorrVerify(signature: Hex, message: Hex, publicKey: Hex): boolean { const e = challenge(numTo32b(r), toRawX(P), m); // int(challenge(bytes(r)||bytes(P)||m)) mod n const R = GmulAdd(P, s, modN(-e)); // R = s⋅G - e⋅P if (!R || !R.hasEvenY() || R.toAffine().x !== r) return false; // -eP == (n-e)P - return true; + return true; // Fail if is_infinite(R) / not has_even_y(R) / x(R) ≠ r. } catch (error) { return false; } @@ -208,7 +214,7 @@ function schnorrVerify(signature: Hex, message: Hex, publicKey: Hex): boolean { export const schnorr = { // Schnorr's pubkey is just `x` of Point (BIP340) - getPublicKey: (privateKey: PrivKey): Uint8Array => toRawX(Gmul(privateKey)), + getPublicKey: schnorrGetPublicKey, sign: schnorrSign, verify: schnorrVerify, }; diff --git a/src/stark.ts b/src/stark.ts index 271362e..e747b4c 100644 --- a/src/stark.ts +++ b/src/stark.ts @@ -301,7 +301,7 @@ export function _poseidonMDS(Fp: Field, name: string, m: number, attempt } if (new Set([...x_values, ...y_values]).size !== 2 * m) throw new Error('X and Y values are not distinct'); - return x_values.map((x) => y_values.map((y) => Fp.invert(Fp.sub(x, y)))); + return x_values.map((x) => y_values.map((y) => Fp.inv(Fp.sub(x, y)))); } const MDS_SMALL = [