Use declarative curve field validation

This commit is contained in:
Paul Miller 2023-01-28 02:19:46 +00:00
parent f39fb80c52
commit c75129e629
No known key found for this signature in database
GPG Key ID: 697079DA6878B89B
11 changed files with 179 additions and 207 deletions

@ -1,6 +1,7 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Abelian group utilities // Abelian group utilities
import { Field, validateField, nLength } from './modular.js'; import { Field, validateField, nLength } from './modular.js';
import { validateObject } from './utils.js';
const _0n = BigInt(0); const _0n = BigInt(0);
const _1n = BigInt(1); const _1n = BigInt(1);
@ -153,7 +154,7 @@ export function wNAF<T extends Group<T>>(c: GroupConstructor<T>, bits: number) {
// Generic BasicCurve interface: works even for polynomial fields (BLS): P, n, h would be ok. // Generic BasicCurve interface: works even for polynomial fields (BLS): P, n, h would be ok.
// Though generator can be different (Fp2 / Fp6 for BLS). // Though generator can be different (Fp2 / Fp6 for BLS).
export type AbstractCurve<T> = { export type BasicCurve<T> = {
Fp: Field<T>; // Field over which we'll do calculations (Fp) Fp: Field<T>; // Field over which we'll do calculations (Fp)
n: bigint; // Curve order, total count of valid points in the field n: bigint; // Curve order, total count of valid points in the field
nBitLength?: number; // bit length of curve order nBitLength?: number; // bit length of curve order
@ -165,20 +166,21 @@ export type AbstractCurve<T> = {
allowInfinityPoint?: boolean; // bls12-381 requires it. ZERO point is valid, but invalid pubkey allowInfinityPoint?: boolean; // bls12-381 requires it. ZERO point is valid, but invalid pubkey
}; };
export function validateAbsOpts<FP, T>(curve: AbstractCurve<FP> & T) { export function validateBasic<FP, T>(curve: BasicCurve<FP> & T) {
validateField(curve.Fp); validateField(curve.Fp);
for (const i of ['n', 'h'] as const) { validateObject(
const val = curve[i]; curve,
if (typeof val !== 'bigint') throw new Error(`Invalid curve param ${i}=${val} (${typeof val})`); {
} n: 'bigint',
if (!curve.Fp.isValid(curve.Gx)) throw new Error('Invalid generator X coordinate Fp element'); h: 'bigint',
if (!curve.Fp.isValid(curve.Gy)) throw new Error('Invalid generator Y coordinate Fp element'); Gx: 'field',
Gy: 'field',
for (const i of ['nBitLength', 'nByteLength'] as const) { },
const val = curve[i]; {
if (val === undefined) continue; // Optional nBitLength: 'isSafeInteger',
if (!Number.isSafeInteger(val)) throw new Error(`Invalid param ${i}=${val} (${typeof val})`); nByteLength: 'isSafeInteger',
} }
);
// Set defaults // Set defaults
return Object.freeze({ ...nLength(curve.n, curve.nBitLength), ...curve } as const); return Object.freeze({ ...nLength(curve.n, curve.nBitLength), ...curve } as const);
} }

@ -1,23 +1,9 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Twisted Edwards curve. The formula is: ax² + y² = 1 + dx²y² // Twisted Edwards curve. The formula is: ax² + y² = 1 + dx²y²
import { mod } from './modular.js'; import { mod } from './modular.js';
import { import * as ut from './utils.js';
bytesToHex, import { ensureBytes, FHash, Hex } from './utils.js';
bytesToNumberLE, import { Group, GroupConstructor, wNAF, BasicCurve, validateBasic, AffinePoint } from './curve.js';
concatBytes,
ensureBytes,
FHash,
Hex,
numberToBytesLE,
} from './utils.js';
import {
Group,
GroupConstructor,
wNAF,
AbstractCurve,
validateAbsOpts,
AffinePoint,
} from './curve.js';
// Be friendly to bad ECMAScript parsers by not using bigint literals like 123n // Be friendly to bad ECMAScript parsers by not using bigint literals like 123n
const _0n = BigInt(0); const _0n = BigInt(0);
@ -26,7 +12,7 @@ const _2n = BigInt(2);
const _8n = BigInt(8); const _8n = BigInt(8);
// Edwards curves must declare params a & d. // Edwards curves must declare params a & d.
export type CurveType = AbstractCurve<bigint> & { export type CurveType = BasicCurve<bigint> & {
a: bigint; // curve param a a: bigint; // curve param a
d: bigint; // curve param d d: bigint; // curve param d
hash: FHash; // Hashing hash: FHash; // Hashing
@ -39,19 +25,22 @@ export type CurveType = AbstractCurve<bigint> & {
}; };
function validateOpts(curve: CurveType) { function validateOpts(curve: CurveType) {
const opts = validateAbsOpts(curve); const opts = validateBasic(curve);
if (typeof opts.hash !== 'function') throw new Error('Invalid hash function'); ut.validateObject(
for (const i of ['a', 'd'] as const) { curve,
const val = opts[i]; {
if (typeof val !== 'bigint') throw new Error(`Invalid curve param ${i}=${val} (${typeof val})`); hash: 'function',
} a: 'bigint',
for (const fn of ['randomBytes'] as const) { d: 'bigint',
if (typeof opts[fn] !== 'function') throw new Error(`Invalid ${fn} function`); randomBytes: 'function',
} },
for (const fn of ['adjustScalarBytes', 'domain', 'uvRatio', 'mapToCurve'] as const) { {
if (opts[fn] === undefined) continue; // Optional adjustScalarBytes: 'function',
if (typeof opts[fn] !== 'function') throw new Error(`Invalid ${fn} function`); domain: 'function',
uvRatio: 'function',
mapToCurve: 'function',
} }
);
// Set defaults // Set defaults
return Object.freeze({ ...opts } as const); return Object.freeze({ ...opts } as const);
} }
@ -75,7 +64,7 @@ export interface ExtPointConstructor extends GroupConstructor<ExtPointType> {
new (x: bigint, y: bigint, z: bigint, t: bigint): ExtPointType; new (x: bigint, y: bigint, z: bigint, t: bigint): ExtPointType;
fromAffine(p: AffinePoint<bigint>): ExtPointType; fromAffine(p: AffinePoint<bigint>): ExtPointType;
fromHex(hex: Hex): ExtPointType; fromHex(hex: Hex): ExtPointType;
fromPrivateKey(privateKey: Hex): ExtPointType; // TODO: remove fromPrivateKey(privateKey: Hex): ExtPointType;
} }
export type CurveFn = { export type CurveFn = {
@ -340,7 +329,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const normed = hex.slice(); // copy again, we'll manipulate it const normed = hex.slice(); // copy again, we'll manipulate it
const lastByte = hex[len - 1]; // select last byte const lastByte = hex[len - 1]; // select last byte
normed[len - 1] = lastByte & ~0x80; // clear last bit normed[len - 1] = lastByte & ~0x80; // clear last bit
const y = bytesToNumberLE(normed); const y = ut.bytesToNumberLE(normed);
if (y === _0n) { if (y === _0n) {
// y=0 is allowed // y=0 is allowed
} else { } else {
@ -366,12 +355,12 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
} }
toRawBytes(): Uint8Array { toRawBytes(): Uint8Array {
const { x, y } = this.toAffine(); const { x, y } = this.toAffine();
const bytes = numberToBytesLE(y, Fp.BYTES); // each y has 2 x values (x, -y) const bytes = ut.numberToBytesLE(y, Fp.BYTES); // each y has 2 x values (x, -y)
bytes[bytes.length - 1] |= x & _1n ? 0x80 : 0; // when compressing, it's enough to store y bytes[bytes.length - 1] |= x & _1n ? 0x80 : 0; // when compressing, it's enough to store y
return bytes; // and use the last byte to encode sign of x return bytes; // and use the last byte to encode sign of x
} }
toHex(): string { toHex(): string {
return bytesToHex(this.toRawBytes()); // Same as toRawBytes, but returns string. return ut.bytesToHex(this.toRawBytes()); // Same as toRawBytes, but returns string.
} }
} }
const { BASE: G, ZERO: I } = Point; const { BASE: G, ZERO: I } = Point;
@ -382,7 +371,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
} }
// Little-endian SHA512 with modulo n // Little-endian SHA512 with modulo n
function modN_LE(hash: Uint8Array): bigint { function modN_LE(hash: Uint8Array): bigint {
return modN(bytesToNumberLE(hash)); return modN(ut.bytesToNumberLE(hash));
} }
function isHex(item: Hex, err: string) { function isHex(item: Hex, err: string) {
if (typeof item !== 'string' && !(item instanceof Uint8Array)) if (typeof item !== 'string' && !(item instanceof Uint8Array))
@ -411,7 +400,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
// int('LE', SHA512(dom2(F, C) || msgs)) mod N // int('LE', SHA512(dom2(F, C) || msgs)) mod N
function hashDomainToScalar(context: Hex = new Uint8Array(), ...msgs: Uint8Array[]) { function hashDomainToScalar(context: Hex = new Uint8Array(), ...msgs: Uint8Array[]) {
const msg = concatBytes(...msgs); const msg = ut.concatBytes(...msgs);
return modN_LE(cHash(domain(msg, ensureBytes(context), !!preHash))); return modN_LE(cHash(domain(msg, ensureBytes(context), !!preHash)));
} }
@ -426,7 +415,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const k = hashDomainToScalar(context, R, pointBytes, msg); // R || A || PH(M) const k = hashDomainToScalar(context, R, pointBytes, msg); // R || A || PH(M)
const s = modN(r + k * scalar); // S = (r + k * s) mod L const s = modN(r + k * scalar); // S = (r + k * s) mod L
assertGE0(s); // 0 <= s < l assertGE0(s); // 0 <= s < l
const res = concatBytes(R, numberToBytesLE(s, Fp.BYTES)); const res = ut.concatBytes(R, ut.numberToBytesLE(s, Fp.BYTES));
return ensureBytes(res, nByteLength * 2); // 64-byte signature return ensureBytes(res, nByteLength * 2); // 64-byte signature
} }
@ -439,7 +428,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
if (preHash) msg = preHash(msg); // for ed25519ph, etc if (preHash) msg = preHash(msg); // for ed25519ph, etc
const A = Point.fromHex(publicKey, false); // Check for s bounds, hex validity const A = Point.fromHex(publicKey, false); // Check for s bounds, hex validity
const R = Point.fromHex(sig.slice(0, len), false); // 0 <= R < 2^256: ZIP215 R can be >= P const R = Point.fromHex(sig.slice(0, len), false); // 0 <= R < 2^256: ZIP215 R can be >= P
const s = bytesToNumberLE(sig.slice(len, 2 * len)); // 0 <= s < l const s = ut.bytesToNumberLE(sig.slice(len, 2 * len)); // 0 <= s < l
const SB = G.multiplyUnsafe(s); const SB = G.multiplyUnsafe(s);
const k = hashDomainToScalar(context, R.toRawBytes(), A.toRawBytes(), msg); const k = hashDomainToScalar(context, R.toRawBytes(), A.toRawBytes(), msg);
const RkA = R.add(A.multiplyUnsafe(k)); const RkA = R.add(A.multiplyUnsafe(k));

