diff --git a/src/abstract/utils.ts b/src/abstract/utils.ts index acd4d89..2a7baf3 100644 --- a/src/abstract/utils.ts +++ b/src/abstract/utils.ts @@ -116,6 +116,12 @@ export function bytesToNumberLE(uint8a: Uint8Array): bigint { export const numberToBytesBE = (n: bigint, len: number) => hexToBytes(n.toString(16).padStart(len * 2, '0')); export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, len).reverse(); +// Returns variable number bytes (minimal bigint encoding?) +export const numberToVarBytesBE = (n: bigint) => { + let hex = n.toString(16); + if (hex.length & 1) hex = '0' + hex; + return hexToBytes(hex); +}; export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array { // Uint8Array.from() instead of hash.slice() because node.js Buffer diff --git a/src/abstract/weierstrass.ts b/src/abstract/weierstrass.ts index 7263d8f..b87f1f0 100644 --- a/src/abstract/weierstrass.ts +++ b/src/abstract/weierstrass.ts @@ -708,7 +708,7 @@ export type CurveType = BasicCurve & { hash: ut.CHash; // Because we need outputLen for DRBG hmac: HmacFnSync; randomBytes: (bytesLength?: number) => Uint8Array; - truncateHash?: (hash: Uint8Array, truncateOnly?: boolean) => bigint; + truncateHash?: (hash: Uint8Array, truncateOnly?: boolean) => Uint8Array; }; function validateOpts(curve: CurveType) { @@ -881,18 +881,15 @@ export function weierstrass(curveDef: CurveType): CurveFn { return isBiggerThanHalfOrder(s) ? mod.mod(-s, CURVE_ORDER) : s; } - function bits2int_2(bytes: Uint8Array): bigint { - const delta = bytes.length * 8 - CURVE.nBitLength; - const num = ut.bytesToNumberBE(bytes); - return delta > 0 ? num >> BigInt(delta) : num; - } - // Ensures ECDSA message hashes are 32 bytes and < curve order - function _truncateHash(hash: Uint8Array, truncateOnly = false): bigint { - const h = bits2int_2(hash); - if (truncateOnly) return h; - const { n } = CURVE; - return h >= n ? h - n : h; + // RFC6979 suggest optional truncating via bits2octets + // FIPS 186-4 Section 4.6 suggest the leftmost min(N, outlen) bits, where N = nBitLength, which is exactly what bits2int does + // However, result of bits2int can be higher than order, but since there is same amount of bits, modulo operation + // can be done via 'h >= n ? h - n : h'. + // But we cannot use int2octets, since it pads small hash with zeros which should not happen on truncate as per RFC6979 vectors + function _truncateHash(hash: Uint8Array, truncateOnly = false): Uint8Array { + const num = bits2int(hash); + return ut.numberToVarBytesBE(truncateOnly ? num : mod.mod(num, CURVE_ORDER)); // same as bits2octets but without zero padding } const truncateHash = CURVE.truncateHash || _truncateHash; @@ -956,7 +953,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { const { r, s, recovery } = this; if (recovery == null) throw new Error('Cannot recover: recovery bit is not present'); if (![0, 1, 2, 3].includes(recovery)) throw new Error('Cannot recover: invalid recovery bit'); - const h = truncateHash(ut.ensureBytes(msgHash)); + const h = ut.bytesToNumberBE(truncateHash(ut.ensureBytes(msgHash))); const { n } = CURVE; const radj = recovery === 2 || recovery === 3 ? r + n : r; if (radj >= Fp.ORDER) throw new Error('Cannot recover: bit 2/3 is invalid with current r'); @@ -1099,38 +1096,42 @@ export function weierstrass(curveDef: CurveType): CurveFn { // RFC6979 methods function bits2int(bytes: Uint8Array): bigint { - const { nByteLength } = CURVE; - if (!(bytes instanceof Uint8Array)) throw new Error('Expected Uint8Array'); - const slice = bytes.length > nByteLength ? bytes.slice(0, nByteLength) : bytes; - // const slice = bytes; nByteLength; nBitLength; - let num = ut.bytesToNumberBE(slice); - // const { nBitLength } = CURVE; - // const delta = (bytes.length * 8) - nBitLength; - // if (delta > 0) { - // // console.log('bits=', bytes.length*8, 'CURVE n=', nBitLength, 'delta=', delta); - // // console.log(bytes.length, nBitLength, delta); - // // console.log(bytes, new Error().stack); - // num >>= BigInt(delta); - // } - return num; - } - function bits2octets(bytes: Uint8Array): Uint8Array { - const z1 = bits2int(bytes); - const z2 = mod.mod(z1, CURVE_ORDER); - return int2octets(z2 < _0n ? z1 : z2); + // Truncate to nBitLength leftmost bits (kinda) + // NOTE: for curves with nBitLength % 8 !== 0: bits2octets(bits2octets(hash)) !== bits2octets(hash) + // for some cases, because bytes.length * 8 is not actual bitLength. + const delta = bytes.length * 8 - CURVE.nBitLength; + const num = ut.bytesToNumberBE(bytes); + return delta > 0 ? num >> BigInt(delta) : num; } + // NOTE: pads output with zero as per spec + const ORDER_MASK = ut.bitMask(CURVE.nBitLength); function int2octets(num: bigint): Uint8Array { - return numToField(num); // prohibits >nByteLength bytes + if (typeof num !== 'bigint') throw new Error('Expected bigint'); + if (!(_0n <= num && num < ORDER_MASK)) + throw new Error(`Expected number < 2^${CURVE.nBitLength}`); + return ut.numberToBytesBE(num, CURVE.nByteLength); // works with order, can have different size than numToField! } // Steps A, D of RFC6979 3.2 // Creates RFC6979 seed; converts msg/privKey to numbers. + // Used only in sign, not in verify. + // NOTE: we cannot assume here that msgHash has same amount of bytes as curve order, this will be wrong at least for P521. + // Also it can be bigger for P224 + SHA256 function initSigArgs(msgHash: Hex, privateKey: PrivKey, extraEntropy?: Entropy) { if (msgHash == null) throw new Error(`sign: expected valid message hash, not "${msgHash}"`); // Step A is ignored, since we already provide hash instead of msg - const h1 = numToField(truncateHash(ut.ensureBytes(msgHash))); + + // NOTE: instead of bits2int, we calling here truncateHash, since we need + // custom truncation for stark. For other curves it is essentially same as calling bits2int + mod + // However, we cannot later call bits2octets (which is truncateHash + int2octets), since nested bits2int is broken + // for curves where nBitLength % 8 !== 0, so we unwrap it here as int2octets call. + // const bits2octets = (bits)=>int2octets(ut.bytesToNumberBE(truncateHash(bits))) + const h1 = truncateHash(ut.ensureBytes(msgHash)); + const h1int = ut.bytesToNumberBE(h1); + const h1octets = int2octets(h1int); + const d = normalizePrivateKey(privateKey); // K = HMAC_K(V || 0x00 || int2octets(x) || bits2octets(h1) || k') - const seedArgs = [int2octets(d), bits2octets(h1)]; + const seedArgs = [int2octets(d), h1octets]; // RFC6979 3.6: additional k' could be provided if (extraEntropy != null) { if (extraEntropy === true) extraEntropy = CURVE.randomBytes(Fp.BYTES); @@ -1142,7 +1143,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { // Step D // V, 0x00 are done in HmacDRBG constructor. const seed = ut.concatBytes(...seedArgs); - const m = bits2int(h1); + const m = h1int; // NOTE: no need to call bits2int second time here, it is inside truncateHash! return { seed, m, d }; } @@ -1156,7 +1157,8 @@ export function weierstrass(curveDef: CurveType): CurveFn { */ function kmdToSig(kBytes: Uint8Array, m: bigint, d: bigint, lowS = true): Signature | undefined { const { n } = CURVE; - const k = truncateHash(kBytes, true); + // RFC 6979 Section 3.2, step 3: k = bits2int(T) + const k = ut.bytesToNumberBE(truncateHash(kBytes, true)); // Cannot use fields methods, since it is group element if (!isWithinCurveOrder(k)) return; // Important: all mod() calls in the function must be done over `n` const kinv = mod.invert(k, n); @@ -1189,6 +1191,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { * ``` * @param opts `lowS, extraEntropy` */ + // TODO: add opts.prehashed = True, if !opts.prehashed do hash on msg? function sign(msgHash: Hex, privKey: PrivKey, opts = defaultSigOpts): Signature { // Steps A, D of RFC6979 3.2. const { seed, m, d } = initSigArgs(msgHash, privKey, opts.extraEntropy); @@ -1256,7 +1259,7 @@ export function weierstrass(curveDef: CurveType): CurveFn { } const { n } = CURVE; const { r, s } = signature; - const h = truncateHash(msgHash); + const h = ut.bytesToNumberBE(truncateHash(msgHash)); // Cannot use fields methods, since it is group element const sinv = mod.invert(s, n); // s^-1 // R = u1⋅G - u2⋅P const u1 = mod.mod(h * sinv, n); diff --git a/src/stark.ts b/src/stark.ts index b71895a..6281e6a 100644 --- a/src/stark.ts +++ b/src/stark.ts @@ -3,7 +3,7 @@ import { keccak_256 } from '@noble/hashes/sha3'; import { sha256 } from '@noble/hashes/sha256'; import { weierstrass, ProjectivePointType } from './abstract/weierstrass.js'; import * as cutils from './abstract/utils.js'; -import { Fp } from './abstract/modular.js'; +import { Fp, mod } from './abstract/modular.js'; import { getHash } from './_shortw_utils.js'; type ProjectivePoint = ProjectivePointType; @@ -31,8 +31,7 @@ export const starkCurve = weierstrass({ // Default options lowS: false, ...getHash(sha256), - truncateHash: (hash: Uint8Array, truncateOnly = false): bigint => { - // TODO: cleanup, ugly code + truncateHash: (hash: Uint8Array, truncateOnly = false): Uint8Array => { // Fix truncation if (!truncateOnly) { let hashS = bytesToNumber0x(hash).toString(16); @@ -43,12 +42,13 @@ export const starkCurve = weierstrass({ } // Truncate zero bytes on left (compat with elliptic) while (hash[0] === 0) hash = hash.subarray(1); + // bits2int + part of bits2octets (mod if !truncateOnly) const byteLength = hash.length; const delta = byteLength * 8 - nBitLength; // size of curve.n (252 bits) let h = hash.length ? bytesToNumber0x(hash) : 0n; - if (delta > 0) h = h >> BigInt(delta); - if (!truncateOnly && h >= CURVE_N) h -= CURVE_N; - return h; + if (delta > 0) h = h >> BigInt(delta); // truncate to nBitLength leftmost bits + if (!truncateOnly) h = mod(h, CURVE_N); + return cutils.numberToVarBytesBE(h); }, });