From 4b2d31ce7f0f2ce217243a77e5723d774da32212 Mon Sep 17 00:00:00 2001 From: Paul Miller Date: Tue, 28 Feb 2023 19:18:06 +0000 Subject: [PATCH] stark: more methods --- src/stark.ts | 75 +++++++++++++++++++----------- test/stark/poseidon.test.js | 91 +++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 26 deletions(-) diff --git a/src/stark.ts b/src/stark.ts index 0afabe3..eefb59f 100644 --- a/src/stark.ts +++ b/src/stark.ts @@ -5,17 +5,8 @@ import { utf8ToBytes } from '@noble/hashes/utils'; import { Fp, mod, Field, validateField } from './abstract/modular.js'; import { poseidon } from './abstract/poseidon.js'; import { weierstrass, ProjPointType, SignatureType } from './abstract/weierstrass.js'; -import { - Hex, - bitMask, - bytesToHex, - bytesToNumberBE, - concatBytes, - ensureBytes as ensureBytesOrig, - hexToBytes, - hexToNumber, - numberToVarBytesBE, -} from './abstract/utils.js'; +import * as u from './abstract/utils.js'; +import type { Hex } from './abstract/utils.js'; import { getHash } from './_shortw_utils.js'; // Stark-friendly elliptic curve @@ -30,7 +21,7 @@ function bits2int(bytes: Uint8Array): bigint { while (bytes[0] === 0) bytes = bytes.subarray(1); // strip leading 0s // Copy-pasted from weierstrass.ts const delta = bytes.length * 8 - nBitLength; - const num = bytesToNumberBE(bytes); + const num = u.bytesToNumberBE(bytes); return delta > 0 ? num >> BigInt(delta) : num; } function hex0xToBytes(hex: string): Uint8Array { @@ -38,7 +29,7 @@ function hex0xToBytes(hex: string): Uint8Array { hex = strip0x(hex); // allow 0x prefix if (hex.length & 1) hex = '0' + hex; // allow unpadded hex } - return hexToBytes(hex); + return u.hexToBytes(hex); } const curve = weierstrass({ a: BigInt(1), // Params: a, b @@ -59,7 +50,7 @@ const curve = weierstrass({ bits2int_modN: (bytes: Uint8Array): bigint => { // 2102820b232636d200cb21f1d330f20d096cae09d1bf3edb1cc333ddee11318 => // 2102820b232636d200cb21f1d330f20d096cae09d1bf3edb1cc333ddee113180 - const hex = bytesToNumberBE(bytes).toString(16); // toHex unpadded + const hex = u.bytesToNumberBE(bytes).toString(16); // toHex unpadded if (hex.length === 63) bytes = hex0xToBytes(hex + '0'); // append trailing 0 return mod(bits2int(bytes), CURVE_ORDER); }, @@ -67,11 +58,11 @@ const curve = weierstrass({ export const _starkCurve = curve; function ensureBytes(hex: Hex): Uint8Array { - return ensureBytesOrig('', typeof hex === 'string' ? hex0xToBytes(hex) : hex); + return u.ensureBytes('', typeof hex === 'string' ? hex0xToBytes(hex) : hex); } function normPrivKey(privKey: Hex): string { - return bytesToHex(ensureBytes(privKey)).padStart(64, '0'); + return u.bytesToHex(ensureBytes(privKey)).padStart(64, '0'); } export function getPublicKey(privKey: Hex, isCompressed = false): Uint8Array { return curve.getPublicKey(normPrivKey(privKey), isCompressed); @@ -91,7 +82,7 @@ const { CURVE, ProjectivePoint, Signature, utils } = curve; export { CURVE, ProjectivePoint, Signature, utils }; function extractX(bytes: Uint8Array): string { - const hex = bytesToHex(bytes.subarray(1)); + const hex = u.bytesToHex(bytes.subarray(1)); const stripped = hex.replace(/^0+/gm, ''); // strip leading 0s return `0x${stripped}`; } @@ -109,7 +100,7 @@ export function grindKey(seed: Hex) { const sha256mask = 2n ** 256n; const limit = sha256mask - mod(sha256mask, CURVE_ORDER); for (let i = 0; ; i++) { - const key = sha256Num(concatBytes(_seed, numberToVarBytesBE(BigInt(i)))); + const key = sha256Num(u.concatBytes(_seed, u.numberToVarBytesBE(BigInt(i)))); if (key < limit) return mod(key, CURVE_ORDER).toString(16); // key should be in [0, limit) if (i === 100000) throw new Error('grindKey is broken: tried 100k vals'); // prevent dos } @@ -135,7 +126,7 @@ export function getAccountPath( ): string { const layerNum = int31(sha256Num(layer)); const applicationNum = int31(sha256Num(application)); - const eth = hexToNumber(strip0x(ethereumAddress)); + const eth = u.hexToNumber(strip0x(ethereumAddress)); return `m/2645'/${layerNum}'/${applicationNum}'/${int31(eth)}'/${int31(eth >> 31n)}'/${index}`; } @@ -196,7 +187,7 @@ function pedersenArg(arg: PedersenArg): bigint { if (!Number.isSafeInteger(arg)) throw new Error(`Invalid pedersenArg: ${arg}`); value = BigInt(arg); } else { - value = bytesToNumberBE(ensureBytes(arg)); + value = u.bytesToNumberBE(ensureBytes(arg)); } if (!(0n <= value && value < curve.CURVE.Fp.ORDER)) throw new Error(`PedersenArg should be 0 <= value < CURVE.P: ${value}`); // [0..Fp) @@ -234,9 +225,9 @@ export function hashChain(data: PedersenArg[], fn = pedersen) { export const computeHashOnElements = (data: PedersenArg[], fn = pedersen) => [0, ...data, data.length].reduce((x, y) => fn(x, y)); -const MASK_250 = bitMask(250); -export const keccak = (data: Uint8Array): bigint => bytesToNumberBE(keccak_256(data)) & MASK_250; -const sha256Num = (data: Uint8Array | string): bigint => bytesToNumberBE(sha256(data)); +const MASK_250 = u.bitMask(250); +export const keccak = (data: Uint8Array): bigint => u.bytesToNumberBE(keccak_256(data)) & MASK_250; +const sha256Num = (data: Uint8Array | string): bigint => u.bytesToNumberBE(sha256(data)); // Poseidon hash export const Fp253 = Fp( @@ -280,7 +271,13 @@ export type PoseidonOpts = { roundsPartial: number; }; -export function poseidonBasic(opts: PoseidonOpts, mds: bigint[][]) { +export type PoseidonFn = ReturnType & { + m: number; + rate: number; + capacity: number; +}; + +export function poseidonBasic(opts: PoseidonOpts, mds: bigint[][]): PoseidonFn { validateField(opts.Fp); if (!Number.isSafeInteger(opts.rate) || !Number.isSafeInteger(opts.capacity)) throw new Error(`Wrong poseidon opts: ${opts}`); @@ -292,7 +289,7 @@ export function poseidonBasic(opts: PoseidonOpts, mds: bigint[][]) { for (let j = 0; j < m; j++) row.push(poseidonRoundConstant(opts.Fp, 'Hades', m * i + j)); roundConstants.push(row); } - return poseidon({ + const res: Partial = poseidon({ ...opts, t: m, sboxPower: 3, @@ -300,6 +297,10 @@ export function poseidonBasic(opts: PoseidonOpts, mds: bigint[][]) { mds, roundConstants, }); + res.m = m; + res.rate = opts.rate; + res.capacity = opts.capacity; + return res as PoseidonFn; } export function poseidonCreate(opts: PoseidonOpts, mdsAttempt = 0) { @@ -313,6 +314,28 @@ export const poseidonSmall = poseidonBasic( MDS_SMALL ); -export function poseidonHash(x: bigint, y: bigint, fn = poseidonSmall) { +export function poseidonHash(x: bigint, y: bigint, fn = poseidonSmall): bigint { return fn([x, y, 2n])[0]; } + +export function poseidonHashFunc(x: Uint8Array, y: Uint8Array, fn = poseidonSmall): Uint8Array { + return u.numberToVarBytesBE(poseidonHash(u.bytesToNumberBE(x), u.bytesToNumberBE(y), fn)); +} + +export function poseidonHashSingle(x: bigint, fn = poseidonSmall): bigint { + return fn([x, 0n, 1n])[0]; +} + +export function poseidonHashMany(values: bigint[], fn = poseidonSmall): bigint { + const { m, rate } = fn; + if (!Array.isArray(values)) throw new Error('bigint array expected in values'); + const padded = Array.from(values); // copy + padded.push(1n); + while (padded.length % rate !== 0) padded.push(0n); + let state: bigint[] = new Array(m).fill(0n); + for (let i = 0; i < padded.length; i += rate) { + for (let j = 0; j < rate; j++) state[j] += padded[i + j]; + state = fn(state); + } + return state[0]; +} diff --git a/test/stark/poseidon.test.js b/test/stark/poseidon.test.js index c1ddd6b..9ac4382 100644 --- a/test/stark/poseidon.test.js +++ b/test/stark/poseidon.test.js @@ -1,6 +1,7 @@ import { deepStrictEqual, throws } from 'assert'; import { describe, should } from 'micro-should'; import * as starknet from '../../esm/stark.js'; +import { bytesToHex as hex } from '@noble/hashes/utils'; import * as fs from 'fs'; function parseTest(path) { @@ -107,6 +108,96 @@ should('Poseidon examples', () => { ]); }); +should('Poseidon 2', () => { + // Cross-test with cairo-lang 0.11 + deepStrictEqual( + starknet.poseidonHash(1n, 1n), + 315729444126170353286530004158376771769107830460625027134495740547491428733n + ); + deepStrictEqual( + starknet.poseidonHash(123n, 123n), + 3149184350054566761517315875549307360045573205732410509163060794402900549639n + ); + deepStrictEqual( + starknet.poseidonHash(1231231231231231231231231312312n, 1231231231231231231231231312312n), + 2544250291965936388474000136445328679708604225006461780180655815882994563864n + ); + // poseidonHashSingle + deepStrictEqual( + starknet.poseidonHashSingle(1n), + 3085182978037364507644541379307921604860861694664657935759708330416374536741n + ); + deepStrictEqual( + starknet.poseidonHashSingle(123n), + 2751345659320901472675327541550911744303539407817894466726181731796247467344n + ); + deepStrictEqual( + starknet.poseidonHashSingle(1231231231231231231231231312312n), + 3083085683696942145160394401206391098729120397175152900096470498748103599322n + ); + // poseidonHashMany + throws(() => starknet.poseidonHash(new Uint8Array([1, 2, 3]))); + deepStrictEqual( + starknet.poseidonHashMany([1n]), + 154809849725474173771833689306955346864791482278938452209165301614543497938n + ); + deepStrictEqual( + starknet.poseidonHashMany([1n, 2n]), + 1557996165160500454210437319447297236715335099509187222888255133199463084263n + ); + deepStrictEqual( + starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n]), + 976552833909388839716191681593200982850734838655927116322079791360264131378n + ); + deepStrictEqual( + starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n, 5n, 6n, 7n]), + 1426681430756292883765769449684978541173830451959857824597431064948702170774n + ); + deepStrictEqual( + starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n, 5n, 6n]), + 3578895185591466904832617962452140411216018208734547126302182794057260630783n + ); + deepStrictEqual( + starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n, 5n]), + 2047942584693618630610564708884241243670450597197937863619828684896211911953n + ); + deepStrictEqual( + starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n, 4n]), + 717812721730784692894550948559585317289413466140233907962980309405694367376n + ); + deepStrictEqual( + starknet.poseidonHashMany([1n, 2n, 3n, 4n, 5n, 6n, 7n, 8n, 9n, 1n, 2n, 3n]), + 2926122208425648133778911655767364584769133265503722614793281770361723147648n + ); + deepStrictEqual( + starknet.poseidonHashMany([ + 154809849725474173771833689306955346864791482278938452209165301614543497938n, + 1557996165160500454210437319447297236715335099509187222888255133199463084263n, + 976552833909388839716191681593200982850734838655927116322079791360264131378n, + 1426681430756292883765769449684978541173830451959857824597431064948702170774n, + 3578895185591466904832617962452140411216018208734547126302182794057260630783n, + ]), + 1019392520709073131437410341528874594624843119359955302374885123884546721410n + ); + // poseidon_hash_func + deepStrictEqual( + hex(starknet.poseidonHashFunc(new Uint8Array([1, 2]), new Uint8Array([3, 4]))), + '01f87cbb9c58139605384d0f0df49b446600af020aa9dac92301d45c96d78c0a' + ); + deepStrictEqual( + hex(starknet.poseidonHashFunc(new Uint8Array(32).fill(255), new Uint8Array(32).fill(255))), + '05fd546b5ee3bcbbcbb733ed90bfc33033169d6765ac37bba71794a11cbb51a6' + ); + deepStrictEqual( + hex(starknet.poseidonHashFunc(new Uint8Array(64).fill(255), new Uint8Array(64).fill(255))), + '07dba6b4d94b3e32697afe0825d6dac2dccafd439f7806a9575693c93735596b' + ); + deepStrictEqual( + hex(starknet.poseidonHashFunc(new Uint8Array(256).fill(255), new Uint8Array(256).fill(255))), + '02f048581901865201dad701a5653d946b961748ec770fc11139aa7c06a9432a' + ); +}); + // ESM is broken. import url from 'url'; if (import.meta.url === url.pathToFileURL(process.argv[1]).href) {