@ -45,7 +45,7 @@ declare const TextDecoder: any;
export function stringToBytes(str: string): Uint8Array { export function stringToBytes(str: string): Uint8Array {
if (typeof str !== 'string') { if (typeof str !== 'string') {
throw new TypeError(`utf8ToBytes expected string, got ${typeof str}`); throw new Error(`utf8ToBytes expected string, got ${typeof str}`);
} }
return new TextEncoder().encode(str); return new TextEncoder().encode(str);
} }

@ -7,6 +7,7 @@ import {
bytesToNumberBE, bytesToNumberBE,
bytesToNumberLE, bytesToNumberLE,
ensureBytes, ensureBytes,
validateObject,
} from './utils.js'; } from './utils.js';
// prettier-ignore // prettier-ignore
const _0n = BigInt(0), _1n = BigInt(1), _2n = BigInt(2), _3n = BigInt(3); const _0n = BigInt(0), _1n = BigInt(1), _2n = BigInt(2), _3n = BigInt(3);
@ -40,7 +41,6 @@ export function pow(num: bigint, power: bigint, modulo: bigint): bigint {
} }
// Does x ^ (2 ^ power) mod p. pow2(30, 4) == 30 ^ (2 ^ 4) // Does x ^ (2 ^ power) mod p. pow2(30, 4) == 30 ^ (2 ^ 4)
// TODO: Fp version?
export function pow2(x: bigint, power: bigint, modulo: bigint): bigint { export function pow2(x: bigint, power: bigint, modulo: bigint): bigint {
let res = x; let res = x;
while (power-- > _0n) { while (power-- > _0n) {
@ -249,18 +249,17 @@ const FIELD_FIELDS = [
'addN', 'subN', 'mulN', 'sqrN' 'addN', 'subN', 'mulN', 'sqrN'
] as const; ] as const;
export function validateField<T>(field: Field<T>) { export function validateField<T>(field: Field<T>) {
for (const i of ['ORDER', 'MASK'] as const) { const initial = {
if (typeof field[i] !== 'bigint') ORDER: 'bigint',
throw new Error(`Invalid field param ${i}=${field[i]} (${typeof field[i]})`); MASK: 'bigint',
} BYTES: 'isSafeInteger',
for (const i of ['BYTES', 'BITS'] as const) { BITS: 'isSafeInteger',
if (typeof field[i] !== 'number') } as Record<string, string>;
throw new Error(`Invalid field param ${i}=${field[i]} (${typeof field[i]})`); const opts = FIELD_FIELDS.reduce((map, val: string) => {
} map[val] = 'function';
for (const i of FIELD_FIELDS) { return map;
if (typeof field[i] !== 'function') }, initial);
throw new Error(`Invalid field param ${i}=${field[i]} (${typeof field[i]})`); return validateObject(field, opts);
}
} }
// Generic field functions // Generic field functions

@ -1,14 +1,13 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
import { mod, pow } from './modular.js'; import { mod, pow } from './modular.js';
import { ensureBytes, numberToBytesLE, bytesToNumberLE } from './utils.js'; import { bytesToNumberLE, ensureBytes, numberToBytesLE, validateObject } from './utils.js';
const _0n = BigInt(0); const _0n = BigInt(0);
const _1n = BigInt(1); const _1n = BigInt(1);
type Hex = string | Uint8Array; type Hex = string | Uint8Array;
export type CurveType = { export type CurveType = {
// Field over which we'll do calculations. Verify with: P: bigint; // finite field prime
P: bigint;
nByteLength: number; nByteLength: number;
adjustScalarBytes?: (bytes: Uint8Array) => Uint8Array; adjustScalarBytes?: (bytes: Uint8Array) => Uint8Array;
domain?: (data: Uint8Array, ctx: Uint8Array, phflag: boolean) => Uint8Array; domain?: (data: Uint8Array, ctx: Uint8Array, phflag: boolean) => Uint8Array;
@ -27,24 +26,20 @@ export type CurveFn = {
}; };
function validateOpts(curve: CurveType) { function validateOpts(curve: CurveType) {
for (const i of ['a24'] as const) { validateObject(
if (typeof curve[i] !== 'bigint') curve,
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`); {
} a24: 'bigint',
for (const i of ['montgomeryBits', 'nByteLength'] as const) { },
if (curve[i] === undefined) continue; // Optional {
if (!Number.isSafeInteger(curve[i])) montgomeryBits: 'isSafeInteger',
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`); nByteLength: 'isSafeInteger',
} adjustScalarBytes: 'function',
for (const fn of ['adjustScalarBytes', 'domain', 'powPminus2'] as const) { domain: 'function',
if (curve[fn] === undefined) continue; // Optional powPminus2: 'function',
if (typeof curve[fn] !== 'function') throw new Error(`Invalid ${fn} function`); Gu: 'string',
}
for (const i of ['Gu'] as const) {
if (curve[i] === undefined) continue; // Optional
if (typeof curve[i] !== 'string')
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`);
} }
);
// Set defaults // Set defaults
return Object.freeze({ ...curve } as const); return Object.freeze({ ...curve } as const);
} }
@ -61,27 +56,7 @@ export function montgomery(curveDef: CurveType): CurveFn {
const adjustScalarBytes = CURVE.adjustScalarBytes || ((bytes: Uint8Array) => bytes); const adjustScalarBytes = CURVE.adjustScalarBytes || ((bytes: Uint8Array) => bytes);
const powPminus2 = CURVE.powPminus2 || ((x: bigint) => pow(x, P - BigInt(2), P)); const powPminus2 = CURVE.powPminus2 || ((x: bigint) => pow(x, P - BigInt(2), P));
/** // cswap from RFC7748. But it is not from RFC7748!
* Checks for num to be in range:
* For strict == true: `0 < num < max`.
* For strict == false: `0 <= num < max`.
* Converts non-float safe numbers to bigints.
*/
function normalizeScalar(num: bigint, max: bigint, strict = true): bigint {
if (!max) throw new TypeError('Specify max value');
if (typeof num === 'number' && Number.isSafeInteger(num)) num = BigInt(num);
if (typeof num === 'bigint' && num < max) {
if (strict) {
if (_0n < num) return num;
} else {
if (_0n <= num) return num;
}
}
throw new TypeError('Expected valid scalar: 0 < scalar < max');
}
// cswap from RFC7748
// NOTE: cswap is not from RFC7748!
/* /*
cswap(swap, x_2, x_3): cswap(swap, x_2, x_3):
dummy = mask(swap) AND (x_2 XOR x_3) dummy = mask(swap) AND (x_2 XOR x_3)
@ -98,6 +73,11 @@ export function montgomery(curveDef: CurveType): CurveFn {
return [x_2, x_3]; return [x_2, x_3];
} }
function assertFieldElement(n: bigint): bigint {
if (typeof n === 'bigint' && _0n <= n && n < P) return n;
throw new Error('Expected valid scalar 0 < scalar < CURVE.P');
}
// x25519 from 4 // x25519 from 4
/** /**
* *
@ -106,11 +86,10 @@ export function montgomery(curveDef: CurveType): CurveFn {
* @returns new Point on Montgomery curve * @returns new Point on Montgomery curve
*/ */
function montgomeryLadder(pointU: bigint, scalar: bigint): bigint { function montgomeryLadder(pointU: bigint, scalar: bigint): bigint {
const { P } = CURVE; const u = assertFieldElement(pointU);
const u = normalizeScalar(pointU, P);
// Section 5: Implementations MUST accept non-canonical values and process them as // Section 5: Implementations MUST accept non-canonical values and process them as
// if they had been reduced modulo the field prime. // if they had been reduced modulo the field prime.
const k = normalizeScalar(scalar, P); const k = assertFieldElement(scalar);
// The constant a24 is (486662 - 2) / 4 = 121665 for curve25519/X25519 // The constant a24 is (486662 - 2) / 4 = 121665 for curve25519/X25519
const a24 = CURVE.a24; const a24 = CURVE.a24;
const x_1 = u; const x_1 = u;
@ -166,28 +145,20 @@ export function montgomery(curveDef: CurveType): CurveFn {
} }
function decodeUCoordinate(uEnc: Hex): bigint { function decodeUCoordinate(uEnc: Hex): bigint {
const u = ensureBytes(uEnc, montgomeryBytes);
// Section 5: When receiving such an array, implementations of X25519 // Section 5: When receiving such an array, implementations of X25519
// MUST mask the most significant bit in the final byte. // MUST mask the most significant bit in the final byte.
// This is very ugly way, but it works because fieldLen-1 is outside of bounds for X448, so this becomes NOOP // This is very ugly way, but it works because fieldLen-1 is outside of bounds for X448, so this becomes NOOP
// fieldLen - scalaryBytes = 1 for X448 and = 0 for X25519 // fieldLen - scalaryBytes = 1 for X448 and = 0 for X25519
const u = ensureBytes(uEnc, montgomeryBytes);
u[fieldLen - 1] &= 127; // 0b0111_1111 u[fieldLen - 1] &= 127; // 0b0111_1111
return bytesToNumberLE(u); return bytesToNumberLE(u);
} }
function decodeScalar(n: Hex): bigint { function decodeScalar(n: Hex): bigint {
const bytes = ensureBytes(n); const bytes = ensureBytes(n);
if (bytes.length !== montgomeryBytes && bytes.length !== fieldLen) if (bytes.length !== montgomeryBytes && bytes.length !== fieldLen)
throw new Error(`Expected ${montgomeryBytes} or ${fieldLen} bytes, got ${bytes.length}`); throw new Error(`Expected ${montgomeryBytes} or ${fieldLen} bytes, got ${bytes.length}`);
return bytesToNumberLE(adjustScalarBytes(bytes)); return bytesToNumberLE(adjustScalarBytes(bytes));
} }
/**
* Computes shared secret between private key "scalar" and public key's "u" (x) coordinate.
* We can get 'y' coordinate from 'u',
* but Point.fromHex also wants 'x' coordinate oddity flag,
* and we cannot get 'x' without knowing 'v'.
* Need to add generic conversion between twisted edwards and complimentary curve for JubJub.
*/
function scalarMult(scalar: Hex, u: Hex): Uint8Array { function scalarMult(scalar: Hex, u: Hex): Uint8Array {
const pointU = decodeUCoordinate(u); const pointU = decodeUCoordinate(u);
const _scalar = decodeScalar(scalar); const _scalar = decodeScalar(scalar);
@ -197,12 +168,7 @@ export function montgomery(curveDef: CurveType): CurveFn {
if (pu === _0n) throw new Error('Invalid private or public key received'); if (pu === _0n) throw new Error('Invalid private or public key received');
return encodeUCoordinate(pu); return encodeUCoordinate(pu);
} }
/** // Computes public key from private. By doing scalar multiplication of base point.
* Computes public key from private.
* Executes scalar multiplication of curve's base point by scalar.
* @param scalar private key
* @returns new public key
*/
function scalarMultBase(scalar: Hex): Uint8Array { function scalarMultBase(scalar: Hex): Uint8Array {
return scalarMult(scalar, CURVE.Gu); return scalarMult(scalar, CURVE.Gu);
} }

@ -1,6 +1,6 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */ /*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Poseidon Hash: https://eprint.iacr.org/2019/458.pdf, https://www.poseidon-hash.info // Poseidon Hash: https://eprint.iacr.org/2019/458.pdf, https://www.poseidon-hash.info
import { Field, validateField, FpPow } from './modular.js'; import { Field, FpPow, validateField } from './modular.js';
// We don't provide any constants, since different implementations use different constants. // We don't provide any constants, since different implementations use different constants.
// For reference constants see './test/poseidon.test.js'. // For reference constants see './test/poseidon.test.js'.
export type PoseidonOpts = { export type PoseidonOpts = {

@ -18,7 +18,7 @@ export type FHash = (message: Uint8Array | string) => Uint8Array;
const hexes = Array.from({ length: 256 }, (v, i) => i.toString(16).padStart(2, '0')); const hexes = Array.from({ length: 256 }, (v, i) => i.toString(16).padStart(2, '0'));
export function bytesToHex(bytes: Uint8Array): string { export function bytesToHex(bytes: Uint8Array): string {
if (!u8a(bytes)) throw new Error('Expected Uint8Array'); if (!u8a(bytes)) throw new Error('Uint8Array expected');
// pre-caching improves the speed 6x // pre-caching improves the speed 6x
let hex = ''; let hex = '';
for (let i = 0; i < bytes.length; i++) { for (let i = 0; i < bytes.length; i++) {
@ -33,21 +33,21 @@ export function numberToHexUnpadded(num: number | bigint): string {
} }
export function hexToNumber(hex: string): bigint { export function hexToNumber(hex: string): bigint {
if (typeof hex !== 'string') throw new Error('hexToNumber: expected string, got ' + typeof hex); if (typeof hex !== 'string') throw new Error('string expected, got ' + typeof hex);
// Big Endian // Big Endian
return BigInt(`0x${hex}`); return BigInt(`0x${hex}`);
} }
// Caching slows it down 2-3x // Caching slows it down 2-3x
export function hexToBytes(hex: string): Uint8Array { export function hexToBytes(hex: string): Uint8Array {
if (typeof hex !== 'string') throw new Error('hexToBytes: expected string, got ' + typeof hex); if (typeof hex !== 'string') throw new Error('string expected, got ' + typeof hex);
if (hex.length % 2) throw new Error('hexToBytes: received invalid unpadded hex ' + hex.length); if (hex.length % 2) throw new Error('hex string is invalid: unpadded ' + hex.length);
const array = new Uint8Array(hex.length / 2); const array = new Uint8Array(hex.length / 2);
for (let i = 0; i < array.length; i++) { for (let i = 0; i < array.length; i++) {
const j = i * 2; const j = i * 2;
const hexByte = hex.slice(j, j + 2); const hexByte = hex.slice(j, j + 2);
const byte = Number.parseInt(hexByte, 16); const byte = Number.parseInt(hexByte, 16);
if (Number.isNaN(byte) || byte < 0) throw new Error('Invalid byte sequence'); if (Number.isNaN(byte) || byte < 0) throw new Error('invalid byte sequence');
array[i] = byte; array[i] = byte;
} }
return array; return array;
@ -58,7 +58,7 @@ export function bytesToNumberBE(bytes: Uint8Array): bigint {
return hexToNumber(bytesToHex(bytes)); return hexToNumber(bytesToHex(bytes));
} }
export function bytesToNumberLE(bytes: Uint8Array): bigint { export function bytesToNumberLE(bytes: Uint8Array): bigint {
if (!u8a(bytes)) throw new Error('Expected Uint8Array'); if (!u8a(bytes)) throw new Error('Uint8Array expected');
return hexToNumber(bytesToHex(Uint8Array.from(bytes).reverse())); return hexToNumber(bytesToHex(Uint8Array.from(bytes).reverse()));
} }
@ -66,11 +66,7 @@ export const numberToBytesBE = (n: bigint, len: number) =>
hexToBytes(n.toString(16).padStart(len * 2, '0')); hexToBytes(n.toString(16).padStart(len * 2, '0'));
export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, len).reverse(); export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, len).reverse();
// Returns variable number bytes (minimal bigint encoding?) // Returns variable number bytes (minimal bigint encoding?)
export const numberToVarBytesBE = (n: bigint) => { export const numberToVarBytesBE = (n: bigint) => hexToBytes(numberToHexUnpadded(n));
let hex = n.toString(16);
if (hex.length & 1) hex = '0' + hex;
return hexToBytes(hex);
};
export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array { export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array {
// Uint8Array.from() instead of hash.slice() because node.js Buffer // Uint8Array.from() instead of hash.slice() because node.js Buffer
@ -82,17 +78,15 @@ export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array {
} }
// Copies several Uint8Arrays into one. // Copies several Uint8Arrays into one.
export function concatBytes(...arrays: Uint8Array[]): Uint8Array { export function concatBytes(...arrs: Uint8Array[]): Uint8Array {
if (!arrays.every((b) => u8a(b))) throw new Error('Uint8Array list expected'); const r = new Uint8Array(arrs.reduce((sum, a) => sum + a.length, 0));
if (arrays.length === 1) return arrays[0]; let pad = 0; // walk through each item, ensure they have proper type
const length = arrays.reduce((a, arr) => a + arr.length, 0); arrs.forEach((a) => {
const result = new Uint8Array(length); if (!u8a(a)) throw new Error('Uint8Array expected');
for (let i = 0, pad = 0; i < arrays.length; i++) { r.set(a, pad);
const arr = arrays[i]; pad += a.length;
result.set(arr, pad); });
pad += arr.length; return r;
}
return result;
} }
export function equalBytes(b1: Uint8Array, b2: Uint8Array) { export function equalBytes(b1: Uint8Array, b2: Uint8Array) {
@ -119,3 +113,32 @@ export const bitSet = (n: bigint, pos: number, value: boolean) =>
// Return mask for N bits (Same as BigInt(`0b${Array(i).fill('1').join('')}`)) // Return mask for N bits (Same as BigInt(`0b${Array(i).fill('1').join('')}`))
// Not using ** operator with bigints for old engines. // Not using ** operator with bigints for old engines.
export const bitMask = (n: number) => (_2n << BigInt(n - 1)) - _1n; export const bitMask = (n: number) => (_2n << BigInt(n - 1)) - _1n;
type ValMap = Record<string, string>;
export function validateObject(object: object, validators: ValMap, optValidators: ValMap = {}) {
const validatorFns: Record<string, (val: any) => boolean> = {
bigint: (val) => typeof val === 'bigint',
function: (val) => typeof val === 'function',
boolean: (val) => typeof val === 'boolean',
string: (val) => typeof val === 'string',
isSafeInteger: (val) => Number.isSafeInteger(val),
array: (val) => Array.isArray(val),
field: (val) => (object as any).Fp.isValid(val),
hash: (val) => typeof val === 'function' && Number.isSafeInteger(val.outputLen),
};
// type Key = keyof typeof validators;
const checkField = (fieldName: string, type: string, isOptional: boolean) => {
const checkVal = validatorFns[type];
if (typeof checkVal !== 'function')
throw new Error(`Invalid validator "${type}", expected function`);
const val = object[fieldName as keyof typeof object];
if (isOptional && val === undefined) return;
if (!checkVal(val)) {
throw new Error(`Invalid param ${fieldName}=${val} (${typeof val}), expected ${type}`);
}
};
for (let [fieldName, type] of Object.entries(validators)) checkField(fieldName, type, false);
for (let [fieldName, type] of Object.entries(optValidators)) checkField(fieldName, type, true);
return object;
}

@ -2,15 +2,8 @@
// Short Weierstrass curve. The formula is: y² = x³ + ax + b // Short Weierstrass curve. The formula is: y² = x³ + ax + b
import * as mod from './modular.js'; import * as mod from './modular.js';
import * as ut from './utils.js'; import * as ut from './utils.js';
import { Hex, PrivKey, ensureBytes, CHash } from './utils.js'; import { CHash, Hex, PrivKey, ensureBytes } from './utils.js';
import { import { Group, GroupConstructor, wNAF, BasicCurve, validateBasic, AffinePoint } from './curve.js';
Group,
GroupConstructor,
wNAF,
AbstractCurve,
validateAbsOpts,
AffinePoint,
} from './curve.js';
export type { AffinePoint }; export type { AffinePoint };
type HmacFnSync = (key: Uint8Array, ...messages: Uint8Array[]) => Uint8Array; type HmacFnSync = (key: Uint8Array, ...messages: Uint8Array[]) => Uint8Array;
@ -18,7 +11,7 @@ type EndomorphismOpts = {
beta: bigint; beta: bigint;
splitScalar: (k: bigint) => { k1neg: boolean; k1: bigint; k2neg: boolean; k2: bigint }; splitScalar: (k: bigint) => { k1neg: boolean; k1: bigint; k2neg: boolean; k2: bigint };
}; };
export type BasicCurve<T> = AbstractCurve<T> & { export type BasicWCurve<T> = BasicCurve<T> & {
// Params: a, b // Params: a, b
a: T; a: T;
b: T; b: T;
@ -86,34 +79,32 @@ export interface ProjConstructor<T> extends GroupConstructor<ProjPointType<T>> {
normalizeZ(points: ProjPointType<T>[]): ProjPointType<T>[]; normalizeZ(points: ProjPointType<T>[]): ProjPointType<T>[];
} }
export type CurvePointsType<T> = BasicCurve<T> & { export type CurvePointsType<T> = BasicWCurve<T> & {
// Bytes // Bytes
fromBytes: (bytes: Uint8Array) => AffinePoint<T>; fromBytes: (bytes: Uint8Array) => AffinePoint<T>;
toBytes: (c: ProjConstructor<T>, point: ProjPointType<T>, compressed: boolean) => Uint8Array; toBytes: (c: ProjConstructor<T>, point: ProjPointType<T>, compressed: boolean) => Uint8Array;
}; };
function validatePointOpts<T>(curve: CurvePointsType<T>) { function validatePointOpts<T>(curve: CurvePointsType<T>) {
const opts = validateAbsOpts(curve); const opts = validateBasic(curve);
const Fp = opts.Fp; ut.validateObject(
for (const i of ['a', 'b'] as const) { opts,
if (!Fp.isValid(curve[i])) {
throw new Error(`Invalid curve param ${i}=${opts[i]} (${typeof opts[i]})`); a: 'field',
b: 'field',
fromBytes: 'function',
toBytes: 'function',
},
{
allowedPrivateKeyLengths: 'array',
wrapPrivateKey: 'boolean',
isTorsionFree: 'function',
clearCofactor: 'function',
} }
for (const i of ['allowedPrivateKeyLengths'] as const) { );
if (curve[i] === undefined) continue; // Optional const { endo, Fp, a } = opts;
if (!Array.isArray(curve[i])) throw new Error(`Invalid ${i} array`);
}
for (const i of ['wrapPrivateKey'] as const) {
if (curve[i] === undefined) continue; // Optional
if (typeof curve[i] !== 'boolean') throw new Error(`Invalid ${i} boolean`);
}
for (const i of ['isTorsionFree', 'clearCofactor'] as const) {
if (curve[i] === undefined) continue; // Optional
if (typeof curve[i] !== 'function') throw new Error(`Invalid ${i} function`);
}
const endo = opts.endo;
if (endo) { if (endo) {
if (!Fp.eql(opts.a, Fp.ZERO)) { if (!Fp.eql(a, Fp.ZERO)) {
throw new Error('Endomorphism can only be defined for Koblitz curves that have a=0'); throw new Error('Endomorphism can only be defined for Koblitz curves that have a=0');
} }
if ( if (
@ -124,9 +115,6 @@ function validatePointOpts<T>(curve: CurvePointsType<T>) {
throw new Error('Expected endomorphism with beta: bigint and splitScalar: function'); throw new Error('Expected endomorphism with beta: bigint and splitScalar: function');
} }
} }
if (typeof opts.fromBytes !== 'function') throw new Error('Invalid fromBytes function');
if (typeof opts.toBytes !== 'function') throw new Error('Invalid fromBytes function');
// Set defaults
return Object.freeze({ ...opts } as const); return Object.freeze({ ...opts } as const);
} }
@ -609,25 +597,30 @@ type SignatureLike = { r: bigint; s: bigint };
export type PubKey = Hex | ProjPointType<bigint>; export type PubKey = Hex | ProjPointType<bigint>;
export type CurveType = BasicCurve<bigint> & { export type CurveType = BasicWCurve<bigint> & {
// Default options hash: CHash; // CHash not FHash because we need outputLen for DRBG
lowS?: boolean;
// Hashes
hash: CHash; // Because we need outputLen for DRBG
hmac: HmacFnSync; hmac: HmacFnSync;
randomBytes: (bytesLength?: number) => Uint8Array; randomBytes: (bytesLength?: number) => Uint8Array;
// truncateHash?: (hash: Uint8Array, truncateOnly?: boolean) => Uint8Array; lowS?: boolean;
bits2int?: (bytes: Uint8Array) => bigint; bits2int?: (bytes: Uint8Array) => bigint;
bits2int_modN?: (bytes: Uint8Array) => bigint; bits2int_modN?: (bytes: Uint8Array) => bigint;
}; };
function validateOpts(curve: CurveType) { function validateOpts(curve: CurveType) {
const opts = validateAbsOpts(curve); const opts = validateBasic(curve);
if (typeof opts.hash !== 'function' || !Number.isSafeInteger(opts.hash.outputLen)) ut.validateObject(
throw new Error('Invalid hash function'); opts,
if (typeof opts.hmac !== 'function') throw new Error('Invalid hmac function'); {
if (typeof opts.randomBytes !== 'function') throw new Error('Invalid randomBytes function'); hash: 'hash',
// Set defaults hmac: 'function',
randomBytes: 'function',
},
{
bits2int: 'function',
bits2int_modN: 'function',
lowS: 'boolean',
}
);
return Object.freeze({ lowS: true, ...opts } as const); return Object.freeze({ lowS: true, ...opts } as const);
} }
@ -756,7 +749,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
return { x, y }; return { x, y };
} else { } else {
throw new Error( throw new Error(
`Point.fromHex: received invalid point. Expected ${compressedLen} compressed bytes or ${uncompressedLen} uncompressed bytes, not ${len}` `Point of length ${len} was invalid. Expected ${compressedLen} compressed bytes or ${uncompressedLen} uncompressed bytes`
); );
} }
}, },
@ -951,9 +944,10 @@ export function weierstrass(curveDef: CurveType): CurveFn {
// NOTE: pads output with zero as per spec // NOTE: pads output with zero as per spec
const ORDER_MASK = ut.bitMask(CURVE.nBitLength); const ORDER_MASK = ut.bitMask(CURVE.nBitLength);
function int2octets(num: bigint): Uint8Array { function int2octets(num: bigint): Uint8Array {
if (typeof num !== 'bigint') throw new Error('Expected bigint'); if (typeof num !== 'bigint') throw new Error('bigint expected');
if (!(_0n <= num && num < ORDER_MASK)) if (!(_0n <= num && num < ORDER_MASK))
throw new Error(`Expected number < 2^${CURVE.nBitLength}`); // n in [0..ORDER_MASK-1]
throw new Error(`bigint expected < 2^${CURVE.nBitLength}`);
// works with order, can have different size than numToField! // works with order, can have different size than numToField!
return ut.numberToBytesBE(num, CURVE.nByteLength); return ut.numberToBytesBE(num, CURVE.nByteLength);
} }
@ -1045,7 +1039,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
* ``` * ```
*/ */
function verify( function verify(
signature: Hex | { r: bigint; s: bigint }, signature: Hex | SignatureLike,
msgHash: Hex, msgHash: Hex,
publicKey: Hex, publicKey: Hex,
opts = defaultVerOpts opts = defaultVerOpts
@ -1090,7 +1084,6 @@ export function weierstrass(curveDef: CurveType): CurveFn {
getSharedSecret, getSharedSecret,
sign, sign,
verify, verify,
// Point,
ProjectivePoint: Point, ProjectivePoint: Point,
Signature, Signature,
utils, utils,

@ -10,7 +10,7 @@ export const P224 = createCurve(
// Params: a, b // Params: a, b
a: BigInt('0xfffffffffffffffffffffffffffffffefffffffffffffffffffffffe'), a: BigInt('0xfffffffffffffffffffffffffffffffefffffffffffffffffffffffe'),
b: BigInt('0xb4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4'), b: BigInt('0xb4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4'),
// Field over which we'll do calculations; 2n**224n - 2n**96n + 1n // Field over which we'll do calculations;
Fp: Fp(BigInt('0xffffffffffffffffffffffffffffffff000000000000000000000001')), Fp: Fp(BigInt('0xffffffffffffffffffffffffffffffff000000000000000000000001')),
// Curve order, total count of valid points in the field // Curve order, total count of valid points in the field
n: BigInt('0xffffffffffffffffffffffffffff16a2e0b8f03e13dd29455c5c2a3d'), n: BigInt('0xffffffffffffffffffffffffffff16a2e0b8f03e13dd29455c5c2a3d'),

@ -175,7 +175,6 @@ function schnorrSign(message: Hex, privateKey: Hex, auxRand: Hex = randomBytes(3
const { x: px, scalar: d } = schnorrGetScalar(bytesToNum(ensureBytes(privateKey, 32))); const { x: px, scalar: d } = schnorrGetScalar(bytesToNum(ensureBytes(privateKey, 32)));
const a = ensureBytes(auxRand, 32); // Auxiliary random data a: a 32-byte array const a = ensureBytes(auxRand, 32); // Auxiliary random data a: a 32-byte array
// TODO: replace with proper xor?
const t = numTo32b(d ^ bytesToNum(taggedHash(TAGS.aux, a))); // Let t be the byte-wise xor of bytes(d) and hash/aux(a) const t = numTo32b(d ^ bytesToNum(taggedHash(TAGS.aux, a))); // Let t be the byte-wise xor of bytes(d) and hash/aux(a)
const rand = taggedHash(TAGS.nonce, t, px, m); // Let rand = hash/nonce(t || bytes(P) || m) const rand = taggedHash(TAGS.nonce, t, px, m); // Let rand = hash/nonce(t || bytes(P) || m)
const k_ = modN(bytesToNum(rand)); // Let k' = int(rand) mod n const k_ = modN(bytesToNum(rand)); // Let k' = int(rand) mod n

@ -86,7 +86,8 @@ describe('wycheproof ECDH', () => {
try { try {
const pub = CURVE.ProjectivePoint.fromHex(test.public); const pub = CURVE.ProjectivePoint.fromHex(test.public);
} catch (e) { } catch (e) {
if (e.message.includes('Point.fromHex: received invalid point.')) continue; // Our strict validation filter doesn't let weird-length DER vectors
if (e.message.startsWith('Point of length')) continue;
throw e; throw e;
} }
const shared = CURVE.getSharedSecret(test.private, test.public); const shared = CURVE.getSharedSecret(test.private, test.public);
@ -140,7 +141,8 @@ describe('wycheproof ECDH', () => {
try { try {
const pub = curve.ProjectivePoint.fromHex(test.public); const pub = curve.ProjectivePoint.fromHex(test.public);
} catch (e) { } catch (e) {
if (e.message.includes('Point.fromHex: received invalid point.')) continue; // Our strict validation filter doesn't let weird-length DER vectors
if (e.message.includes('Point of length')) continue;
throw e; throw e;
} }
const shared = curve.getSharedSecret(test.private, test.public); const shared = curve.getSharedSecret(test.private, test.public);
@ -194,7 +196,6 @@ const WYCHEPROOF_ECDSA = {
secp256k1: { secp256k1: {
curve: secp256k1, curve: secp256k1,
hashes: { hashes: {
// TODO: debug why fails, can be bug
sha256: { sha256: {
hash: sha256, hash: sha256,
tests: [secp256k1_sha256_test], tests: [secp256k1_sha256_test],