weierstrass: make points compressed by def. Rewrite drbg, k generation.

This commit is contained in:
Paul Miller 2023-01-24 03:02:38 +00:00
parent 2b908ad602
commit 7fda6de619
No known key found for this signature in database
GPG Key ID: 697079DA6878B89B
4 changed files with 104 additions and 113 deletions

@ -9,7 +9,7 @@ export interface Group<T extends Group<T>> {
add(other: T): T;
subtract(other: T): T;
equals(other: T): boolean;
multiply(scalar: number | bigint): T;
multiply(scalar: bigint): T;
}
export type GroupConstructor<T> = {

@ -5,7 +5,7 @@ const _1n = BigInt(1);
const _2n = BigInt(2);
const str = (a: any): a is string => typeof a === 'string';
const big = (a: any): a is bigint => typeof a === 'bigint';
export const big = (a: any): a is bigint => typeof a === 'bigint';
const u8a = (a: any): a is Uint8Array => a instanceof Uint8Array;
// We accept hex strings besides Uint8Array for simplicity

@ -114,7 +114,7 @@ export interface ProjectivePointType<T> extends Group<ProjectivePointType<T>> {
readonly x: T;
readonly y: T;
readonly z: T;
multiply(scalar: number | bigint, affinePoint?: PointType<T>): ProjectivePointType<T>;
multiply(scalar: bigint, affinePoint?: PointType<T>): ProjectivePointType<T>;
multiplyUnsafe(scalar: bigint): ProjectivePointType<T>;
toAffine(invZ?: T): PointType<T>;
clearCofactor(): ProjectivePointType<T>;
@ -249,7 +249,7 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
* Validates if a scalar ("private number") is valid.
* Scalars are valid only if they are less than curve order.
*/
function normalizeScalar(num: number | bigint): bigint {
function normalizeScalar(num: bigint): bigint {
if (ut.isPositiveInt(num)) return BigInt(num);
if (typeof num === 'bigint' && isWithinCurveOrder(num)) return num;
throw new TypeError('Expected valid private scalar: 0 < scalar < curve.n');
@ -470,7 +470,7 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
* @param affinePoint optional point ot save cached precompute windows on it
* @returns New point
*/
multiply(scalar: number | bigint, affinePoint?: Point): ProjectivePoint {
multiply(scalar: bigint, affinePoint?: Point): ProjectivePoint {
let n = normalizeScalar(scalar);
// Real point.
@ -582,12 +582,12 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
return Point.BASE.multiply(normalizePrivateKey(privateKey));
}
toRawBytes(isCompressed = false): Uint8Array {
toRawBytes(isCompressed = true): Uint8Array {
this.assertValidity();
return CURVE.toBytes(Point, this, isCompressed);
}
toHex(isCompressed = false): string {
toHex(isCompressed = true): string {
return bytesToHex(this.toRawBytes(isCompressed));
}
// A point on curve is valid if it conforms to equation.
@ -635,7 +635,7 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>) {
return this.add(other.negate());
}
multiply(scalar: number | bigint) {
multiply(scalar: bigint) {
return this.toProj().multiply(scalar, this).toAffine();
}
@ -754,54 +754,59 @@ export type CurveFn = {
};
};
/**
* Minimal HMAC-DRBG (NIST 800-90) for signatures.
* Used only for RFC6979, does not fully implement DRBG spec.
*/
class HmacDrbg {
k: Uint8Array;
v: Uint8Array;
counter: number;
constructor(public hashLen: number, public qByteLen: number, public hmacFn: HmacFnSync) {
const u8n = (data?: any) => new Uint8Array(data); // creates Uint8Array
const u8fr = (arr: any) => Uint8Array.from(arr); // another shortcut
// Minimal HMAC-DRBG from NIST 800-90 for RFC6979 sigs.
type Pred<T> = (v: Uint8Array) => T | undefined;
function hmacDrbg<T>(
hashLen: number,
qByteLen: number,
hmacFn: HmacFnSync
): (seed: Uint8Array, predicate: Pred<T>) => T {
if (typeof hashLen !== 'number' || hashLen < 2) throw new Error('hashLen must be a number');
if (typeof qByteLen !== 'number' || qByteLen < 2) throw new Error('qByteLen must be a number');
if (typeof hmacFn !== 'function') throw new Error('hmacFn must be a function');
// Step B, Step C: set hashLen to 8*ceil(hlen/8)
this.v = new Uint8Array(hashLen).fill(1);
this.k = new Uint8Array(hashLen).fill(0);
this.counter = 0;
}
private hmacSync(...values: Uint8Array[]) {
return this.hmacFn(this.k, ...values);
}
incr() {
if (this.counter >= 1000) throw new Error('Tried 1,000 k values for sign(), all were invalid');
this.counter += 1;
}
reseedSync(seed = new Uint8Array()) {
this.k = this.hmacSync(this.v, Uint8Array.from([0x00]), seed);
this.v = this.hmacSync(this.v);
let v = u8n(hashLen); // Minimal non-full-spec HMAC-DRBG from NIST 800-90 for RFC6979 sigs.
let k = u8n(hashLen); // Steps B and C of RFC6979 3.2: set hashLen, in our case always same
let i = 0; // Iterations counter, will throw when over 1000
const reset = () => {
v.fill(1);
k.fill(0);
i = 0;
};
const h = (...b: Uint8Array[]) => hmacFn(k, v, ...b); // hmac(k)(v, ...values)
const reseed = (seed = u8n()) => {
// HMAC-DRBG reseed() function. Steps D-G
k = h(u8fr([0x00]), seed); // k = hmac(k || v || 0x00 || seed)
v = h(); // v = hmac(k || v)
if (seed.length === 0) return;
this.k = this.hmacSync(this.v, Uint8Array.from([0x01]), seed);
this.v = this.hmacSync(this.v);
}
// TODO: review
generateSync(): Uint8Array {
this.incr();
k = h(u8fr([0x01]), seed); // k = hmac(k || v || 0x01 || seed)
v = h(); // v = hmac(k || v)
};
const gen = () => {
// HMAC-DRBG generate() function
if (i++ >= 1000) throw new Error('drbg: tried 1000 values');
let len = 0;
const out: Uint8Array[] = [];
while (len < this.qByteLen) {
this.v = this.hmacSync(this.v);
const sl = this.v.slice();
while (len < qByteLen) {
v = h();
const sl = v.slice();
out.push(sl);
len += this.v.length;
len += v.length;
}
return ut.concatBytes(...out);
};
const genUntil = (seed: Uint8Array, pred: Pred<T>): T => {
reset();
reseed(seed); // Steps D-G
let res: T | undefined = undefined; // Step H: grind until k is in [1..n-1]
while (!(res = pred(gen()))) reseed();
reset();
return res;
};
return genUntil;
}
// There are no guarantees with JS GC whether bigints are removed even if you clean Uint8Arrays.
}
export function weierstrass(curveDef: CurveType): CurveFn {
const CURVE = validateOpts(curveDef) as ReturnType<typeof validateOpts>;
const CURVE_ORDER = CURVE.n;
@ -1051,7 +1056,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
* @param isCompressed whether to return compact (default), or full key
* @returns Public key, full when isCompressed=false; short when isCompressed=true
*/
function getPublicKey(privateKey: PrivKey, isCompressed = false): Uint8Array {
function getPublicKey(privateKey: PrivKey, isCompressed = true): Uint8Array {
return Point.fromPrivateKey(privateKey).toRawBytes(isCompressed);
}
@ -1077,7 +1082,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
* @param isCompressed whether to return compact (default), or full key
* @returns shared public key
*/
function getSharedSecret(privateA: PrivKey, publicB: PubKey, isCompressed = false): Uint8Array {
function getSharedSecret(privateA: PrivKey, publicB: PubKey, isCompressed = true): Uint8Array {
if (isProbPub(privateA)) throw new TypeError('getSharedSecret: first arg must be private key');
if (!isProbPub(publicB)) throw new TypeError('getSharedSecret: second arg must be public key');
const b = normalizePublicKey(publicB);
@ -1085,21 +1090,17 @@ export function weierstrass(curveDef: CurveType): CurveFn {
return b.multiply(normalizePrivateKey(privateA)).toRawBytes(isCompressed);
}
// RFC6979 methods
// Ensures ECDSA message hashes are 32 bytes and < curve order
// RFC6979 suggest optional truncating via bits2octets
// FIPS 186-4 Section 4.6 suggest the leftmost min(N, outlen) bits, where N = nBitLength, which is exactly what bits2int does
// However, result of bits2int can be higher than order, but since there is same amount of bits, modulo operation
// can be done via 'h >= n ? h - n : h'.
// But we cannot use int2octets, since it pads small hash with zeros which should not happen on truncate as per RFC6979 vectors
// RFC6979: ensure ECDSA msg is X bytes and < N. RFC suggests optional truncating via bits2octets.
// FIPS 186-4 4.6 suggests the leftmost min(nBitLen, outLen) bits, which matches bits2int.
// bits2int can produce res>N, we can do mod(res, N) since the bitLen is the same.
// int2octets can't be used; pads small msgs with 0: unacceptatble for trunc as per RFC vectors
const bits2int =
CURVE.bits2int ||
function (bytes: Uint8Array): bigint {
// Truncate to nBitLength leftmost bits (kinda)
// NOTE: for curves with nBitLength % 8 !== 0: bits2octets(bits2octets(hash)) !== bits2octets(hash)
// for some cases, because bytes.length * 8 is not actual bitLength.
const delta = bytes.length * 8 - CURVE.nBitLength;
const num = ut.bytesToNumberBE(bytes);
// For curves with nBitLength % 8 !== 0: bits2octets(bits2octets(m)) !== bits2octets(m)
// for some cases, since bytes.length * 8 is not actual bitLength.
const delta = bytes.length * 8 - CURVE.nBitLength; // truncate to nBitLength leftmost bits
const num = ut.bytesToNumberBE(bytes); // check for == u8 done here
return delta > 0 ? num >> BigInt(delta) : num;
};
const bits2int_modN =
@ -1113,15 +1114,20 @@ export function weierstrass(curveDef: CurveType): CurveFn {
if (typeof num !== 'bigint') throw new Error('Expected bigint');
if (!(_0n <= num && num < ORDER_MASK))
throw new Error(`Expected number < 2^${CURVE.nBitLength}`);
return ut.numberToBytesBE(num, CURVE.nByteLength); // works with order, can have different size than numToField!
// works with order, can have different size than numToField!
return ut.numberToBytesBE(num, CURVE.nByteLength);
}
// Steps A, D of RFC6979 3.2
// Creates RFC6979 seed; converts msg/privKey to numbers.
// Used only in sign, not in verify.
// NOTE: we cannot assume here that msgHash has same amount of bytes as curve order, this will be wrong at least for P521.
// Also it can be bigger for P224 + SHA256
function initSigArgs(msgHash: Hex, privateKey: PrivKey, extraEntropy?: Entropy) {
function prepSig(msgHash: Hex, privateKey: PrivKey, opts = defaultSigOpts) {
if (msgHash == null) throw new Error(`sign: expected valid message hash, not "${msgHash}"`);
if (['recovered', 'canonical'].some(k => k in opts)) // Ban legacy options
throw new Error('sign() legacy options not supported');
let { lowS } = opts; // generates low-s sigs by default
if (lowS == null) lowS = true; // RFC6979 3.2: we skip step A, because
// Step A is ignored, since we already provide hash instead of msg
// NOTE: instead of bits2int, we calling here truncateHash, since we need
@ -1135,10 +1141,10 @@ export function weierstrass(curveDef: CurveType): CurveFn {
const d = normalizePrivateKey(privateKey);
// K = HMAC_K(V || 0x00 || int2octets(x) || bits2octets(h1) || k')
const seedArgs = [int2octets(d), h1octets];
// RFC6979 3.6: additional k' could be provided
if (extraEntropy != null) {
if (extraEntropy === true) extraEntropy = CURVE.randomBytes(Fp.BYTES);
const e = ut.ensureBytes(extraEntropy);
let ent = opts.extraEntropy; // RFC6979 3.6: additional k' (optional)
if (ent != null) {
if (ent === true) ent = CURVE.randomBytes(Fp.BYTES);
const e = ut.ensureBytes(ent);
if (e.length !== Fp.BYTES) throw new Error(`sign: Expected ${Fp.BYTES} bytes of extra data`);
seedArgs.push(e);
}
@ -1147,30 +1153,20 @@ export function weierstrass(curveDef: CurveType): CurveFn {
// V, 0x00 are done in HmacDRBG constructor.
const seed = ut.concatBytes(...seedArgs);
const m = h1int; // NOTE: no need to call bits2int second time here, it is inside truncateHash!
return { seed, m, d };
}
/**
* Converts signature params into point & r/s, checks them for validity.
* k must be in range [1, n-1]
* @param k signature's k param: deterministic in our case, random in non-rfc6979 sigs
* @param m message that would be signed
* @param d private key
* @returns Signature with its point on curve Q OR undefined if params were invalid
*/
function kmdToSig(kBytes: Uint8Array, m: bigint, d: bigint, lowS = true): Signature | undefined {
// Converts signature params into point w r/s, checks result for validity.
function k2sig(kBytes: Uint8Array): Signature | undefined {
const { n } = CURVE;
// RFC 6979 Section 3.2, step 3: k = bits2int(T)
const k = bits2int(kBytes); // Cannot use fields methods, since it is group element
if (!isWithinCurveOrder(k)) return;
// Important: all mod() calls in the function must be done over `n`
const kinv = mod.invert(k, n);
const ik = mod.invert(k, n);
const q = Point.BASE.multiply(k);
// r = x mod n
const r = mod.mod(q.x, n);
if (r === _0n) return;
// s = (m + dr)/k mod n where x/k == x*inv(k)
const s = mod.mod(kinv * mod.mod(m + mod.mod(d * r, n), n), n);
const s = mod.mod(ik * mod.mod(m + mod.mod(d * r, n), n), n);
if (s === _0n) return;
// recovery bit is usually 0 or 1; rarely it's 2 or 3, when q.x > n
let recovery = (q.x === r ? 0 : 2) | Number(q.y & _1n);
@ -1181,7 +1177,8 @@ export function weierstrass(curveDef: CurveType): CurveFn {
}
return new Signature(r, normS, recovery);
}
return { seed, k2sig };
}
const defaultSigOpts: SignOpts = { lowS: CURVE.lowS };
/**
@ -1196,15 +1193,9 @@ export function weierstrass(curveDef: CurveType): CurveFn {
*/
// TODO: add opts.prehashed = True, if !opts.prehashed do hash on msg?
function sign(msgHash: Hex, privKey: PrivKey, opts = defaultSigOpts): Signature {
// Steps A, D of RFC6979 3.2.
const { seed, m, d } = initSigArgs(msgHash, privKey, opts.extraEntropy);
// Steps B, C, D, E, F, G
const drbg = new HmacDrbg(CURVE.hash.outputLen, CURVE.nByteLength, CURVE.hmac);
drbg.reseedSync(seed);
// Step H3, repeat until k is in range [1, n-1]
let sig: Signature | undefined;
while (!(sig = kmdToSig(drbg.generateSync(), m, d, opts.lowS))) drbg.reseedSync();
return sig;
const { seed, k2sig } = prepSig(msgHash, privKey, opts); // Steps A, D of RFC6979 3.2.
const genUntil = hmacDrbg<Signature>(CURVE.hash.outputLen, CURVE.nByteLength, CURVE.hmac);
return genUntil(seed, k2sig); // Steps B, C, D, E, F, G
}
/**

@ -419,11 +419,11 @@ export class RistrettoPoint {
return new RistrettoPoint(this.ep.subtract(other.ep));
}
multiply(scalar: number | bigint): RistrettoPoint {
multiply(scalar: bigint): RistrettoPoint {
return new RistrettoPoint(this.ep.multiply(scalar));
}
multiplyUnsafe(scalar: number | bigint): RistrettoPoint {
multiplyUnsafe(scalar: bigint): RistrettoPoint {
return new RistrettoPoint(this.ep.multiplyUnsafe(scalar));
}
}