From a13e920af01692cb07a520cda688f1cc5b5469dd Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 18 Apr 2017 13:37:10 +0200 Subject: [PATCH] trie: clean up iterator constructors Make it so each iterator has exactly one public constructor: - NodeIterators can be created through a method. - Iterators can be created through NewIterator on any NodeIterator. --- core/state/dump.go | 5 +++-- core/state/iterator.go | 2 +- core/state/statedb.go | 2 +- trie/iterator.go | 15 ++++----------- trie/iterator_test.go | 14 +++++++------- trie/secure_trie.go | 6 +----- trie/sync_test.go | 2 +- trie/trie.go | 4 ++-- trie/trie_test.go | 2 +- 9 files changed, 21 insertions(+), 31 deletions(-) diff --git a/core/state/dump.go b/core/state/dump.go index 8294d61b9..6338ddf88 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -22,6 +22,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" ) type DumpAccount struct { @@ -44,7 +45,7 @@ func (self *StateDB) RawDump() Dump { Accounts: make(map[string]DumpAccount), } - it := self.trie.Iterator() + it := trie.NewIterator(self.trie.NodeIterator()) for it.Next() { addr := self.trie.GetKey(it.Key) var data Account @@ -61,7 +62,7 @@ func (self *StateDB) RawDump() Dump { Code: common.Bytes2Hex(obj.Code(self.db)), Storage: make(map[string]string), } - storageIt := obj.getTrie(self.db).Iterator() + storageIt := trie.NewIterator(obj.getTrie(self.db).NodeIterator()) for storageIt.Next() { account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) } diff --git a/core/state/iterator.go b/core/state/iterator.go index 170aec983..d2dd5a74e 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -118,7 +118,7 @@ func (it *NodeIterator) step() error { if err != nil { return err } - it.dataIt = trie.NewNodeIterator(dataTrie) + it.dataIt = dataTrie.NodeIterator() if !it.dataIt.Next(true) { it.dataIt = nil } diff --git a/core/state/statedb.go b/core/state/statedb.go index 0c72fc6b0..24381ced5 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -481,7 +481,7 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common cb(h, value) } - it := so.getTrie(db.db).Iterator() + it := trie.NewIterator(so.getTrie(db.db).NodeIterator()) for it.Next() { // ignore cached values key := common.BytesToHash(db.trie.GetKey(it.Key)) diff --git a/trie/iterator.go b/trie/iterator.go index dd63a0c5a..fef5b2593 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -31,15 +31,8 @@ type Iterator struct { Value []byte // Current data value on which the iterator is positioned on } -// NewIterator creates a new key-value iterator. -func NewIterator(trie *Trie) *Iterator { - return &Iterator{ - nodeIt: NewNodeIterator(trie), - } -} - -// FromNodeIterator creates a new key-value iterator from a node iterator -func NewIteratorFromNodeIterator(it NodeIterator) *Iterator { +// NewIterator creates a new key-value iterator from a node iterator +func NewIterator(it NodeIterator) *Iterator { return &Iterator{ nodeIt: it, } @@ -99,8 +92,8 @@ type nodeIterator struct { path []byte // Path to the current node } -// NewNodeIterator creates an post-order trie iterator. -func NewNodeIterator(trie *Trie) NodeIterator { +// newNodeIterator creates an post-order trie iterator. +func newNodeIterator(trie *Trie) NodeIterator { if trie.Hash() == emptyState { return new(nodeIterator) } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index c101bb7b0..04d51aaf5 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -42,7 +42,7 @@ func TestIterator(t *testing.T) { trie.Commit() found := make(map[string]string) - it := NewIterator(trie) + it := NewIterator(trie.NodeIterator()) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -72,7 +72,7 @@ func TestIteratorLargeData(t *testing.T) { vals[string(value2.k)] = value2 } - it := NewIterator(trie) + it := NewIterator(trie.NodeIterator()) for it.Next() { vals[string(it.Key)].t = true } @@ -99,7 +99,7 @@ func TestNodeIteratorCoverage(t *testing.T) { // Gather all the node hashes found by the iterator hashes := make(map[common.Hash]struct{}) - for it := NewNodeIterator(trie); it.Next(true); { + for it := trie.NodeIterator(); it.Next(true); { if it.Hash() != (common.Hash{}) { hashes[it.Hash()] = struct{}{} } @@ -154,8 +154,8 @@ func TestDifferenceIterator(t *testing.T) { trieb.Commit() found := make(map[string]string) - di, _ := NewDifferenceIterator(NewNodeIterator(triea), NewNodeIterator(trieb)) - it := NewIteratorFromNodeIterator(di) + di, _ := NewDifferenceIterator(triea.NodeIterator(), trieb.NodeIterator()) + it := NewIterator(di) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -189,8 +189,8 @@ func TestUnionIterator(t *testing.T) { } trieb.Commit() - di, _ := NewUnionIterator([]NodeIterator{NewNodeIterator(triea), NewNodeIterator(trieb)}) - it := NewIteratorFromNodeIterator(di) + di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(), trieb.NodeIterator()}) + it := NewIterator(di) all := []struct{ k, v string }{ {"aardvark", "c"}, diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 113fb6a1a..201716d18 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -156,12 +156,8 @@ func (t *SecureTrie) Root() []byte { return t.trie.Root() } -func (t *SecureTrie) Iterator() *Iterator { - return t.trie.Iterator() -} - func (t *SecureTrie) NodeIterator() NodeIterator { - return NewNodeIterator(&t.trie) + return t.trie.NodeIterator() } // CommitTo writes all nodes and the secure hash pre-images to the given database. diff --git a/trie/sync_test.go b/trie/sync_test.go index acae039cd..6d345ad3f 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -80,7 +80,7 @@ func checkTrieConsistency(db Database, root common.Hash) error { if err != nil { return nil // // Consider a non existent state consistent } - it := NewNodeIterator(trie) + it := trie.NodeIterator() for it.Next(true) { } return it.Error() diff --git a/trie/trie.go b/trie/trie.go index e61bd0383..dbffc0ac3 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -126,8 +126,8 @@ func New(root common.Hash, db Database) (*Trie, error) { } // Iterator returns an iterator over all mappings in the trie. -func (t *Trie) Iterator() *Iterator { - return NewIterator(t) +func (t *Trie) NodeIterator() NodeIterator { + return newNodeIterator(t) } // Get returns the value for key stored in the trie. diff --git a/trie/trie_test.go b/trie/trie_test.go index 01ae3a4e7..cacb08824 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -439,7 +439,7 @@ func runRandTest(rt randTest) bool { tr = newtr case opItercheckhash: checktr, _ := New(common.Hash{}, nil) - it := tr.Iterator() + it := NewIterator(tr.NodeIterator()) for it.Next() { checktr.Update(it.Key, it.Value) }