diff --git a/core/state/iterator.go b/core/state/iterator.go index a58a15ad3..170aec983 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -31,15 +31,14 @@ import ( type NodeIterator struct { state *StateDB // State being iterated - stateIt *trie.NodeIterator // Primary iterator for the global state trie - dataIt *trie.NodeIterator // Secondary iterator for the data trie of a contract + stateIt trie.NodeIterator // Primary iterator for the global state trie + dataIt trie.NodeIterator // Secondary iterator for the data trie of a contract accountHash common.Hash // Hash of the node containing the account codeHash common.Hash // Hash of the contract source code code []byte // Source code associated with a contract Hash common.Hash // Hash of the current entry being iterated (nil if not standalone) - Entry interface{} // Current state entry being iterated (internal representation) Parent common.Hash // Hash of the first full ancestor node (nil if current is the root) Error error // Failure set in case of an internal error in the iterator @@ -80,9 +79,9 @@ func (it *NodeIterator) step() error { } // If we had data nodes previously, we surely have at least state nodes if it.dataIt != nil { - if cont := it.dataIt.Next(); !cont { - if it.dataIt.Error != nil { - return it.dataIt.Error + if cont := it.dataIt.Next(true); !cont { + if it.dataIt.Error() != nil { + return it.dataIt.Error() } it.dataIt = nil } @@ -94,15 +93,15 @@ func (it *NodeIterator) step() error { return nil } // Step to the next state trie node, terminating if we're out of nodes - if cont := it.stateIt.Next(); !cont { - if it.stateIt.Error != nil { - return it.stateIt.Error + if cont := it.stateIt.Next(true); !cont { + if it.stateIt.Error() != nil { + return it.stateIt.Error() } it.state, it.stateIt = nil, nil return nil } // If the state trie node is an internal entry, leave as is - if !it.stateIt.Leaf { + if !it.stateIt.Leaf() { return nil } // Otherwise we've reached an account node, initiate data iteration @@ -112,7 +111,7 @@ func (it *NodeIterator) step() error { Root common.Hash CodeHash []byte } - if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob), &account); err != nil { + if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { return err } dataTrie, err := trie.New(account.Root, it.state.db) @@ -120,7 +119,7 @@ func (it *NodeIterator) step() error { return err } it.dataIt = trie.NewNodeIterator(dataTrie) - if !it.dataIt.Next() { + if !it.dataIt.Next(true) { it.dataIt = nil } if !bytes.Equal(account.CodeHash, emptyCodeHash) { @@ -130,7 +129,7 @@ func (it *NodeIterator) step() error { return fmt.Errorf("code %x: %v", account.CodeHash, err) } } - it.accountHash = it.stateIt.Parent + it.accountHash = it.stateIt.Parent() return nil } @@ -138,7 +137,7 @@ func (it *NodeIterator) step() error { // The method returns whether there are any more data left for inspection. func (it *NodeIterator) retrieve() bool { // Clear out any previously set values - it.Hash, it.Entry = common.Hash{}, nil + it.Hash = common.Hash{} // If the iteration's done, return no available data if it.state == nil { @@ -147,14 +146,14 @@ func (it *NodeIterator) retrieve() bool { // Otherwise retrieve the current entry switch { case it.dataIt != nil: - it.Hash, it.Entry, it.Parent = it.dataIt.Hash, it.dataIt.Node, it.dataIt.Parent + it.Hash, it.Parent = it.dataIt.Hash(), it.dataIt.Parent() if it.Parent == (common.Hash{}) { it.Parent = it.accountHash } case it.code != nil: - it.Hash, it.Entry, it.Parent = it.codeHash, it.code, it.accountHash + it.Hash, it.Parent = it.codeHash, it.accountHash case it.stateIt != nil: - it.Hash, it.Entry, it.Parent = it.stateIt.Hash, it.stateIt.Node, it.stateIt.Parent + it.Hash, it.Parent = it.stateIt.Hash(), it.stateIt.Parent() } return true } diff --git a/trie/iterator.go b/trie/iterator.go index afde6e19e..234c49ecc 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -16,13 +16,14 @@ package trie -import "github.com/ethereum/go-ethereum/common" +import ( + "bytes" + "github.com/ethereum/go-ethereum/common" +) // Iterator is a key-value trie iterator that traverses a Trie. type Iterator struct { - trie *Trie - nodeIt *NodeIterator - keyBuf []byte + nodeIt NodeIterator Key []byte // Current data key on which the iterator is positioned on Value []byte // Current data value on which the iterator is positioned on @@ -31,19 +32,23 @@ type Iterator struct { // NewIterator creates a new key-value iterator. func NewIterator(trie *Trie) *Iterator { return &Iterator{ - trie: trie, nodeIt: NewNodeIterator(trie), - keyBuf: make([]byte, 0, 64), - Key: nil, + } +} + +// FromNodeIterator creates a new key-value iterator from a node iterator +func NewIteratorFromNodeIterator(it NodeIterator) *Iterator { + return &Iterator{ + nodeIt: it, } } // Next moves the iterator forward one key-value entry. func (it *Iterator) Next() bool { - for it.nodeIt.Next() { - if it.nodeIt.Leaf { - it.Key = it.makeKey() - it.Value = it.nodeIt.LeafBlob + for it.nodeIt.Next(true) { + if it.nodeIt.Leaf() { + it.Key = decodeCompact(it.nodeIt.Path()) + it.Value = it.nodeIt.LeafBlob() return true } } @@ -52,74 +57,123 @@ func (it *Iterator) Next() bool { return false } -func (it *Iterator) makeKey() []byte { - key := it.keyBuf[:0] - for _, se := range it.nodeIt.stack { - switch node := se.node.(type) { - case *fullNode: - if se.child <= 16 { - key = append(key, byte(se.child)) - } - case *shortNode: - if hasTerm(node.Key) { - key = append(key, node.Key[:len(node.Key)-1]...) - } else { - key = append(key, node.Key...) - } - } - } - return decodeCompact(key) +// NodeIterator is an iterator to traverse the trie pre-order. +type NodeIterator interface { + // Hash returns the hash of the current node + Hash() common.Hash + // Parent returns the hash of the parent of the current node + Parent() common.Hash + // Leaf returns true iff the current node is a leaf node. + Leaf() bool + // LeafBlob returns the contents of the node, if it is a leaf. + // Callers must not retain references to the return value after calling Next() + LeafBlob() []byte + // Path returns the hex-encoded path to the current node. + // Callers must not retain references to the return value after calling Next() + Path() []byte + // Next moves the iterator to the next node. If the parameter is false, any child + // nodes will be skipped. + Next(bool) bool + // Error returns the error status of the iterator. + Error() error } // nodeIteratorState represents the iteration state at one particular node of the // trie, which can be resumed at a later invocation. type nodeIteratorState struct { - hash common.Hash // Hash of the node being iterated (nil if not standalone) - node node // Trie node being iterated - parent common.Hash // Hash of the first full ancestor node (nil if current is the root) - child int // Child to be processed next + hash common.Hash // Hash of the node being iterated (nil if not standalone) + node node // Trie node being iterated + parent common.Hash // Hash of the first full ancestor node (nil if current is the root) + child int // Child to be processed next + pathlen int // Length of the path to this node } -// NodeIterator is an iterator to traverse the trie post-order. -type NodeIterator struct { +type nodeIterator struct { trie *Trie // Trie being iterated stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state - Hash common.Hash // Hash of the current node being iterated (nil if not standalone) - Node node // Current node being iterated (internal representation) - Parent common.Hash // Hash of the first full ancestor node (nil if current is the root) - Leaf bool // Flag whether the current node is a value (data) node - LeafBlob []byte // Data blob contained within a leaf (otherwise nil) + err error // Failure set in case of an internal error in the iterator - Error error // Failure set in case of an internal error in the iterator + path []byte // Path to the current node } // NewNodeIterator creates an post-order trie iterator. -func NewNodeIterator(trie *Trie) *NodeIterator { +func NewNodeIterator(trie *Trie) NodeIterator { if trie.Hash() == emptyState { - return new(NodeIterator) + return new(nodeIterator) } - return &NodeIterator{trie: trie} + return &nodeIterator{trie: trie} +} + +// Hash returns the hash of the current node +func (it *nodeIterator) Hash() common.Hash { + if len(it.stack) == 0 { + return common.Hash{} + } + + return it.stack[len(it.stack)-1].hash +} + +// Parent returns the hash of the parent node +func (it *nodeIterator) Parent() common.Hash { + if len(it.stack) == 0 { + return common.Hash{} + } + + return it.stack[len(it.stack)-1].parent +} + +// Leaf returns true if the current node is a leaf +func (it *nodeIterator) Leaf() bool { + if len(it.stack) == 0 { + return false + } + + _, ok := it.stack[len(it.stack)-1].node.(valueNode) + return ok +} + +// LeafBlob returns the data for the current node, if it is a leaf +func (it *nodeIterator) LeafBlob() []byte { + if len(it.stack) == 0 { + return nil + } + + if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { + return []byte(node) + } + return nil +} + +// Path returns the hex-encoded path to the current node +func (it *nodeIterator) Path() []byte { + return it.path +} + +// Error returns the error set in case of an internal error in the iterator +func (it *nodeIterator) Error() error { + return it.err } // Next moves the iterator to the next node, returning whether there are any // further nodes. In case of an internal error this method returns false and -// sets the Error field to the encountered failure. -func (it *NodeIterator) Next() bool { +// sets the Error field to the encountered failure. If `descend` is false, +// skips iterating over any subnodes of the current node. +func (it *nodeIterator) Next(descend bool) bool { // If the iterator failed previously, don't do anything - if it.Error != nil { + if it.err != nil { return false } // Otherwise step forward with the iterator and report any errors - if err := it.step(); err != nil { - it.Error = err + if err := it.step(descend); err != nil { + it.err = err return false } - return it.retrieve() + return it.trie != nil } // step moves the iterator to the next node of the trie. -func (it *NodeIterator) step() error { +func (it *nodeIterator) step(descend bool) error { if it.trie == nil { // Abort if we reached the end of the iteration return nil @@ -132,93 +186,180 @@ func (it *NodeIterator) step() error { state.hash = root } it.stack = append(it.stack, state) - } else { - // Continue iterating at the previous node otherwise. + return nil + } + + if !descend { + // If we're skipping children, pop the current node first + it.path = it.path[:it.stack[len(it.stack)-1].pathlen] it.stack = it.stack[:len(it.stack)-1] + } + + // Continue iteration to the next child +outer: + for { if len(it.stack) == 0 { it.trie = nil return nil } - } - - // Continue iteration to the next child - for { parent := it.stack[len(it.stack)-1] ancestor := parent.hash if (ancestor == common.Hash{}) { ancestor = parent.parent } if node, ok := parent.node.(*fullNode); ok { - // Full node, traverse all children, then the node itself - if parent.child >= len(node.Children) { - break - } + // Full node, iterate over children for parent.child++; parent.child < len(node.Children); parent.child++ { - if current := node.Children[parent.child]; current != nil { + child := node.Children[parent.child] + if child != nil { + hash, _ := child.cache() it.stack = append(it.stack, &nodeIteratorState{ - hash: common.BytesToHash(node.flags.hash), - node: current, - parent: ancestor, - child: -1, + hash: common.BytesToHash(hash), + node: child, + parent: ancestor, + child: -1, + pathlen: len(it.path), }) - break + it.path = append(it.path, byte(parent.child)) + break outer } } } else if node, ok := parent.node.(*shortNode); ok { - // Short node, traverse the pointer singleton child, then the node itself - if parent.child >= 0 { + // Short node, return the pointer singleton child + if parent.child < 0 { + parent.child++ + hash, _ := node.Val.cache() + it.stack = append(it.stack, &nodeIteratorState{ + hash: common.BytesToHash(hash), + node: node.Val, + parent: ancestor, + child: -1, + pathlen: len(it.path), + }) + if hasTerm(node.Key) { + it.path = append(it.path, node.Key[:len(node.Key)-1]...) + } else { + it.path = append(it.path, node.Key...) + } break } - parent.child++ - it.stack = append(it.stack, &nodeIteratorState{ - hash: common.BytesToHash(node.flags.hash), - node: node.Val, - parent: ancestor, - child: -1, - }) } else if hash, ok := parent.node.(hashNode); ok { - // Hash node, resolve the hash child from the database, then the node itself - if parent.child >= 0 { + // Hash node, resolve the hash child from the database + if parent.child < 0 { + parent.child++ + node, err := it.trie.resolveHash(hash, nil, nil) + if err != nil { + return err + } + it.stack = append(it.stack, &nodeIteratorState{ + hash: common.BytesToHash(hash), + node: node, + parent: ancestor, + child: -1, + pathlen: len(it.path), + }) break } - parent.child++ - - node, err := it.trie.resolveHash(hash, nil, nil) - if err != nil { - return err - } - it.stack = append(it.stack, &nodeIteratorState{ - hash: common.BytesToHash(hash), - node: node, - parent: ancestor, - child: -1, - }) - } else { - break } + it.path = it.path[:parent.pathlen] + it.stack = it.stack[:len(it.stack)-1] } return nil } -// retrieve pulls and caches the current trie node the iterator is traversing. -// In case of a value node, the additional leaf blob is also populated with the -// data contents for external interpretation. -// -// The method returns whether there are any more data left for inspection. -func (it *NodeIterator) retrieve() bool { - // Clear out any previously set values - it.Hash, it.Node, it.Parent, it.Leaf, it.LeafBlob = common.Hash{}, nil, common.Hash{}, false, nil +type differenceIterator struct { + a, b NodeIterator // Nodes returned are those in b - a. + eof bool // Indicates a has run out of elements + count int // Number of nodes scanned on either trie +} - // If the iteration's done, return no available data - if it.trie == nil { +// NewDifferenceIterator constructs a NodeIterator that iterates over elements in b that +// are not in a. Returns the iterator, and a pointer to an integer recording the number +// of nodes seen. +func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) { + a.Next(true) + it := &differenceIterator{ + a: a, + b: b, + } + return it, &it.count +} + +func (it *differenceIterator) Hash() common.Hash { + return it.b.Hash() +} + +func (it *differenceIterator) Parent() common.Hash { + return it.b.Parent() +} + +func (it *differenceIterator) Leaf() bool { + return it.b.Leaf() +} + +func (it *differenceIterator) LeafBlob() []byte { + return it.b.LeafBlob() +} + +func (it *differenceIterator) Path() []byte { + return it.b.Path() +} + +func (it *differenceIterator) Next(bool) bool { + // Invariants: + // - We always advance at least one element in b. + // - At the start of this function, a's path is lexically greater than b's. + if !it.b.Next(true) { return false } - // Otherwise retrieve the current node and resolve leaf accessors - state := it.stack[len(it.stack)-1] + it.count += 1 - it.Hash, it.Node, it.Parent = state.hash, state.node, state.parent - if value, ok := it.Node.(valueNode); ok { - it.Leaf, it.LeafBlob = true, []byte(value) + if it.eof { + // a has reached eof, so we just return all elements from b + return true + } + + for { + apath, bpath := it.a.Path(), it.b.Path() + switch bytes.Compare(apath, bpath) { + case -1: + // b jumped past a; advance a + if !it.a.Next(true) { + it.eof = true + return true + } + it.count += 1 + case 1: + // b is before a + return true + case 0: + if it.a.Hash() != it.b.Hash() || it.a.Leaf() != it.b.Leaf() { + // Keys are identical, but hashes or leaf status differs + return true + } + if it.a.Leaf() && it.b.Leaf() && !bytes.Equal(it.a.LeafBlob(), it.b.LeafBlob()) { + // Both are leaf nodes, but with different values + return true + } + + // a and b are identical; skip this whole subtree if the nodes have hashes + hasHash := it.a.Hash() == common.Hash{} + if !it.b.Next(hasHash) { + return false + } + it.count += 1 + if !it.a.Next(hasHash) { + it.eof = true + return true + } + it.count += 1 + } } - return true +} + +func (it *differenceIterator) Error() error { + if err := it.a.Error(); err != nil { + return err + } + return it.b.Error() } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index c56ac85be..0ad9711ed 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -99,9 +99,9 @@ func TestNodeIteratorCoverage(t *testing.T) { // Gather all the node hashes found by the iterator hashes := make(map[common.Hash]struct{}) - for it := NewNodeIterator(trie); it.Next(); { - if it.Hash != (common.Hash{}) { - hashes[it.Hash] = struct{}{} + for it := NewNodeIterator(trie); it.Next(true); { + if it.Hash() != (common.Hash{}) { + hashes[it.Hash()] = struct{}{} } } // Cross check the hashes and the database itself @@ -116,3 +116,60 @@ func TestNodeIteratorCoverage(t *testing.T) { } } } + +func TestDifferenceIterator(t *testing.T) { + triea := newEmpty() + valsa := []struct{ k, v string }{ + {"bar", "b"}, + {"barb", "ba"}, + {"bars", "bb"}, + {"bard", "bc"}, + {"fab", "z"}, + {"foo", "a"}, + {"food", "ab"}, + {"foos", "aa"}, + } + for _, val := range valsa { + triea.Update([]byte(val.k), []byte(val.v)) + } + triea.Commit() + + trieb := newEmpty() + valsb := []struct{ k, v string }{ + {"aardvark", "c"}, + {"bar", "b"}, + {"barb", "bd"}, + {"bars", "be"}, + {"fab", "z"}, + {"foo", "a"}, + {"foos", "aa"}, + {"food", "ab"}, + {"jars", "d"}, + } + for _, val := range valsb { + trieb.Update([]byte(val.k), []byte(val.v)) + } + trieb.Commit() + + found := make(map[string]string) + di, _ := NewDifferenceIterator(NewNodeIterator(triea), NewNodeIterator(trieb)) + it := NewIteratorFromNodeIterator(di) + for it.Next() { + found[string(it.Key)] = string(it.Value) + } + + all := []struct{ k, v string }{ + {"aardvark", "c"}, + {"barb", "bd"}, + {"bars", "be"}, + {"jars", "d"}, + } + for _, item := range all { + if found[item.k] != item.v { + t.Errorf("iterator value mismatch for %s: got %q want %q", item.k, found[item.k], item.v) + } + } + if len(found) != len(all) { + t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all)) + } +} diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 4d9ebe4d3..8b90da02f 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -159,7 +159,7 @@ func (t *SecureTrie) Iterator() *Iterator { return t.trie.Iterator() } -func (t *SecureTrie) NodeIterator() *NodeIterator { +func (t *SecureTrie) NodeIterator() NodeIterator { return NewNodeIterator(&t.trie) } diff --git a/trie/sync_test.go b/trie/sync_test.go index 4168c4d65..acae039cd 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -81,9 +81,9 @@ func checkTrieConsistency(db Database, root common.Hash) error { return nil // // Consider a non existent state consistent } it := NewNodeIterator(trie) - for it.Next() { + for it.Next(true) { } - return it.Error + return it.Error() } // Tests that an empty trie is not scheduled for syncing.