Use declarative curve field validation

This commit is contained in:
Paul Miller 2023-01-28 02:19:46 +00:00
parent f39fb80c52
commit c75129e629
No known key found for this signature in database
GPG Key ID: 697079DA6878B89B
11 changed files with 179 additions and 207 deletions

@ -1,6 +1,7 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Abelian group utilities
import { Field, validateField, nLength } from './modular.js';
import { validateObject } from './utils.js';
const _0n = BigInt(0);
const _1n = BigInt(1);
@ -153,7 +154,7 @@ export function wNAF<T extends Group<T>>(c: GroupConstructor<T>, bits: number) {
// Generic BasicCurve interface: works even for polynomial fields (BLS): P, n, h would be ok.
// Though generator can be different (Fp2 / Fp6 for BLS).
export type AbstractCurve<T> = {
export type BasicCurve<T> = {
Fp: Field<T>; // Field over which we'll do calculations (Fp)
n: bigint; // Curve order, total count of valid points in the field
nBitLength?: number; // bit length of curve order
@ -165,20 +166,21 @@ export type AbstractCurve<T> = {
allowInfinityPoint?: boolean; // bls12-381 requires it. ZERO point is valid, but invalid pubkey
};
export function validateAbsOpts<FP, T>(curve: AbstractCurve<FP> & T) {
export function validateBasic<FP, T>(curve: BasicCurve<FP> & T) {
validateField(curve.Fp);
for (const i of ['n', 'h'] as const) {
const val = curve[i];
if (typeof val !== 'bigint') throw new Error(`Invalid curve param ${i}=${val} (${typeof val})`);
}
if (!curve.Fp.isValid(curve.Gx)) throw new Error('Invalid generator X coordinate Fp element');
if (!curve.Fp.isValid(curve.Gy)) throw new Error('Invalid generator Y coordinate Fp element');
for (const i of ['nBitLength', 'nByteLength'] as const) {
const val = curve[i];
if (val === undefined) continue; // Optional
if (!Number.isSafeInteger(val)) throw new Error(`Invalid param ${i}=${val} (${typeof val})`);
}
validateObject(
curve,
{
n: 'bigint',
h: 'bigint',
Gx: 'field',
Gy: 'field',
},
{
nBitLength: 'isSafeInteger',
nByteLength: 'isSafeInteger',
}
);
// Set defaults
return Object.freeze({ ...nLength(curve.n, curve.nBitLength), ...curve } as const);
}

@ -1,23 +1,9 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Twisted Edwards curve. The formula is: ax² + y² = 1 + dx²y²
import { mod } from './modular.js';
import {
bytesToHex,
bytesToNumberLE,
concatBytes,
ensureBytes,
FHash,
Hex,
numberToBytesLE,
} from './utils.js';
import {
Group,
GroupConstructor,
wNAF,
AbstractCurve,
validateAbsOpts,
AffinePoint,
} from './curve.js';
import * as ut from './utils.js';
import { ensureBytes, FHash, Hex } from './utils.js';
import { Group, GroupConstructor, wNAF, BasicCurve, validateBasic, AffinePoint } from './curve.js';
// Be friendly to bad ECMAScript parsers by not using bigint literals like 123n
const _0n = BigInt(0);
@ -26,7 +12,7 @@ const _2n = BigInt(2);
const _8n = BigInt(8);
// Edwards curves must declare params a & d.
export type CurveType = AbstractCurve<bigint> & {
export type CurveType = BasicCurve<bigint> & {
a: bigint; // curve param a
d: bigint; // curve param d
hash: FHash; // Hashing
@ -39,19 +25,22 @@ export type CurveType = AbstractCurve<bigint> & {
};
function validateOpts(curve: CurveType) {
const opts = validateAbsOpts(curve);
if (typeof opts.hash !== 'function') throw new Error('Invalid hash function');
for (const i of ['a', 'd'] as const) {
const val = opts[i];
if (typeof val !== 'bigint') throw new Error(`Invalid curve param ${i}=${val} (${typeof val})`);
}
for (const fn of ['randomBytes'] as const) {
if (typeof opts[fn] !== 'function') throw new Error(`Invalid ${fn} function`);
}
for (const fn of ['adjustScalarBytes', 'domain', 'uvRatio', 'mapToCurve'] as const) {
if (opts[fn] === undefined) continue; // Optional
if (typeof opts[fn] !== 'function') throw new Error(`Invalid ${fn} function`);
}
const opts = validateBasic(curve);
ut.validateObject(
curve,
{
hash: 'function',
a: 'bigint',
d: 'bigint',
randomBytes: 'function',
},
{
adjustScalarBytes: 'function',
domain: 'function',
uvRatio: 'function',
mapToCurve: 'function',
}
);
// Set defaults
return Object.freeze({ ...opts } as const);
}
@ -75,7 +64,7 @@ export interface ExtPointConstructor extends GroupConstructor<ExtPointType> {
new (x: bigint, y: bigint, z: bigint, t: bigint): ExtPointType;
fromAffine(p: AffinePoint<bigint>): ExtPointType;
fromHex(hex: Hex): ExtPointType;
fromPrivateKey(privateKey: Hex): ExtPointType; // TODO: remove
fromPrivateKey(privateKey: Hex): ExtPointType;
}
export type CurveFn = {
@ -340,7 +329,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const normed = hex.slice(); // copy again, we'll manipulate it
const lastByte = hex[len - 1]; // select last byte
normed[len - 1] = lastByte & ~0x80; // clear last bit
const y = bytesToNumberLE(normed);
const y = ut.bytesToNumberLE(normed);
if (y === _0n) {
// y=0 is allowed
} else {
@ -366,12 +355,12 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
}
toRawBytes(): Uint8Array {
const { x, y } = this.toAffine();
const bytes = numberToBytesLE(y, Fp.BYTES); // each y has 2 x values (x, -y)
const bytes = ut.numberToBytesLE(y, Fp.BYTES); // each y has 2 x values (x, -y)
bytes[bytes.length - 1] |= x & _1n ? 0x80 : 0; // when compressing, it's enough to store y
return bytes; // and use the last byte to encode sign of x
}
toHex(): string {
return bytesToHex(this.toRawBytes()); // Same as toRawBytes, but returns string.
return ut.bytesToHex(this.toRawBytes()); // Same as toRawBytes, but returns string.
}
}
const { BASE: G, ZERO: I } = Point;
@ -382,7 +371,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
}
// Little-endian SHA512 with modulo n
function modN_LE(hash: Uint8Array): bigint {
return modN(bytesToNumberLE(hash));
return modN(ut.bytesToNumberLE(hash));
}
function isHex(item: Hex, err: string) {
if (typeof item !== 'string' && !(item instanceof Uint8Array))
@ -411,7 +400,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
// int('LE', SHA512(dom2(F, C) || msgs)) mod N
function hashDomainToScalar(context: Hex = new Uint8Array(), ...msgs: Uint8Array[]) {
const msg = concatBytes(...msgs);
const msg = ut.concatBytes(...msgs);
return modN_LE(cHash(domain(msg, ensureBytes(context), !!preHash)));
}
@ -426,7 +415,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const k = hashDomainToScalar(context, R, pointBytes, msg); // R || A || PH(M)
const s = modN(r + k * scalar); // S = (r + k * s) mod L
assertGE0(s); // 0 <= s < l
const res = concatBytes(R, numberToBytesLE(s, Fp.BYTES));
const res = ut.concatBytes(R, ut.numberToBytesLE(s, Fp.BYTES));
return ensureBytes(res, nByteLength * 2); // 64-byte signature
}
@ -439,7 +428,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
if (preHash) msg = preHash(msg); // for ed25519ph, etc
const A = Point.fromHex(publicKey, false); // Check for s bounds, hex validity
const R = Point.fromHex(sig.slice(0, len), false); // 0 <= R < 2^256: ZIP215 R can be >= P
const s = bytesToNumberLE(sig.slice(len, 2 * len)); // 0 <= s < l
const s = ut.bytesToNumberLE(sig.slice(len, 2 * len)); // 0 <= s < l
const SB = G.multiplyUnsafe(s);
const k = hashDomainToScalar(context, R.toRawBytes(), A.toRawBytes(), msg);
const RkA = R.add(A.multiplyUnsafe(k));

@ -45,7 +45,7 @@ declare const TextDecoder: any;
export function stringToBytes(str: string): Uint8Array {
if (typeof str !== 'string') {
throw new TypeError(`utf8ToBytes expected string, got ${typeof str}`);
throw new Error(`utf8ToBytes expected string, got ${typeof str}`);
}
return new TextEncoder().encode(str);
}

@ -7,6 +7,7 @@ import {
bytesToNumberBE,
bytesToNumberLE,
ensureBytes,
validateObject,
} from './utils.js';
// prettier-ignore
const _0n = BigInt(0), _1n = BigInt(1), _2n = BigInt(2), _3n = BigInt(3);
@ -40,7 +41,6 @@ export function pow(num: bigint, power: bigint, modulo: bigint): bigint {
}
// Does x ^ (2 ^ power) mod p. pow2(30, 4) == 30 ^ (2 ^ 4)
// TODO: Fp version?
export function pow2(x: bigint, power: bigint, modulo: bigint): bigint {
let res = x;
while (power-- > _0n) {
@ -249,18 +249,17 @@ const FIELD_FIELDS = [
'addN', 'subN', 'mulN', 'sqrN'
] as const;
export function validateField<T>(field: Field<T>) {
for (const i of ['ORDER', 'MASK'] as const) {
if (typeof field[i] !== 'bigint')
throw new Error(`Invalid field param ${i}=${field[i]} (${typeof field[i]})`);
}
for (const i of ['BYTES', 'BITS'] as const) {
if (typeof field[i] !== 'number')
throw new Error(`Invalid field param ${i}=${field[i]} (${typeof field[i]})`);
}
for (const i of FIELD_FIELDS) {
if (typeof field[i] !== 'function')
throw new Error(`Invalid field param ${i}=${field[i]} (${typeof field[i]})`);
}
const initial = {
ORDER: 'bigint',
MASK: 'bigint',
BYTES: 'isSafeInteger',
BITS: 'isSafeInteger',
} as Record<string, string>;
const opts = FIELD_FIELDS.reduce((map, val: string) => {
map[val] = 'function';
return map;
}, initial);
return validateObject(field, opts);
}
// Generic field functions

@ -1,14 +1,13 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
import { mod, pow } from './modular.js';
import { ensureBytes, numberToBytesLE, bytesToNumberLE } from './utils.js';
import { bytesToNumberLE, ensureBytes, numberToBytesLE, validateObject } from './utils.js';
const _0n = BigInt(0);
const _1n = BigInt(1);
type Hex = string | Uint8Array;
export type CurveType = {
// Field over which we'll do calculations. Verify with:
P: bigint;
P: bigint; // finite field prime
nByteLength: number;
adjustScalarBytes?: (bytes: Uint8Array) => Uint8Array;
domain?: (data: Uint8Array, ctx: Uint8Array, phflag: boolean) => Uint8Array;
@ -27,24 +26,20 @@ export type CurveFn = {
};
function validateOpts(curve: CurveType) {
for (const i of ['a24'] as const) {
if (typeof curve[i] !== 'bigint')
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`);
}
for (const i of ['montgomeryBits', 'nByteLength'] as const) {
if (curve[i] === undefined) continue; // Optional
if (!Number.isSafeInteger(curve[i]))
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`);
}
for (const fn of ['adjustScalarBytes', 'domain', 'powPminus2'] as const) {
if (curve[fn] === undefined) continue; // Optional
if (typeof curve[fn] !== 'function') throw new Error(`Invalid ${fn} function`);
}
for (const i of ['Gu'] as const) {
if (curve[i] === undefined) continue; // Optional
if (typeof curve[i] !== 'string')
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`);
}
validateObject(
curve,
{
a24: 'bigint',
},
{
montgomeryBits: 'isSafeInteger',
nByteLength: 'isSafeInteger',
adjustScalarBytes: 'function',
domain: 'function',
powPminus2: 'function',
Gu: 'string',
}
);
// Set defaults
return Object.freeze({ ...curve } as const);
}
@ -61,27 +56,7 @@ export function montgomery(curveDef: CurveType): CurveFn {
const adjustScalarBytes = CURVE.adjustScalarBytes || ((bytes: Uint8Array) => bytes);
const powPminus2 = CURVE.powPminus2 || ((x: bigint) => pow(x, P - BigInt(2), P));
/**
* Checks for num to be in range:
* For strict == true: `0 < num < max`.
* For strict == false: `0 <= num < max`.
* Converts non-float safe numbers to bigints.
*/
function normalizeScalar(num: bigint, max: bigint, strict = true): bigint {
if (!max) throw new TypeError('Specify max value');
if (typeof num === 'number' && Number.isSafeInteger(num)) num = BigInt(num);
if (typeof num === 'bigint' && num < max) {
if (strict) {
if (_0n < num) return num;
} else {
if (_0n <= num) return num;
}
}
throw new TypeError('Expected valid scalar: 0 < scalar < max');
}
// cswap from RFC7748
// NOTE: cswap is not from RFC7748!
// cswap from RFC7748. But it is not from RFC7748!
/*
cswap(swap, x_2, x_3):
dummy = mask(swap) AND (x_2 XOR x_3)
@ -98,6 +73,11 @@ export function montgomery(curveDef: CurveType): CurveFn {
return [x_2, x_3];
}
function assertFieldElement(n: bigint): bigint {
if (typeof n === 'bigint' && _0n <= n && n < P) return n;
throw new Error('Expected valid scalar 0 < scalar < CURVE.P');
}
// x25519 from 4
/**
*
@ -106,11 +86,10 @@ export function montgomery(curveDef: CurveType): CurveFn {
* @returns new Point on Montgomery curve
*/
function montgomeryLadder(pointU: bigint, scalar: bigint): bigint {
const { P } = CURVE;
const u = normalizeScalar(pointU, P);
const u = assertFieldElement(pointU);
// Section 5: Implementations MUST accept non-canonical values and process them as
// if they had been reduced modulo the field prime.
const k = normalizeScalar(scalar, P);
const k = assertFieldElement(scalar);
// The constant a24 is (486662 - 2) / 4 = 121665 for curve25519/X25519
const a24 = CURVE.a24;
const x_1 = u;
@ -166,28 +145,20 @@ export function montgomery(curveDef: CurveType): CurveFn {
}
function decodeUCoordinate(uEnc: Hex): bigint {
const u = ensureBytes(uEnc, montgomeryBytes);
// Section 5: When receiving such an array, implementations of X25519
// MUST mask the most significant bit in the final byte.
// This is very ugly way, but it works because fieldLen-1 is outside of bounds for X448, so this becomes NOOP
// fieldLen - scalaryBytes = 1 for X448 and = 0 for X25519
const u = ensureBytes(uEnc, montgomeryBytes);
u[fieldLen - 1] &= 127; // 0b0111_1111
return bytesToNumberLE(u);
}
function decodeScalar(n: Hex): bigint {
const bytes = ensureBytes(n);
if (bytes.length !== montgomeryBytes && bytes.length !== fieldLen)
throw new Error(`Expected ${montgomeryBytes} or ${fieldLen} bytes, got ${bytes.length}`);
return bytesToNumberLE(adjustScalarBytes(bytes));
}
/**
* Computes shared secret between private key "scalar" and public key's "u" (x) coordinate.
* We can get 'y' coordinate from 'u',
* but Point.fromHex also wants 'x' coordinate oddity flag,
* and we cannot get 'x' without knowing 'v'.
* Need to add generic conversion between twisted edwards and complimentary curve for JubJub.
*/
function scalarMult(scalar: Hex, u: Hex): Uint8Array {
const pointU = decodeUCoordinate(u);
const _scalar = decodeScalar(scalar);
@ -197,12 +168,7 @@ export function montgomery(curveDef: CurveType): CurveFn {
if (pu === _0n) throw new Error('Invalid private or public key received');
return encodeUCoordinate(pu);
}
/**
* Computes public key from private.
* Executes scalar multiplication of curve's base point by scalar.
* @param scalar private key
* @returns new public key
*/
// Computes public key from private. By doing scalar multiplication of base point.
function scalarMultBase(scalar: Hex): Uint8Array {
return scalarMult(scalar, CURVE.Gu);
}

@ -1,6 +1,6 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Poseidon Hash: https://eprint.iacr.org/2019/458.pdf, https://www.poseidon-hash.info
import { Field, validateField, FpPow } from './modular.js';
import { Field, FpPow, validateField } from './modular.js';
// We don't provide any constants, since different implementations use different constants.
// For reference constants see './test/poseidon.test.js'.
export type PoseidonOpts = {

@ -18,7 +18,7 @@ export type FHash = (message: Uint8Array | string) => Uint8Array;
const hexes = Array.from({ length: 256 }, (v, i) => i.toString(16).padStart(2, '0'));
export function bytesToHex(bytes: Uint8Array): string {
if (!u8a(bytes)) throw new Error('Expected Uint8Array');
if (!u8a(bytes)) throw new Error('Uint8Array expected');
// pre-caching improves the speed 6x
let hex = '';
for (let i = 0; i < bytes.length; i++) {
@ -33,21 +33,21 @@ export function numberToHexUnpadded(num: number | bigint): string {
}
export function hexToNumber(hex: string): bigint {
if (typeof hex !== 'string') throw new Error('hexToNumber: expected string, got ' + typeof hex);
if (typeof hex !== 'string') throw new Error('string expected, got ' + typeof hex);
// Big Endian
return BigInt(`0x${hex}`);
}
// Caching slows it down 2-3x
export function hexToBytes(hex: string): Uint8Array {
if (typeof hex !== 'string') throw new Error('hexToBytes: expected string, got ' + typeof hex);
if (hex.length % 2) throw new Error('hexToBytes: received invalid unpadded hex ' + hex.length);
if (typeof hex !== 'string') throw new Error('string expected, got ' + typeof hex);
if (hex.length % 2) throw new Error('hex string is invalid: unpadded ' + hex.length);
const array = new Uint8Array(hex.length / 2);
for (let i = 0; i < array.length; i++) {
const j = i * 2;
const hexByte = hex.slice(j, j + 2);
const byte = Number.parseInt(hexByte, 16);
if (Number.isNaN(byte) || byte < 0) throw new Error('Invalid byte sequence');
if (Number.isNaN(byte) || byte < 0) throw new Error('invalid byte sequence');
array[i] = byte;
}
return array;
@ -58,7 +58,7 @@ export function bytesToNumberBE(bytes: Uint8Array): bigint {
return hexToNumber(bytesToHex(bytes));
}
export function bytesToNumberLE(bytes: Uint8Array): bigint {
if (!u8a(bytes)) throw new Error('Expected Uint8Array');
if (!u8a(bytes)) throw new Error('Uint8Array expected');
return hexToNumber(bytesToHex(Uint8Array.from(bytes).reverse()));
}
@ -66,11 +66,7 @@ export const numberToBytesBE = (n: bigint, len: number) =>
hexToBytes(n.toString(16).padStart(len * 2, '0'));
export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, len).reverse();
// Returns variable number bytes (minimal bigint encoding?)
export const numberToVarBytesBE = (n: bigint) => {
let hex = n.toString(16);
if (hex.length & 1) hex = '0' + hex;
return hexToBytes(hex);
};
export const numberToVarBytesBE = (n: bigint) => hexToBytes(numberToHexUnpadded(n));
export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array {
// Uint8Array.from() instead of hash.slice() because node.js Buffer
@ -82,17 +78,15 @@ export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array {
}
// Copies several Uint8Arrays into one.
export function concatBytes(...arrays: Uint8Array[]): Uint8Array {
if (!arrays.every((b) => u8a(b))) throw new Error('Uint8Array list expected');
if (arrays.length === 1) return arrays[0];
const length = arrays.reduce((a, arr) => a + arr.length, 0);
const result = new Uint8Array(length);
for (let i = 0, pad = 0; i < arrays.length; i++) {
const arr = arrays[i];
result.set(arr, pad);
pad += arr.length;
}
return result;
export function concatBytes(...arrs: Uint8Array[]): Uint8Array {
const r = new Uint8Array(arrs.reduce((sum, a) => sum + a.length, 0));
let pad = 0; // walk through each item, ensure they have proper type
arrs.forEach((a) => {
if (!u8a(a)) throw new Error('Uint8Array expected');
r.set(a, pad);
pad += a.length;
});
return r;
}
export function equalBytes(b1: Uint8Array, b2: Uint8Array) {
@ -119,3 +113,32 @@ export const bitSet = (n: bigint, pos: number, value: boolean) =>
// Return mask for N bits (Same as BigInt(`0b${Array(i).fill('1').join('')}`))
// Not using ** operator with bigints for old engines.
export const bitMask = (n: number) => (_2n << BigInt(n - 1)) - _1n;
type ValMap = Record<string, string>;
export function validateObject(object: object, validators: ValMap, optValidators: ValMap = {}) {
const validatorFns: Record<string, (val: any) => boolean> = {
bigint: (val) => typeof val === 'bigint',
function: (val) => typeof val === 'function',
boolean: (val) => typeof val === 'boolean',
string: (val) => typeof val === 'string',
isSafeInteger: (val) => Number.isSafeInteger(val),
array: (val) => Array.isArray(val),
field: (val) => (object as any).Fp.isValid(val),
hash: (val) => typeof val === 'function' && Number.isSafeInteger(val.outputLen),
};
// type Key = keyof typeof validators;
const checkField = (fieldName: string, type: string, isOptional: boolean) => {
const checkVal = validatorFns[type];
if (typeof checkVal !== 'function')
throw new Error(`Invalid validator "${type}", expected function`);
const val = object[fieldName as keyof typeof object];
if (isOptional && val === undefined) return;
if (!checkVal(val)) {
throw new Error(`Invalid param ${fieldName}=${val} (${typeof val}), expected ${type}`);
}
};
for (let [fieldName, type] of Object.entries(validators)) checkField(fieldName, type, false);
for (let [fieldName, type] of Object.entries(optValidators)) checkField(fieldName, type, true);
return object;
}

@ -2,15 +2,8 @@
// Short Weierstrass curve. The formula is: y² = x³ + ax + b
import * as mod from './modular.js';
import * as ut from './utils.js';
import { Hex, PrivKey, ensureBytes, CHash } from './utils.js';
import {
Group,
GroupConstructor,
wNAF,
AbstractCurve,
validateAbsOpts,
AffinePoint,
} from './curve.js';
import { CHash, Hex, PrivKey, ensureBytes } from './utils.js';
import { Group, GroupConstructor, wNAF, BasicCurve, validateBasic, AffinePoint } from './curve.js';
export type { AffinePoint };
type HmacFnSync = (key: Uint8Array, ...messages: Uint8Array[]) => Uint8Array;
@ -18,7 +11,7 @@ type EndomorphismOpts = {
beta: bigint;
splitScalar: (k: bigint) => { k1neg: boolean; k1: bigint; k2neg: boolean; k2: bigint };
};
export type BasicCurve<T> = AbstractCurve<T> & {
export type BasicWCurve<T> = BasicCurve<T> & {
// Params: a, b
a: T;
b: T;
@ -86,34 +79,32 @@ export interface ProjConstructor<T> extends GroupConstructor<ProjPointType<T>> {
normalizeZ(points: ProjPointType<T>[]): ProjPointType<T>[];
}
export type CurvePointsType<T> = BasicCurve<T> & {
export type CurvePointsType<T> = BasicWCurve<T> & {
// Bytes
fromBytes: (bytes: Uint8Array) => AffinePoint<T>;
toBytes: (c: ProjConstructor<T>, point: ProjPointType<T>, compressed: boolean) => Uint8Array;
};
function validatePointOpts<T>(curve: CurvePointsType<T>) {
const opts = validateAbsOpts(curve);
const Fp = opts.Fp;
for (const i of ['a', 'b'] as const) {
if (!Fp.isValid(curve[i]))
throw new Error(`Invalid curve param ${i}=${opts[i]} (${typeof opts[i]})`);
}
for (const i of ['allowedPrivateKeyLengths'] as const) {
if (curve[i] === undefined) continue; // Optional
if (!Array.isArray(curve[i])) throw new Error(`Invalid ${i} array`);
}
for (const i of ['wrapPrivateKey'] as const) {
if (curve[i] === undefined) continue; // Optional
if (typeof curve[i] !== 'boolean') throw new Error(`Invalid ${i} boolean`);
}
for (const i of ['isTorsionFree', 'clearCofactor'] as const) {
if (curve[i] === undefined) continue; // Optional
if (typeof curve[i] !== 'function') throw new Error(`Invalid ${i} function`);
}
const endo = opts.endo;
const opts = validateBasic(curve);
ut.validateObject(
opts,
{
a: 'field',
b: 'field',
fromBytes: 'function',
toBytes: 'function',
},
{
allowedPrivateKeyLengths: 'array',
wrapPrivateKey: 'boolean',
isTorsionFree: 'function',
clearCofactor: 'function',
}
);
const { endo, Fp, a } = opts;
if (endo) {
if (!Fp.eql(opts.a, Fp.ZERO)) {
if (!Fp.eql(a, Fp.ZERO)) {
throw new Error('Endomorphism can only be defined for Koblitz curves that have a=0');
}
if (
@ -124,9 +115,6 @@ function validatePointOpts<T>(curve: CurvePointsType<T>) {
throw new Error('Expected endomorphism with beta: bigint and splitScalar: function');
}
}
if (typeof opts.fromBytes !== 'function') throw new Error('Invalid fromBytes function');
if (typeof opts.toBytes !== 'function') throw new Error('Invalid fromBytes function');
// Set defaults
return Object.freeze({ ...opts } as const);
}
@ -609,25 +597,30 @@ type SignatureLike = { r: bigint; s: bigint };
export type PubKey = Hex | ProjPointType<bigint>;
export type CurveType = BasicCurve<bigint> & {
// Default options
lowS?: boolean;
// Hashes
hash: CHash; // Because we need outputLen for DRBG
export type CurveType = BasicWCurve<bigint> & {
hash: CHash; // CHash not FHash because we need outputLen for DRBG
hmac: HmacFnSync;
randomBytes: (bytesLength?: number) => Uint8Array;
// truncateHash?: (hash: Uint8Array, truncateOnly?: boolean) => Uint8Array;
lowS?: boolean;
bits2int?: (bytes: Uint8Array) => bigint;
bits2int_modN?: (bytes: Uint8Array) => bigint;
};
function validateOpts(curve: CurveType) {
const opts = validateAbsOpts(curve);
if (typeof opts.hash !== 'function' || !Number.isSafeInteger(opts.hash.outputLen))
throw new Error('Invalid hash function');
if (typeof opts.hmac !== 'function') throw new Error('Invalid hmac function');
if (typeof opts.randomBytes !== 'function') throw new Error('Invalid randomBytes function');
// Set defaults
const opts = validateBasic(curve);
ut.validateObject(
opts,
{
hash: 'hash',
hmac: 'function',
randomBytes: 'function',
},
{
bits2int: 'function',
bits2int_modN: 'function',
lowS: 'boolean',
}
);
return Object.freeze({ lowS: true, ...opts } as const);
}
@ -756,7 +749,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
return { x, y };
} else {
throw new Error(
`Point.fromHex: received invalid point. Expected ${compressedLen} compressed bytes or ${uncompressedLen} uncompressed bytes, not ${len}`
`Point of length ${len} was invalid. Expected ${compressedLen} compressed bytes or ${uncompressedLen} uncompressed bytes`
);
}
},
@ -821,7 +814,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
const ir = invN(radj); // r^-1
const u1 = modN(-h * ir); // -hr^-1
const u2 = modN(s * ir); // sr^-1
const Q = Point.BASE.multiplyAndAddUnsafe(R, u1, u2); // (sr^-1)R-(hr^-1)G = -(hr^-1)G + (sr^-1)
const Q = Point.BASE.multiplyAndAddUnsafe(R, u1, u2); // (sr^-1)R-(hr^-1)G = -(hr^-1)G + (sr^-1)
if (!Q) throw new Error('point at infinify'); // unsafe is fine: no priv data leaked
Q.assertValidity();
return Q;
@ -951,9 +944,10 @@ export function weierstrass(curveDef: CurveType): CurveFn {
// NOTE: pads output with zero as per spec
const ORDER_MASK = ut.bitMask(CURVE.nBitLength);
function int2octets(num: bigint): Uint8Array {
if (typeof num !== 'bigint') throw new Error('Expected bigint');
if (typeof num !== 'bigint') throw new Error('bigint expected');
if (!(_0n <= num && num < ORDER_MASK))
throw new Error(`Expected number < 2^${CURVE.nBitLength}`);
// n in [0..ORDER_MASK-1]
throw new Error(`bigint expected < 2^${CURVE.nBitLength}`);
// works with order, can have different size than numToField!
return ut.numberToBytesBE(num, CURVE.nByteLength);
}
@ -1045,7 +1039,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
* ```
*/
function verify(
signature: Hex | { r: bigint; s: bigint },
signature: Hex | SignatureLike,
msgHash: Hex,
publicKey: Hex,
opts = defaultVerOpts
@ -1090,7 +1084,6 @@ export function weierstrass(curveDef: CurveType): CurveFn {
getSharedSecret,
sign,
verify,
// Point,
ProjectivePoint: Point,
Signature,
utils,

@ -10,7 +10,7 @@ export const P224 = createCurve(
// Params: a, b
a: BigInt('0xfffffffffffffffffffffffffffffffefffffffffffffffffffffffe'),
b: BigInt('0xb4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4'),
// Field over which we'll do calculations; 2n**224n - 2n**96n + 1n
// Field over which we'll do calculations;
Fp: Fp(BigInt('0xffffffffffffffffffffffffffffffff000000000000000000000001')),
// Curve order, total count of valid points in the field
n: BigInt('0xffffffffffffffffffffffffffff16a2e0b8f03e13dd29455c5c2a3d'),

@ -175,7 +175,6 @@ function schnorrSign(message: Hex, privateKey: Hex, auxRand: Hex = randomBytes(3
const { x: px, scalar: d } = schnorrGetScalar(bytesToNum(ensureBytes(privateKey, 32)));
const a = ensureBytes(auxRand, 32); // Auxiliary random data a: a 32-byte array
// TODO: replace with proper xor?
const t = numTo32b(d ^ bytesToNum(taggedHash(TAGS.aux, a))); // Let t be the byte-wise xor of bytes(d) and hash/aux(a)
const rand = taggedHash(TAGS.nonce, t, px, m); // Let rand = hash/nonce(t || bytes(P) || m)
const k_ = modN(bytesToNum(rand)); // Let k' = int(rand) mod n

@ -86,7 +86,8 @@ describe('wycheproof ECDH', () => {
try {
const pub = CURVE.ProjectivePoint.fromHex(test.public);
} catch (e) {
if (e.message.includes('Point.fromHex: received invalid point.')) continue;
// Our strict validation filter doesn't let weird-length DER vectors
if (e.message.startsWith('Point of length')) continue;
throw e;
}
const shared = CURVE.getSharedSecret(test.private, test.public);
@ -140,7 +141,8 @@ describe('wycheproof ECDH', () => {
try {
const pub = curve.ProjectivePoint.fromHex(test.public);
} catch (e) {
if (e.message.includes('Point.fromHex: received invalid point.')) continue;
// Our strict validation filter doesn't let weird-length DER vectors
if (e.message.includes('Point of length')) continue;
throw e;
}
const shared = curve.getSharedSecret(test.private, test.public);
@ -194,7 +196,6 @@ const WYCHEPROOF_ECDSA = {
secp256k1: {
curve: secp256k1,
hashes: {
// TODO: debug why fails, can be bug
sha256: {
hash: sha256,
tests: [secp256k1_sha256_test],