forked from tornado-packages/noble-curves
weierstrass: bits2int, int2octets, truncateHash now comply with standard
This commit is contained in:
parent
1e47bf2372
commit
6f99f6042e
@ -116,6 +116,12 @@ export function bytesToNumberLE(uint8a: Uint8Array): bigint {
|
|||||||
export const numberToBytesBE = (n: bigint, len: number) =>
|
export const numberToBytesBE = (n: bigint, len: number) =>
|
||||||
hexToBytes(n.toString(16).padStart(len * 2, '0'));
|
hexToBytes(n.toString(16).padStart(len * 2, '0'));
|
||||||
export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, len).reverse();
|
export const numberToBytesLE = (n: bigint, len: number) => numberToBytesBE(n, len).reverse();
|
||||||
|
// Returns variable number bytes (minimal bigint encoding?)
|
||||||
|
export const numberToVarBytesBE = (n: bigint) => {
|
||||||
|
let hex = n.toString(16);
|
||||||
|
if (hex.length & 1) hex = '0' + hex;
|
||||||
|
return hexToBytes(hex);
|
||||||
|
};
|
||||||
|
|
||||||
export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array {
|
export function ensureBytes(hex: Hex, expectedLength?: number): Uint8Array {
|
||||||
// Uint8Array.from() instead of hash.slice() because node.js Buffer
|
// Uint8Array.from() instead of hash.slice() because node.js Buffer
|
||||||
|
@ -708,7 +708,7 @@ export type CurveType = BasicCurve<bigint> & {
|
|||||||
hash: ut.CHash; // Because we need outputLen for DRBG
|
hash: ut.CHash; // Because we need outputLen for DRBG
|
||||||
hmac: HmacFnSync;
|
hmac: HmacFnSync;
|
||||||
randomBytes: (bytesLength?: number) => Uint8Array;
|
randomBytes: (bytesLength?: number) => Uint8Array;
|
||||||
truncateHash?: (hash: Uint8Array, truncateOnly?: boolean) => bigint;
|
truncateHash?: (hash: Uint8Array, truncateOnly?: boolean) => Uint8Array;
|
||||||
};
|
};
|
||||||
|
|
||||||
function validateOpts(curve: CurveType) {
|
function validateOpts(curve: CurveType) {
|
||||||
@ -881,18 +881,15 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
|||||||
return isBiggerThanHalfOrder(s) ? mod.mod(-s, CURVE_ORDER) : s;
|
return isBiggerThanHalfOrder(s) ? mod.mod(-s, CURVE_ORDER) : s;
|
||||||
}
|
}
|
||||||
|
|
||||||
function bits2int_2(bytes: Uint8Array): bigint {
|
|
||||||
const delta = bytes.length * 8 - CURVE.nBitLength;
|
|
||||||
const num = ut.bytesToNumberBE(bytes);
|
|
||||||
return delta > 0 ? num >> BigInt(delta) : num;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensures ECDSA message hashes are 32 bytes and < curve order
|
// Ensures ECDSA message hashes are 32 bytes and < curve order
|
||||||
function _truncateHash(hash: Uint8Array, truncateOnly = false): bigint {
|
// RFC6979 suggest optional truncating via bits2octets
|
||||||
const h = bits2int_2(hash);
|
// FIPS 186-4 Section 4.6 suggest the leftmost min(N, outlen) bits, where N = nBitLength, which is exactly what bits2int does
|
||||||
if (truncateOnly) return h;
|
// However, result of bits2int can be higher than order, but since there is same amount of bits, modulo operation
|
||||||
const { n } = CURVE;
|
// can be done via 'h >= n ? h - n : h'.
|
||||||
return h >= n ? h - n : h;
|
// But we cannot use int2octets, since it pads small hash with zeros which should not happen on truncate as per RFC6979 vectors
|
||||||
|
function _truncateHash(hash: Uint8Array, truncateOnly = false): Uint8Array {
|
||||||
|
const num = bits2int(hash);
|
||||||
|
return ut.numberToVarBytesBE(truncateOnly ? num : mod.mod(num, CURVE_ORDER)); // same as bits2octets but without zero padding
|
||||||
}
|
}
|
||||||
const truncateHash = CURVE.truncateHash || _truncateHash;
|
const truncateHash = CURVE.truncateHash || _truncateHash;
|
||||||
|
|
||||||
@ -956,7 +953,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
|||||||
const { r, s, recovery } = this;
|
const { r, s, recovery } = this;
|
||||||
if (recovery == null) throw new Error('Cannot recover: recovery bit is not present');
|
if (recovery == null) throw new Error('Cannot recover: recovery bit is not present');
|
||||||
if (![0, 1, 2, 3].includes(recovery)) throw new Error('Cannot recover: invalid recovery bit');
|
if (![0, 1, 2, 3].includes(recovery)) throw new Error('Cannot recover: invalid recovery bit');
|
||||||
const h = truncateHash(ut.ensureBytes(msgHash));
|
const h = ut.bytesToNumberBE(truncateHash(ut.ensureBytes(msgHash)));
|
||||||
const { n } = CURVE;
|
const { n } = CURVE;
|
||||||
const radj = recovery === 2 || recovery === 3 ? r + n : r;
|
const radj = recovery === 2 || recovery === 3 ? r + n : r;
|
||||||
if (radj >= Fp.ORDER) throw new Error('Cannot recover: bit 2/3 is invalid with current r');
|
if (radj >= Fp.ORDER) throw new Error('Cannot recover: bit 2/3 is invalid with current r');
|
||||||
@ -1099,38 +1096,42 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
|||||||
|
|
||||||
// RFC6979 methods
|
// RFC6979 methods
|
||||||
function bits2int(bytes: Uint8Array): bigint {
|
function bits2int(bytes: Uint8Array): bigint {
|
||||||
const { nByteLength } = CURVE;
|
// Truncate to nBitLength leftmost bits (kinda)
|
||||||
if (!(bytes instanceof Uint8Array)) throw new Error('Expected Uint8Array');
|
// NOTE: for curves with nBitLength % 8 !== 0: bits2octets(bits2octets(hash)) !== bits2octets(hash)
|
||||||
const slice = bytes.length > nByteLength ? bytes.slice(0, nByteLength) : bytes;
|
// for some cases, because bytes.length * 8 is not actual bitLength.
|
||||||
// const slice = bytes; nByteLength; nBitLength;
|
const delta = bytes.length * 8 - CURVE.nBitLength;
|
||||||
let num = ut.bytesToNumberBE(slice);
|
const num = ut.bytesToNumberBE(bytes);
|
||||||
// const { nBitLength } = CURVE;
|
return delta > 0 ? num >> BigInt(delta) : num;
|
||||||
// const delta = (bytes.length * 8) - nBitLength;
|
|
||||||
// if (delta > 0) {
|
|
||||||
// // console.log('bits=', bytes.length*8, 'CURVE n=', nBitLength, 'delta=', delta);
|
|
||||||
// // console.log(bytes.length, nBitLength, delta);
|
|
||||||
// // console.log(bytes, new Error().stack);
|
|
||||||
// num >>= BigInt(delta);
|
|
||||||
// }
|
|
||||||
return num;
|
|
||||||
}
|
|
||||||
function bits2octets(bytes: Uint8Array): Uint8Array {
|
|
||||||
const z1 = bits2int(bytes);
|
|
||||||
const z2 = mod.mod(z1, CURVE_ORDER);
|
|
||||||
return int2octets(z2 < _0n ? z1 : z2);
|
|
||||||
}
|
}
|
||||||
|
// NOTE: pads output with zero as per spec
|
||||||
|
const ORDER_MASK = ut.bitMask(CURVE.nBitLength);
|
||||||
function int2octets(num: bigint): Uint8Array {
|
function int2octets(num: bigint): Uint8Array {
|
||||||
return numToField(num); // prohibits >nByteLength bytes
|
if (typeof num !== 'bigint') throw new Error('Expected bigint');
|
||||||
|
if (!(_0n <= num && num < ORDER_MASK))
|
||||||
|
throw new Error(`Expected number < 2^${CURVE.nBitLength}`);
|
||||||
|
return ut.numberToBytesBE(num, CURVE.nByteLength); // works with order, can have different size than numToField!
|
||||||
}
|
}
|
||||||
// Steps A, D of RFC6979 3.2
|
// Steps A, D of RFC6979 3.2
|
||||||
// Creates RFC6979 seed; converts msg/privKey to numbers.
|
// Creates RFC6979 seed; converts msg/privKey to numbers.
|
||||||
|
// Used only in sign, not in verify.
|
||||||
|
// NOTE: we cannot assume here that msgHash has same amount of bytes as curve order, this will be wrong at least for P521.
|
||||||
|
// Also it can be bigger for P224 + SHA256
|
||||||
function initSigArgs(msgHash: Hex, privateKey: PrivKey, extraEntropy?: Entropy) {
|
function initSigArgs(msgHash: Hex, privateKey: PrivKey, extraEntropy?: Entropy) {
|
||||||
if (msgHash == null) throw new Error(`sign: expected valid message hash, not "${msgHash}"`);
|
if (msgHash == null) throw new Error(`sign: expected valid message hash, not "${msgHash}"`);
|
||||||
// Step A is ignored, since we already provide hash instead of msg
|
// Step A is ignored, since we already provide hash instead of msg
|
||||||
const h1 = numToField(truncateHash(ut.ensureBytes(msgHash)));
|
|
||||||
|
// NOTE: instead of bits2int, we calling here truncateHash, since we need
|
||||||
|
// custom truncation for stark. For other curves it is essentially same as calling bits2int + mod
|
||||||
|
// However, we cannot later call bits2octets (which is truncateHash + int2octets), since nested bits2int is broken
|
||||||
|
// for curves where nBitLength % 8 !== 0, so we unwrap it here as int2octets call.
|
||||||
|
// const bits2octets = (bits)=>int2octets(ut.bytesToNumberBE(truncateHash(bits)))
|
||||||
|
const h1 = truncateHash(ut.ensureBytes(msgHash));
|
||||||
|
const h1int = ut.bytesToNumberBE(h1);
|
||||||
|
const h1octets = int2octets(h1int);
|
||||||
|
|
||||||
const d = normalizePrivateKey(privateKey);
|
const d = normalizePrivateKey(privateKey);
|
||||||
// K = HMAC_K(V || 0x00 || int2octets(x) || bits2octets(h1) || k')
|
// K = HMAC_K(V || 0x00 || int2octets(x) || bits2octets(h1) || k')
|
||||||
const seedArgs = [int2octets(d), bits2octets(h1)];
|
const seedArgs = [int2octets(d), h1octets];
|
||||||
// RFC6979 3.6: additional k' could be provided
|
// RFC6979 3.6: additional k' could be provided
|
||||||
if (extraEntropy != null) {
|
if (extraEntropy != null) {
|
||||||
if (extraEntropy === true) extraEntropy = CURVE.randomBytes(Fp.BYTES);
|
if (extraEntropy === true) extraEntropy = CURVE.randomBytes(Fp.BYTES);
|
||||||
@ -1142,7 +1143,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
|||||||
// Step D
|
// Step D
|
||||||
// V, 0x00 are done in HmacDRBG constructor.
|
// V, 0x00 are done in HmacDRBG constructor.
|
||||||
const seed = ut.concatBytes(...seedArgs);
|
const seed = ut.concatBytes(...seedArgs);
|
||||||
const m = bits2int(h1);
|
const m = h1int; // NOTE: no need to call bits2int second time here, it is inside truncateHash!
|
||||||
return { seed, m, d };
|
return { seed, m, d };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1156,7 +1157,8 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
|||||||
*/
|
*/
|
||||||
function kmdToSig(kBytes: Uint8Array, m: bigint, d: bigint, lowS = true): Signature | undefined {
|
function kmdToSig(kBytes: Uint8Array, m: bigint, d: bigint, lowS = true): Signature | undefined {
|
||||||
const { n } = CURVE;
|
const { n } = CURVE;
|
||||||
const k = truncateHash(kBytes, true);
|
// RFC 6979 Section 3.2, step 3: k = bits2int(T)
|
||||||
|
const k = ut.bytesToNumberBE(truncateHash(kBytes, true)); // Cannot use fields methods, since it is group element
|
||||||
if (!isWithinCurveOrder(k)) return;
|
if (!isWithinCurveOrder(k)) return;
|
||||||
// Important: all mod() calls in the function must be done over `n`
|
// Important: all mod() calls in the function must be done over `n`
|
||||||
const kinv = mod.invert(k, n);
|
const kinv = mod.invert(k, n);
|
||||||
@ -1189,6 +1191,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
|||||||
* ```
|
* ```
|
||||||
* @param opts `lowS, extraEntropy`
|
* @param opts `lowS, extraEntropy`
|
||||||
*/
|
*/
|
||||||
|
// TODO: add opts.prehashed = True, if !opts.prehashed do hash on msg?
|
||||||
function sign(msgHash: Hex, privKey: PrivKey, opts = defaultSigOpts): Signature {
|
function sign(msgHash: Hex, privKey: PrivKey, opts = defaultSigOpts): Signature {
|
||||||
// Steps A, D of RFC6979 3.2.
|
// Steps A, D of RFC6979 3.2.
|
||||||
const { seed, m, d } = initSigArgs(msgHash, privKey, opts.extraEntropy);
|
const { seed, m, d } = initSigArgs(msgHash, privKey, opts.extraEntropy);
|
||||||
@ -1256,7 +1259,7 @@ export function weierstrass(curveDef: CurveType): CurveFn {
|
|||||||
}
|
}
|
||||||
const { n } = CURVE;
|
const { n } = CURVE;
|
||||||
const { r, s } = signature;
|
const { r, s } = signature;
|
||||||
const h = truncateHash(msgHash);
|
const h = ut.bytesToNumberBE(truncateHash(msgHash)); // Cannot use fields methods, since it is group element
|
||||||
const sinv = mod.invert(s, n); // s^-1
|
const sinv = mod.invert(s, n); // s^-1
|
||||||
// R = u1⋅G - u2⋅P
|
// R = u1⋅G - u2⋅P
|
||||||
const u1 = mod.mod(h * sinv, n);
|
const u1 = mod.mod(h * sinv, n);
|
||||||
|
12
src/stark.ts
12
src/stark.ts
@ -3,7 +3,7 @@ import { keccak_256 } from '@noble/hashes/sha3';
|
|||||||
import { sha256 } from '@noble/hashes/sha256';
|
import { sha256 } from '@noble/hashes/sha256';
|
||||||
import { weierstrass, ProjectivePointType } from './abstract/weierstrass.js';
|
import { weierstrass, ProjectivePointType } from './abstract/weierstrass.js';
|
||||||
import * as cutils from './abstract/utils.js';
|
import * as cutils from './abstract/utils.js';
|
||||||
import { Fp } from './abstract/modular.js';
|
import { Fp, mod } from './abstract/modular.js';
|
||||||
import { getHash } from './_shortw_utils.js';
|
import { getHash } from './_shortw_utils.js';
|
||||||
|
|
||||||
type ProjectivePoint = ProjectivePointType<bigint>;
|
type ProjectivePoint = ProjectivePointType<bigint>;
|
||||||
@ -31,8 +31,7 @@ export const starkCurve = weierstrass({
|
|||||||
// Default options
|
// Default options
|
||||||
lowS: false,
|
lowS: false,
|
||||||
...getHash(sha256),
|
...getHash(sha256),
|
||||||
truncateHash: (hash: Uint8Array, truncateOnly = false): bigint => {
|
truncateHash: (hash: Uint8Array, truncateOnly = false): Uint8Array => {
|
||||||
// TODO: cleanup, ugly code
|
|
||||||
// Fix truncation
|
// Fix truncation
|
||||||
if (!truncateOnly) {
|
if (!truncateOnly) {
|
||||||
let hashS = bytesToNumber0x(hash).toString(16);
|
let hashS = bytesToNumber0x(hash).toString(16);
|
||||||
@ -43,12 +42,13 @@ export const starkCurve = weierstrass({
|
|||||||
}
|
}
|
||||||
// Truncate zero bytes on left (compat with elliptic)
|
// Truncate zero bytes on left (compat with elliptic)
|
||||||
while (hash[0] === 0) hash = hash.subarray(1);
|
while (hash[0] === 0) hash = hash.subarray(1);
|
||||||
|
// bits2int + part of bits2octets (mod if !truncateOnly)
|
||||||
const byteLength = hash.length;
|
const byteLength = hash.length;
|
||||||
const delta = byteLength * 8 - nBitLength; // size of curve.n (252 bits)
|
const delta = byteLength * 8 - nBitLength; // size of curve.n (252 bits)
|
||||||
let h = hash.length ? bytesToNumber0x(hash) : 0n;
|
let h = hash.length ? bytesToNumber0x(hash) : 0n;
|
||||||
if (delta > 0) h = h >> BigInt(delta);
|
if (delta > 0) h = h >> BigInt(delta); // truncate to nBitLength leftmost bits
|
||||||
if (!truncateOnly && h >= CURVE_N) h -= CURVE_N;
|
if (!truncateOnly) h = mod(h, CURVE_N);
|
||||||
return h;
|
return cutils.numberToVarBytesBE(h);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user