stark: more methods

This commit is contained in:
Paul Miller 2023-02-28 19:18:06 +00:00
parent 16115f27a6
commit 4b2d31ce7f
No known key found for this signature in database
GPG Key ID: 697079DA6878B89B
2 changed files with 140 additions and 26 deletions

@ -5,17 +5,8 @@ import { utf8ToBytes } from '@noble/hashes/utils';
import { Fp, mod, Field, validateField } from './abstract/modular.js'; import { Fp, mod, Field, validateField } from './abstract/modular.js';
import { poseidon } from './abstract/poseidon.js'; import { poseidon } from './abstract/poseidon.js';
import { weierstrass, ProjPointType, SignatureType } from './abstract/weierstrass.js'; import { weierstrass, ProjPointType, SignatureType } from './abstract/weierstrass.js';
import { import * as u from './abstract/utils.js';
Hex, import type { Hex } from './abstract/utils.js';
bitMask,
bytesToHex,
bytesToNumberBE,
concatBytes,
ensureBytes as ensureBytesOrig,
hexToBytes,
hexToNumber,
numberToVarBytesBE,
} from './abstract/utils.js';
import { getHash } from './_shortw_utils.js'; import { getHash } from './_shortw_utils.js';
// Stark-friendly elliptic curve // Stark-friendly elliptic curve
@ -30,7 +21,7 @@ function bits2int(bytes: Uint8Array): bigint {
while (bytes[0] === 0) bytes = bytes.subarray(1); // strip leading 0s while (bytes[0] === 0) bytes = bytes.subarray(1); // strip leading 0s
// Copy-pasted from weierstrass.ts // Copy-pasted from weierstrass.ts
const delta = bytes.length * 8 - nBitLength; const delta = bytes.length * 8 - nBitLength;
const num = bytesToNumberBE(bytes); const num = u.bytesToNumberBE(bytes);
return delta > 0 ? num >> BigInt(delta) : num; return delta > 0 ? num >> BigInt(delta) : num;
} }
function hex0xToBytes(hex: string): Uint8Array { function hex0xToBytes(hex: string): Uint8Array {
@ -38,7 +29,7 @@ function hex0xToBytes(hex: string): Uint8Array {
hex = strip0x(hex); // allow 0x prefix hex = strip0x(hex); // allow 0x prefix
if (hex.length & 1) hex = '0' + hex; // allow unpadded hex if (hex.length & 1) hex = '0' + hex; // allow unpadded hex
} }
return hexToBytes(hex); return u.hexToBytes(hex);
} }
const curve = weierstrass({ const curve = weierstrass({
a: BigInt(1), // Params: a, b a: BigInt(1), // Params: a, b
@ -59,7 +50,7 @@ const curve = weierstrass({
bits2int_modN: (bytes: Uint8Array): bigint => { bits2int_modN: (bytes: Uint8Array): bigint => {
// 2102820b232636d200cb21f1d330f20d096cae09d1bf3edb1cc333ddee11318 => // 2102820b232636d200cb21f1d330f20d096cae09d1bf3edb1cc333ddee11318 =>
// 2102820b232636d200cb21f1d330f20d096cae09d1bf3edb1cc333ddee113180 // 2102820b232636d200cb21f1d330f20d096cae09d1bf3edb1cc333ddee113180
const hex = bytesToNumberBE(bytes).toString(16); // toHex unpadded const hex = u.bytesToNumberBE(bytes).toString(16); // toHex unpadded
if (hex.length === 63) bytes = hex0xToBytes(hex + '0'); // append trailing 0 if (hex.length === 63) bytes = hex0xToBytes(hex + '0'); // append trailing 0
return mod(bits2int(bytes), CURVE_ORDER); return mod(bits2int(bytes), CURVE_ORDER);
}, },
@ -67,11 +58,11 @@ const curve = weierstrass({
export const _starkCurve = curve; export const _starkCurve = curve;
function ensureBytes(hex: Hex): Uint8Array { function ensureBytes(hex: Hex): Uint8Array {
return ensureBytesOrig('', typeof hex === 'string' ? hex0xToBytes(hex) : hex); return u.ensureBytes('', typeof hex === 'string' ? hex0xToBytes(hex) : hex);
} }
function normPrivKey(privKey: Hex): string { function normPrivKey(privKey: Hex): string {
return bytesToHex(ensureBytes(privKey)).padStart(64, '0'); return u.bytesToHex(ensureBytes(privKey)).padStart(64, '0');
} }
export function getPublicKey(privKey: Hex, isCompressed = false): Uint8Array { export function getPublicKey(privKey: Hex, isCompressed = false): Uint8Array {
return curve.getPublicKey(normPrivKey(privKey), isCompressed); return curve.getPublicKey(normPrivKey(privKey), isCompressed);
@ -91,7 +82,7 @@ const { CURVE, ProjectivePoint, Signature, utils } = curve;
export { CURVE, ProjectivePoint, Signature, utils }; export { CURVE, ProjectivePoint, Signature, utils };
function extractX(bytes: Uint8Array): string { function extractX(bytes: Uint8Array): string {
const hex = bytesToHex(bytes.subarray(1)); const hex = u.bytesToHex(bytes.subarray(1));
const stripped = hex.replace(/^0+/gm, ''); // strip leading 0s const stripped = hex.replace(/^0+/gm, ''); // strip leading 0s
return `0x${stripped}`; return `0x${stripped}`;
} }
@ -109,7 +100,7 @@ export function grindKey(seed: Hex) {
const sha256mask = 2n ** 256n; const sha256mask = 2n ** 256n;
const limit = sha256mask - mod(sha256mask, CURVE_ORDER); const limit = sha256mask - mod(sha256mask, CURVE_ORDER);
for (let i = 0; ; i++) { for (let i = 0; ; i++) {
const key = sha256Num(concatBytes(_seed, numberToVarBytesBE(BigInt(i)))); const key = sha256Num(u.concatBytes(_seed, u.numberToVarBytesBE(BigInt(i))));
if (key < limit) return mod(key, CURVE_ORDER).toString(16); // key should be in [0, limit) if (key < limit) return mod(key, CURVE_ORDER).toString(16); // key should be in [0, limit)
if (i === 100000) throw new Error('grindKey is broken: tried 100k vals'); // prevent dos if (i === 100000) throw new Error('grindKey is broken: tried 100k vals'); // prevent dos
} }
@ -135,7 +126,7 @@ export function getAccountPath(
): string { ): string {
const layerNum = int31(sha256Num(layer)); const layerNum = int31(sha256Num(layer));
const applicationNum = int31(sha256Num(application)); const applicationNum = int31(sha256Num(application));
const eth = hexToNumber(strip0x(ethereumAddress)); const eth = u.hexToNumber(strip0x(ethereumAddress));
return `m/2645'/${layerNum}'/${applicationNum}'/${int31(eth)}'/${int31(eth >> 31n)}'/${index}`; return `m/2645'/${layerNum}'/${applicationNum}'/${int31(eth)}'/${int31(eth >> 31n)}'/${index}`;
} }
@ -196,7 +187,7 @@ function pedersenArg(arg: PedersenArg): bigint {
if (!Number.isSafeInteger(arg)) throw new Error(`Invalid pedersenArg: ${arg}`); if (!Number.isSafeInteger(arg)) throw new Error(`Invalid pedersenArg: ${arg}`);
value = BigInt(arg); value = BigInt(arg);
} else { } else {
value = bytesToNumberBE(ensureBytes(arg)); value = u.bytesToNumberBE(ensureBytes(arg));
} }
if (!(0n <= value && value < curve.CURVE.Fp.ORDER)) if (!(0n <= value && value < curve.CURVE.Fp.ORDER))
throw new Error(`PedersenArg should be 0 <= value < CURVE.P: ${value}`); // [0..Fp) throw new Error(`PedersenArg should be 0 <= value < CURVE.P: ${value}`); // [0..Fp)
@ -234,9 +225,9 @@ export function hashChain(data: PedersenArg[], fn = pedersen) {
export const computeHashOnElements = (data: PedersenArg[], fn = pedersen) => export const computeHashOnElements = (data: PedersenArg[], fn = pedersen) =>
[0, ...data, data.length].reduce((x, y) => fn(x, y)); [0, ...data, data.length].reduce((x, y) => fn(x, y));
const MASK_250 = bitMask(250); const MASK_250 = u.bitMask(250);
export const keccak = (data: Uint8Array): bigint => bytesToNumberBE(keccak_256(data)) & MASK_250; export const keccak = (data: Uint8Array): bigint => u.bytesToNumberBE(keccak_256(data)) & MASK_250;
const sha256Num = (data: Uint8Array | string): bigint => bytesToNumberBE(sha256(data)); const sha256Num = (data: Uint8Array | string): bigint => u.bytesToNumberBE(sha256(data));
// Poseidon hash // Poseidon hash
export const Fp253 = Fp( export const Fp253 = Fp(
@ -280,7 +271,13 @@ export type PoseidonOpts = {
roundsPartial: number; roundsPartial: number;
}; };
export function poseidonBasic(opts: PoseidonOpts, mds: bigint[][]) { export type PoseidonFn = ReturnType<typeof poseidon> & {
m: number;
rate: number;
capacity: number;
};
export function poseidonBasic(opts: PoseidonOpts, mds: bigint[][]): PoseidonFn {
validateField(opts.Fp); validateField(opts.Fp);
if (!Number.isSafeInteger(opts.rate) || !Number.isSafeInteger(opts.capacity)) if (!Number.isSafeInteger(opts.rate) || !Number.isSafeInteger(opts.capacity))
throw new Error(`Wrong poseidon opts: ${opts}`); throw new Error(`Wrong poseidon opts: ${opts}`);
@ -292,7 +289,7 @@ export function poseidonBasic(opts: PoseidonOpts, mds: bigint[][]) {
for (let j = 0; j < m; j++) row.push(poseidonRoundConstant(opts.Fp, 'Hades', m * i + j)); for (let j = 0; j < m; j++) row.push(poseidonRoundConstant(opts.Fp, 'Hades', m * i + j));
roundConstants.push(row); roundConstants.push(row);
} }
return poseidon({ const res: Partial<PoseidonFn> = poseidon({
...opts, ...opts,
t: m, t: m,
sboxPower: 3, sboxPower: 3,
@ -300,6 +297,10 @@ export function poseidonBasic(opts: PoseidonOpts, mds: bigint[][]) {
mds, mds,
roundConstants, roundConstants,
}); });
res.m = m;
res.rate = opts.rate;
res.capacity = opts.capacity;
return res as PoseidonFn;
} }
export function poseidonCreate(opts: PoseidonOpts, mdsAttempt = 0) { export function poseidonCreate(opts: PoseidonOpts, mdsAttempt = 0) {
@ -313,6 +314,28 @@ export const poseidonSmall = poseidonBasic(
MDS_SMALL MDS_SMALL
); );
export function poseidonHash(x: bigint, y: bigint, fn = poseidonSmall) { export function poseidonHash(x: bigint, y: bigint, fn = poseidonSmall): bigint {
return fn([x, y, 2n])[0]; return fn([x, y, 2n])[0];
} }
export function poseidonHashFunc(x: Uint8Array, y: Uint8Array, fn = poseidonSmall): Uint8Array {
return u.numberToVarBytesBE(poseidonHash(u.bytesToNumberBE(x), u.bytesToNumberBE(y), fn));
}
export function poseidonHashSingle(x: bigint, fn = poseidonSmall): bigint {
return fn([x, 0n, 1n])[0];
}
export function poseidonHashMany(values: bigint[], fn = poseidonSmall): bigint {
const { m, rate } = fn;
if (!Array.isArray(values)) throw new Error('bigint array expected in values');
const padded = Array.from(values); // copy
padded.push(1n);
while (padded.length % rate !== 0) padded.push(0n);
let state: bigint[] = new Array(m).fill(0n);
for (let i = 0; i < padded.length; i += rate) {
for (let j = 0; j < rate; j++) state[j] += padded[i + j];
state = fn(state);
}
return state[0];
}

@ -1,6 +1,7 @@
import { deepStrictEqual, throws } from 'assert'; import { deepStrictEqual, throws } from 'assert';
import { describe, should } from 'micro-should'; import { describe, should } from 'micro-should';
import * as starknet from '../../esm/stark.js'; import * as starknet from '../../esm/stark.js';
import { bytesToHex as hex } from '@noble/hashes/utils';
import * as fs from 'fs'; import * as fs from 'fs';
function parseTest(path) { function parseTest(path) {
@ -107,6 +108,96 @@ should('Poseidon examples', () => {
]); ]);
}); });
should('Poseidon 2', () => {
// Cross-test with cairo-lang 0.11
deepStrictEqual(
starknet.poseidonHash(1n, 1n),
315729444126170353286530004158376771769107830460625027134495740547491428733n
);
deepStrictEqual(
starknet.poseidonHash(123n, 123n),
3149184350054566761517315875549307360045573205732410509163060794402900549639n
);
deepStrictEqual(
starknet.poseidonHash(1231231231231231231231231312312n, 1231231231231231231231231312312n),
2544250291965936388474000136445328679708604225006461780180655815882994563864n
);
// poseidonHashSingle
deepStrictEqual(
starknet.poseidonHashSingle(1n),
3085182978037364507644541379307921604860861694664657935759708330416374536741n
);
deepStrictEqual(
starknet.poseidonHashSingle(123n),
2751345659320901472675327541550911744303539407817894466726181731796247467344n
);
deepStrictEqual(
starknet.poseidonHashSingle(1231231231231231231231231312312n),
3083085683696942145160394401206391098729120397175152900096470498748103599322n
);
// poseidonHashMany
throws(() => starknet.poseidonHash(new Uint8Array([1, 2, 3])));
deepStrictEqual(
starknet.poseidonHashMany([1n]),
154809849725474173771833689306955346864791482278938452209165301614543497938n
);
deepStrictEqual(
starknet.poseidonHashMany([1n, 2n]),
1557996165160500454210437319447297236715335099509187222888255133199463084263n
);
deepStrictEqual(
starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n]),
976552833909388839716191681593200982850734838655927116322079791360264131378n
);
deepStrictEqual(
starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n, 5n, 6n, 7n]),
1426681430756292883765769449684978541173830451959857824597431064948702170774n
);
deepStrictEqual(
starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n, 5n, 6n]),
3578895185591466904832617962452140411216018208734547126302182794057260630783n
);
deepStrictEqual(
starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n, 5n]),
2047942584693618630610564708884241243670450597197937863619828684896211911953n
);
deepStrictEqual(
starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n]),
717812721730784692894550948559585317289413466140233907962980309405694367376n
);
deepStrictEqual(
starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n]),
2926122208425648133778911655767364584769133265503722614793281770361723147648n
);
deepStrictEqual(
starknet.poseidonHashMany([
154809849725474173771833689306955346864791482278938452209165301614543497938n,
1557996165160500454210437319447297236715335099509187222888255133199463084263n,
976552833909388839716191681593200982850734838655927116322079791360264131378n,
1426681430756292883765769449684978541173830451959857824597431064948702170774n,
3578895185591466904832617962452140411216018208734547126302182794057260630783n,
]),
1019392520709073131437410341528874594624843119359955302374885123884546721410n
);
// poseidon_hash_func
deepStrictEqual(
hex(starknet.poseidonHashFunc(new Uint8Array([1, 2]), new Uint8Array([3, 4]))),
'01f87cbb9c58139605384d0f0df49b446600af020aa9dac92301d45c96d78c0a'
);
deepStrictEqual(
hex(starknet.poseidonHashFunc(new Uint8Array(32).fill(255), new Uint8Array(32).fill(255))),
'05fd546b5ee3bcbbcbb733ed90bfc33033169d6765ac37bba71794a11cbb51a6'
);
deepStrictEqual(
hex(starknet.poseidonHashFunc(new Uint8Array(64).fill(255), new Uint8Array(64).fill(255))),
'07dba6b4d94b3e32697afe0825d6dac2dccafd439f7806a9575693c93735596b'
);
deepStrictEqual(
hex(starknet.poseidonHashFunc(new Uint8Array(256).fill(255), new Uint8Array(256).fill(255))),
'02f048581901865201dad701a5653d946b961748ec770fc11139aa7c06a9432a'
);
});
// ESM is broken. // ESM is broken.
import url from 'url'; import url from 'url';
if (import.meta.url === url.pathToFileURL(process.argv[1]).href) { if (import.meta.url === url.pathToFileURL(process.argv[1]).href) {