From c75129e629c46c5bdc222be93af1e5943eac4ee3 Mon Sep 17 00:00:00 2001 From: Paul Miller Date: Sat, 28 Jan 2023 02:19:46 +0000 Subject: [PATCH] Use declarative curve field validation --- src/abstract/curve.ts | 30 ++++++----- src/abstract/edwards.ts | 67 ++++++++++-------------- src/abstract/hash-to-curve.ts | 2 +- src/abstract/modular.ts | 25 +++++---- src/abstract/montgomery.ts | 86 ++++++++++--------------------- src/abstract/poseidon.ts | 2 +- src/abstract/utils.ts | 67 ++++++++++++++++-------- src/abstract/weierstrass.ts | 97 ++++++++++++++++------------------- src/p224.ts | 2 +- src/secp256k1.ts | 1 - test/nist.test.js | 7 +-- 11 files changed, 179 insertions(+), 207 deletions(-) diff --git a/src/abstract/curve.ts b/src/abstract/curve.ts index d2add41..0cbd80a 100644 --- a/src/abstract/curve.ts +++ b/src/abstract/curve.ts @@ -1,6 +1,7 @@ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ // Abelian group utilities import { Field, validateField, nLength } from './modular.js'; +import { validateObject } from './utils.js'; const _0n = BigInt(0); const _1n = BigInt(1); @@ -153,7 +154,7 @@ export function wNAF>(c: GroupConstructor, bits: number) { // Generic BasicCurve interface: works even for polynomial fields (BLS): P, n, h would be ok. // Though generator can be different (Fp2 / Fp6 for BLS). -export type AbstractCurve = { +export type BasicCurve = { Fp: Field; // Field over which we'll do calculations (Fp) n: bigint; // Curve order, total count of valid points in the field nBitLength?: number; // bit length of curve order @@ -165,20 +166,21 @@ export type AbstractCurve = { allowInfinityPoint?: boolean; // bls12-381 requires it. ZERO point is valid, but invalid pubkey }; -export function validateAbsOpts(curve: AbstractCurve & T) { +export function validateBasic(curve: BasicCurve & T) { validateField(curve.Fp); - for (const i of ['n', 'h'] as const) { - const val = curve[i]; - if (typeof val !== 'bigint') throw new Error(`Invalid curve param ${i}=${val} (${typeof val})`); - } - if (!curve.Fp.isValid(curve.Gx)) throw new Error('Invalid generator X coordinate Fp element'); - if (!curve.Fp.isValid(curve.Gy)) throw new Error('Invalid generator Y coordinate Fp element'); - - for (const i of ['nBitLength', 'nByteLength'] as const) { - const val = curve[i]; - if (val === undefined) continue; // Optional - if (!Number.isSafeInteger(val)) throw new Error(`Invalid param ${i}=${val} (${typeof val})`); - } + validateObject( + curve, + { + n: 'bigint', + h: 'bigint', + Gx: 'field', + Gy: 'field', + }, + { + nBitLength: 'isSafeInteger', + nByteLength: 'isSafeInteger', + } + ); // Set defaults return Object.freeze({ ...nLength(curve.n, curve.nBitLength), ...curve } as const); } diff --git a/src/abstract/edwards.ts b/src/abstract/edwards.ts index e37abb0..d1c0708 100644 --- a/src/abstract/edwards.ts +++ b/src/abstract/edwards.ts @@ -1,23 +1,9 @@ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ // Twisted Edwards curve. The formula is: ax² + y² = 1 + dx²y² import { mod } from './modular.js'; -import { - bytesToHex, - bytesToNumberLE, - concatBytes, - ensureBytes, - FHash, - Hex, - numberToBytesLE, -} from './utils.js'; -import { - Group, - GroupConstructor, - wNAF, - AbstractCurve, - validateAbsOpts, - AffinePoint, -} from './curve.js'; +import * as ut from './utils.js'; +import { ensureBytes, FHash, Hex } from './utils.js'; +import { Group, GroupConstructor, wNAF, BasicCurve, validateBasic, AffinePoint } from './curve.js'; // Be friendly to bad ECMAScript parsers by not using bigint literals like 123n const _0n = BigInt(0); @@ -26,7 +12,7 @@ const _2n = BigInt(2); const _8n = BigInt(8); // Edwards curves must declare params a & d. -export type CurveType = AbstractCurve & { +export type CurveType = BasicCurve & { a: bigint; // curve param a d: bigint; // curve param d hash: FHash; // Hashing @@ -39,19 +25,22 @@ export type CurveType = AbstractCurve & { }; function validateOpts(curve: CurveType) { - const opts = validateAbsOpts(curve); - if (typeof opts.hash !== 'function') throw new Error('Invalid hash function'); - for (const i of ['a', 'd'] as const) { - const val = opts[i]; - if (typeof val !== 'bigint') throw new Error(`Invalid curve param ${i}=${val} (${typeof val})`); - } - for (const fn of ['randomBytes'] as const) { - if (typeof opts[fn] !== 'function') throw new Error(`Invalid ${fn} function`); - } - for (const fn of ['adjustScalarBytes', 'domain', 'uvRatio', 'mapToCurve'] as const) { - if (opts[fn] === undefined) continue; // Optional - if (typeof opts[fn] !== 'function') throw new Error(`Invalid ${fn} function`); - } + const opts = validateBasic(curve); + ut.validateObject( + curve, + { + hash: 'function', + a: 'bigint', + d: 'bigint', + randomBytes: 'function', + }, + { + adjustScalarBytes: 'function', + domain: 'function', + uvRatio: 'function', + mapToCurve: 'function', + } + ); // Set defaults return Object.freeze({ ...opts } as const); } @@ -75,7 +64,7 @@ export interface ExtPointConstructor extends GroupConstructor { new (x: bigint, y: bigint, z: bigint, t: bigint): ExtPointType; fromAffine(p: AffinePoint): ExtPointType; fromHex(hex: Hex): ExtPointType; - fromPrivateKey(privateKey: Hex): ExtPointType; // TODO: remove + fromPrivateKey(privateKey: Hex): ExtPointType; } export type CurveFn = { @@ -340,7 +329,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { const normed = hex.slice(); // copy again, we'll manipulate it const lastByte = hex[len - 1]; // select last byte normed[len - 1] = lastByte & ~0x80; // clear last bit - const y = bytesToNumberLE(normed); + const y = ut.bytesToNumberLE(normed); if (y === _0n) { // y=0 is allowed } else { @@ -366,12 +355,12 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { } toRawBytes(): Uint8Array { const { x, y } = this.toAffine(); - const bytes = numberToBytesLE(y, Fp.BYTES); // each y has 2 x values (x, -y) + const bytes = ut.numberToBytesLE(y, Fp.BYTES); // each y has 2 x values (x, -y) bytes[bytes.length - 1] |= x & _1n ? 0x80 : 0; // when compressing, it's enough to store y return bytes; // and use the last byte to encode sign of x } toHex(): string { - return bytesToHex(this.toRawBytes()); // Same as toRawBytes, but returns string. + return ut.bytesToHex(this.toRawBytes()); // Same as toRawBytes, but returns string. } } const { BASE: G, ZERO: I } = Point; @@ -382,7 +371,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { } // Little-endian SHA512 with modulo n function modN_LE(hash: Uint8Array): bigint { - return modN(bytesToNumberLE(hash)); + return modN(ut.bytesToNumberLE(hash)); } function isHex(item: Hex, err: string) { if (typeof item !== 'string' && !(item instanceof Uint8Array)) @@ -411,7 +400,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { // int('LE', SHA512(dom2(F, C) || msgs)) mod N function hashDomainToScalar(context: Hex = new Uint8Array(), ...msgs: Uint8Array[]) { - const msg = concatBytes(...msgs); + const msg = ut.concatBytes(...msgs); return modN_LE(cHash(domain(msg, ensureBytes(context), !!preHash))); } @@ -426,7 +415,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { const k = hashDomainToScalar(context, R, pointBytes, msg); // R || A || PH(M) const s = modN(r + k * scalar); // S = (r + k * s) mod L assertGE0(s); // 0 <= s < l - const res = concatBytes(R, numberToBytesLE(s, Fp.BYTES)); + const res = ut.concatBytes(R, ut.numberToBytesLE(s, Fp.BYTES)); return ensureBytes(res, nByteLength * 2); // 64-byte signature } @@ -439,7 +428,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn { if (preHash) msg = preHash(msg); // for ed25519ph, etc const A = Point.fromHex(publicKey, false); // Check for s bounds, hex validity const R = Point.fromHex(sig.slice(0, len), false); // 0 <= R < 2^256: ZIP215 R can be >= P - const s = bytesToNumberLE(sig.slice(len, 2 * len)); // 0 <= s < l + const s = ut.bytesToNumberLE(sig.slice(len, 2 * len)); // 0 <= s < l const SB = G.multiplyUnsafe(s); const k = hashDomainToScalar(context, R.toRawBytes(), A.toRawBytes(), msg); const RkA = R.add(A.multiplyUnsafe(k)); diff --git a/src/abstract/hash-to-curve.ts b/src/abstract/hash-to-curve.ts index 75c94ad..47e6211 100644 --- a/src/abstract/hash-to-curve.ts +++ b/src/abstract/hash-to-curve.ts @@ -45,7 +45,7 @@ declare const TextDecoder: any; export function stringToBytes(str: string): Uint8Array { if (typeof str !== 'string') { - throw new TypeError(`utf8ToBytes expected string, got ${typeof str}`); + throw new Error(`utf8ToBytes expected string, got ${typeof str}`); } return new TextEncoder().encode(str); } diff --git a/src/abstract/modular.ts b/src/abstract/modular.ts index 0060edb..3b44c2d 100644 --- a/src/abstract/modular.ts +++ b/src/abstract/modular.ts @@ -7,6 +7,7 @@ import { bytesToNumberBE, bytesToNumberLE, ensureBytes, + validateObject, } from './utils.js'; // prettier-ignore const _0n = BigInt(0), _1n = BigInt(1), _2n = BigInt(2), _3n = BigInt(3); @@ -40,7 +41,6 @@ export function pow(num: bigint, power: bigint, modulo: bigint): bigint { } // Does x ^ (2 ^ power) mod p. pow2(30, 4) == 30 ^ (2 ^ 4) -// TODO: Fp version? export function pow2(x: bigint, power: bigint, modulo: bigint): bigint { let res = x; while (power-- > _0n) { @@ -249,18 +249,17 @@ const FIELD_FIELDS = [ 'addN', 'subN', 'mulN', 'sqrN' ] as const; export function validateField(field: Field) { - for (const i of ['ORDER', 'MASK'] as const) { - if (typeof field[i] !== 'bigint') - throw new Error(`Invalid field param ${i}=${field[i]} (${typeof field[i]})`); - } - for (const i of ['BYTES', 'BITS'] as const) { - if (typeof field[i] !== 'number') - throw new Error(`Invalid field param ${i}=${field[i]} (${typeof field[i]})`); - } - for (const i of FIELD_FIELDS) { - if (typeof field[i] !== 'function') - throw new Error(`Invalid field param ${i}=${field[i]} (${typeof field[i]})`); - } + const initial = { + ORDER: 'bigint', + MASK: 'bigint', + BYTES: 'isSafeInteger', + BITS: 'isSafeInteger', + } as Record; + const opts = FIELD_FIELDS.reduce((map, val: string) => { + map[val] = 'function'; + return map; + }, initial); + return validateObject(field, opts); } // Generic field functions diff --git a/src/abstract/montgomery.ts b/src/abstract/montgomery.ts index b658d87..69a6c8e 100644 --- a/src/abstract/montgomery.ts +++ b/src/abstract/montgomery.ts @@ -1,14 +1,13 @@ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ import { mod, pow } from './modular.js'; -import { ensureBytes, numberToBytesLE, bytesToNumberLE } from './utils.js'; +import { bytesToNumberLE, ensureBytes, numberToBytesLE, validateObject } from './utils.js'; const _0n = BigInt(0); const _1n = BigInt(1); type Hex = string | Uint8Array; export type CurveType = { - // Field over which we'll do calculations. Verify with: - P: bigint; + P: bigint; // finite field prime nByteLength: number; adjustScalarBytes?: (bytes: Uint8Array) => Uint8Array; domain?: (data: Uint8Array, ctx: Uint8Array, phflag: boolean) => Uint8Array; @@ -27,24 +26,20 @@ export type CurveFn = { }; function validateOpts(curve: CurveType) { - for (const i of ['a24'] as const) { - if (typeof curve[i] !== 'bigint') - throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`); - } - for (const i of ['montgomeryBits', 'nByteLength'] as const) { - if (curve[i] === undefined) continue; // Optional - if (!Number.isSafeInteger(curve[i])) - throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`); - } - for (const fn of ['adjustScalarBytes', 'domain', 'powPminus2'] as const) { - if (curve[fn] === undefined) continue; // Optional - if (typeof curve[fn] !== 'function') throw new Error(`Invalid ${fn} function`); - } - for (const i of ['Gu'] as const) { - if (curve[i] === undefined) continue; // Optional - if (typeof curve[i] !== 'string') - throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`); - } + validateObject( + curve, + { + a24: 'bigint', + }, + { + montgomeryBits: 'isSafeInteger', + nByteLength: 'isSafeInteger', + adjustScalarBytes: 'function', + domain: 'function', + powPminus2: 'function', + Gu: 'string', + } + ); // Set defaults return Object.freeze({ ...curve } as const); } @@ -61,27 +56,7 @@ export function montgomery(curveDef: CurveType): CurveFn { const adjustScalarBytes = CURVE.adjustScalarBytes || ((bytes: Uint8Array) => bytes); const powPminus2 = CURVE.powPminus2 || ((x: bigint) => pow(x, P - BigInt(2), P)); - /** - * Checks for num to be in range: - * For strict == true: `0 < num < max`. - * For strict == false: `0 <= num < max`. - * Converts non-float safe numbers to bigints. - */ - function normalizeScalar(num: bigint, max: bigint, strict = true): bigint { - if (!max) throw new TypeError('Specify max value'); - if (typeof num === 'number' && Number.isSafeInteger(num)) num = BigInt(num); - if (typeof num === 'bigint' && num < max) { - if (strict) { - if (_0n < num) return num; - } else { - if (_0n <= num) return num; - } - } - throw new TypeError('Expected valid scalar: 0 < scalar < max'); - } - - // cswap from RFC7748 - // NOTE: cswap is not from RFC7748! + // cswap from RFC7748. But it is not from RFC7748! /* cswap(swap, x_2, x_3): dummy = mask(swap) AND (x_2 XOR x_3) @@ -98,6 +73,11 @@ export function montgomery(curveDef: CurveType): CurveFn { return [x_2, x_3]; } + function assertFieldElement(n: bigint): bigint { + if (typeof n === 'bigint' && _0n <= n && n < P) return n; + throw new Error('Expected valid scalar 0 < scalar < CURVE.P'); + } + // x25519 from 4 /** * @@ -106,11 +86,10 @@ export function montgomery(curveDef: CurveType): CurveFn { * @returns new Point on Montgomery curve */ function montgomeryLadder(pointU: bigint, scalar: bigint): bigint { - const { P } = CURVE; - const u = normalizeScalar(pointU, P); + const u = assertFieldElement(pointU); // Section 5: Implementations MUST accept non-canonical values and process them as // if they had been reduced modulo the field prime. - const k = normalizeScalar(scalar, P); + const k = assertFieldElement(scalar); // The constant a24 is (486662 - 2) / 4 = 121665 for curve25519/X25519 const a24 = CURVE.a24; const x_1 = u; @@ -166,28 +145,20 @@ export function montgomery(curveDef: CurveType): CurveFn { } function decodeUCoordinate(uEnc: Hex): bigint { - const u = ensureBytes(uEnc, montgomeryBytes); // Section 5: When receiving such an array, implementations of X25519 // MUST mask the most significant bit in the final byte. // This is very ugly way, but it works because fieldLen-1 is outside of bounds for X448, so this becomes NOOP // fieldLen - scalaryBytes = 1 for X448 and = 0 for X25519 + const u = ensureBytes(uEnc, montgomeryBytes); u[fieldLen - 1] &= 127; // 0b0111_1111 return bytesToNumberLE(u); } - function decodeScalar(n: Hex): bigint { const bytes = ensureBytes(n); if (bytes.length !== montgomeryBytes && bytes.length !== fieldLen) throw new Error(`Expected ${montgomeryBytes} or ${fieldLen} bytes, got ${bytes.length}`); return bytesToNumberLE(adjustScalarBytes(bytes)); } - /** - * Computes shared secret between private key "scalar" and public key's "u" (x) coordinate. - * We can get 'y' coordinate from 'u', - * but Point.fromHex also wants 'x' coordinate oddity flag, - * and we cannot get 'x' without knowing 'v'. - * Need to add generic conversion between twisted edwards and complimentary curve for JubJub. - */ function scalarMult(scalar: Hex, u: Hex): Uint8Array { const pointU = decodeUCoordinate(u); const _scalar = decodeScalar(scalar); @@ -197,12 +168,7 @@ export function montgomery(curveDef: CurveType): CurveFn { if (pu === _0n) throw new Error('Invalid private or public key received'); return encodeUCoordinate(pu); } - /** - * Computes public key from private. - * Executes scalar multiplication of curve's base point by scalar. - * @param scalar private key - * @returns new public key - */ + // Computes public key from private. By doing scalar multiplication of base point. function scalarMultBase(scalar: Hex): Uint8Array { return scalarMult(scalar, CURVE.Gu); } diff --git a/src/abstract/poseidon.ts b/src/abstract/poseidon.ts index 7e5052d..302913f 100644 --- a/src/abstract/poseidon.ts +++ b/src/abstract/poseidon.ts @@ -1,6 +1,6 @@ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ // Poseidon Hash: https://eprint.iacr.org/2019/458.pdf, https://www.poseidon-hash.info -import { Field, validateField, FpPow } from './modular.js'; +import { Field, FpPow, validateField } from './modular.js'; // We don't provide any constants, since different implementations use different constants. // For reference constants see './test/poseidon.test.js'. export type PoseidonOpts = { diff --git a/src/abstract/utils.ts b/src/abstract/utils.ts index 1fb207f..b7d4550 100644 --- a/src/abstract/utils.ts +++ b/src/abstract/utils.ts @@ -18,7 +18,7 @@ export type FHash = (message: Uint8Array | string) => Uint8Array; const hexes = Array.from({ length: 256 }, (v, i) => i.toString(16).padStart(2, '0')); export function bytesToHex(bytes: Uint8Array): string { - if (!u8a(bytes)) throw new Error('Expected Uint8Array'); + if (!u8a(bytes)) throw new Error('Uint8Array expected'); // pre-caching improves the speed 6x let hex = ''; for (let i = 0; i < bytes.length; i++) { @@ -33,21 +33,21 @@ export function numberToHexUnpadded(num: number | bigint): string { } export function hexToNumber(hex: string): bigint { - if (typeof hex !== 'string') throw new Error('hexToNumber: expected string, got ' + typeof hex); + if (typeof hex !== 'string') throw new Error('string expected, got ' + typeof hex); // Big Endian return BigInt(`0x${hex}`); } // Caching slows it down 2-3x export function hexToBytes(hex: string): Uint8Array { - if (typeof hex !== 'string') throw new Error('hexToBytes: expected string, got ' + typeof hex); - if (hex.length % 2) throw new Error('hexToBytes: received invalid unpadded hex ' + hex.length); + if (typeof hex !== 'string') throw new Error('string expected, got ' + typeof hex); + if (hex.length % 2) throw new Error('hex string is invalid: unpadded ' + hex.length); const array = new Uint8Array(hex.length / 2); for (let i = 0; i < array.length; i++) { const j = i * 2; const hexByte = hex.slice(j, j + 2); const byte = Number.parseInt(hexByte, 16); - if (Number.isNaN(byte) || byte < 0) throw new Error('Invalid byte sequence'); + if (Number.isNaN(byte) || byte < 0) throw new Error('invalid byte sequence'); array[i] = byte; } return array; @@ -58,7 +58,7 @@ export function bytesToNumberBE(bytes: Uint8Array): bigint { return hexToNumber(bytesToHex(bytes)); } export function bytesToNumberLE(bytes: Uint8Array): bigint { - if (!u8a(bytes)) throw new Error('Expected Uint8Array'); + if (!u8a(bytes)) throw new Error('Uint8Array expected'); return hexToNumber(bytesToHex(Uint8Array.from(bytes).reverse())); } @@ -66,11 +66,7 @@ export const numberToBytesBE = (n: bigint, len: number) => hexToBytes(n.toString(16).padStart(len * 2, '0')); export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, len).reverse(); // Returns variable number bytes (minimal bigint encoding?) -export const numberToVarBytesBE = (n: bigint) => { - let hex = n.toString(16); - if (hex.length & 1) hex = '0' + hex; - return hexToBytes(hex); -}; +export const numberToVarBytesBE = (n: bigint) => hexToBytes(numberToHexUnpadded(n)); export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array { // Uint8Array.from() instead of hash.slice() because node.js Buffer @@ -82,17 +78,15 @@ export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array { } // Copies several Uint8Arrays into one. -export function concatBytes(...arrays: Uint8Array[]): Uint8Array { - if (!arrays.every((b) => u8a(b))) throw new Error('Uint8Array list expected'); - if (arrays.length === 1) return arrays[0]; - const length = arrays.reduce((a, arr) => a + arr.length, 0); - const result = new Uint8Array(length); - for (let i = 0, pad = 0; i < arrays.length; i++) { - const arr = arrays[i]; - result.set(arr, pad); - pad += arr.length; - } - return result; +export function concatBytes(...arrs: Uint8Array[]): Uint8Array { + const r = new Uint8Array(arrs.reduce((sum, a) => sum + a.length, 0)); + let pad = 0; // walk through each item, ensure they have proper type + arrs.forEach((a) => { + if (!u8a(a)) throw new Error('Uint8Array expected'); + r.set(a, pad); + pad += a.length; + }); + return r; } export function equalBytes(b1: Uint8Array, b2: Uint8Array) { @@ -119,3 +113,32 @@ export const bitSet = (n: bigint, pos: number, value: boolean) => // Return mask for N bits (Same as BigInt(`0b${Array(i).fill('1').join('')}`)) // Not using ** operator with bigints for old engines. export const bitMask = (n: number) => (_2n << BigInt(n - 1)) - _1n; + +type ValMap = Record; +export function validateObject(object: object, validators: ValMap, optValidators: ValMap = {}) { + const validatorFns: Record boolean> = { + bigint: (val) => typeof val === 'bigint', + function: (val) => typeof val === 'function', + boolean: (val) => typeof val === 'boolean', + string: (val) => typeof val === 'string', + isSafeInteger: (val) => Number.isSafeInteger(val), + array: (val) => Array.isArray(val), + field: (val) => (object as any).Fp.isValid(val), + hash: (val) => typeof val === 'function' && Number.isSafeInteger(val.outputLen), + }; + // type Key = keyof typeof validators; + const checkField = (fieldName: string, type: string, isOptional: boolean) => { + const checkVal = validatorFns[type]; + if (typeof checkVal !== 'function') + throw new Error(`Invalid validator "${type}", expected function`); + + const val = object[fieldName as keyof typeof object]; + if (isOptional && val === undefined) return; + if (!checkVal(val)) { + throw new Error(`Invalid param ${fieldName}=${val} (${typeof val}), expected ${type}`); + } + }; + for (let [fieldName, type] of Object.entries(validators)) checkField(fieldName, type, false); + for (let [fieldName, type] of Object.entries(optValidators)) checkField(fieldName, type, true); + return object; +} diff --git a/src/abstract/weierstrass.ts b/src/abstract/weierstrass.ts index f865473..6af7341 100644 --- a/src/abstract/weierstrass.ts +++ b/src/abstract/weierstrass.ts @@ -2,15 +2,8 @@ // Short Weierstrass curve. The formula is: y² = x³ + ax + b import * as mod from './modular.js'; import * as ut from './utils.js'; -import { Hex, PrivKey, ensureBytes, CHash } from './utils.js'; -import { - Group, - GroupConstructor, - wNAF, - AbstractCurve, - validateAbsOpts, - AffinePoint, -} from './curve.js'; +import { CHash, Hex, PrivKey, ensureBytes } from './utils.js'; +import { Group, GroupConstructor, wNAF, BasicCurve, validateBasic, AffinePoint } from './curve.js'; export type { AffinePoint }; type HmacFnSync = (key: Uint8Array, ...messages: Uint8Array[]) => Uint8Array; @@ -18,7 +11,7 @@ type EndomorphismOpts = { beta: bigint; splitScalar: (k: bigint) => { k1neg: boolean; k1: bigint; k2neg: boolean; k2: bigint }; }; -export type BasicCurve = AbstractCurve & { +export type BasicWCurve = BasicCurve & { // Params: a, b a: T; b: T; @@ -86,34 +79,32 @@ export interface ProjConstructor extends GroupConstructor> { normalizeZ(points: ProjPointType[]): ProjPointType[]; } -export type CurvePointsType = BasicCurve & { +export type CurvePointsType = BasicWCurve & { // Bytes fromBytes: (bytes: Uint8Array) => AffinePoint; toBytes: (c: ProjConstructor, point: ProjPointType, compressed: boolean) => Uint8Array; }; function validatePointOpts(curve: CurvePointsType) { - const opts = validateAbsOpts(curve); - const Fp = opts.Fp; - for (const i of ['a', 'b'] as const) { - if (!Fp.isValid(curve[i])) - throw new Error(`Invalid curve param ${i}=${opts[i]} (${typeof opts[i]})`); - } - for (const i of ['allowedPrivateKeyLengths'] as const) { - if (curve[i] === undefined) continue; // Optional - if (!Array.isArray(curve[i])) throw new Error(`Invalid ${i} array`); - } - for (const i of ['wrapPrivateKey'] as const) { - if (curve[i] === undefined) continue; // Optional - if (typeof curve[i] !== 'boolean') throw new Error(`Invalid ${i} boolean`); - } - for (const i of ['isTorsionFree', 'clearCofactor'] as const) { - if (curve[i] === undefined) continue; // Optional - if (typeof curve[i] !== 'function') throw new Error(`Invalid ${i} function`); - } - const endo = opts.endo; + const opts = validateBasic(curve); + ut.validateObject( + opts, + { + a: 'field', + b: 'field', + fromBytes: 'function', + toBytes: 'function', + }, + { + allowedPrivateKeyLengths: 'array', + wrapPrivateKey: 'boolean', + isTorsionFree: 'function', + clearCofactor: 'function', + } + ); + const { endo, Fp, a } = opts; if (endo) { - if (!Fp.eql(opts.a, Fp.ZERO)) { + if (!Fp.eql(a, Fp.ZERO)) { throw new Error('Endomorphism can only be defined for Koblitz curves that have a=0'); } if ( @@ -124,9 +115,6 @@ function validatePointOpts(curve: CurvePointsType) { throw new Error('Expected endomorphism with beta: bigint and splitScalar: function'); } } - if (typeof opts.fromBytes !== 'function') throw new Error('Invalid fromBytes function'); - if (typeof opts.toBytes !== 'function') throw new Error('Invalid fromBytes function'); - // Set defaults return Object.freeze({ ...opts } as const); } @@ -609,25 +597,30 @@ type SignatureLike = { r: bigint; s: bigint }; export type PubKey = Hex | ProjPointType; -export type CurveType = BasicCurve & { - // Default options - lowS?: boolean; - // Hashes - hash: CHash; // Because we need outputLen for DRBG +export type CurveType = BasicWCurve & { + hash: CHash; // CHash not FHash because we need outputLen for DRBG hmac: HmacFnSync; randomBytes: (bytesLength?: number) => Uint8Array; - // truncateHash?: (hash: Uint8Array, truncateOnly?: boolean) => Uint8Array; + lowS?: boolean; bits2int?: (bytes: Uint8Array) => bigint; bits2int_modN?: (bytes: Uint8Array) => bigint; }; function validateOpts(curve: CurveType) { - const opts = validateAbsOpts(curve); - if (typeof opts.hash !== 'function' || !Number.isSafeInteger(opts.hash.outputLen)) - throw new Error('Invalid hash function'); - if (typeof opts.hmac !== 'function') throw new Error('Invalid hmac function'); - if (typeof opts.randomBytes !== 'function') throw new Error('Invalid randomBytes function'); - // Set defaults + const opts = validateBasic(curve); + ut.validateObject( + opts, + { + hash: 'hash', + hmac: 'function', + randomBytes: 'function', + }, + { + bits2int: 'function', + bits2int_modN: 'function', + lowS: 'boolean', + } + ); return Object.freeze({ lowS: true, ...opts } as const); } @@ -756,7 +749,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { return { x, y }; } else { throw new Error( - `Point.fromHex: received invalid point. Expected ${compressedLen} compressed bytes or ${uncompressedLen} uncompressed bytes, not ${len}` + `Point of length ${len} was invalid. Expected ${compressedLen} compressed bytes or ${uncompressedLen} uncompressed bytes` ); } }, @@ -821,7 +814,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { const ir = invN(radj); // r^-1 const u1 = modN(-h * ir); // -hr^-1 const u2 = modN(s * ir); // sr^-1 - const Q = Point.BASE.multiplyAndAddUnsafe(R, u1, u2); // (sr^-1)R-(hr^-1)G = -(hr^-1)G + (sr^-1) + const Q = Point.BASE.multiplyAndAddUnsafe(R, u1, u2); // (sr^-1)R-(hr^-1)G = -(hr^-1)G + (sr^-1) if (!Q) throw new Error('point at infinify'); // unsafe is fine: no priv data leaked Q.assertValidity(); return Q; @@ -951,9 +944,10 @@ export function weierstrass(curveDef: CurveType): CurveFn { // NOTE: pads output with zero as per spec const ORDER_MASK = ut.bitMask(CURVE.nBitLength); function int2octets(num: bigint): Uint8Array { - if (typeof num !== 'bigint') throw new Error('Expected bigint'); + if (typeof num !== 'bigint') throw new Error('bigint expected'); if (!(_0n <= num && num < ORDER_MASK)) - throw new Error(`Expected number < 2^${CURVE.nBitLength}`); + // n in [0..ORDER_MASK-1] + throw new Error(`bigint expected < 2^${CURVE.nBitLength}`); // works with order, can have different size than numToField! return ut.numberToBytesBE(num, CURVE.nByteLength); } @@ -1045,7 +1039,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { * ``` */ function verify( - signature: Hex | { r: bigint; s: bigint }, + signature: Hex | SignatureLike, msgHash: Hex, publicKey: Hex, opts = defaultVerOpts @@ -1090,7 +1084,6 @@ export function weierstrass(curveDef: CurveType): CurveFn { getSharedSecret, sign, verify, - // Point, ProjectivePoint: Point, Signature, utils, diff --git a/src/p224.ts b/src/p224.ts index be4a34c..e46c26d 100644 --- a/src/p224.ts +++ b/src/p224.ts @@ -10,7 +10,7 @@ export const P224 = createCurve( // Params: a, b a: BigInt('0xfffffffffffffffffffffffffffffffefffffffffffffffffffffffe'), b: BigInt('0xb4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4'), - // Field over which we'll do calculations; 2n**224n - 2n**96n + 1n + // Field over which we'll do calculations; Fp: Fp(BigInt('0xffffffffffffffffffffffffffffffff000000000000000000000001')), // Curve order, total count of valid points in the field n: BigInt('0xffffffffffffffffffffffffffff16a2e0b8f03e13dd29455c5c2a3d'), diff --git a/src/secp256k1.ts b/src/secp256k1.ts index b04d03f..4f2f6ea 100644 --- a/src/secp256k1.ts +++ b/src/secp256k1.ts @@ -175,7 +175,6 @@ function schnorrSign(message: Hex, privateKey: Hex, auxRand: Hex = randomBytes(3 const { x: px, scalar: d } = schnorrGetScalar(bytesToNum(ensureBytes(privateKey, 32))); const a = ensureBytes(auxRand, 32); // Auxiliary random data a: a 32-byte array - // TODO: replace with proper xor? const t = numTo32b(d ^ bytesToNum(taggedHash(TAGS.aux, a))); // Let t be the byte-wise xor of bytes(d) and hash/aux(a) const rand = taggedHash(TAGS.nonce, t, px, m); // Let rand = hash/nonce(t || bytes(P) || m) const k_ = modN(bytesToNum(rand)); // Let k' = int(rand) mod n diff --git a/test/nist.test.js b/test/nist.test.js index 4c3a1b0..f93f15d 100644 --- a/test/nist.test.js +++ b/test/nist.test.js @@ -86,7 +86,8 @@ describe('wycheproof ECDH', () => { try { const pub = CURVE.ProjectivePoint.fromHex(test.public); } catch (e) { - if (e.message.includes('Point.fromHex: received invalid point.')) continue; + // Our strict validation filter doesn't let weird-length DER vectors + if (e.message.startsWith('Point of length')) continue; throw e; } const shared = CURVE.getSharedSecret(test.private, test.public); @@ -140,7 +141,8 @@ describe('wycheproof ECDH', () => { try { const pub = curve.ProjectivePoint.fromHex(test.public); } catch (e) { - if (e.message.includes('Point.fromHex: received invalid point.')) continue; + // Our strict validation filter doesn't let weird-length DER vectors + if (e.message.includes('Point of length')) continue; throw e; } const shared = curve.getSharedSecret(test.private, test.public); @@ -194,7 +196,6 @@ const WYCHEPROOF_ECDSA = { secp256k1: { curve: secp256k1, hashes: { - // TODO: debug why fails, can be bug sha256: { hash: sha256, tests: [secp256k1_sha256_test],