forked from tornado-packages/noble-curves
weierstrass: make points compressed by def. Rewrite drbg, k generation.
This commit is contained in:
parent
2b908ad602
commit
7fda6de619
@ -9,7 +9,7 @@ export interface Group<T extends Group<T>> {
|
||||
add(other: T): T;
|
||||
subtract(other: T): T;
|
||||
equals(other: T): boolean;
|
||||
multiply(scalar: number | bigint): T;
|
||||
multiply(scalar: bigint): T;
|
||||
}
|
||||
|
||||
export type GroupConstructor<T> = {
|
||||
|
@ -5,7 +5,7 @@ const _1n = BigInt(1);
|
||||
const _2n = BigInt(2);
|
||||
|
||||
const str = (a: any): a is string => typeof a === 'string';
|
||||
const big = (a: any): a is bigint => typeof a === 'bigint';
|
||||
export const big = (a: any): a is bigint => typeof a === 'bigint';
|
||||
const u8a = (a: any): a is Uint8Array => a instanceof Uint8Array;
|
||||
|
||||
// We accept hex strings besides Uint8Array for simplicity
|
||||
|
@ -114,7 +114,7 @@ export interface ProjectivePointType<T> extends Group<ProjectivePointType<T>> {
|
||||
readonly x: T;
|
||||
readonly y: T;
|
||||
readonly z: T;
|
||||
multiply(scalar: number | bigint, affinePoint?: PointType<T>): ProjectivePointType<T>;
|
||||
multiply(scalar: bigint, affinePoint?: PointType<T>): ProjectivePointType<T>;
|
||||
multiplyUnsafe(scalar: bigint): ProjectivePointType<T>;
|
||||
toAffine(invZ?: T): PointType<T>;
|
||||
clearCofactor(): ProjectivePointType<T>;
|
||||
@ -249,7 +249,7 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
|
||||
* Validates if a scalar ("private number") is valid.
|
||||
* Scalars are valid only if they are less than curve order.
|
||||
*/
|
||||
function normalizeScalar(num: number | bigint): bigint {
|
||||
function normalizeScalar(num: bigint): bigint {
|
||||
if (ut.isPositiveInt(num)) return BigInt(num);
|
||||
if (typeof num === 'bigint' && isWithinCurveOrder(num)) return num;
|
||||
throw new TypeError('Expected valid private scalar: 0 < scalar < curve.n');
|
||||
@ -470,7 +470,7 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
|
||||
* @param affinePoint optional point ot save cached precompute windows on it
|
||||
* @returns New point
|
||||
*/
|
||||
multiply(scalar: number | bigint, affinePoint?: Point): ProjectivePoint {
|
||||
multiply(scalar: bigint, affinePoint?: Point): ProjectivePoint {
|
||||
let n = normalizeScalar(scalar);
|
||||
|
||||
// Real point.
|
||||
@ -582,12 +582,12 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
|
||||
return Point.BASE.multiply(normalizePrivateKey(privateKey));
|
||||
}
|
||||
|
||||
toRawBytes(isCompressed = false): Uint8Array {
|
||||
toRawBytes(isCompressed = true): Uint8Array {
|
||||
this.assertValidity();
|
||||
return CURVE.toBytes(Point, this, isCompressed);
|
||||
}
|
||||
|
||||
toHex(isCompressed = false): string {
|
||||
toHex(isCompressed = true): string {
|
||||
return bytesToHex(this.toRawBytes(isCompressed));
|
||||
}
|
||||
// A point on curve is valid if it conforms to equation.
|
||||
@ -635,7 +635,7 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
|
||||
return this.add(other.negate());
|
||||
}
|
||||
|
||||
multiply(scalar: number | bigint) {
|
||||
multiply(scalar: bigint) {
|
||||
return this.toProj().multiply(scalar, this).toAffine();
|
||||
}
|
||||
|
||||
@ -754,54 +754,59 @@ export type CurveFn = {
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Minimal HMAC-DRBG (NIST 800-90) for signatures.
|
||||
* Used only for RFC6979, does not fully implement DRBG spec.
|
||||
*/
|
||||
class HmacDrbg {
|
||||
k: Uint8Array;
|
||||
v: Uint8Array;
|
||||
counter: number;
|
||||
constructor(public hashLen: number, public qByteLen: number, public hmacFn: HmacFnSync) {
|
||||
if (typeof hashLen !== 'number' || hashLen < 2) throw new Error('hashLen must be a number');
|
||||
if (typeof qByteLen !== 'number' || qByteLen < 2) throw new Error('qByteLen must be a number');
|
||||
if (typeof hmacFn !== 'function') throw new Error('hmacFn must be a function');
|
||||
// Step B, Step C: set hashLen to 8*ceil(hlen/8)
|
||||
this.v = new Uint8Array(hashLen).fill(1);
|
||||
this.k = new Uint8Array(hashLen).fill(0);
|
||||
this.counter = 0;
|
||||
}
|
||||
private hmacSync(...values: Uint8Array[]) {
|
||||
return this.hmacFn(this.k, ...values);
|
||||
}
|
||||
incr() {
|
||||
if (this.counter >= 1000) throw new Error('Tried 1,000 k values for sign(), all were invalid');
|
||||
this.counter += 1;
|
||||
}
|
||||
reseedSync(seed = new Uint8Array()) {
|
||||
this.k = this.hmacSync(this.v, Uint8Array.from([0x00]), seed);
|
||||
this.v = this.hmacSync(this.v);
|
||||
const u8n = (data?: any) => new Uint8Array(data); // creates Uint8Array
|
||||
const u8fr = (arr: any) => Uint8Array.from(arr); // another shortcut
|
||||
// Minimal HMAC-DRBG from NIST 800-90 for RFC6979 sigs.
|
||||
type Pred<T> = (v: Uint8Array) => T | undefined;
|
||||
function hmacDrbg<T>(
|
||||
hashLen: number,
|
||||
qByteLen: number,
|
||||
hmacFn: HmacFnSync
|
||||
): (seed: Uint8Array, predicate: Pred<T>) => T {
|
||||
if (typeof hashLen !== 'number' || hashLen < 2) throw new Error('hashLen must be a number');
|
||||
if (typeof qByteLen !== 'number' || qByteLen < 2) throw new Error('qByteLen must be a number');
|
||||
if (typeof hmacFn !== 'function') throw new Error('hmacFn must be a function');
|
||||
// Step B, Step C: set hashLen to 8*ceil(hlen/8)
|
||||
let v = u8n(hashLen); // Minimal non-full-spec HMAC-DRBG from NIST 800-90 for RFC6979 sigs.
|
||||
let k = u8n(hashLen); // Steps B and C of RFC6979 3.2: set hashLen, in our case always same
|
||||
let i = 0; // Iterations counter, will throw when over 1000
|
||||
const reset = () => {
|
||||
v.fill(1);
|
||||
k.fill(0);
|
||||
i = 0;
|
||||
};
|
||||
const h = (...b: Uint8Array[]) => hmacFn(k, v, ...b); // hmac(k)(v, ...values)
|
||||
const reseed = (seed = u8n()) => {
|
||||
// HMAC-DRBG reseed() function. Steps D-G
|
||||
k = h(u8fr([0x00]), seed); // k = hmac(k || v || 0x00 || seed)
|
||||
v = h(); // v = hmac(k || v)
|
||||
if (seed.length === 0) return;
|
||||
this.k = this.hmacSync(this.v, Uint8Array.from([0x01]), seed);
|
||||
this.v = this.hmacSync(this.v);
|
||||
}
|
||||
// TODO: review
|
||||
generateSync(): Uint8Array {
|
||||
this.incr();
|
||||
|
||||
k = h(u8fr([0x01]), seed); // k = hmac(k || v || 0x01 || seed)
|
||||
v = h(); // v = hmac(k || v)
|
||||
};
|
||||
const gen = () => {
|
||||
// HMAC-DRBG generate() function
|
||||
if (i++ >= 1000) throw new Error('drbg: tried 1000 values');
|
||||
let len = 0;
|
||||
const out: Uint8Array[] = [];
|
||||
while (len < this.qByteLen) {
|
||||
this.v = this.hmacSync(this.v);
|
||||
const sl = this.v.slice();
|
||||
while (len < qByteLen) {
|
||||
v = h();
|
||||
const sl = v.slice();
|
||||
out.push(sl);
|
||||
len += this.v.length;
|
||||
len += v.length;
|
||||
}
|
||||
return ut.concatBytes(...out);
|
||||
}
|
||||
// There are no guarantees with JS GC whether bigints are removed even if you clean Uint8Arrays.
|
||||
};
|
||||
const genUntil = (seed: Uint8Array, pred: Pred<T>): T => {
|
||||
reset();
|
||||
reseed(seed); // Steps D-G
|
||||
let res: T | undefined = undefined; // Step H: grind until k is in [1..n-1]
|
||||
while (!(res = pred(gen()))) reseed();
|
||||
reset();
|
||||
return res;
|
||||
};
|
||||
return genUntil;
|
||||
}
|
||||
|
||||
export function weierstrass(curveDef: CurveType): CurveFn {
|
||||
const CURVE = validateOpts(curveDef) as ReturnType<typeof validateOpts>;
|
||||
const CURVE_ORDER = CURVE.n;
|
||||
@ -1051,7 +1056,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
||||
* @param isCompressed whether to return compact (default), or full key
|
||||
* @returns Public key, full when isCompressed=false; short when isCompressed=true
|
||||
*/
|
||||
function getPublicKey(privateKey: PrivKey, isCompressed = false): Uint8Array {
|
||||
function getPublicKey(privateKey: PrivKey, isCompressed = true): Uint8Array {
|
||||
return Point.fromPrivateKey(privateKey).toRawBytes(isCompressed);
|
||||
}
|
||||
|
||||
@ -1077,7 +1082,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
||||
* @param isCompressed whether to return compact (default), or full key
|
||||
* @returns shared public key
|
||||
*/
|
||||
function getSharedSecret(privateA: PrivKey, publicB: PubKey, isCompressed = false): Uint8Array {
|
||||
function getSharedSecret(privateA: PrivKey, publicB: PubKey, isCompressed = true): Uint8Array {
|
||||
if (isProbPub(privateA)) throw new TypeError('getSharedSecret: first arg must be private key');
|
||||
if (!isProbPub(publicB)) throw new TypeError('getSharedSecret: second arg must be public key');
|
||||
const b = normalizePublicKey(publicB);
|
||||
@ -1085,21 +1090,17 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
||||
return b.multiply(normalizePrivateKey(privateA)).toRawBytes(isCompressed);
|
||||
}
|
||||
|
||||
// RFC6979 methods
|
||||
// Ensures ECDSA message hashes are 32 bytes and < curve order
|
||||
// RFC6979 suggest optional truncating via bits2octets
|
||||
// FIPS 186-4 Section 4.6 suggest the leftmost min(N, outlen) bits, where N = nBitLength, which is exactly what bits2int does
|
||||
// However, result of bits2int can be higher than order, but since there is same amount of bits, modulo operation
|
||||
// can be done via 'h >= n ? h - n : h'.
|
||||
// But we cannot use int2octets, since it pads small hash with zeros which should not happen on truncate as per RFC6979 vectors
|
||||
// RFC6979: ensure ECDSA msg is X bytes and < N. RFC suggests optional truncating via bits2octets.
|
||||
// FIPS 186-4 4.6 suggests the leftmost min(nBitLen, outLen) bits, which matches bits2int.
|
||||
// bits2int can produce res>N, we can do mod(res, N) since the bitLen is the same.
|
||||
// int2octets can't be used; pads small msgs with 0: unacceptatble for trunc as per RFC vectors
|
||||
const bits2int =
|
||||
CURVE.bits2int ||
|
||||
function (bytes: Uint8Array): bigint {
|
||||
// Truncate to nBitLength leftmost bits (kinda)
|
||||
// NOTE: for curves with nBitLength % 8 !== 0: bits2octets(bits2octets(hash)) !== bits2octets(hash)
|
||||
// for some cases, because bytes.length * 8 is not actual bitLength.
|
||||
const delta = bytes.length * 8 - CURVE.nBitLength;
|
||||
const num = ut.bytesToNumberBE(bytes);
|
||||
// For curves with nBitLength % 8 !== 0: bits2octets(bits2octets(m)) !== bits2octets(m)
|
||||
// for some cases, since bytes.length * 8 is not actual bitLength.
|
||||
const delta = bytes.length * 8 - CURVE.nBitLength; // truncate to nBitLength leftmost bits
|
||||
const num = ut.bytesToNumberBE(bytes); // check for == u8 done here
|
||||
return delta > 0 ? num >> BigInt(delta) : num;
|
||||
};
|
||||
const bits2int_modN =
|
||||
@ -1113,15 +1114,20 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
||||
if (typeof num !== 'bigint') throw new Error('Expected bigint');
|
||||
if (!(_0n <= num && num < ORDER_MASK))
|
||||
throw new Error(`Expected number < 2^${CURVE.nBitLength}`);
|
||||
return ut.numberToBytesBE(num, CURVE.nByteLength); // works with order, can have different size than numToField!
|
||||
// works with order, can have different size than numToField!
|
||||
return ut.numberToBytesBE(num, CURVE.nByteLength);
|
||||
}
|
||||
// Steps A, D of RFC6979 3.2
|
||||
// Creates RFC6979 seed; converts msg/privKey to numbers.
|
||||
// Used only in sign, not in verify.
|
||||
// NOTE: we cannot assume here that msgHash has same amount of bytes as curve order, this will be wrong at least for P521.
|
||||
// Also it can be bigger for P224 + SHA256
|
||||
function initSigArgs(msgHash: Hex, privateKey: PrivKey, extraEntropy?: Entropy) {
|
||||
function prepSig(msgHash: Hex, privateKey: PrivKey, opts = defaultSigOpts) {
|
||||
if (msgHash == null) throw new Error(`sign: expected valid message hash, not "${msgHash}"`);
|
||||
if (['recovered', 'canonical'].some(k => k in opts)) // Ban legacy options
|
||||
throw new Error('sign() legacy options not supported');
|
||||
let { lowS } = opts; // generates low-s sigs by default
|
||||
if (lowS == null) lowS = true; // RFC6979 3.2: we skip step A, because
|
||||
// Step A is ignored, since we already provide hash instead of msg
|
||||
|
||||
// NOTE: instead of bits2int, we calling here truncateHash, since we need
|
||||
@ -1135,10 +1141,10 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
||||
const d = normalizePrivateKey(privateKey);
|
||||
// K = HMAC_K(V || 0x00 || int2octets(x) || bits2octets(h1) || k')
|
||||
const seedArgs = [int2octets(d), h1octets];
|
||||
// RFC6979 3.6: additional k' could be provided
|
||||
if (extraEntropy != null) {
|
||||
if (extraEntropy === true) extraEntropy = CURVE.randomBytes(Fp.BYTES);
|
||||
const e = ut.ensureBytes(extraEntropy);
|
||||
let ent = opts.extraEntropy; // RFC6979 3.6: additional k' (optional)
|
||||
if (ent != null) {
|
||||
if (ent === true) ent = CURVE.randomBytes(Fp.BYTES);
|
||||
const e = ut.ensureBytes(ent);
|
||||
if (e.length !== Fp.BYTES) throw new Error(`sign: Expected ${Fp.BYTES} bytes of extra data`);
|
||||
seedArgs.push(e);
|
||||
}
|
||||
@ -1147,41 +1153,32 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
||||
// V, 0x00 are done in HmacDRBG constructor.
|
||||
const seed = ut.concatBytes(...seedArgs);
|
||||
const m = h1int; // NOTE: no need to call bits2int second time here, it is inside truncateHash!
|
||||
return { seed, m, d };
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts signature params into point & r/s, checks them for validity.
|
||||
* k must be in range [1, n-1]
|
||||
* @param k signature's k param: deterministic in our case, random in non-rfc6979 sigs
|
||||
* @param m message that would be signed
|
||||
* @param d private key
|
||||
* @returns Signature with its point on curve Q OR undefined if params were invalid
|
||||
*/
|
||||
function kmdToSig(kBytes: Uint8Array, m: bigint, d: bigint, lowS = true): Signature | undefined {
|
||||
const { n } = CURVE;
|
||||
// RFC 6979 Section 3.2, step 3: k = bits2int(T)
|
||||
const k = bits2int(kBytes); // Cannot use fields methods, since it is group element
|
||||
if (!isWithinCurveOrder(k)) return;
|
||||
// Important: all mod() calls in the function must be done over `n`
|
||||
const kinv = mod.invert(k, n);
|
||||
const q = Point.BASE.multiply(k);
|
||||
// r = x mod n
|
||||
const r = mod.mod(q.x, n);
|
||||
if (r === _0n) return;
|
||||
// s = (m + dr)/k mod n where x/k == x*inv(k)
|
||||
const s = mod.mod(kinv * mod.mod(m + mod.mod(d * r, n), n), n);
|
||||
if (s === _0n) return;
|
||||
// recovery bit is usually 0 or 1; rarely it's 2 or 3, when q.x > n
|
||||
let recovery = (q.x === r ? 0 : 2) | Number(q.y & _1n);
|
||||
let normS = s;
|
||||
if (lowS && isBiggerThanHalfOrder(s)) {
|
||||
normS = normalizeS(s);
|
||||
recovery ^= 1;
|
||||
// Converts signature params into point w r/s, checks result for validity.
|
||||
function k2sig(kBytes: Uint8Array): Signature | undefined {
|
||||
const { n } = CURVE;
|
||||
// RFC 6979 Section 3.2, step 3: k = bits2int(T)
|
||||
const k = bits2int(kBytes); // Cannot use fields methods, since it is group element
|
||||
if (!isWithinCurveOrder(k)) return;
|
||||
// Important: all mod() calls in the function must be done over `n`
|
||||
const ik = mod.invert(k, n);
|
||||
const q = Point.BASE.multiply(k);
|
||||
// r = x mod n
|
||||
const r = mod.mod(q.x, n);
|
||||
if (r === _0n) return;
|
||||
// s = (m + dr)/k mod n where x/k == x*inv(k)
|
||||
const s = mod.mod(ik * mod.mod(m + mod.mod(d * r, n), n), n);
|
||||
if (s === _0n) return;
|
||||
// recovery bit is usually 0 or 1; rarely it's 2 or 3, when q.x > n
|
||||
let recovery = (q.x === r ? 0 : 2) | Number(q.y & _1n);
|
||||
let normS = s;
|
||||
if (lowS && isBiggerThanHalfOrder(s)) {
|
||||
normS = normalizeS(s);
|
||||
recovery ^= 1;
|
||||
}
|
||||
return new Signature(r, normS, recovery);
|
||||
}
|
||||
return new Signature(r, normS, recovery);
|
||||
return { seed, k2sig };
|
||||
}
|
||||
|
||||
const defaultSigOpts: SignOpts = { lowS: CURVE.lowS };
|
||||
|
||||
/**
|
||||
@ -1196,15 +1193,9 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
||||
*/
|
||||
// TODO: add opts.prehashed = True, if !opts.prehashed do hash on msg?
|
||||
function sign(msgHash: Hex, privKey: PrivKey, opts = defaultSigOpts): Signature {
|
||||
// Steps A, D of RFC6979 3.2.
|
||||
const { seed, m, d } = initSigArgs(msgHash, privKey, opts.extraEntropy);
|
||||
// Steps B, C, D, E, F, G
|
||||
const drbg = new HmacDrbg(CURVE.hash.outputLen, CURVE.nByteLength, CURVE.hmac);
|
||||
drbg.reseedSync(seed);
|
||||
// Step H3, repeat until k is in range [1, n-1]
|
||||
let sig: Signature | undefined;
|
||||
while (!(sig = kmdToSig(drbg.generateSync(), m, d, opts.lowS))) drbg.reseedSync();
|
||||
return sig;
|
||||
const { seed, k2sig } = prepSig(msgHash, privKey, opts); // Steps A, D of RFC6979 3.2.
|
||||
const genUntil = hmacDrbg<Signature>(CURVE.hash.outputLen, CURVE.nByteLength, CURVE.hmac);
|
||||
return genUntil(seed, k2sig); // Steps B, C, D, E, F, G
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -419,11 +419,11 @@ export class RistrettoPoint {
|
||||
return new RistrettoPoint(this.ep.subtract(other.ep));
|
||||
}
|
||||
|
||||
multiply(scalar: number | bigint): RistrettoPoint {
|
||||
multiply(scalar: bigint): RistrettoPoint {
|
||||
return new RistrettoPoint(this.ep.multiply(scalar));
|
||||
}
|
||||
|
||||
multiplyUnsafe(scalar: number | bigint): RistrettoPoint {
|
||||
multiplyUnsafe(scalar: bigint): RistrettoPoint {
|
||||
return new RistrettoPoint(this.ep.multiplyUnsafe(scalar));
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user