check elements length in constructor

This commit is contained in:
smart_ex 2022-03-11 15:37:20 +10:00
parent 6df03ab139
commit bc533ede2d
4 changed files with 45 additions and 11 deletions

@ -192,7 +192,7 @@ export default class MerkleTree {
throw new Error('Element not found') throw new Error('Element not found')
} }
const edgePath = this.path(edgeIndex) const edgePath = this.path(edgeIndex)
return { edgePath, edgeElement, edgeIndex } return { edgePath, edgeElement, edgeIndex, edgeElementsCount: this._layers[0].length }
} }
/** /**
@ -202,7 +202,7 @@ export default class MerkleTree {
getTreeSlices(count = 4): { edge: TreeEdge, elements: Element[] }[] { getTreeSlices(count = 4): { edge: TreeEdge, elements: Element[] }[] {
const length = this._layers[0].length const length = this._layers[0].length
let size = Math.ceil(length / count) let size = Math.ceil(length / count)
size % 2 && size++ if (size % 2) size++
const slices = [] const slices = []
for (let i = 0; i < length; i += size) { for (let i = 0; i < length; i += size) {
const edgeLeft = i const edgeLeft = i

@ -32,16 +32,16 @@ export class PartialMerkleTree {
edgePath, edgePath,
edgeElement, edgeElement,
edgeIndex, edgeIndex,
edgeElementsCount,
}: TreeEdge, leaves: Element[], { hashFunction, zeroElement }: MerkleTreeOptions = {}) { }: TreeEdge, leaves: Element[], { hashFunction, zeroElement }: MerkleTreeOptions = {}) {
hashFunction = hashFunction || defaultHash if (edgeIndex + leaves.length !== edgeElementsCount) throw new Error('Invalid number of elements')
const hashFn = (left, right) => (left !== undefined && right !== undefined) ? hashFunction(left, right) : undefined
this._edgeLeafProof = edgePath this._edgeLeafProof = edgePath
this._initialRoot = edgePath.pathRoot this._initialRoot = edgePath.pathRoot
this.zeroElement = zeroElement ?? 0 this.zeroElement = zeroElement ?? 0
this._edgeLeaf = { data: edgeElement, index: edgeIndex } this._edgeLeaf = { data: edgeElement, index: edgeIndex }
this._leavesAfterEdge = leaves this._leavesAfterEdge = leaves
this.levels = levels this.levels = levels
this._hashFn = hashFn this._hashFn = hashFunction || defaultHash
this._createProofMap() this._createProofMap()
this._buildTree() this._buildTree()
} }
@ -257,9 +257,9 @@ export class PartialMerkleTree {
serialize(): SerializedPartialTreeState { serialize(): SerializedPartialTreeState {
const leaves = this.layers[0].slice(this._edgeLeaf.index) const leaves = this.layers[0].slice(this._edgeLeaf.index)
return { return {
_initialRoot: this._initialRoot,
_edgeLeafProof: this._edgeLeafProof, _edgeLeafProof: this._edgeLeafProof,
_edgeLeaf: this._edgeLeaf, _edgeLeaf: this._edgeLeaf,
_edgeElementsCount: this._layers[0].length,
levels: this.levels, levels: this.levels,
leaves, leaves,
_zeros: this._zeros, _zeros: this._zeros,
@ -271,6 +271,7 @@ export class PartialMerkleTree {
edgePath: data._edgeLeafProof, edgePath: data._edgeLeafProof,
edgeElement: data._edgeLeaf.data, edgeElement: data._edgeLeaf.data,
edgeIndex: data._edgeLeaf.index, edgeIndex: data._edgeLeaf.index,
edgeElementsCount: data._edgeElementsCount,
} }
return new PartialMerkleTree(data.levels, edge, data.leaves, { return new PartialMerkleTree(data.levels, edge, data.leaves, {
hashFunction, hashFunction,

@ -20,11 +20,11 @@ export type SerializedTreeState = {
} }
export type SerializedPartialTreeState = { export type SerializedPartialTreeState = {
levels: number, levels: number
leaves: Element[] leaves: Element[]
_zeros: Array<Element>, _edgeElementsCount: number
_edgeLeafProof: ProofPath, _zeros: Array<Element>
_initialRoot: Element, _edgeLeafProof: ProofPath
_edgeLeaf: LeafWithIndex _edgeLeaf: LeafWithIndex
} }
@ -38,6 +38,7 @@ export type TreeEdge = {
edgeElement: Element; edgeElement: Element;
edgePath: ProofPath; edgePath: ProofPath;
edgeIndex: number; edgeIndex: number;
edgeElementsCount: number;
} }
export type LeafWithIndex = { index: number, data: Element } export type LeafWithIndex = { index: number, data: Element }

@ -1,4 +1,4 @@
import { MerkleTree, TreeEdge } from '../src' import { MerkleTree, PartialMerkleTree, TreeEdge } from '../src'
import { assert, should } from 'chai' import { assert, should } from 'chai'
import { buildMimcSponge } from 'circomlibjs' import { buildMimcSponge } from 'circomlibjs'
import { createHash } from 'crypto' import { createHash } from 'crypto'
@ -292,6 +292,7 @@ describe('MerkleTree', () => {
}, },
edgeElement: 4, edgeElement: 4,
edgeIndex: 4, edgeIndex: 4,
edgeElementsCount: 6,
} }
const tree = new MerkleTree(4, [0, 1, 2, 3, 4, 5]) const tree = new MerkleTree(4, [0, 1, 2, 3, 4, 5])
assert.deepEqual(tree.getTreeEdge(4), expectedEdge) assert.deepEqual(tree.getTreeEdge(4), expectedEdge)
@ -302,7 +303,38 @@ describe('MerkleTree', () => {
should().throw(call, 'Element not found') should().throw(call, 'Element not found')
}) })
}) })
describe('#getTreeSlices', () => {
let fullTree: MerkleTree
before(() => {
const elements = Array.from({ length: 128 }, (_, i) => i)
fullTree = new MerkleTree(10, elements)
})
it('should return correct slices count', () => {
const count = 5
const slicesCount = fullTree.getTreeSlices(5).length
should().equal(count, slicesCount)
})
it('should be able to create partial tree from last slice', () => {
const lastSlice = fullTree.getTreeSlices().pop()
const partialTree = new PartialMerkleTree(10, lastSlice.edge, lastSlice.elements)
should().equal(partialTree.root, fullTree.root)
})
it('should be able to build full tree from slices', () => {
const slices = fullTree.getTreeSlices()
const lastSlice = slices.pop()
const partialTree = new PartialMerkleTree(10, lastSlice.edge, lastSlice.elements)
slices.reverse().forEach(({ edge, elements }) => partialTree.shiftEdge(edge, elements))
assert.deepEqual(partialTree.layers, fullTree.layers)
})
it('should throw if invalid number of elements', () => {
const [firstSlice] = fullTree.getTreeSlices()
const call = () => new PartialMerkleTree(10, firstSlice.edge, firstSlice.elements)
should().throw(call, 'Invalid number of elements')
})
})
describe('#getters', () => { describe('#getters', () => {
const elements = [1, 2, 3, 4, 5] const elements = [1, 2, 3, 4, 5]
const layers = [ const layers = [