check elements length in constructor

This commit is contained in:
smart_ex 2022-03-11 15:37:20 +10:00
parent 6df03ab139
commit bc533ede2d
4 changed files with 45 additions and 11 deletions

@ -192,7 +192,7 @@ export default class MerkleTree {
throw new Error('Element not found')
}
const edgePath = this.path(edgeIndex)
return { edgePath, edgeElement, edgeIndex }
return { edgePath, edgeElement, edgeIndex, edgeElementsCount: this._layers[0].length }
}
/**
@ -202,7 +202,7 @@ export default class MerkleTree {
getTreeSlices(count = 4): { edge: TreeEdge, elements: Element[] }[] {
const length = this._layers[0].length
let size = Math.ceil(length / count)
size % 2 && size++
if (size % 2) size++
const slices = []
for (let i = 0; i < length; i += size) {
const edgeLeft = i

@ -32,16 +32,16 @@ export class PartialMerkleTree {
edgePath,
edgeElement,
edgeIndex,
edgeElementsCount,
}: TreeEdge, leaves: Element[], { hashFunction, zeroElement }: MerkleTreeOptions = {}) {
hashFunction = hashFunction || defaultHash
const hashFn = (left, right) => (left !== undefined && right !== undefined) ? hashFunction(left, right) : undefined
if (edgeIndex + leaves.length !== edgeElementsCount) throw new Error('Invalid number of elements')
this._edgeLeafProof = edgePath
this._initialRoot = edgePath.pathRoot
this.zeroElement = zeroElement ?? 0
this._edgeLeaf = { data: edgeElement, index: edgeIndex }
this._leavesAfterEdge = leaves
this.levels = levels
this._hashFn = hashFn
this._hashFn = hashFunction || defaultHash
this._createProofMap()
this._buildTree()
}
@ -257,9 +257,9 @@ export class PartialMerkleTree {
serialize(): SerializedPartialTreeState {
const leaves = this.layers[0].slice(this._edgeLeaf.index)
return {
_initialRoot: this._initialRoot,
_edgeLeafProof: this._edgeLeafProof,
_edgeLeaf: this._edgeLeaf,
_edgeElementsCount: this._layers[0].length,
levels: this.levels,
leaves,
_zeros: this._zeros,
@ -271,6 +271,7 @@ export class PartialMerkleTree {
edgePath: data._edgeLeafProof,
edgeElement: data._edgeLeaf.data,
edgeIndex: data._edgeLeaf.index,
edgeElementsCount: data._edgeElementsCount,
}
return new PartialMerkleTree(data.levels, edge, data.leaves, {
hashFunction,

@ -20,11 +20,11 @@ export type SerializedTreeState = {
}
export type SerializedPartialTreeState = {
levels: number,
levels: number
leaves: Element[]
_zeros: Array<Element>,
_edgeLeafProof: ProofPath,
_initialRoot: Element,
_edgeElementsCount: number
_zeros: Array<Element>
_edgeLeafProof: ProofPath
_edgeLeaf: LeafWithIndex
}
@ -38,6 +38,7 @@ export type TreeEdge = {
edgeElement: Element;
edgePath: ProofPath;
edgeIndex: number;
edgeElementsCount: number;
}
export type LeafWithIndex = { index: number, data: Element }

@ -1,4 +1,4 @@
import { MerkleTree, TreeEdge } from '../src'
import { MerkleTree, PartialMerkleTree, TreeEdge } from '../src'
import { assert, should } from 'chai'
import { buildMimcSponge } from 'circomlibjs'
import { createHash } from 'crypto'
@ -292,6 +292,7 @@ describe('MerkleTree', () => {
},
edgeElement: 4,
edgeIndex: 4,
edgeElementsCount: 6,
}
const tree = new MerkleTree(4, [0, 1, 2, 3, 4, 5])
assert.deepEqual(tree.getTreeEdge(4), expectedEdge)
@ -302,7 +303,38 @@ describe('MerkleTree', () => {
should().throw(call, 'Element not found')
})
})
describe('#getTreeSlices', () => {
let fullTree: MerkleTree
before(() => {
const elements = Array.from({ length: 128 }, (_, i) => i)
fullTree = new MerkleTree(10, elements)
})
it('should return correct slices count', () => {
const count = 5
const slicesCount = fullTree.getTreeSlices(5).length
should().equal(count, slicesCount)
})
it('should be able to create partial tree from last slice', () => {
const lastSlice = fullTree.getTreeSlices().pop()
const partialTree = new PartialMerkleTree(10, lastSlice.edge, lastSlice.elements)
should().equal(partialTree.root, fullTree.root)
})
it('should be able to build full tree from slices', () => {
const slices = fullTree.getTreeSlices()
const lastSlice = slices.pop()
const partialTree = new PartialMerkleTree(10, lastSlice.edge, lastSlice.elements)
slices.reverse().forEach(({ edge, elements }) => partialTree.shiftEdge(edge, elements))
assert.deepEqual(partialTree.layers, fullTree.layers)
})
it('should throw if invalid number of elements', () => {
const [firstSlice] = fullTree.getTreeSlices()
const call = () => new PartialMerkleTree(10, firstSlice.edge, firstSlice.elements)
should().throw(call, 'Invalid number of elements')
})
})
describe('#getters', () => {
const elements = [1, 2, 3, 4, 5]
const layers = [