From 7fda6de61939fc4914fc061772f20c77a9794ddf Mon Sep 17 00:00:00 2001 From: Paul Miller Date: Tue, 24 Jan 2023 03:02:38 +0000 Subject: [PATCH] weierstrass: make points compressed by def. Rewrite drbg, k generation. --- src/abstract/group.ts | 2 +- src/abstract/utils.ts | 2 +- src/abstract/weierstrass.ts | 209 +++++++++++++++++------------------- src/ed25519.ts | 4 +- 4 files changed, 104 insertions(+), 113 deletions(-) diff --git a/src/abstract/group.ts b/src/abstract/group.ts index 09fdc8b..fc36403 100644 --- a/src/abstract/group.ts +++ b/src/abstract/group.ts @@ -9,7 +9,7 @@ export interface Group> { add(other: T): T; subtract(other: T): T; equals(other: T): boolean; - multiply(scalar: number | bigint): T; + multiply(scalar: bigint): T; } export type GroupConstructor = { diff --git a/src/abstract/utils.ts b/src/abstract/utils.ts index 700a35d..5d0a6aa 100644 --- a/src/abstract/utils.ts +++ b/src/abstract/utils.ts @@ -5,7 +5,7 @@ const _1n = BigInt(1); const _2n = BigInt(2); const str = (a: any): a is string => typeof a === 'string'; -const big = (a: any): a is bigint => typeof a === 'bigint'; +export const big = (a: any): a is bigint => typeof a === 'bigint'; const u8a = (a: any): a is Uint8Array => a instanceof Uint8Array; // We accept hex strings besides Uint8Array for simplicity diff --git a/src/abstract/weierstrass.ts b/src/abstract/weierstrass.ts index 410d63e..d6492a7 100644 --- a/src/abstract/weierstrass.ts +++ b/src/abstract/weierstrass.ts @@ -114,7 +114,7 @@ export interface ProjectivePointType extends Group> { readonly x: T; readonly y: T; readonly z: T; - multiply(scalar: number | bigint, affinePoint?: PointType): ProjectivePointType; + multiply(scalar: bigint, affinePoint?: PointType): ProjectivePointType; multiplyUnsafe(scalar: bigint): ProjectivePointType; toAffine(invZ?: T): PointType; clearCofactor(): ProjectivePointType; @@ -249,7 +249,7 @@ export function weierstrassPoints(opts: CurvePointsType) { * Validates if a scalar ("private number") is valid. * Scalars are valid only if they are less than curve order. */ - function normalizeScalar(num: number | bigint): bigint { + function normalizeScalar(num: bigint): bigint { if (ut.isPositiveInt(num)) return BigInt(num); if (typeof num === 'bigint' && isWithinCurveOrder(num)) return num; throw new TypeError('Expected valid private scalar: 0 < scalar < curve.n'); @@ -470,7 +470,7 @@ export function weierstrassPoints(opts: CurvePointsType) { * @param affinePoint optional point ot save cached precompute windows on it * @returns New point */ - multiply(scalar: number | bigint, affinePoint?: Point): ProjectivePoint { + multiply(scalar: bigint, affinePoint?: Point): ProjectivePoint { let n = normalizeScalar(scalar); // Real point. @@ -582,12 +582,12 @@ export function weierstrassPoints(opts: CurvePointsType) { return Point.BASE.multiply(normalizePrivateKey(privateKey)); } - toRawBytes(isCompressed = false): Uint8Array { + toRawBytes(isCompressed = true): Uint8Array { this.assertValidity(); return CURVE.toBytes(Point, this, isCompressed); } - toHex(isCompressed = false): string { + toHex(isCompressed = true): string { return bytesToHex(this.toRawBytes(isCompressed)); } // A point on curve is valid if it conforms to equation. @@ -635,7 +635,7 @@ export function weierstrassPoints(opts: CurvePointsType) { return this.add(other.negate()); } - multiply(scalar: number | bigint) { + multiply(scalar: bigint) { return this.toProj().multiply(scalar, this).toAffine(); } @@ -754,54 +754,59 @@ export type CurveFn = { }; }; -/** - * Minimal HMAC-DRBG (NIST 800-90) for signatures. - * Used only for RFC6979, does not fully implement DRBG spec. - */ -class HmacDrbg { - k: Uint8Array; - v: Uint8Array; - counter: number; - constructor(public hashLen: number, public qByteLen: number, public hmacFn: HmacFnSync) { - if (typeof hashLen !== 'number' || hashLen < 2) throw new Error('hashLen must be a number'); - if (typeof qByteLen !== 'number' || qByteLen < 2) throw new Error('qByteLen must be a number'); - if (typeof hmacFn !== 'function') throw new Error('hmacFn must be a function'); - // Step B, Step C: set hashLen to 8*ceil(hlen/8) - this.v = new Uint8Array(hashLen).fill(1); - this.k = new Uint8Array(hashLen).fill(0); - this.counter = 0; - } - private hmacSync(...values: Uint8Array[]) { - return this.hmacFn(this.k, ...values); - } - incr() { - if (this.counter >= 1000) throw new Error('Tried 1,000 k values for sign(), all were invalid'); - this.counter += 1; - } - reseedSync(seed = new Uint8Array()) { - this.k = this.hmacSync(this.v, Uint8Array.from([0x00]), seed); - this.v = this.hmacSync(this.v); +const u8n = (data?: any) => new Uint8Array(data); // creates Uint8Array +const u8fr = (arr: any) => Uint8Array.from(arr); // another shortcut +// Minimal HMAC-DRBG from NIST 800-90 for RFC6979 sigs. +type Pred = (v: Uint8Array) => T | undefined; +function hmacDrbg( + hashLen: number, + qByteLen: number, + hmacFn: HmacFnSync +): (seed: Uint8Array, predicate: Pred) => T { + if (typeof hashLen !== 'number' || hashLen < 2) throw new Error('hashLen must be a number'); + if (typeof qByteLen !== 'number' || qByteLen < 2) throw new Error('qByteLen must be a number'); + if (typeof hmacFn !== 'function') throw new Error('hmacFn must be a function'); + // Step B, Step C: set hashLen to 8*ceil(hlen/8) + let v = u8n(hashLen); // Minimal non-full-spec HMAC-DRBG from NIST 800-90 for RFC6979 sigs. + let k = u8n(hashLen); // Steps B and C of RFC6979 3.2: set hashLen, in our case always same + let i = 0; // Iterations counter, will throw when over 1000 + const reset = () => { + v.fill(1); + k.fill(0); + i = 0; + }; + const h = (...b: Uint8Array[]) => hmacFn(k, v, ...b); // hmac(k)(v, ...values) + const reseed = (seed = u8n()) => { + // HMAC-DRBG reseed() function. Steps D-G + k = h(u8fr([0x00]), seed); // k = hmac(k || v || 0x00 || seed) + v = h(); // v = hmac(k || v) if (seed.length === 0) return; - this.k = this.hmacSync(this.v, Uint8Array.from([0x01]), seed); - this.v = this.hmacSync(this.v); - } - // TODO: review - generateSync(): Uint8Array { - this.incr(); - + k = h(u8fr([0x01]), seed); // k = hmac(k || v || 0x01 || seed) + v = h(); // v = hmac(k || v) + }; + const gen = () => { + // HMAC-DRBG generate() function + if (i++ >= 1000) throw new Error('drbg: tried 1000 values'); let len = 0; const out: Uint8Array[] = []; - while (len < this.qByteLen) { - this.v = this.hmacSync(this.v); - const sl = this.v.slice(); + while (len < qByteLen) { + v = h(); + const sl = v.slice(); out.push(sl); - len += this.v.length; + len += v.length; } return ut.concatBytes(...out); - } - // There are no guarantees with JS GC whether bigints are removed even if you clean Uint8Arrays. + }; + const genUntil = (seed: Uint8Array, pred: Pred): T => { + reset(); + reseed(seed); // Steps D-G + let res: T | undefined = undefined; // Step H: grind until k is in [1..n-1] + while (!(res = pred(gen()))) reseed(); + reset(); + return res; + }; + return genUntil; } - export function weierstrass(curveDef: CurveType): CurveFn { const CURVE = validateOpts(curveDef) as ReturnType; const CURVE_ORDER = CURVE.n; @@ -1051,7 +1056,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { * @param isCompressed whether to return compact (default), or full key * @returns Public key, full when isCompressed=false; short when isCompressed=true */ - function getPublicKey(privateKey: PrivKey, isCompressed = false): Uint8Array { + function getPublicKey(privateKey: PrivKey, isCompressed = true): Uint8Array { return Point.fromPrivateKey(privateKey).toRawBytes(isCompressed); } @@ -1077,7 +1082,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { * @param isCompressed whether to return compact (default), or full key * @returns shared public key */ - function getSharedSecret(privateA: PrivKey, publicB: PubKey, isCompressed = false): Uint8Array { + function getSharedSecret(privateA: PrivKey, publicB: PubKey, isCompressed = true): Uint8Array { if (isProbPub(privateA)) throw new TypeError('getSharedSecret: first arg must be private key'); if (!isProbPub(publicB)) throw new TypeError('getSharedSecret: second arg must be public key'); const b = normalizePublicKey(publicB); @@ -1085,21 +1090,17 @@ export function weierstrass(curveDef: CurveType): CurveFn { return b.multiply(normalizePrivateKey(privateA)).toRawBytes(isCompressed); } - // RFC6979 methods - // Ensures ECDSA message hashes are 32 bytes and < curve order - // RFC6979 suggest optional truncating via bits2octets - // FIPS 186-4 Section 4.6 suggest the leftmost min(N, outlen) bits, where N = nBitLength, which is exactly what bits2int does - // However, result of bits2int can be higher than order, but since there is same amount of bits, modulo operation - // can be done via 'h >= n ? h - n : h'. - // But we cannot use int2octets, since it pads small hash with zeros which should not happen on truncate as per RFC6979 vectors + // RFC6979: ensure ECDSA msg is X bytes and < N. RFC suggests optional truncating via bits2octets. + // FIPS 186-4 4.6 suggests the leftmost min(nBitLen, outLen) bits, which matches bits2int. + // bits2int can produce res>N, we can do mod(res, N) since the bitLen is the same. + // int2octets can't be used; pads small msgs with 0: unacceptatble for trunc as per RFC vectors const bits2int = CURVE.bits2int || function (bytes: Uint8Array): bigint { - // Truncate to nBitLength leftmost bits (kinda) - // NOTE: for curves with nBitLength % 8 !== 0: bits2octets(bits2octets(hash)) !== bits2octets(hash) - // for some cases, because bytes.length * 8 is not actual bitLength. - const delta = bytes.length * 8 - CURVE.nBitLength; - const num = ut.bytesToNumberBE(bytes); + // For curves with nBitLength % 8 !== 0: bits2octets(bits2octets(m)) !== bits2octets(m) + // for some cases, since bytes.length * 8 is not actual bitLength. + const delta = bytes.length * 8 - CURVE.nBitLength; // truncate to nBitLength leftmost bits + const num = ut.bytesToNumberBE(bytes); // check for == u8 done here return delta > 0 ? num >> BigInt(delta) : num; }; const bits2int_modN = @@ -1113,15 +1114,20 @@ export function weierstrass(curveDef: CurveType): CurveFn { if (typeof num !== 'bigint') throw new Error('Expected bigint'); if (!(_0n <= num && num < ORDER_MASK)) throw new Error(`Expected number < 2^${CURVE.nBitLength}`); - return ut.numberToBytesBE(num, CURVE.nByteLength); // works with order, can have different size than numToField! + // works with order, can have different size than numToField! + return ut.numberToBytesBE(num, CURVE.nByteLength); } // Steps A, D of RFC6979 3.2 // Creates RFC6979 seed; converts msg/privKey to numbers. // Used only in sign, not in verify. // NOTE: we cannot assume here that msgHash has same amount of bytes as curve order, this will be wrong at least for P521. // Also it can be bigger for P224 + SHA256 - function initSigArgs(msgHash: Hex, privateKey: PrivKey, extraEntropy?: Entropy) { + function prepSig(msgHash: Hex, privateKey: PrivKey, opts = defaultSigOpts) { if (msgHash == null) throw new Error(`sign: expected valid message hash, not "${msgHash}"`); + if (['recovered', 'canonical'].some(k => k in opts)) // Ban legacy options + throw new Error('sign() legacy options not supported'); + let { lowS } = opts; // generates low-s sigs by default + if (lowS == null) lowS = true; // RFC6979 3.2: we skip step A, because // Step A is ignored, since we already provide hash instead of msg // NOTE: instead of bits2int, we calling here truncateHash, since we need @@ -1135,10 +1141,10 @@ export function weierstrass(curveDef: CurveType): CurveFn { const d = normalizePrivateKey(privateKey); // K = HMAC_K(V || 0x00 || int2octets(x) || bits2octets(h1) || k') const seedArgs = [int2octets(d), h1octets]; - // RFC6979 3.6: additional k' could be provided - if (extraEntropy != null) { - if (extraEntropy === true) extraEntropy = CURVE.randomBytes(Fp.BYTES); - const e = ut.ensureBytes(extraEntropy); + let ent = opts.extraEntropy; // RFC6979 3.6: additional k' (optional) + if (ent != null) { + if (ent === true) ent = CURVE.randomBytes(Fp.BYTES); + const e = ut.ensureBytes(ent); if (e.length !== Fp.BYTES) throw new Error(`sign: Expected ${Fp.BYTES} bytes of extra data`); seedArgs.push(e); } @@ -1147,41 +1153,32 @@ export function weierstrass(curveDef: CurveType): CurveFn { // V, 0x00 are done in HmacDRBG constructor. const seed = ut.concatBytes(...seedArgs); const m = h1int; // NOTE: no need to call bits2int second time here, it is inside truncateHash! - return { seed, m, d }; - } - - /** - * Converts signature params into point & r/s, checks them for validity. - * k must be in range [1, n-1] - * @param k signature's k param: deterministic in our case, random in non-rfc6979 sigs - * @param m message that would be signed - * @param d private key - * @returns Signature with its point on curve Q OR undefined if params were invalid - */ - function kmdToSig(kBytes: Uint8Array, m: bigint, d: bigint, lowS = true): Signature | undefined { - const { n } = CURVE; - // RFC 6979 Section 3.2, step 3: k = bits2int(T) - const k = bits2int(kBytes); // Cannot use fields methods, since it is group element - if (!isWithinCurveOrder(k)) return; - // Important: all mod() calls in the function must be done over `n` - const kinv = mod.invert(k, n); - const q = Point.BASE.multiply(k); - // r = x mod n - const r = mod.mod(q.x, n); - if (r === _0n) return; - // s = (m + dr)/k mod n where x/k == x*inv(k) - const s = mod.mod(kinv * mod.mod(m + mod.mod(d * r, n), n), n); - if (s === _0n) return; - // recovery bit is usually 0 or 1; rarely it's 2 or 3, when q.x > n - let recovery = (q.x === r ? 0 : 2) | Number(q.y & _1n); - let normS = s; - if (lowS && isBiggerThanHalfOrder(s)) { - normS = normalizeS(s); - recovery ^= 1; + // Converts signature params into point w r/s, checks result for validity. + function k2sig(kBytes: Uint8Array): Signature | undefined { + const { n } = CURVE; + // RFC 6979 Section 3.2, step 3: k = bits2int(T) + const k = bits2int(kBytes); // Cannot use fields methods, since it is group element + if (!isWithinCurveOrder(k)) return; + // Important: all mod() calls in the function must be done over `n` + const ik = mod.invert(k, n); + const q = Point.BASE.multiply(k); + // r = x mod n + const r = mod.mod(q.x, n); + if (r === _0n) return; + // s = (m + dr)/k mod n where x/k == x*inv(k) + const s = mod.mod(ik * mod.mod(m + mod.mod(d * r, n), n), n); + if (s === _0n) return; + // recovery bit is usually 0 or 1; rarely it's 2 or 3, when q.x > n + let recovery = (q.x === r ? 0 : 2) | Number(q.y & _1n); + let normS = s; + if (lowS && isBiggerThanHalfOrder(s)) { + normS = normalizeS(s); + recovery ^= 1; + } + return new Signature(r, normS, recovery); } - return new Signature(r, normS, recovery); + return { seed, k2sig }; } - const defaultSigOpts: SignOpts = { lowS: CURVE.lowS }; /** @@ -1196,15 +1193,9 @@ export function weierstrass(curveDef: CurveType): CurveFn { */ // TODO: add opts.prehashed = True, if !opts.prehashed do hash on msg? function sign(msgHash: Hex, privKey: PrivKey, opts = defaultSigOpts): Signature { - // Steps A, D of RFC6979 3.2. - const { seed, m, d } = initSigArgs(msgHash, privKey, opts.extraEntropy); - // Steps B, C, D, E, F, G - const drbg = new HmacDrbg(CURVE.hash.outputLen, CURVE.nByteLength, CURVE.hmac); - drbg.reseedSync(seed); - // Step H3, repeat until k is in range [1, n-1] - let sig: Signature | undefined; - while (!(sig = kmdToSig(drbg.generateSync(), m, d, opts.lowS))) drbg.reseedSync(); - return sig; + const { seed, k2sig } = prepSig(msgHash, privKey, opts); // Steps A, D of RFC6979 3.2. + const genUntil = hmacDrbg(CURVE.hash.outputLen, CURVE.nByteLength, CURVE.hmac); + return genUntil(seed, k2sig); // Steps B, C, D, E, F, G } /** diff --git a/src/ed25519.ts b/src/ed25519.ts index ebf8cd7..a5f873c 100644 --- a/src/ed25519.ts +++ b/src/ed25519.ts @@ -419,11 +419,11 @@ export class RistrettoPoint { return new RistrettoPoint(this.ep.subtract(other.ep)); } - multiply(scalar: number | bigint): RistrettoPoint { + multiply(scalar: bigint): RistrettoPoint { return new RistrettoPoint(this.ep.multiply(scalar)); } - multiplyUnsafe(scalar: number | bigint): RistrettoPoint { + multiplyUnsafe(scalar: bigint): RistrettoPoint { return new RistrettoPoint(this.ep.multiplyUnsafe(scalar)); } }