Add eddsa prehashed mode, diffie-hellman

This commit is contained in:
Paul Miller 2022-12-11 14:54:30 +00:00
parent 4c6ca2326a
commit c8fc24fd8f
No known key found for this signature in database
GPG Key ID: 697079DA6878B89B
4 changed files with 305 additions and 86 deletions

@ -11,7 +11,14 @@
// 5. Domain function was no-op for ed25519, but adds some data even with empty context for ed448 // 5. Domain function was no-op for ed25519, but adds some data even with empty context for ed448
import * as mod from './modular.js'; import * as mod from './modular.js';
import { bytesToHex, concatBytes, ensureBytes, numberToBytesLE, nLength } from './utils.js'; import {
bytesToHex,
concatBytes,
ensureBytes,
numberToBytesLE,
nLength,
hashToPrivateScalar,
} from './utils.js';
import { wNAF } from './group.js'; import { wNAF } from './group.js';
// Be friendly to bad ECMAScript parsers by not using bigint literals like 123n // Be friendly to bad ECMAScript parsers by not using bigint literals like 123n
@ -42,16 +49,21 @@ export type CurveType = {
// Base point (x, y) aka generator point // Base point (x, y) aka generator point
Gx: bigint; Gx: bigint;
Gy: bigint; Gy: bigint;
// Other constants
a24: bigint;
// ECDH bits (can be different from N bits)
scalarBits: number;
// Hashes // Hashes
hash: CHash; // Because we need outputLen for DRBG hash: CHash; // Because we need outputLen for DRBG
randomBytes: (bytesLength?: number) => Uint8Array; randomBytes: (bytesLength?: number) => Uint8Array;
adjustScalarBytes: (bytes: Uint8Array) => Uint8Array; adjustScalarBytes?: (bytes: Uint8Array) => Uint8Array;
domain: (data: Uint8Array, ctx: Uint8Array, hflag: boolean) => Uint8Array; domain?: (data: Uint8Array, ctx: Uint8Array, phflag: boolean) => Uint8Array;
uvRatio: (u: bigint, v: bigint) => { isValid: boolean; value: bigint }; uvRatio?: (u: bigint, v: bigint) => { isValid: boolean; value: bigint };
preHash?: CHash;
// ECDH related
// Other constants
a24: bigint; // Related to d, but cannot be derived from it
// ECDH bits (can be different from N bits)
montgomeryBits?: number;
basePointU?: string; // TODO: why not bigint?
powPminus2?: (x: bigint) => bigint;
UfromPoint?: (p: PointType) => Uint8Array;
}; };
// We accept hex strings besides Uint8Array for simplicity // We accept hex strings besides Uint8Array for simplicity
@ -67,18 +79,29 @@ function validateOpts(curve: CurveType) {
if (typeof curve[i] !== 'bigint') if (typeof curve[i] !== 'bigint')
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`); throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`);
} }
for (const i of ['scalarBits'] as const) { for (const i of ['nBitLength', 'nByteLength', 'montgomeryBits'] as const) {
if (typeof curve[i] !== 'number')
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`);
}
for (const i of ['nBitLength', 'nByteLength'] as const) {
if (curve[i] === undefined) continue; // Optional if (curve[i] === undefined) continue; // Optional
if (!Number.isSafeInteger(curve[i])) if (!Number.isSafeInteger(curve[i]))
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`); throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`);
} }
for (const fn of ['randomBytes', 'adjustScalarBytes', 'domain', 'uvRatio'] as const) { for (const fn of ['randomBytes'] as const) {
if (typeof curve[fn] !== 'function') throw new Error(`Invalid ${fn} function`); if (typeof curve[fn] !== 'function') throw new Error(`Invalid ${fn} function`);
} }
for (const fn of [
'adjustScalarBytes',
'domain',
'uvRatio',
'powPminus2',
'UfromPoint',
] as const) {
if (curve[fn] === undefined) continue; // Optional
if (typeof curve[fn] !== 'function') throw new Error(`Invalid ${fn} function`);
}
for (const i of ['basePointU'] as const) {
if (curve[i] === undefined) continue; // Optional
if (typeof curve[i] !== 'string')
throw new Error(`Invalid curve param ${i}=${curve[i]} (${typeof curve[i]})`);
}
// Set defaults // Set defaults
return Object.freeze({ ...nLength(curve.n, curve.nBitLength), ...curve } as const); return Object.freeze({ ...nLength(curve.n, curve.nBitLength), ...curve } as const);
} }
@ -131,7 +154,6 @@ export interface PointType {
_setWindowSize(windowSize: number): void; _setWindowSize(windowSize: number): void;
toRawBytes(isCompressed?: boolean): Uint8Array; toRawBytes(isCompressed?: boolean): Uint8Array;
toHex(isCompressed?: boolean): string; toHex(isCompressed?: boolean): string;
// toX25519(): Uint8Array;
isTorsionFree(): boolean; isTorsionFree(): boolean;
equals(other: PointType): boolean; equals(other: PointType): boolean;
negate(): PointType; negate(): PointType;
@ -154,11 +176,20 @@ export type SigType = Hex | SignatureType;
export type CurveFn = { export type CurveFn = {
CURVE: ReturnType<typeof validateOpts>; CURVE: ReturnType<typeof validateOpts>;
getPublicKey: (privateKey: PrivKey, isCompressed?: boolean) => Uint8Array; getPublicKey: (privateKey: PrivKey, isCompressed?: boolean) => Uint8Array;
getSharedSecret: (privateKey: PrivKey, publicKey: Hex) => Uint8Array;
sign: (message: Hex, privateKey: Hex) => Uint8Array; sign: (message: Hex, privateKey: Hex) => Uint8Array;
verify: (sig: SigType, message: Hex, publicKey: PubKey) => boolean; verify: (sig: SigType, message: Hex, publicKey: PubKey) => boolean;
Point: PointConstructor; Point: PointConstructor;
ExtendedPoint: ExtendedPointConstructor; ExtendedPoint: ExtendedPointConstructor;
Signature: SignatureConstructor; Signature: SignatureConstructor;
montgomeryCurve: {
BASE_POINT_U: string;
UfromPoint: (p: PointType) => Uint8Array;
scalarMult: (u: Hex, scalar: Hex) => Uint8Array;
scalarMultBase: (scalar: Hex) => Uint8Array;
getPublicKey: (privateKey: Hex) => Uint8Array;
getSharedSecret: (privateKey: Hex, publicKey: Hex) => Uint8Array;
};
utils: { utils: {
mod: (a: bigint, b?: bigint) => bigint; mod: (a: bigint, b?: bigint) => bigint;
invert: (number: bigint, modulo?: bigint) => bigint; invert: (number: bigint, modulo?: bigint) => bigint;
@ -184,11 +215,29 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const maxGroupElement = _2n ** BigInt(groupLen * 8); // previous POW_2_256 const maxGroupElement = _2n ** BigInt(groupLen * 8); // previous POW_2_256
// Function overrides // Function overrides
const { adjustScalarBytes, randomBytes, uvRatio } = CURVE; const { P, randomBytes } = CURVE;
const modP = (a: bigint) => mod.mod(a, P);
function modP(a: bigint) { // sqrt(u/v)
return mod.mod(a, CURVE.P); function _uvRatio(u: bigint, v: bigint) {
try {
const value = mod.sqrt(u * mod.invert(v, P), P);
return { isValid: true, value };
} catch (e) {
return { isValid: false, value: _0n };
}
} }
const uvRatio = CURVE.uvRatio || _uvRatio;
const _powPminus2 = (x: bigint) => mod.pow(x, P - _2n, P);
const powPminus2 = CURVE.powPminus2 || _powPminus2;
const _adjustScalarBytes = (bytes: Uint8Array) => bytes; // NOOP
const adjustScalarBytes = CURVE.adjustScalarBytes || _adjustScalarBytes;
function _domain(data: Uint8Array, ctx: Uint8Array, phflag: boolean) {
if (ctx.length || phflag) throw new Error('Contexts/pre-hash are not supported');
return data;
}
const domain = CURVE.domain || _domain; // NOOP
/** /**
* Extended Point works in extended coordinates: (x, y, z, t) (x=x/z, y=y/z, t=xy). * Extended Point works in extended coordinates: (x, y, z, t) (x=x/z, y=y/z, t=xy).
@ -213,7 +262,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
static toAffineBatch(points: ExtendedPoint[]): Point[] { static toAffineBatch(points: ExtendedPoint[]): Point[] {
const toInv = mod.invertBatch( const toInv = mod.invertBatch(
points.map((p) => p.z), points.map((p) => p.z),
CURVE.P P
); );
return points.map((p, i) => p.toAffine(toInv[i])); return points.map((p, i) => p.toAffine(toInv[i]));
} }
@ -341,7 +390,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
toAffine(invZ?: bigint): Point { toAffine(invZ?: bigint): Point {
const { x, y, z } = this; const { x, y, z } = this;
const is0 = this.equals(ExtendedPoint.ZERO); const is0 = this.equals(ExtendedPoint.ZERO);
if (invZ == null) invZ = is0 ? _8n : mod.invert(z, CURVE.P); // 8 was chosen arbitrarily if (invZ == null) invZ = is0 ? _8n : mod.invert(z, P); // 8 was chosen arbitrarily
const ax = modP(x * invZ); const ax = modP(x * invZ);
const ay = modP(y * invZ); const ay = modP(y * invZ);
const zz = modP(z * invZ); const zz = modP(z * invZ);
@ -419,9 +468,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
// 2, set x <-- p - x. Return the decoded point (x,y). // 2, set x <-- p - x. Return the decoded point (x,y).
const isXOdd = (x & _1n) === _1n; const isXOdd = (x & _1n) === _1n;
const isLastByteOdd = (lastByte & 0x80) !== 0; const isLastByteOdd = (lastByte & 0x80) !== 0;
if (isLastByteOdd !== isXOdd) { if (isLastByteOdd !== isXOdd) x = modP(-x);
x = modP(-x);
}
return new Point(x, y); return new Point(x, y);
} }
@ -575,26 +622,36 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
return getExtendedPublicKey(privateKey).pointBytes; return getExtendedPublicKey(privateKey).pointBytes;
} }
const EMPTY = new Uint8Array();
function hashDomainToScalar(message: Uint8Array, context: Hex = EMPTY) {
context = ensureBytes(context);
return modlLE(CURVE.hash(domain(message, context, !!CURVE.preHash)));
}
/** Signs message with privateKey. RFC8032 5.1.6 */ /** Signs message with privateKey. RFC8032 5.1.6 */
function sign(message: Hex, privateKey: Hex): Uint8Array { function sign(message: Hex, privateKey: Hex, context?: Hex): Uint8Array {
message = ensureBytes(message); message = ensureBytes(message);
if (CURVE.preHash) message = CURVE.preHash(message);
const { prefix, scalar, pointBytes } = getExtendedPublicKey(privateKey); const { prefix, scalar, pointBytes } = getExtendedPublicKey(privateKey);
const rDomain = CURVE.domain(concatBytes(prefix, message), new Uint8Array(), false); const r = hashDomainToScalar(concatBytes(prefix, message), context);
const r = modlLE(CURVE.hash(rDomain)); // r = hash(prefix + msg)
const R = Point.BASE.multiply(r); // R = rG const R = Point.BASE.multiply(r); // R = rG
const kDomain = CURVE.domain( const k = hashDomainToScalar(concatBytes(R.toRawBytes(), pointBytes, message), context); // k = hash(R+P+msg)
concatBytes(R.toRawBytes(), pointBytes, message),
new Uint8Array(),
false
);
const k = modlLE(CURVE.hash(kDomain)); // k = hash(R+P+msg)
const s = mod.mod(r + k * scalar, CURVE_ORDER); // s = r + kp const s = mod.mod(r + k * scalar, CURVE_ORDER); // s = r + kp
return new Signature(R, s).toRawBytes(); return new Signature(R, s).toRawBytes();
} }
// Helper functions because we have async and sync methods. /**
function prepareVerification(sig: SigType, message: Hex, publicKey: PubKey) { * Verifies EdDSA signature against message and public key.
* An extended group equation is checked.
* RFC8032 5.1.7
* Compliant with ZIP215:
* 0 <= sig.R/publicKey < 2**256 (can be >= curve.P)
* 0 <= sig.s < l
* Not compliant with RFC8032: it's not possible to comply to both ZIP & RFC at the same time.
*/
function verify(sig: SigType, message: Hex, publicKey: PubKey, context?: Hex): boolean {
message = ensureBytes(message); message = ensureBytes(message);
if (CURVE.preHash) message = CURVE.preHash(message);
// When hex is passed, we check public key fully. // When hex is passed, we check public key fully.
// When Point instance is passed, we assume it has already been checked, for performance. // When Point instance is passed, we assume it has already been checked, for performance.
// If user passes Point/Sig instance, we assume it has been already verified. // If user passes Point/Sig instance, we assume it has been already verified.
@ -614,40 +671,189 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const { r, s } = sig; const { r, s } = sig;
const SB = ExtendedPoint.BASE.multiplyUnsafe(s); const SB = ExtendedPoint.BASE.multiplyUnsafe(s);
return { r, s, SB, pub: publicKey, msg: message }; const k = hashDomainToScalar(
} concatBytes(r.toRawBytes(), publicKey.toRawBytes(), message),
context
function finishVerification(publicKey: Point, r: Point, SB: ExtendedPoint, hashed: Uint8Array) { );
const k = modlLE(hashed);
const kA = ExtendedPoint.fromAffine(publicKey).multiplyUnsafe(k); const kA = ExtendedPoint.fromAffine(publicKey).multiplyUnsafe(k);
const RkA = ExtendedPoint.fromAffine(r).add(kA); const RkA = ExtendedPoint.fromAffine(r).add(kA);
// [8][S]B = [8]R + [8][k]A' // [8][S]B = [8]R + [8][k]A'
return RkA.subtract(SB).multiplyUnsafe(CURVE.h).equals(ExtendedPoint.ZERO); return RkA.subtract(SB).multiplyUnsafe(CURVE.h).equals(ExtendedPoint.ZERO);
} }
/**
* Verifies EdDSA signature against message and public key.
* An extended group equation is checked.
* RFC8032 5.1.7
* Compliant with ZIP215:
* 0 <= sig.R/publicKey < 2**256 (can be >= curve.P)
* 0 <= sig.s < l
* Not compliant with RFC8032: it's not possible to comply to both ZIP & RFC at the same time.
*/
function verify(sig: SigType, message: Hex, publicKey: PubKey): boolean {
const { r, SB, msg, pub } = prepareVerification(sig, message, publicKey);
const domain = CURVE.domain(
concatBytes(r.toRawBytes(), pub.toRawBytes(), msg),
new Uint8Array([]),
false
);
const hashed = CURVE.hash(domain);
return finishVerification(pub, r, SB, hashed);
}
// Enable precomputes. Slows down first publicKey computation by 20ms. // Enable precomputes. Slows down first publicKey computation by 20ms.
Point.BASE._setWindowSize(8); Point.BASE._setWindowSize(8);
// ECDH (X22519/X448)
// https://datatracker.ietf.org/doc/html/rfc7748
// Every twisted Edwards curve is birationally equivalent to an elliptic curve in Montgomery form and vice versa.
const montgomeryBits = CURVE.montgomeryBits || CURVE.nBitLength;
const montgomeryBytes = Math.ceil(montgomeryBits / 8);
// cswap from RFC7748
function cswap(swap: bigint, x_2: bigint, x_3: bigint): [bigint, bigint] {
const dummy = modP(swap * (x_2 - x_3));
x_2 = modP(x_2 - dummy);
x_3 = modP(x_3 + dummy);
return [x_2, x_3];
}
// x25519 from 4
/**
*
* @param pointU u coordinate (x) on Montgomery Curve 25519
* @param scalar by which the point would be multiplied
* @returns new Point on Montgomery curve
*/
function montgomeryLadder(pointU: bigint, scalar: bigint): bigint {
const { P } = CURVE;
const u = normalizeScalar(pointU, P);
// Section 5: Implementations MUST accept non-canonical values and process them as
// if they had been reduced modulo the field prime.
const k = normalizeScalar(scalar, P);
// The constant a24 is (486662 - 2) / 4 = 121665 for curve25519/X25519
const a24 = CURVE.a24;
const x_1 = u;
let x_2 = _1n;
let z_2 = _0n;
let x_3 = u;
let z_3 = _1n;
let swap = _0n;
let sw: [bigint, bigint];
for (let t = BigInt(montgomeryBits - 1); t >= _0n; t--) {
const k_t = (k >> t) & _1n;
swap ^= k_t;
sw = cswap(swap, x_2, x_3);
x_2 = sw[0];
x_3 = sw[1];
sw = cswap(swap, z_2, z_3);
z_2 = sw[0];
z_3 = sw[1];
swap = k_t;
const A = x_2 + z_2;
const AA = modP(A * A);
const B = x_2 - z_2;
const BB = modP(B * B);
const E = AA - BB;
const C = x_3 + z_3;
const D = x_3 - z_3;
const DA = modP(D * A);
const CB = modP(C * B);
const dacb = DA + CB;
const da_cb = DA - CB;
x_3 = modP(dacb * dacb);
z_3 = modP(x_1 * modP(da_cb * da_cb));
x_2 = modP(AA * BB);
z_2 = modP(E * (AA + modP(a24 * E)));
}
// (x_2, x_3) = cswap(swap, x_2, x_3)
sw = cswap(swap, x_2, x_3);
x_2 = sw[0];
x_3 = sw[1];
// (z_2, z_3) = cswap(swap, z_2, z_3)
sw = cswap(swap, z_2, z_3);
z_2 = sw[0];
z_3 = sw[1];
// z_2^(p - 2)
const z2 = powPminus2(z_2);
// Return x_2 * (z_2^(p - 2))
return modP(x_2 * z2);
}
function encodeUCoordinate(u: bigint): Uint8Array {
return numberToBytesLE(modP(u), montgomeryBytes);
}
function decodeUCoordinate(uEnc: Hex): bigint {
const u = ensureBytes(uEnc, montgomeryBytes);
// Section 5: When receiving such an array, implementations of X25519
// MUST mask the most significant bit in the final byte.
// This is very ugly way, but it works because fieldLen-1 is outside of bounds for X448, so this becomes NOOP
// fieldLen - scalaryBytes = 1 for X448 and = 0 for X25519
u[fieldLen - 1] &= 127; // 0b0111_1111
return bytesToNumberLE(u);
}
function decodeScalar(n: Hex): bigint {
const bytes = ensureBytes(n);
if (bytes.length !== montgomeryBytes && bytes.length !== fieldLen)
throw new Error(`Expected ${montgomeryBytes} or ${fieldLen} bytes, got ${bytes.length}`);
return bytesToNumberLE(adjustScalarBytes(bytes));
}
/*
Converts Point to Montgomery Curve
- u, v: curve25519 coordinates
- x, y: ed25519 coordinates
RFC 7748 (https://www.rfc-editor.org/rfc/rfc7748) says
- The birational maps are (25519):
(u, v) = ((1+y)/(1-y), sqrt(-486664)*u/x)
(x, y) = (sqrt(-486664)*u/v, (u-1)/(u+1))
- The birational maps are (448):
(u, v) = ((y-1)/(y+1), sqrt(156324)*u/x)
(x, y) = (sqrt(156324)*u/v, (1+u)/(1-u))
But original Twisted Edwards paper (https://eprint.iacr.org/2008/013.pdf) and hyperelliptics (http://hyperelliptic.org/EFD/g1p/data/twisted/coordinates)
says that mapping is always:
- u = (1+y)/(1-y)
- v = 2 (1+y)/(x(1-y))
- x = 2 u/v
- y = (u-1)/(u+1)
Which maps correctly, but to completely different curve. There is different mapping for ed448 (which done with replaceble function).
Returns 'u' coordinate of curve25519 point.
NOTE: jubjub will need full mapping, for now only Point -> U is enough
*/
function _UfromPoint(p: Point): Uint8Array {
if (!(p instanceof Point)) throw new Error('Wrong point');
const { y } = p;
const u = modP((y + _1n) * mod.invert(_1n - y, P));
return numberToBytesLE(u, montgomeryBytes);
}
const UfromPoint = CURVE.UfromPoint || _UfromPoint;
const BASE_POINT_U = CURVE.basePointU || bytesToHex(UfromPoint(Point.BASE));
// Multiply point u by scalar
function scalarMult(u: Hex, scalar: Hex): Uint8Array {
const pointU = decodeUCoordinate(u);
const _scalar = decodeScalar(scalar);
const pu = montgomeryLadder(pointU, _scalar);
// The result was not contributory
// https://cr.yp.to/ecdh.html#validate
if (pu === _0n) throw new Error('Invalid private or public key received');
return encodeUCoordinate(pu);
}
// Multiply base point by scalar
const scalarMultBase = (scalar: Hex): Uint8Array =>
montgomeryCurve.scalarMult(montgomeryCurve.BASE_POINT_U, scalar);
const montgomeryCurve = {
BASE_POINT_U,
UfromPoint,
// NOTE: we can get 'y' coordinate from 'u', but Point.fromHex also wants 'x' coordinate oddity flag, and we cannot get 'x' without knowing 'v'
// Need to add generic conversion between twisted edwards and complimentary curve for JubJub
scalarMult,
scalarMultBase,
// NOTE: these function work on complimentary montgomery curve
getSharedSecret: (privateKey: Hex, publicKey: Hex) => scalarMult(publicKey, privateKey),
getPublicKey: (privateKey: Hex): Uint8Array => scalarMultBase(privateKey),
};
/**
* Calculates X25519 DH shared secret from ed25519 private & public keys.
* Curve25519 used in X25519 consumes private keys as-is, while ed25519 hashes them with sha512.
* Which means we will need to normalize ed25519 seeds to "hashed repr".
* @param privateKey ed25519 private key
* @param publicKey ed25519 public key
* @returns X25519 shared key
*/
function getSharedSecret(privateKey: PrivKey, publicKey: Hex): Uint8Array {
const { head } = getExtendedPublicKey(privateKey);
const u = montgomeryCurve.UfromPoint(Point.fromHex(publicKey));
return montgomeryCurve.getSharedSecret(head, u);
}
const utils = { const utils = {
getExtendedPublicKey, getExtendedPublicKey,
mod: modP, mod: modP,
@ -661,12 +867,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
* @param hash hash output from sha512, or a similar function * @param hash hash output from sha512, or a similar function
* @returns valid private scalar * @returns valid private scalar
*/ */
hashToPrivateScalar: (hash: Hex): bigint => { hashToPrivateScalar: (hash: Hex): bigint => hashToPrivateScalar(hash, CURVE_ORDER, true),
hash = ensureBytes(hash);
if (hash.length < 40 || hash.length > 1024)
throw new Error('Expected 40-1024 bytes of private key as per FIPS 186');
return mod.mod(bytesToNumberLE(hash), CURVE_ORDER - _1n) + _1n;
},
/** /**
* ed25519 private keys are uniform 32-bit strings. We do not need to check for * ed25519 private keys are uniform 32-bit strings. We do not need to check for
@ -690,6 +891,8 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
return { return {
CURVE, CURVE,
montgomeryCurve,
getSharedSecret,
ExtendedPoint, ExtendedPoint,
Point, Point,
Signature, Signature,

@ -99,7 +99,6 @@ export function legendre(num: bigint, fieldPrime: bigint): bigint {
/** /**
* Calculates square root of a number in a finite field. * Calculates square root of a number in a finite field.
* Used to calculate y - the square root of y².
*/ */
export function sqrt(number: bigint, modulo: bigint): bigint { export function sqrt(number: bigint, modulo: bigint): bigint {
const n = number; const n = number;
@ -109,6 +108,7 @@ export function sqrt(number: bigint, modulo: bigint): bigint {
// P = 3 (mod 4) // P = 3 (mod 4)
// sqrt n = n^((P+1)/4) // sqrt n = n^((P+1)/4)
if (P % _4n === _3n) return pow(n, p1div4, P); if (P % _4n === _3n) return pow(n, p1div4, P);
// P = 5 (mod 8) // P = 5 (mod 8)
if (P % _8n === _5n) { if (P % _8n === _5n) {
const n2 = mod(n * _2n, P); const n2 = mod(n * _2n, P);

@ -2,6 +2,9 @@
// Convert between types // Convert between types
// --------------------- // ---------------------
type Hex = string | Uint8Array;
import * as mod from './modular.js';
const hexes = Array.from({ length: 256 }, (v, i) => i.toString(16).padStart(2, '0')); const hexes = Array.from({ length: 256 }, (v, i) => i.toString(16).padStart(2, '0'));
export function bytesToHex(uint8a: Uint8Array): string { export function bytesToHex(uint8a: Uint8Array): string {
if (!(uint8a instanceof Uint8Array)) throw new Error('Expected Uint8Array'); if (!(uint8a instanceof Uint8Array)) throw new Error('Expected Uint8Array');
@ -44,15 +47,19 @@ export function hexToBytes(hex: string): Uint8Array {
} }
// Big Endian // Big Endian
export function bytesToNumber(bytes: Uint8Array): bigint { export function bytesToNumberBE(bytes: Uint8Array): bigint {
return hexToNumber(bytesToHex(bytes)); return hexToNumber(bytesToHex(bytes));
} }
export function bytesToNumberLE(uint8a: Uint8Array): bigint {
if (!(uint8a instanceof Uint8Array)) throw new Error('Expected Uint8Array');
return BigInt('0x' + bytesToHex(Uint8Array.from(uint8a).reverse()));
}
export const numberToBytesBE = (n: bigint, len: number) => export const numberToBytesBE = (n: bigint, len: number) =>
hexToBytes(n.toString(16).padStart(len * 2, '0')); hexToBytes(n.toString(16).padStart(len * 2, '0'));
export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, len).reverse(); export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, len).reverse();
export function ensureBytes(hex: string | Uint8Array, expectedLength?: number): Uint8Array { export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array {
// Uint8Array.from() instead of hash.slice() because node.js Buffer // Uint8Array.from() instead of hash.slice() because node.js Buffer
// is instance of Uint8Array, and its slice() creates **mutable** copy // is instance of Uint8Array, and its slice() creates **mutable** copy
const bytes = hex instanceof Uint8Array ? Uint8Array.from(hex) : hexToBytes(hex); const bytes = hex instanceof Uint8Array ? Uint8Array.from(hex) : hexToBytes(hex);
@ -82,3 +89,19 @@ export function nLength(n: bigint, nBitLength?: number) {
const nByteLength = Math.ceil(_nBitLength / 8); const nByteLength = Math.ceil(_nBitLength / 8);
return { nBitLength: _nBitLength, nByteLength }; return { nBitLength: _nBitLength, nByteLength };
} }
/**
* Can take (n+8) or more bytes of uniform input e.g. from CSPRNG or KDF
* and convert them into private scalar, with the modulo bias being neglible.
* As per FIPS 186 B.4.1.
* @param hash hash output from sha512, or a similar function
* @returns valid private scalar
*/
const _1n = BigInt(1);
export function hashToPrivateScalar(hash: Hex, CURVE_ORDER: bigint, isLE = false): bigint {
hash = ensureBytes(hash);
if (hash.length < 40 || hash.length > 1024)
throw new Error('Expected 40-1024 bytes of private key as per FIPS 186');
const num = isLE ? bytesToNumberLE(hash) : bytesToNumberBE(hash);
return mod.mod(num, CURVE_ORDER - _1n) + _1n;
}

@ -12,13 +12,14 @@
import * as mod from './modular.js'; import * as mod from './modular.js';
import { import {
bytesToHex, bytesToHex,
bytesToNumber, bytesToNumberBE,
concatBytes, concatBytes,
ensureBytes, ensureBytes,
hexToBytes, hexToBytes,
hexToNumber, hexToNumber,
numberToHexUnpadded, numberToHexUnpadded,
nLength, nLength,
hashToPrivateScalar,
} from './utils.js'; } from './utils.js';
import { wNAF } from './group.js'; import { wNAF } from './group.js';
@ -131,7 +132,7 @@ function parseDERInt(data: Uint8Array) {
if (res[0] === 0x00 && res[1] <= 0x7f) { if (res[0] === 0x00 && res[1] <= 0x7f) {
throw new DERError('Invalid signature integer: trailing length'); throw new DERError('Invalid signature integer: trailing length');
} }
return { data: bytesToNumber(res), left: data.subarray(len + 2) }; return { data: bytesToNumberBE(res), left: data.subarray(len + 2) };
} }
function parseDERSignature(data: Uint8Array) { function parseDERSignature(data: Uint8Array) {
@ -214,8 +215,8 @@ export interface JacobianPointType {
double(): JacobianPointType; double(): JacobianPointType;
add(other: JacobianPointType): JacobianPointType; add(other: JacobianPointType): JacobianPointType;
subtract(other: JacobianPointType): JacobianPointType; subtract(other: JacobianPointType): JacobianPointType;
multiplyUnsafe(scalar: bigint): JacobianPointType;
multiply(scalar: number | bigint, affinePoint?: PointType): JacobianPointType; multiply(scalar: number | bigint, affinePoint?: PointType): JacobianPointType;
multiplyUnsafe(scalar: bigint): JacobianPointType;
toAffine(invZ?: bigint): PointType; toAffine(invZ?: bigint): PointType;
} }
// Static methods // Static methods
@ -392,7 +393,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
num = hexToNumber(key); num = hexToNumber(key);
} else if (key instanceof Uint8Array) { } else if (key instanceof Uint8Array) {
if (key.length !== groupLen) throw new Error(`Expected ${groupLen} bytes of private key`); if (key.length !== groupLen) throw new Error(`Expected ${groupLen} bytes of private key`);
num = bytesToNumber(key); num = bytesToNumberBE(key);
} else { } else {
throw new TypeError('Expected valid private key'); throw new TypeError('Expected valid private key');
} }
@ -435,7 +436,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
const { n, nBitLength } = CURVE; const { n, nBitLength } = CURVE;
const byteLength = hash.length; const byteLength = hash.length;
const delta = byteLength * 8 - nBitLength; // size of curve.n (252 bits) const delta = byteLength * 8 - nBitLength; // size of curve.n (252 bits)
let h = bytesToNumber(hash); let h = bytesToNumberBE(hash);
if (delta > 0) h = h >> BigInt(delta); if (delta > 0) h = h >> BigInt(delta);
if (!truncateOnly && h >= n) h -= n; if (!truncateOnly && h >= n) h -= n;
return h; return h;
@ -720,7 +721,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
*/ */
private static fromCompressedHex(bytes: Uint8Array) { private static fromCompressedHex(bytes: Uint8Array) {
const P = CURVE.P; const P = CURVE.P;
const x = bytesToNumber(bytes.subarray(1)); const x = bytesToNumberBE(bytes.subarray(1));
if (!isValidFieldElement(x)) throw new Error('Point is not on curve'); if (!isValidFieldElement(x)) throw new Error('Point is not on curve');
const y2 = weierstrassEquation(x); // y² = x³ + ax + b const y2 = weierstrassEquation(x); // y² = x³ + ax + b
let y = sqrtModCurve(y2, P); // y = y² ^ (p+1)/4 let y = sqrtModCurve(y2, P); // y = y² ^ (p+1)/4
@ -734,8 +735,8 @@ export function weierstrass(curveDef: CurveType): CurveFn {
} }
private static fromUncompressedHex(bytes: Uint8Array) { private static fromUncompressedHex(bytes: Uint8Array) {
const x = bytesToNumber(bytes.subarray(1, fieldLen + 1)); const x = bytesToNumberBE(bytes.subarray(1, fieldLen + 1));
const y = bytesToNumber(bytes.subarray(fieldLen + 1, 2 * fieldLen + 1)); const y = bytesToNumberBE(bytes.subarray(fieldLen + 1, 2 * fieldLen + 1));
const point = new Point(x, y); const point = new Point(x, y);
point.assertValidity(); point.assertValidity();
return point; return point;
@ -962,15 +963,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
* @param hash hash output from sha512, or a similar function * @param hash hash output from sha512, or a similar function
* @returns valid private key * @returns valid private key
*/ */
hashToPrivateKey: (hash: Hex): Uint8Array => { hashToPrivateKey: (hash: Hex): Uint8Array => numToField(hashToPrivateScalar(hash, CURVE_ORDER)),
hash = ensureBytes(hash);
const minLen = fieldLen + 8;
if (hash.length < minLen || hash.length > 1024) {
throw new Error(`Expected ${minLen}-1024 bytes of private key as per FIPS 186`);
}
const num = mod.mod(bytesToNumber(hash), CURVE_ORDER - _1n) + _1n;
return numToField(num);
},
// Takes curve order + 64 bits from CSPRNG // Takes curve order + 64 bits from CSPRNG
// so that modulo bias is neglible, matches FIPS 186 B.4.1. // so that modulo bias is neglible, matches FIPS 186 B.4.1.
@ -1035,7 +1028,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
// RFC6979 methods // RFC6979 methods
function bits2int(bytes: Uint8Array) { function bits2int(bytes: Uint8Array) {
const slice = bytes.length > fieldLen ? bytes.slice(0, fieldLen) : bytes; const slice = bytes.length > fieldLen ? bytes.slice(0, fieldLen) : bytes;
return bytesToNumber(slice); return bytesToNumberBE(slice);
} }
function bits2octets(bytes: Uint8Array): Uint8Array { function bits2octets(bytes: Uint8Array): Uint8Array {
const z1 = bits2int(bytes); const z1 = bits2int(bytes);