Remove utils.mod(), utils.invert()

This commit is contained in:
Paul Miller 2023-01-13 00:26:00 +00:00
parent 36998fede8
commit 2d37edf7d1
No known key found for this signature in database
GPG Key ID: 697079DA6878B89B
11 changed files with 55 additions and 64 deletions

@ -201,8 +201,6 @@ export type CurveFn = {
ExtendedPoint: ExtendedPointConstructor; ExtendedPoint: ExtendedPointConstructor;
Signature: SignatureConstructor; Signature: SignatureConstructor;
utils: { utils: {
mod: (a: bigint, b?: bigint) => bigint;
invert: (number: bigint, modulo?: bigint) => bigint;
randomPrivateKey: () => Uint8Array; randomPrivateKey: () => Uint8Array;
getExtendedPublicKey: (key: PrivKey) => { getExtendedPublicKey: (key: PrivKey) => {
head: Uint8Array; head: Uint8Array;
@ -306,6 +304,7 @@ export type CurveFn = {
getPublicKey: (privateKey: PrivKey, isCompressed?: boolean) => Uint8Array; getPublicKey: (privateKey: PrivKey, isCompressed?: boolean) => Uint8Array;
getSharedSecret: (privateA: PrivKey, publicB: PubKey, isCompressed?: boolean) => Uint8Array; getSharedSecret: (privateA: PrivKey, publicB: PubKey, isCompressed?: boolean) => Uint8Array;
sign: (msgHash: Hex, privKey: PrivKey, opts?: SignOpts) => SignatureType; sign: (msgHash: Hex, privKey: PrivKey, opts?: SignOpts) => SignatureType;
signUnhashed: (msg: Uint8Array, privKey: PrivKey, opts?: SignOpts) => SignatureType;
verify: ( verify: (
signature: Hex | SignatureType, signature: Hex | SignatureType,
msgHash: Hex, msgHash: Hex,
@ -316,8 +315,6 @@ export type CurveFn = {
ProjectivePoint: ProjectivePointConstructor; ProjectivePoint: ProjectivePointConstructor;
Signature: SignatureConstructor; Signature: SignatureConstructor;
utils: { utils: {
mod: (a: bigint) => bigint;
invert: (number: bigint) => bigint;
isValidPrivateKey(privateKey: PrivKey): boolean; isValidPrivateKey(privateKey: PrivKey): boolean;
hashToPrivateKey: (hash: Hex) => Uint8Array; hashToPrivateKey: (hash: Hex) => Uint8Array;
randomPrivateKey: () => Uint8Array; randomPrivateKey: () => Uint8Array;

@ -93,12 +93,9 @@ export type CurveFn<Fp, Fp2, Fp6, Fp12> = {
publicKeys: (Hex | PointType<Fp>)[] publicKeys: (Hex | PointType<Fp>)[]
) => boolean; ) => boolean;
utils: { utils: {
bytesToHex: typeof ut.bytesToHex;
hexToBytes: typeof ut.hexToBytes;
stringToBytes: typeof stringToBytes; stringToBytes: typeof stringToBytes;
hashToField: typeof hashToField; hashToField: typeof hashToField;
expandMessageXMD: typeof expandMessageXMD; expandMessageXMD: typeof expandMessageXMD;
mod: typeof mod.mod;
getDSTLabel: () => string; getDSTLabel: () => string;
setDSTLabel(newLabel: string): void; setDSTLabel(newLabel: string): void;
}; };
@ -177,7 +174,6 @@ export function bls<Fp2, Fp6, Fp12>(
const utils = { const utils = {
hexToBytes: ut.hexToBytes, hexToBytes: ut.hexToBytes,
bytesToHex: ut.bytesToHex, bytesToHex: ut.bytesToHex,
mod: mod.mod,
stringToBytes: stringToBytes, stringToBytes: stringToBytes,
// TODO: do we need to export it here? // TODO: do we need to export it here?
hashToField: ( hashToField: (

@ -130,8 +130,6 @@ export type CurveFn = {
ExtendedPoint: ExtendedPointConstructor; ExtendedPoint: ExtendedPointConstructor;
Signature: SignatureConstructor; Signature: SignatureConstructor;
utils: { utils: {
mod: (a: bigint) => bigint;
invert: (number: bigint) => bigint;
randomPrivateKey: () => Uint8Array; randomPrivateKey: () => Uint8Array;
getExtendedPublicKey: (key: PrivKey) => { getExtendedPublicKey: (key: PrivKey) => {
head: Uint8Array; head: Uint8Array;
@ -146,7 +144,7 @@ export type CurveFn = {
// NOTE: it is not generic twisted curve for now, but ed25519/ed448 generic implementation // NOTE: it is not generic twisted curve for now, but ed25519/ed448 generic implementation
export function twistedEdwards(curveDef: CurveType): CurveFn { export function twistedEdwards(curveDef: CurveType): CurveFn {
const CURVE = validateOpts(curveDef) as ReturnType<typeof validateOpts>; const CURVE = validateOpts(curveDef) as ReturnType<typeof validateOpts>;
const Fp = CURVE.Fp as mod.Field<bigint>; const Fp = CURVE.Fp;
const CURVE_ORDER = CURVE.n; const CURVE_ORDER = CURVE.n;
const maxGroupElement = _2n ** BigInt(CURVE.nByteLength * 8); const maxGroupElement = _2n ** BigInt(CURVE.nByteLength * 8);
@ -662,9 +660,6 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const utils = { const utils = {
getExtendedPublicKey, getExtendedPublicKey,
mod: modP,
invert: Fp.invert,
/** /**
* Not needed for ed25519 private keys. Needed if you use scalars directly (rare). * Not needed for ed25519 private keys. Needed if you use scalars directly (rare).
*/ */

@ -608,13 +608,9 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
const { x, y } = this; const { x, y } = this;
// Check if x, y are valid field elements // Check if x, y are valid field elements
if (!Fp.isValid(x) || !Fp.isValid(y)) throw new Error(msg); if (!Fp.isValid(x) || !Fp.isValid(y)) throw new Error(msg);
const left = Fp.square(y); const left = Fp.square(y); // y²
const right = weierstrassEquation(x); const right = weierstrassEquation(x); // x³ + ax + b
// We subtract instead of comparing: it's safer if (!Fp.equals(left, right)) throw new Error(msg);
// (y²) - (x³ + ax + b) == 0
if (!Fp.isZero(Fp.sub(left, right))) throw new Error(msg);
// if (!Fp.equals(left, right))
// TODO: flag to disable this?
if (!this.isTorsionFree()) throw new Error('Point must be of prime-order subgroup'); if (!this.isTorsionFree()) throw new Error('Point must be of prime-order subgroup');
} }
@ -771,8 +767,6 @@ export type CurveFn = {
ProjectivePoint: ProjectiveConstructor<bigint>; ProjectivePoint: ProjectiveConstructor<bigint>;
Signature: SignatureConstructor; Signature: SignatureConstructor;
utils: { utils: {
mod: (a: bigint, b?: bigint) => bigint;
invert: (number: bigint, modulo?: bigint) => bigint;
_bigintToBytes: (num: bigint) => Uint8Array; _bigintToBytes: (num: bigint) => Uint8Array;
_bigintToString: (num: bigint) => string; _bigintToString: (num: bigint) => string;
_normalizePrivateKey: (key: PrivKey) => bigint; _normalizePrivateKey: (key: PrivKey) => bigint;
@ -831,9 +825,7 @@ class HmacDrbg {
} }
return ut.concatBytes(...out); return ut.concatBytes(...out);
} }
// There is no need in clean() method // There are no guarantees with JS GC whether bigints are removed even if you clean Uint8Arrays.
// It's useless, there are no guarantees with JS GC
// whether bigints are removed even if you clean Uint8Arrays.
} }
export function weierstrass(curveDef: CurveType): CurveFn { export function weierstrass(curveDef: CurveType): CurveFn {
@ -1049,8 +1041,6 @@ export function weierstrass(curveDef: CurveType): CurveFn {
} }
const utils = { const utils = {
mod: (n: bigint, modulo = Fp.ORDER) => mod.mod(n, modulo),
invert: Fp.invert,
isValidPrivateKey(privateKey: PrivKey) { isValidPrivateKey(privateKey: PrivKey) {
try { try {
normalizePrivateKey(privateKey); normalizePrivateKey(privateKey);

@ -135,6 +135,7 @@ const Fp2: mod.Field<Fp2> & Fp2Utils = {
return { c0: Fp.mul(factor, Fp.create(a)), c1: Fp.mul(factor, Fp.create(-b)) }; return { c0: Fp.mul(factor, Fp.create(a)), c1: Fp.mul(factor, Fp.create(-b)) };
}, },
sqrt: (num) => { sqrt: (num) => {
if (Fp2.equals(num, Fp2.ZERO)) return Fp2.ZERO; // Algo doesn't handles this case
// TODO: Optimize this line. It's extremely slow. // TODO: Optimize this line. It's extremely slow.
// Speeding this up would boost aggregateSignatures. // Speeding this up would boost aggregateSignatures.
// https://eprint.iacr.org/2012/685.pdf applicable? // https://eprint.iacr.org/2012/685.pdf applicable?

@ -260,7 +260,7 @@ const invertSqrt = (number: bigint) => uvRatio(_1n, number);
const MAX_255B = BigInt('0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff'); const MAX_255B = BigInt('0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff');
const bytes255ToNumberLE = (bytes: Uint8Array) => const bytes255ToNumberLE = (bytes: Uint8Array) =>
ed25519.utils.mod(bytesToNumberLE(bytes) & MAX_255B); ed25519.CURVE.Fp.create(bytesToNumberLE(bytes) & MAX_255B);
type ExtendedPoint = ExtendedPointType; type ExtendedPoint = ExtendedPointType;
@ -269,7 +269,7 @@ type ExtendedPoint = ExtendedPointType;
function calcElligatorRistrettoMap(r0: bigint): ExtendedPoint { function calcElligatorRistrettoMap(r0: bigint): ExtendedPoint {
const { d } = ed25519.CURVE; const { d } = ed25519.CURVE;
const P = ed25519.CURVE.Fp.ORDER; const P = ed25519.CURVE.Fp.ORDER;
const { mod } = ed25519.utils; const mod = ed25519.CURVE.Fp.create;
const r = mod(SQRT_M1 * r0 * r0); // 1 const r = mod(SQRT_M1 * r0 * r0); // 1
const Ns = mod((r + _1n) * ONE_MINUS_D_SQ); // 2 const Ns = mod((r + _1n) * ONE_MINUS_D_SQ); // 2
let c = BigInt(-1); // 3 let c = BigInt(-1); // 3
@ -327,7 +327,7 @@ export class RistrettoPoint {
hex = ensureBytes(hex, 32); hex = ensureBytes(hex, 32);
const { a, d } = ed25519.CURVE; const { a, d } = ed25519.CURVE;
const P = ed25519.CURVE.Fp.ORDER; const P = ed25519.CURVE.Fp.ORDER;
const { mod } = ed25519.utils; const mod = ed25519.CURVE.Fp.create;
const emsg = 'RistrettoPoint.fromHex: the hex is not valid encoding of RistrettoPoint'; const emsg = 'RistrettoPoint.fromHex: the hex is not valid encoding of RistrettoPoint';
const s = bytes255ToNumberLE(hex); const s = bytes255ToNumberLE(hex);
// 1. Check that s_bytes is the canonical encoding of a field element, or else abort. // 1. Check that s_bytes is the canonical encoding of a field element, or else abort.
@ -357,7 +357,7 @@ export class RistrettoPoint {
toRawBytes(): Uint8Array { toRawBytes(): Uint8Array {
let { x, y, z, t } = this.ep; let { x, y, z, t } = this.ep;
const P = ed25519.CURVE.Fp.ORDER; const P = ed25519.CURVE.Fp.ORDER;
const { mod } = ed25519.utils; const mod = ed25519.CURVE.Fp.create;
const u1 = mod(mod(z + y) * mod(z - y)); // 1 const u1 = mod(mod(z + y) * mod(z - y)); // 1
const u2 = mod(x * y); // 2 const u2 = mod(x * y); // 2
// Square root always exists // Square root always exists
@ -395,7 +395,7 @@ export class RistrettoPoint {
assertRstPoint(other); assertRstPoint(other);
const a = this.ep; const a = this.ep;
const b = other.ep; const b = other.ep;
const { mod } = ed25519.utils; const mod = ed25519.CURVE.Fp.create;
// (x1 * y2 == y1 * x2) | (y1 * y2 == x1 * x2) // (x1 * y2 == y1 * x2) | (y1 * y2 == x1 * x2)
const one = mod(a.x * b.y) === mod(a.y * b.x); const one = mod(a.x * b.y) === mod(a.y * b.x);
const two = mod(a.y * b.y) === mod(a.x * b.x); const two = mod(a.y * b.y) === mod(a.x * b.x);

@ -127,7 +127,7 @@ export const secp256k1 = createCurve(
const b1 = -_1n * BigInt('0xe4437ed6010e88286f547fa90abfe4c3'); const b1 = -_1n * BigInt('0xe4437ed6010e88286f547fa90abfe4c3');
const a2 = BigInt('0x114ca50f7a8e2f3f657c1108d9d44cfd8'); const a2 = BigInt('0x114ca50f7a8e2f3f657c1108d9d44cfd8');
const b2 = a1; const b2 = a1;
const POW_2_128 = BigInt('0x100000000000000000000000000000000'); const POW_2_128 = BigInt('0x100000000000000000000000000000000'); // (2n**128n).toString(16)
const c1 = divNearest(b2 * k, n); const c1 = divNearest(b2 * k, n);
const c2 = divNearest(-b1 * k, n); const c2 = divNearest(-b1 * k, n);
@ -173,20 +173,17 @@ function normalizePublicKey(publicKey: Hex | PointType<bigint>): PointType<bigin
} else { } else {
const bytes = ensureBytes(publicKey); const bytes = ensureBytes(publicKey);
// Schnorr is 32 bytes // Schnorr is 32 bytes
if (bytes.length === 32) { if (bytes.length !== 32) throw new Error('Schnorr pubkeys must be 32 bytes');
const x = bytesToNumberBE(bytes); const x = bytesToNumberBE(bytes);
if (!isValidFieldElement(x)) throw new Error('Point is not on curve'); if (!isValidFieldElement(x)) throw new Error('Point is not on curve');
const y2 = secp256k1.utils._weierstrassEquation(x); // y² = x³ + ax + b const y2 = secp256k1.utils._weierstrassEquation(x); // y² = x³ + ax + b
let y = sqrtMod(y2); // y = y² ^ (p+1)/4 let y = sqrtMod(y2); // y = y² ^ (p+1)/4
const isYOdd = (y & _1n) === _1n; const isYOdd = (y & _1n) === _1n;
// Schnorr // Schnorr
if (isYOdd) y = secp256k1.CURVE.Fp.negate(y); if (isYOdd) y = secp256k1.CURVE.Fp.negate(y);
const point = new secp256k1.Point(x, y); const point = new secp256k1.Point(x, y);
point.assertValidity(); point.assertValidity();
return point; return point;
}
// Do we need that in schnorr at all?
return secp256k1.Point.fromHex(publicKey);
} }
} }
@ -225,10 +222,13 @@ class SchnorrSignature {
} }
static fromHex(hex: Hex) { static fromHex(hex: Hex) {
const bytes = ensureBytes(hex); const bytes = ensureBytes(hex);
if (bytes.length !== 64) const len = 32; // group length
throw new TypeError(`SchnorrSignature.fromHex: expected 64 bytes, not ${bytes.length}`); if (bytes.length !== 2 * len)
const r = bytesToNumberBE(bytes.subarray(0, 32)); throw new TypeError(
const s = bytesToNumberBE(bytes.subarray(32, 64)); `SchnorrSignature.fromHex: expected ${2 * len} bytes, not ${bytes.length}`
);
const r = bytesToNumberBE(bytes.subarray(0, len));
const s = bytesToNumberBE(bytes.subarray(len, 2 * len));
return new SchnorrSignature(r, s); return new SchnorrSignature(r, s);
} }
assertValidity() { assertValidity() {

@ -138,11 +138,12 @@ function hashKeyWithIndex(key: Uint8Array, index: number) {
export function grindKey(seed: Hex) { export function grindKey(seed: Hex) {
const _seed = ensureBytes0x(seed); const _seed = ensureBytes0x(seed);
const sha256mask = 2n ** 256n; const sha256mask = 2n ** 256n;
const limit = sha256mask - starkCurve.utils.mod(sha256mask, starkCurve.CURVE.n); const Fn = Fp(CURVE.n);
const limit = sha256mask - Fn.create(sha256mask);
for (let i = 0; ; i++) { for (let i = 0; ; i++) {
const key = hashKeyWithIndex(_seed, i); const key = hashKeyWithIndex(_seed, i);
// key should be in [0, limit) // key should be in [0, limit)
if (key < limit) return starkCurve.utils.mod(key, starkCurve.CURVE.n).toString(16); if (key < limit) return Fn.create(key).toString(16);
} }
} }

@ -646,7 +646,8 @@ for (let i = 0; i < VECTORS_RFC8032_PH.length; i++) {
should('X25519 base point', () => { should('X25519 base point', () => {
const { y } = ed25519.Point.BASE; const { y } = ed25519.Point.BASE;
const u = ed25519.utils.mod((y + 1n) * ed25519.utils.invert(1n - y, ed25519.CURVE.P)); const { Fp } = ed25519.CURVE;
const u = Fp.create((y + 1n) * Fp.invert(1n - y));
deepStrictEqual(hex(numberToBytesLE(u, 32)), x25519.Gu); deepStrictEqual(hex(numberToBytesLE(u, 32)), x25519.Gu);
}); });

@ -651,9 +651,10 @@ for (let i = 0; i < VECTORS_RFC8032_PH.length; i++) {
should('X448 base point', () => { should('X448 base point', () => {
const { x, y } = ed448.Point.BASE; const { x, y } = ed448.Point.BASE;
const { P } = ed448.CURVE; const { Fp } = ed448.CURVE;
const invX = ed448.utils.invert(x * x, P); // x² // const invX = Fp.invert(x * x); // x²
const u = ed448.utils.mod(y * y * invX, P); // (y²/x²) const u = Fp.div(Fp.create(y * y), Fp.create(x * x)); // (y²/x²)
// const u = Fp.create(y * y * invX);
deepStrictEqual(hex(numberToBytesLE(u, 56)), x448.Gu); deepStrictEqual(hex(numberToBytesLE(u, 56)), x448.Gu);
}); });

@ -1,5 +1,6 @@
import * as fc from 'fast-check'; import * as fc from 'fast-check';
import { secp256k1, schnorr } from '../lib/esm/secp256k1.js'; import { secp256k1, schnorr } from '../lib/esm/secp256k1.js';
import { Fp } from '../lib/esm/abstract/modular.js';
import { readFileSync } from 'fs'; import { readFileSync } from 'fs';
import { default as ecdsa } from './vectors/ecdsa.json' assert { type: 'json' }; import { default as ecdsa } from './vectors/ecdsa.json' assert { type: 'json' };
import { default as ecdh } from './vectors/ecdh.json' assert { type: 'json' }; import { default as ecdh } from './vectors/ecdh.json' assert { type: 'json' };
@ -16,7 +17,6 @@ const privatesTxt = readFileSync('./test/vectors/privates-2.txt', 'utf-8');
const schCsv = readFileSync('./test/vectors/schnorr.csv', 'utf-8'); const schCsv = readFileSync('./test/vectors/schnorr.csv', 'utf-8');
const FC_BIGINT = fc.bigInt(1n + 1n, secp.CURVE.n - 1n); const FC_BIGINT = fc.bigInt(1n + 1n, secp.CURVE.n - 1n);
const P = secp.CURVE.Fp.ORDER;
// prettier-ignore // prettier-ignore
const INVALID_ITEMS = ['deadbeef', Math.pow(2, 53), [1], 'xyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxy', secp.CURVE.n + 2n]; const INVALID_ITEMS = ['deadbeef', Math.pow(2, 53), [1], 'xyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxyxyzxyzxy', secp.CURVE.n + 2n];
@ -50,9 +50,9 @@ should('secp256k1.getPublicKey()', () => {
} }
}); });
should('secp256k1.getPublicKey() rejects invalid keys', () => { should('secp256k1.getPublicKey() rejects invalid keys', () => {
// for (const item of INVALID_ITEMS) { for (const item of INVALID_ITEMS) {
// throws(() => secp.getPublicKey(item)); throws(() => secp.getPublicKey(item));
// } }
}); });
should('secp256k1.precompute', () => { should('secp256k1.precompute', () => {
secp.utils.precompute(4); secp.utils.precompute(4);
@ -434,17 +434,26 @@ should('secp256k1.utils.isValidPrivateKey()', () => {
deepStrictEqual(secp.utils.isValidPrivateKey(d), expected); deepStrictEqual(secp.utils.isValidPrivateKey(d), expected);
} }
}); });
should('have proper curve equation in assertValidity()', () => {
throws(() => {
const { Fp } = secp.CURVE;
let point = new secp.Point(Fp.create(-2n), Fp.create(-1n));
point.assertValidity();
});
});
const Fn = Fp(secp.CURVE.n);
const normal = secp.utils._normalizePrivateKey; const normal = secp.utils._normalizePrivateKey;
const tweakUtils = { const tweakUtils = {
privateAdd: (privateKey, tweak) => { privateAdd: (privateKey, tweak) => {
const p = normal(privateKey); const p = normal(privateKey);
const t = normal(tweak); const t = normal(tweak);
return secp.utils._bigintToBytes(secp.utils.mod(p + t, secp.CURVE.n)); return secp.utils._bigintToBytes(Fn.create(p + t));
}, },
privateNegate: (privateKey) => { privateNegate: (privateKey) => {
const p = normal(privateKey); const p = normal(privateKey);
return secp.utils._bigintToBytes(secp.CURVE.n - p); return secp.utils._bigintToBytes(Fn.negate(p));
}, },
pointAddScalar: (p, tweak, isCompressed) => { pointAddScalar: (p, tweak, isCompressed) => {