From 82c2f606cc3039377529f78715632b5579a68160 Mon Sep 17 00:00:00 2001 From: poma Date: Sun, 9 Aug 2020 14:31:38 +0300 Subject: [PATCH] Rewrite Poseidon hash implementation to be compatible with reference implementation --- circuits/poseidon.circom | 167 +- circuits/poseidon_constants.circom | 94 + src/poseidon.js | 148 +- src/poseidon_constants.json | 3449 +++++++++++++++++ ...idon_test.circom => poseidon2_test.circom} | 2 +- test/circuits/poseidon4_test.circom | 3 + test/poseidoncircuit.js | 68 +- 7 files changed, 3649 insertions(+), 282 deletions(-) create mode 100644 circuits/poseidon_constants.circom create mode 100644 src/poseidon_constants.json rename test/circuits/{poseidon_test.circom => poseidon2_test.circom} (51%) create mode 100644 test/circuits/poseidon4_test.circom diff --git a/circuits/poseidon.circom b/circuits/poseidon.circom index dad6806..4834748 100644 --- a/circuits/poseidon.circom +++ b/circuits/poseidon.circom @@ -1,3 +1,4 @@ +include "./poseidon_constants.circom"; template Sigma() { signal input in; @@ -12,163 +13,52 @@ template Sigma() { out <== in4*in; } -template Ark(t, C) { +template Ark(t, C, r) { signal input in[t]; signal output out[t]; + for (var i=0; i= (nRoundsP + nRoundsF/2))) { - k= i= nRoundsP + nRoundsF/2) { + k = i < nRoundsF/2 ? i : i - nRoundsP; + mix[i] = Mix(t, M); for (var j=0; j F.mul(a, F.square(F.square(a, a))); + +function poseidon(inputs) { + assert(inputs.length > 0); + assert(inputs.length < N_ROUNDS_P.length - 1); + + const t = inputs.length + 1; + const nRoundsF = N_ROUNDS_F; + const nRoundsP = N_ROUNDS_P[t - 2]; + + let state = [...inputs.map(a => F.e(a)), F.zero]; + for (let r = 0; r < nRoundsF + nRoundsP; r++) { + state = state.map((a, i) => F.add(a, C[t - 2][r * t + i])); + + if (r < nRoundsF / 2 || r >= nRoundsF / 2 + nRoundsP) { + state = state.map(a => pow5(a)); + } else { + state[0] = pow5(state[0]); + } + + // no matrix multiplication in the last round + if (r < nRoundsF + nRoundsP - 1) { + state = state.map((_, i) => + state.reduce((acc, a, j) => F.add(acc, F.mul(M[t - 2][j][i], a)), F.zero) + ); } } - return true; + return F.normalize(state[0]); } -exports.getMatrix = (t, seed, nRounds) => { - if (typeof seed === "undefined") seed = SEED; - if (typeof nRounds === "undefined") nRounds = NROUNDSF + NROUNDSP; - if (typeof t === "undefined") t = T; - let nonce = "0000"; - let cmatrix = getPseudoRandom(seed+"_matrix_"+nonce, t*2); - while (!allDifferent(cmatrix)) { - nonce = (Number(nonce)+1)+""; - while(nonce.length<4) nonce = "0"+nonce; - cmatrix = getPseudoRandom(seed+"_matrix_"+nonce, t*2); - } - - const M = new Array(t); - for (let i=0; i { - if (typeof seed === "undefined") seed = SEED; - if (typeof nRounds === "undefined") nRounds = NROUNDSF + NROUNDSP; - if (typeof t === "undefined") t = T; - const cts = getPseudoRandom(seed+"_constants", nRounds); - return cts; -}; - -function ark(state, c) { - for (let j=0; j { - - if (typeof seed === "undefined") seed = SEED; - if (typeof nRoundsF === "undefined") nRoundsF = NROUNDSF; - if (typeof nRoundsP === "undefined") nRoundsP = NROUNDSP; - if (typeof t === "undefined") t = T; - - assert(nRoundsF % 2 == 0); - const C = exports.getConstants(t, seed, nRoundsF + nRoundsP); - const M = exports.getMatrix(t, seed, nRoundsF + nRoundsP); - return function(inputs) { - let state = []; - assert(inputs.length <= t); - assert(inputs.length > 0); - for (let i=0; i= nRoundsF/2 + nRoundsP)) { - for (let j=0; j { - var output = new Uint8Array(32); - var input = Buffer.from('poseidon_constants'); - h = blake2b(output.length).update(input).digest('hex') - assert.equal('e57ba154fb2c47811dc1a2369b27e25a44915b4e4ece4eb8ec74850cb78e01b1', h); - }); -}); - describe("Poseidon Circuit test", function () { - let circuit; + let circuit2; + let circuit4; this.timeout(100000); - before( async () => { - const cirDef = await compiler(path.join(__dirname, "circuits", "poseidon_test.circom")); - - circuit = new snarkjs.Circuit(cirDef); - - console.log("Poseidon constraints: " + circuit.nConstraints); + before(async () => { + circuit2 = await tester(path.join(__dirname, "circuits", "poseidon2_test.circom")); + circuit4 = await tester(path.join(__dirname, "circuits", "poseidon4_test.circom")); }); it("Should check constrain of hash([1, 2])", async () => { - const w = circuit.calculateWitness({inputs: [1, 2]}); - - const res = w[circuit.getSignalIdx("main.out")]; - - const hash = poseidon.createHash(6, 8, 57); - - const res2 = hash([1,2]); - assert.equal('12242166908188651009877250812424843524687801523336557272219921456462821518061', res2.toString()); - assert.equal(res.toString(), res2.toString()); - assert(circuit.checkWitness(w)); + const hash = poseidon([1, 2]); + assert.equal("17117985411748610629288516079940078114952304104811071254131751175361957805920", hash.toString()); + const w = await circuit2.calculateWitness({inputs: [1, 2]}, true); + await circuit2.assertOut(w, {out : hash}); + await circuit2.checkConstraints(w); }); it("Should check constrain of hash([3, 4])", async () => { - const w = circuit.calculateWitness({inputs: [3, 4]}); + const hash = poseidon([3, 4]); + assert.equal("21867347236198497199818917118739170715216974132230970409806500217655788551452", hash.toString()); + const w = await circuit2.calculateWitness({inputs: [3, 4]}); + await circuit2.assertOut(w, {out : hash}); + await circuit2.checkConstraints(w); + }); - const res = w[circuit.getSignalIdx("main.out")]; - const hash = poseidon.createHash(6, 8, 57); + it("Should check constrain of hash([1, 2, 3, 4])", async () => { + const hash = poseidon([1, 2, 3, 4]); + assert.equal("10501812514110530158422365608831771203648472822841727510887411206067265790462", hash.toString()); + const w = await circuit4.calculateWitness({inputs: [1, 2, 3, 4]}); + await circuit4.assertOut(w, {out : hash}); + await circuit4.checkConstraints(w); + }); - const res2 = hash([3, 4]); - assert.equal('17185195740979599334254027721507328033796809509313949281114643312710535000993', res2.toString()); - - assert.equal(res.toString(), res2.toString()); - - assert(circuit.checkWitness(w)); + it("Should check constrain of hash([5, 6, 7, 8])", async () => { + const hash = poseidon([5, 6, 7, 8]); + assert.equal("20761996991478317428195238015626872345373101531750069996451149877836620406299", hash.toString()); + const w = await circuit4.calculateWitness({inputs: [5, 6, 7, 8]}); + await circuit4.assertOut(w, {out : hash}); + await circuit4.checkConstraints(w); }); });