From 6d2aeb43d516fbe2cafd9e65df7ccbd885f861d3 Mon Sep 17 00:00:00 2001 From: rjl493456442 Date: Wed, 21 Jun 2023 03:31:45 +0800 Subject: [PATCH] cmd, core/state, eth, tests, trie: improve state reader (#27428) The state availability is checked during the creation of a state reader. - In hash-based database, if the specified root node does not exist on disk disk, then the state reader won't be created and an error will be returned. - In path-based database, if the specified state layer is not available, then the state reader won't be created and an error will be returned. This change also contains a stricter semantics regarding the `Commit` operation: once it has been performed, the trie is no longer usable, and certain operations will return an error. --- cmd/evm/internal/t8ntool/execution.go | 8 ++- cmd/geth/dbcmd.go | 6 +- cmd/geth/snapshot.go | 26 +++++-- core/state/database.go | 5 +- core/state/dump.go | 13 +++- core/state/iterator.go | 11 ++- core/state/pruner/pruner.go | 10 ++- core/state/snapshot/generate.go | 7 +- core/state/state_test.go | 12 ++-- core/state/statedb.go | 24 +++++-- core/state/statedb_test.go | 97 ++++++++++++++++++-------- core/state/sync_test.go | 8 ++- eth/api_debug.go | 16 ++++- eth/api_debug_test.go | 26 +++---- eth/protocols/snap/sync_test.go | 4 +- light/trie.go | 10 ++- light/trie_test.go | 12 +++- tests/fuzzers/stacktrie/trie_fuzzer.go | 2 +- tests/fuzzers/trie/trie-fuzzer.go | 2 +- tests/state_test_util.go | 9 ++- trie/database_wrap.go | 4 +- trie/errors.go | 6 ++ trie/iterator_test.go | 41 ++++++----- trie/proof.go | 4 ++ trie/proof_test.go | 2 +- trie/secure_trie.go | 12 +++- trie/sync_test.go | 52 +++++++++++--- trie/tracer_test.go | 8 +-- trie/trie.go | 56 ++++++++++++--- trie/trie_reader.go | 25 ++++--- trie/trie_test.go | 8 +-- trie/triedb/hashdb/database.go | 9 ++- 32 files changed, 384 insertions(+), 151 deletions(-) diff --git a/cmd/evm/internal/t8ntool/execution.go b/cmd/evm/internal/t8ntool/execution.go index 5633198d39..fb85794b76 100644 --- a/cmd/evm/internal/t8ntool/execution.go +++ b/cmd/evm/internal/t8ntool/execution.go @@ -19,7 +19,6 @@ package t8ntool import ( "fmt" "math/big" - "os" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" @@ -269,7 +268,6 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, // Commit block root, err := statedb.Commit(chainConfig.IsEIP158(vmContext.BlockNumber)) if err != nil { - fmt.Fprintf(os.Stderr, "Could not commit state: %v", err) return nil, nil, NewError(ErrorEVM, fmt.Errorf("could not commit state: %v", err)) } execRs := &ExecutionResult{ @@ -288,6 +286,12 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, h := types.DeriveSha(types.Withdrawals(pre.Env.Withdrawals), trie.NewStackTrie(nil)) execRs.WithdrawalsRoot = &h } + // Re-create statedb instance with new root upon the updated database + // for accessing latest states. + statedb, err = state.New(root, statedb.Database(), nil) + if err != nil { + return nil, nil, NewError(ErrorEVM, fmt.Errorf("could not reopen state: %v", err)) + } return statedb, execRs, nil } diff --git a/cmd/geth/dbcmd.go b/cmd/geth/dbcmd.go index b409b19260..aa7939851d 100644 --- a/cmd/geth/dbcmd.go +++ b/cmd/geth/dbcmd.go @@ -519,8 +519,12 @@ func dbDumpTrie(ctx *cli.Context) error { if err != nil { return err } + trieIt, err := theTrie.NodeIterator(start) + if err != nil { + return err + } var count int64 - it := trie.NewIterator(theTrie.NodeIterator(start)) + it := trie.NewIterator(trieIt) for it.Next() { if max > 0 && count == max { fmt.Printf("Exiting after %d values\n", count) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 7b6b939be3..67837261c6 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -292,7 +292,12 @@ func traverseState(ctx *cli.Context) error { lastReport time.Time start = time.Now() ) - accIter := trie.NewIterator(t.NodeIterator(nil)) + acctIt, err := t.NodeIterator(nil) + if err != nil { + log.Error("Failed to open iterator", "root", root, "err", err) + return err + } + accIter := trie.NewIterator(acctIt) for accIter.Next() { accounts += 1 var acc types.StateAccount @@ -307,7 +312,12 @@ func traverseState(ctx *cli.Context) error { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) return err } - storageIter := trie.NewIterator(storageTrie.NodeIterator(nil)) + storageIt, err := storageTrie.NodeIterator(nil) + if err != nil { + log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) + return err + } + storageIter := trie.NewIterator(storageIt) for storageIter.Next() { slots += 1 } @@ -385,7 +395,11 @@ func traverseRawState(ctx *cli.Context) error { hasher = crypto.NewKeccakState() got = make([]byte, 32) ) - accIter := t.NodeIterator(nil) + accIter, err := t.NodeIterator(nil) + if err != nil { + log.Error("Failed to open iterator", "root", root, "err", err) + return err + } for accIter.Next(true) { nodes += 1 node := accIter.Hash() @@ -422,7 +436,11 @@ func traverseRawState(ctx *cli.Context) error { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) return errors.New("missing storage trie") } - storageIter := storageTrie.NodeIterator(nil) + storageIter, err := storageTrie.NodeIterator(nil) + if err != nil { + log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) + return err + } for storageIter.Next(true) { nodes += 1 node := storageIter.Hash() diff --git a/core/state/database.go b/core/state/database.go index 304c04f963..2c8a7afee6 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -113,8 +113,9 @@ type Trie interface { Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) // NodeIterator returns an iterator that returns nodes of the trie. Iteration - // starts at the key after the given start key. - NodeIterator(startKey []byte) trie.NodeIterator + // starts at the key after the given start key. And error will be returned + // if fails to create node iterator. + NodeIterator(startKey []byte) (trie.NodeIterator, error) // Prove constructs a Merkle proof for key. The result contains all encoded nodes // on the path to the value at key. The value itself is also included in the last diff --git a/core/state/dump.go b/core/state/dump.go index 96a2ac3d3e..2d9c65e473 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -140,7 +140,11 @@ func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey [] log.Info("Trie dumping started", "root", s.trie.Hash()) c.OnRoot(s.trie.Hash()) - it := trie.NewIterator(s.trie.NodeIterator(conf.Start)) + trieIt, err := s.trie.NodeIterator(conf.Start) + if err != nil { + return nil + } + it := trie.NewIterator(trieIt) for it.Next() { var data types.StateAccount if err := rlp.DecodeBytes(it.Value, &data); err != nil { @@ -178,7 +182,12 @@ func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey [] log.Error("Failed to load storage trie", "err", err) continue } - storageIt := trie.NewIterator(tr.NodeIterator(nil)) + trieIt, err := tr.NodeIterator(nil) + if err != nil { + log.Error("Failed to create trie iterator", "err", err) + continue + } + storageIt := trie.NewIterator(trieIt) for storageIt.Next() { _, content, _, err := rlp.Split(storageIt.Value) if err != nil { diff --git a/core/state/iterator.go b/core/state/iterator.go index 79bb650abc..7802263b75 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -74,8 +74,12 @@ func (it *nodeIterator) step() error { return nil } // Initialize the iterator if we've just started + var err error if it.stateIt == nil { - it.stateIt = it.state.trie.NodeIterator(nil) + it.stateIt, err = it.state.trie.NodeIterator(nil) + if err != nil { + return err + } } // If we had data nodes previously, we surely have at least state nodes if it.dataIt != nil { @@ -113,7 +117,10 @@ func (it *nodeIterator) step() error { if err != nil { return err } - it.dataIt = dataTrie.NodeIterator(nil) + it.dataIt, err = dataTrie.NodeIterator(nil) + if err != nil { + return err + } if !it.dataIt.Next(true) { it.dataIt = nil } diff --git a/core/state/pruner/pruner.go b/core/state/pruner/pruner.go index aed002f905..1ee5a0e625 100644 --- a/core/state/pruner/pruner.go +++ b/core/state/pruner/pruner.go @@ -420,7 +420,10 @@ func extractGenesis(db ethdb.Database, stateBloom *stateBloom) error { if err != nil { return err } - accIter := t.NodeIterator(nil) + accIter, err := t.NodeIterator(nil) + if err != nil { + return err + } for accIter.Next(true) { hash := accIter.Hash() @@ -441,7 +444,10 @@ func extractGenesis(db ethdb.Database, stateBloom *stateBloom) error { if err != nil { return err } - storageIter := storageTrie.NodeIterator(nil) + storageIter, err := storageTrie.NodeIterator(nil) + if err != nil { + return err + } for storageIter.Next(true) { hash := storageIter.Hash() if hash != (common.Hash{}) { diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go index 78aa23a5de..1ec82b9442 100644 --- a/core/state/snapshot/generate.go +++ b/core/state/snapshot/generate.go @@ -382,8 +382,6 @@ func (dl *diskLayer) generateRange(ctx *generatorContext, trieId *trie.ID, prefi } var ( trieMore bool - nodeIt = tr.NodeIterator(origin) - iter = trie.NewIterator(nodeIt) kvkeys, kvvals = result.keys, result.vals // counters @@ -397,7 +395,12 @@ func (dl *diskLayer) generateRange(ctx *generatorContext, trieId *trie.ID, prefi start = time.Now() internal time.Duration ) + nodeIt, err := tr.NodeIterator(origin) + if err != nil { + return false, nil, err + } nodeIt.AddResolver(resolver) + iter := trie.NewIterator(nodeIt) for iter.Next() { if last != nil && bytes.Compare(iter.Key, last) > 0 { diff --git a/core/state/state_test.go b/core/state/state_test.go index 15e6037265..555e1f7454 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -43,7 +43,8 @@ func newStateTest() *stateTest { func TestDump(t *testing.T) { db := rawdb.NewMemoryDatabase() - sdb, _ := New(types.EmptyRootHash, NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), nil) + tdb := NewDatabaseWithConfig(db, &trie.Config{Preimages: true}) + sdb, _ := New(types.EmptyRootHash, tdb, nil) s := &stateTest{db: db, state: sdb} // generate a few entries @@ -57,9 +58,10 @@ func TestDump(t *testing.T) { // write some of them to the trie s.state.updateStateObject(obj1) s.state.updateStateObject(obj2) - s.state.Commit(false) + root, _ := s.state.Commit(false) // check that DumpToCollector contains the state objects that are in trie + s.state, _ = New(root, tdb, nil) got := string(s.state.Dump(nil)) want := `{ "root": "71edff0130dd2385947095001c73d9e28d862fc286fca2b922ca6f6f3cddfdd2", @@ -95,7 +97,8 @@ func TestDump(t *testing.T) { func TestIterativeDump(t *testing.T) { db := rawdb.NewMemoryDatabase() - sdb, _ := New(types.EmptyRootHash, NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), nil) + tdb := NewDatabaseWithConfig(db, &trie.Config{Preimages: true}) + sdb, _ := New(types.EmptyRootHash, tdb, nil) s := &stateTest{db: db, state: sdb} // generate a few entries @@ -111,7 +114,8 @@ func TestIterativeDump(t *testing.T) { // write some of them to the trie s.state.updateStateObject(obj1) s.state.updateStateObject(obj2) - s.state.Commit(false) + root, _ := s.state.Commit(false) + s.state, _ = New(root, tdb, nil) b := &bytes.Buffer{} s.state.IterativeDump(nil, json.NewEncoder(b)) diff --git a/core/state/statedb.go b/core/state/statedb.go index 90b32e3fc0..f8457c6de3 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -56,8 +56,14 @@ func (n *proofList) Delete(key []byte) error { // StateDB structs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: +// // * Contracts // * Accounts +// +// Once the state is committed, tries cached in stateDB (including account +// trie, storage tries) will no longer be functional. A new state instance +// must be created with new root and updated database for accessing post- +// commit states. type StateDB struct { db Database prefetcher *triePrefetcher @@ -680,19 +686,23 @@ func (s *StateDB) CreateAccount(addr common.Address) { } } -func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { - so := db.getStateObject(addr) +func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { + so := s.getStateObject(addr) if so == nil { return nil } - tr, err := so.getTrie(db.db) + tr, err := so.getTrie(s.db) if err != nil { return err } - it := trie.NewIterator(tr.NodeIterator(nil)) + trieIt, err := tr.NodeIterator(nil) + if err != nil { + return err + } + it := trie.NewIterator(trieIt) for it.Next() { - key := common.BytesToHash(db.trie.GetKey(it.Key)) + key := common.BytesToHash(s.trie.GetKey(it.Key)) if value, dirty := so.dirtyStorage[key]; dirty { if !cb(key, value) { return nil @@ -977,6 +987,10 @@ func (s *StateDB) clearJournalAndRefund() { } // Commit writes the state to the underlying in-memory trie database. +// Once the state is committed, tries cached in stateDB (including account +// trie, storage tries) will no longer be functional. A new state instance +// must be created with new root and updated database for accessing post- +// commit states. func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) { // Short circuit in case any database failure occurred earlier. if s.dbErr != nil { diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 82118fc5f4..44eace01d8 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -19,6 +19,7 @@ package state import ( "bytes" "encoding/binary" + "errors" "fmt" "math" "math/big" @@ -521,7 +522,8 @@ func TestCopyOfCopy(t *testing.T) { // // See https://github.com/ethereum/go-ethereum/issues/20106. func TestCopyCommitCopy(t *testing.T) { - state, _ := New(types.EmptyRootHash, NewDatabase(rawdb.NewMemoryDatabase()), nil) + tdb := NewDatabase(rawdb.NewMemoryDatabase()) + state, _ := New(types.EmptyRootHash, tdb, nil) // Create an account and check if the retrieved balance is correct addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe") @@ -558,20 +560,6 @@ func TestCopyCommitCopy(t *testing.T) { if val := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("first copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } - - copyOne.Commit(false) - if balance := copyOne.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { - t.Fatalf("first copy post-commit balance mismatch: have %v, want %v", balance, 42) - } - if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { - t.Fatalf("first copy post-commit code mismatch: have %x, want %x", code, []byte("hello")) - } - if val := copyOne.GetState(addr, skey); val != sval { - t.Fatalf("first copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval) - } - if val := copyOne.GetCommittedState(addr, skey); val != sval { - t.Fatalf("first copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) - } // Copy the copy and check the balance once more copyTwo := copyOne.Copy() if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { @@ -583,8 +571,23 @@ func TestCopyCommitCopy(t *testing.T) { if val := copyTwo.GetState(addr, skey); val != sval { t.Fatalf("second copy non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyTwo.GetCommittedState(addr, skey); val != sval { - t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) + if val := copyTwo.GetCommittedState(addr, skey); val != (common.Hash{}) { + t.Fatalf("second copy committed storage slot mismatch: have %x, want %x", val, sval) + } + // Commit state, ensure states can be loaded from disk + root, _ := state.Commit(false) + state, _ = New(root, tdb, nil) + if balance := state.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("state post-commit balance mismatch: have %v, want %v", balance, 42) + } + if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("state post-commit code mismatch: have %x, want %x", code, []byte("hello")) + } + if val := state.GetState(addr, skey); val != sval { + t.Fatalf("state post-commit non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := state.GetCommittedState(addr, skey); val != sval { + t.Fatalf("state post-commit committed storage slot mismatch: have %x, want %x", val, sval) } } @@ -644,19 +647,6 @@ func TestCopyCopyCommitCopy(t *testing.T) { if val := copyTwo.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("second copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } - copyTwo.Commit(false) - if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { - t.Fatalf("second copy post-commit balance mismatch: have %v, want %v", balance, 42) - } - if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { - t.Fatalf("second copy post-commit code mismatch: have %x, want %x", code, []byte("hello")) - } - if val := copyTwo.GetState(addr, skey); val != sval { - t.Fatalf("second copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval) - } - if val := copyTwo.GetCommittedState(addr, skey); val != sval { - t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) - } // Copy the copy-copy and check the balance once more copyThree := copyTwo.Copy() if balance := copyThree.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { @@ -668,11 +658,56 @@ func TestCopyCopyCommitCopy(t *testing.T) { if val := copyThree.GetState(addr, skey); val != sval { t.Fatalf("third copy non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyThree.GetCommittedState(addr, skey); val != sval { + if val := copyThree.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("third copy committed storage slot mismatch: have %x, want %x", val, sval) } } +// TestCommitCopy tests the copy from a committed state is not functional. +func TestCommitCopy(t *testing.T) { + state, _ := New(types.EmptyRootHash, NewDatabase(rawdb.NewMemoryDatabase()), nil) + + // Create an account and check if the retrieved balance is correct + addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe") + skey := common.HexToHash("aaa") + sval := common.HexToHash("bbb") + + state.SetBalance(addr, big.NewInt(42)) // Change the account trie + state.SetCode(addr, []byte("hello")) // Change an external metadata + state.SetState(addr, skey, sval) // Change the storage trie + + if balance := state.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("initial balance mismatch: have %v, want %v", balance, 42) + } + if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello")) + } + if val := state.GetState(addr, skey); val != sval { + t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := state.GetCommittedState(addr, skey); val != (common.Hash{}) { + t.Fatalf("initial committed storage slot mismatch: have %x, want %x", val, common.Hash{}) + } + // Copy the committed state database, the copied one is not functional. + state.Commit(true) + copied := state.Copy() + if balance := copied.GetBalance(addr); balance.Cmp(big.NewInt(0)) != 0 { + t.Fatalf("unexpected balance: have %v", balance) + } + if code := copied.GetCode(addr); code != nil { + t.Fatalf("unexpected code: have %x", code) + } + if val := copied.GetState(addr, skey); val != (common.Hash{}) { + t.Fatalf("unexpected storage slot: have %x", val) + } + if val := copied.GetCommittedState(addr, skey); val != (common.Hash{}) { + t.Fatalf("unexpected storage slot: have %x", val) + } + if !errors.Is(copied.Error(), trie.ErrCommitted) { + t.Fatalf("unexpected state error, %v", copied.Error()) + } +} + // TestDeleteCreateRevert tests a weird state transition corner case that we hit // while changing the internals of StateDB. The workflow is that a contract is // self-destructed, then in a follow-up transaction (but same block) it's created diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 6e9d9342ee..cac0ee6e48 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -109,7 +109,7 @@ func checkTrieConsistency(db ethdb.Database, root common.Hash) error { if err != nil { return err } - it := trie.NodeIterator(nil) + it := trie.MustNodeIterator(nil) for it.Next(true) { } return it.Error() @@ -566,6 +566,10 @@ func TestIncompleteStateSync(t *testing.T) { addedPaths []string addedHashes []common.Hash ) + reader, err := srcDb.TrieDB().Reader(srcRoot) + if err != nil { + t.Fatalf("state is not available %x", srcRoot) + } nodeQueue := make(map[string]stateElement) codeQueue := make(map[common.Hash]struct{}) paths, nodes, codes := sched.Missing(1) @@ -603,7 +607,7 @@ func TestIncompleteStateSync(t *testing.T) { results := make([]trie.NodeSyncResult, 0, len(nodeQueue)) for path, element := range nodeQueue { owner, inner := trie.ResolvePath([]byte(element.path)) - data, err := srcDb.TrieDB().Reader(srcRoot).Node(owner, inner, element.hash) + data, err := reader.Node(owner, inner, element.hash) if err != nil { t.Fatalf("failed to retrieve node data for %x", element.hash) } diff --git a/eth/api_debug.go b/eth/api_debug.go index 6f9daadd6b..c25e024cb2 100644 --- a/eth/api_debug.go +++ b/eth/api_debug.go @@ -235,7 +235,11 @@ func (api *DebugAPI) StorageRangeAt(ctx context.Context, blockNrOrHash rpc.Block } func storageRangeAt(st state.Trie, start []byte, maxResult int) (StorageRangeResult, error) { - it := trie.NewIterator(st.NodeIterator(start)) + trieIt, err := st.NodeIterator(start) + if err != nil { + return StorageRangeResult{}, err + } + it := trie.NewIterator(trieIt) result := StorageRangeResult{Storage: storageMap{}} for i := 0; i < maxResult && it.Next(); i++ { _, content, _, err := rlp.Split(it.Value) @@ -326,7 +330,15 @@ func (api *DebugAPI) getModifiedAccounts(startBlock, endBlock *types.Block) ([]c if err != nil { return nil, err } - diff, _ := trie.NewDifferenceIterator(oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{})) + oldIt, err := oldTrie.NodeIterator([]byte{}) + if err != nil { + return nil, err + } + newIt, err := newTrie.NodeIterator([]byte{}) + if err != nil { + return nil, err + } + diff, _ := trie.NewDifferenceIterator(oldIt, newIt) iter := trie.NewIterator(diff) var dirty []common.Address diff --git a/eth/api_debug_test.go b/eth/api_debug_test.go index c7f0d4ac90..ee54e4c3c8 100644 --- a/eth/api_debug_test.go +++ b/eth/api_debug_test.go @@ -62,34 +62,34 @@ func TestAccountRange(t *testing.T) { t.Parallel() var ( - statedb = state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), &trie.Config{Preimages: true}) - state, _ = state.New(types.EmptyRootHash, statedb, nil) - addrs = [AccountRangeMaxResults * 2]common.Address{} - m = map[common.Address]bool{} + statedb = state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), &trie.Config{Preimages: true}) + sdb, _ = state.New(types.EmptyRootHash, statedb, nil) + addrs = [AccountRangeMaxResults * 2]common.Address{} + m = map[common.Address]bool{} ) for i := range addrs { hash := common.HexToHash(fmt.Sprintf("%x", i)) addr := common.BytesToAddress(crypto.Keccak256Hash(hash.Bytes()).Bytes()) addrs[i] = addr - state.SetBalance(addrs[i], big.NewInt(1)) + sdb.SetBalance(addrs[i], big.NewInt(1)) if _, ok := m[addr]; ok { t.Fatalf("bad") } else { m[addr] = true } } - state.Commit(true) - root := state.IntermediateRoot(true) + root, _ := sdb.Commit(true) + sdb, _ = state.New(root, statedb, nil) trie, err := statedb.OpenTrie(root) if err != nil { t.Fatal(err) } - accountRangeTest(t, &trie, state, common.Hash{}, AccountRangeMaxResults/2, AccountRangeMaxResults/2) + accountRangeTest(t, &trie, sdb, common.Hash{}, AccountRangeMaxResults/2, AccountRangeMaxResults/2) // test pagination - firstResult := accountRangeTest(t, &trie, state, common.Hash{}, AccountRangeMaxResults, AccountRangeMaxResults) - secondResult := accountRangeTest(t, &trie, state, common.BytesToHash(firstResult.Next), AccountRangeMaxResults, AccountRangeMaxResults) + firstResult := accountRangeTest(t, &trie, sdb, common.Hash{}, AccountRangeMaxResults, AccountRangeMaxResults) + secondResult := accountRangeTest(t, &trie, sdb, common.BytesToHash(firstResult.Next), AccountRangeMaxResults, AccountRangeMaxResults) hList := make([]common.Hash, 0) for addr1 := range firstResult.Accounts { @@ -107,7 +107,7 @@ func TestAccountRange(t *testing.T) { // set and get an even split between the first and second sets. slices.SortFunc(hList, common.Hash.Less) middleH := hList[AccountRangeMaxResults/2] - middleResult := accountRangeTest(t, &trie, state, middleH, AccountRangeMaxResults, AccountRangeMaxResults) + middleResult := accountRangeTest(t, &trie, sdb, middleH, AccountRangeMaxResults, AccountRangeMaxResults) missing, infirst, insecond := 0, 0, 0 for h := range middleResult.Accounts { if _, ok := firstResult.Accounts[h]; ok { @@ -136,8 +136,10 @@ func TestEmptyAccountRange(t *testing.T) { statedb = state.NewDatabase(rawdb.NewMemoryDatabase()) st, _ = state.New(types.EmptyRootHash, statedb, nil) ) + // Commit(although nothing to flush) and re-init the statedb st.Commit(true) - st.IntermediateRoot(true) + st, _ = state.New(types.EmptyRootHash, statedb, nil) + results := st.IteratorDump(&state.DumpConfig{ SkipCode: true, SkipStorage: true, diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index 6efe0c8325..173f1da604 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -1664,7 +1664,7 @@ func verifyTrie(db ethdb.KeyValueStore, root common.Hash, t *testing.T) { t.Fatal(err) } accounts, slots := 0, 0 - accIt := trie.NewIterator(accTrie.NodeIterator(nil)) + accIt := trie.NewIterator(accTrie.MustNodeIterator(nil)) for accIt.Next() { var acc struct { Nonce uint64 @@ -1682,7 +1682,7 @@ func verifyTrie(db ethdb.KeyValueStore, root common.Hash, t *testing.T) { if err != nil { t.Fatal(err) } - storeIt := trie.NewIterator(storeTrie.NodeIterator(nil)) + storeIt := trie.NewIterator(storeTrie.MustNodeIterator(nil)) for storeIt.Next() { slots++ } diff --git a/light/trie.go b/light/trie.go index 888598ca92..e0675f98c9 100644 --- a/light/trie.go +++ b/light/trie.go @@ -184,8 +184,8 @@ func (t *odrTrie) Hash() common.Hash { return t.trie.Hash() } -func (t *odrTrie) NodeIterator(startkey []byte) trie.NodeIterator { - return newNodeIterator(t, startkey) +func (t *odrTrie) NodeIterator(startkey []byte) (trie.NodeIterator, error) { + return newNodeIterator(t, startkey), nil } func (t *odrTrie) GetKey(sha []byte) []byte { @@ -248,7 +248,11 @@ func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator { }) } it.do(func() error { - it.NodeIterator = it.t.trie.NodeIterator(startkey) + var err error + it.NodeIterator, err = it.t.trie.NodeIterator(startkey) + if err != nil { + return err + } return it.NodeIterator.Error() }) return it diff --git a/light/trie_test.go b/light/trie_test.go index 0a24609cbd..ad7d769c84 100644 --- a/light/trie_test.go +++ b/light/trie_test.go @@ -62,8 +62,16 @@ func TestNodeIterator(t *testing.T) { } func diffTries(t1, t2 state.Trie) error { - i1 := trie.NewIterator(t1.NodeIterator(nil)) - i2 := trie.NewIterator(t2.NodeIterator(nil)) + trieIt1, err := t1.NodeIterator(nil) + if err != nil { + return err + } + trieIt2, err := t2.NodeIterator(nil) + if err != nil { + return err + } + i1 := trie.NewIterator(trieIt1) + i2 := trie.NewIterator(trieIt2) for i1.Next() && i2.Next() { if !bytes.Equal(i1.Key, i2.Key) { spew.Dump(i2) diff --git a/tests/fuzzers/stacktrie/trie_fuzzer.go b/tests/fuzzers/stacktrie/trie_fuzzer.go index 77e602f807..d7be54eda0 100644 --- a/tests/fuzzers/stacktrie/trie_fuzzer.go +++ b/tests/fuzzers/stacktrie/trie_fuzzer.go @@ -221,7 +221,7 @@ func (f *fuzzer) fuzz() int { panic(fmt.Sprintf("roots differ: (trie) %x != %x (stacktrie)", rootA, rootC)) } trieA, _ = trie.New(trie.TrieID(rootA), dbA) - iterA := trieA.NodeIterator(nil) + iterA := trieA.MustNodeIterator(nil) for iterA.Next(true) { if iterA.Hash() == (common.Hash{}) { if _, present := nodeset[string(iterA.Path())]; present { diff --git a/tests/fuzzers/trie/trie-fuzzer.go b/tests/fuzzers/trie/trie-fuzzer.go index 0926d04764..33edc2a18a 100644 --- a/tests/fuzzers/trie/trie-fuzzer.go +++ b/tests/fuzzers/trie/trie-fuzzer.go @@ -179,7 +179,7 @@ func runRandTest(rt randTest) error { origin = hash case opItercheckhash: checktr := trie.NewEmpty(triedb) - it := trie.NewIterator(tr.NodeIterator(nil)) + it := trie.NewIterator(tr.MustNodeIterator(nil)) for it.Next() { checktr.MustUpdate(it.Key, it.Value) } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index eec91b67a7..667e951b66 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -204,6 +204,11 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bo if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { return snaps, statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs) } + // Re-init the post-state instance for further operation + statedb, err = state.New(root, statedb.Database(), snaps) + if err != nil { + return nil, nil, err + } return snaps, statedb, nil } @@ -275,9 +280,7 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh // the coinbase gets no txfee, so isn't created, and thus needs to be touched statedb.AddBalance(block.Coinbase(), new(big.Int)) // Commit block - statedb.Commit(config.IsEIP158(block.Number())) - // And _now_ get the state root - root := statedb.IntermediateRoot(config.IsEIP158(block.Number())) + root, _ := statedb.Commit(config.IsEIP158(block.Number())) return snaps, statedb, root, err } diff --git a/trie/database_wrap.go b/trie/database_wrap.go index b05f56847a..30835b7d25 100644 --- a/trie/database_wrap.go +++ b/trie/database_wrap.go @@ -113,8 +113,8 @@ func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database { } // Reader returns a reader for accessing all trie nodes with provided state root. -// Nil is returned in case the state is not available. -func (db *Database) Reader(blockRoot common.Hash) Reader { +// An error will be returned if the requested state is not available. +func (db *Database) Reader(blockRoot common.Hash) (Reader, error) { return db.backend.(*hashdb.Database).Reader(blockRoot) } diff --git a/trie/errors.go b/trie/errors.go index bd82b950aa..7be7041c7f 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -17,11 +17,17 @@ package trie import ( + "errors" "fmt" "github.com/ethereum/go-ethereum/common" ) +// ErrCommitted is returned when a already committed trie is requested for usage. +// The potential usages can be `Get`, `Update`, `Delete`, `NodeIterator`, `Prove` +// and so on. +var ErrCommitted = errors.New("trie is already committed") + // MissingNodeError is returned by the trie functions (Get, Update, Delete) // in the case where a trie node is not present in the local database. It contains // information necessary for retrieving the missing node. diff --git a/trie/iterator_test.go b/trie/iterator_test.go index ccf5ea0258..5c7f9c98fb 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -34,7 +34,7 @@ import ( func TestEmptyIterator(t *testing.T) { trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) - iter := trie.NodeIterator(nil) + iter := trie.MustNodeIterator(nil) seen := make(map[string]struct{}) for iter.Next(true) { @@ -67,7 +67,7 @@ func TestIterator(t *testing.T) { trie, _ = New(TrieID(root), db) found := make(map[string]string) - it := NewIterator(trie.NodeIterator(nil)) + it := NewIterator(trie.MustNodeIterator(nil)) for it.Next() { found[string(it.Key)] = string(it.Value) } @@ -101,7 +101,7 @@ func TestIteratorLargeData(t *testing.T) { vals[string(value2.k)] = value2 } - it := NewIterator(trie.NodeIterator(nil)) + it := NewIterator(trie.MustNodeIterator(nil)) for it.Next() { vals[string(it.Key)].t = true } @@ -139,7 +139,7 @@ func testNodeIteratorCoverage(t *testing.T, scheme string) { // Gather all the node hashes found by the iterator var elements = make(map[common.Hash]iterationElement) - for it := trie.NodeIterator(nil); it.Next(true); { + for it := trie.MustNodeIterator(nil); it.Next(true); { if it.Hash() != (common.Hash{}) { elements[it.Hash()] = iterationElement{ hash: it.Hash(), @@ -149,8 +149,12 @@ func testNodeIteratorCoverage(t *testing.T, scheme string) { } } // Cross check the hashes and the database itself + reader, err := nodeDb.Reader(trie.Hash()) + if err != nil { + t.Fatalf("state is not available %x", trie.Hash()) + } for _, element := range elements { - if blob, err := nodeDb.Reader(trie.Hash()).Node(common.Hash{}, element.path, element.hash); err != nil { + if blob, err := reader.Node(common.Hash{}, element.path, element.hash); err != nil { t.Errorf("failed to retrieve reported node %x: %v", element.hash, err) } else if !bytes.Equal(blob, element.blob) { t.Errorf("node blob is different, want %v got %v", element.blob, blob) @@ -210,19 +214,19 @@ func TestIteratorSeek(t *testing.T) { } // Seek to the middle. - it := NewIterator(trie.NodeIterator([]byte("fab"))) + it := NewIterator(trie.MustNodeIterator([]byte("fab"))) if err := checkIteratorOrder(testdata1[4:], it); err != nil { t.Fatal(err) } // Seek to a non-existent key. - it = NewIterator(trie.NodeIterator([]byte("barc"))) + it = NewIterator(trie.MustNodeIterator([]byte("barc"))) if err := checkIteratorOrder(testdata1[1:], it); err != nil { t.Fatal(err) } // Seek beyond the end. - it = NewIterator(trie.NodeIterator([]byte("z"))) + it = NewIterator(trie.MustNodeIterator([]byte("z"))) if err := checkIteratorOrder(nil, it); err != nil { t.Fatal(err) } @@ -264,7 +268,7 @@ func TestDifferenceIterator(t *testing.T) { trieb, _ = New(TrieID(rootB), dbb) found := make(map[string]string) - di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + di, _ := NewDifferenceIterator(triea.MustNodeIterator(nil), trieb.MustNodeIterator(nil)) it := NewIterator(di) for it.Next() { found[string(it.Key)] = string(it.Value) @@ -305,7 +309,7 @@ func TestUnionIterator(t *testing.T) { dbb.Update(rootB, types.EmptyRootHash, trienode.NewWithNodeSet(nodesB)) trieb, _ = New(TrieID(rootB), dbb) - di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) + di, _ := NewUnionIterator([]NodeIterator{triea.MustNodeIterator(nil), trieb.MustNodeIterator(nil)}) it := NewIterator(di) all := []struct{ k, v string }{ @@ -344,7 +348,7 @@ func TestIteratorNoDups(t *testing.T) { for _, val := range testdata1 { tr.MustUpdate([]byte(val.k), []byte(val.v)) } - checkIteratorNoDups(t, tr.NodeIterator(nil), nil) + checkIteratorNoDups(t, tr.MustNodeIterator(nil), nil) } // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes. @@ -369,7 +373,7 @@ func testIteratorContinueAfterError(t *testing.T, memonly bool, scheme string) { tdb.Commit(root, false) } tr, _ = New(TrieID(root), tdb) - wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil) + wantNodeCount := checkIteratorNoDups(t, tr.MustNodeIterator(nil), nil) var ( paths [][]byte @@ -428,7 +432,7 @@ func testIteratorContinueAfterError(t *testing.T, memonly bool, scheme string) { } // Iterate until the error is hit. seen := make(map[string]bool) - it := tr.NodeIterator(nil) + it := tr.MustNodeIterator(nil) checkIteratorNoDups(t, it, seen) missing, ok := it.Error().(*MissingNodeError) if !ok || missing.NodeHash != rhash { @@ -496,7 +500,7 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool, scheme strin } // Create a new iterator that seeks to "bars". Seeking can't proceed because // the node is missing. - it := tr.NodeIterator([]byte("bars")) + it := tr.MustNodeIterator([]byte("bars")) missing, ok := it.Error().(*MissingNodeError) if !ok { t.Fatal("want MissingNodeError, got", it.Error()) @@ -602,7 +606,10 @@ func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) { } root, nodes := trie.Commit(false) triedb.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes)) + triedb.Commit(root, false) + // Return the generated trie + trie, _ = NewStateTrie(TrieID(root), triedb) return triedb, trie, logDb } @@ -614,8 +621,8 @@ func TestNodeIteratorLargeTrie(t *testing.T) { // Do a seek operation trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885")) // master: 24 get operations - // this pr: 5 get operations - if have, want := logDb.getCount, uint64(5); have != want { + // this pr: 6 get operations + if have, want := logDb.getCount, uint64(6); have != want { t.Fatalf("Too many lookups during seek, have %d want %d", have, want) } } @@ -646,7 +653,7 @@ func testIteratorNodeBlob(t *testing.T, scheme string) { var found = make(map[common.Hash][]byte) trie, _ = New(TrieID(root), triedb) - it := trie.NodeIterator(nil) + it := trie.MustNodeIterator(nil) for it.Next(true) { if it.Hash() == (common.Hash{}) { continue diff --git a/trie/proof.go b/trie/proof.go index 818e338892..a463c80b48 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -34,6 +34,10 @@ import ( // nodes of the longest existing prefix of the key (at least the root node), ending // with the node that proves the absence of the key. func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { + // Short circuit if the trie is already committed and not usable. + if t.committed { + return ErrCommitted + } // Collect all nodes on the path to key. var ( prefix []byte diff --git a/trie/proof_test.go b/trie/proof_test.go index 6af2c3e472..d8effd0d04 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -63,7 +63,7 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database { // Create a leaf iterator based Merkle prover provers = append(provers, func(key []byte) *memorydb.Database { proof := memorydb.New() - if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) { + if it := NewIterator(trie.MustNodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) { for _, p := range it.Prove() { proof.Put(crypto.Keccak256(p), p) } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 62666c7f0c..a6d5676197 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -250,12 +250,18 @@ func (t *StateTrie) Copy() *StateTrie { } } -// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration -// starts at the key after the given start key. -func (t *StateTrie) NodeIterator(start []byte) NodeIterator { +// NodeIterator returns an iterator that returns nodes of the underlying trie. +// Iteration starts at the key after the given start key. +func (t *StateTrie) NodeIterator(start []byte) (NodeIterator, error) { return t.trie.NodeIterator(start) } +// MustNodeIterator is a wrapper of NodeIterator and will omit any encountered +// error but just print out an error message. +func (t *StateTrie) MustNodeIterator(start []byte) NodeIterator { + return t.trie.MustNodeIterator(start) +} + // hashKey returns the hash of key as an ephemeral buffer. // The caller must not hold onto the return value because it will become // invalid on the next call to hashKey or secKey. diff --git a/trie/sync_test.go b/trie/sync_test.go index 7cd4058176..ec49beef04 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -94,7 +94,7 @@ func checkTrieConsistency(db ethdb.Database, scheme string, root common.Hash) er if err != nil { return nil // Consider a non existent state consistent } - it := trie.NodeIterator(nil) + it := trie.MustNodeIterator(nil) for it.Next(true) { } return it.Error() @@ -159,12 +159,16 @@ func testIterativeSync(t *testing.T, count int, bypath bool, scheme string) { syncPath: NewSyncPath([]byte(paths[i])), }) } + reader, err := srcDb.Reader(srcTrie.Hash()) + if err != nil { + t.Fatalf("State is not available %x", srcTrie.Hash()) + } for len(elements) > 0 { results := make([]NodeSyncResult, len(elements)) if !bypath { for i, element := range elements { owner, inner := ResolvePath([]byte(element.path)) - data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash) + data, err := reader.Node(owner, inner, element.hash) if err != nil { t.Fatalf("failed to retrieve node data for hash %x: %v", element.hash, err) } @@ -230,12 +234,16 @@ func testIterativeDelayedSync(t *testing.T, scheme string) { syncPath: NewSyncPath([]byte(paths[i])), }) } + reader, err := srcDb.Reader(srcTrie.Hash()) + if err != nil { + t.Fatalf("State is not available %x", srcTrie.Hash()) + } for len(elements) > 0 { // Sync only half of the scheduled nodes results := make([]NodeSyncResult, len(elements)/2+1) for i, element := range elements[:len(results)] { owner, inner := ResolvePath([]byte(element.path)) - data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash) + data, err := reader.Node(owner, inner, element.hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err) } @@ -295,12 +303,16 @@ func testIterativeRandomSync(t *testing.T, count int, scheme string) { syncPath: NewSyncPath([]byte(paths[i])), } } + reader, err := srcDb.Reader(srcTrie.Hash()) + if err != nil { + t.Fatalf("State is not available %x", srcTrie.Hash()) + } for len(queue) > 0 { // Fetch all the queued nodes in a random order results := make([]NodeSyncResult, 0, len(queue)) for path, element := range queue { owner, inner := ResolvePath([]byte(element.path)) - data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash) + data, err := reader.Node(owner, inner, element.hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err) } @@ -358,12 +370,16 @@ func testIterativeRandomDelayedSync(t *testing.T, scheme string) { syncPath: NewSyncPath([]byte(path)), } } + reader, err := srcDb.Reader(srcTrie.Hash()) + if err != nil { + t.Fatalf("State is not available %x", srcTrie.Hash()) + } for len(queue) > 0 { // Sync only half of the scheduled nodes, even those in random order results := make([]NodeSyncResult, 0, len(queue)/2+1) for path, element := range queue { owner, inner := ResolvePath([]byte(element.path)) - data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash) + data, err := reader.Node(owner, inner, element.hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err) } @@ -426,13 +442,16 @@ func testDuplicateAvoidanceSync(t *testing.T, scheme string) { syncPath: NewSyncPath([]byte(paths[i])), }) } + reader, err := srcDb.Reader(srcTrie.Hash()) + if err != nil { + t.Fatalf("State is not available %x", srcTrie.Hash()) + } requested := make(map[common.Hash]struct{}) - for len(elements) > 0 { results := make([]NodeSyncResult, len(elements)) for i, element := range elements { owner, inner := ResolvePath([]byte(element.path)) - data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash) + data, err := reader.Node(owner, inner, element.hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err) } @@ -501,12 +520,16 @@ func testIncompleteSync(t *testing.T, scheme string) { syncPath: NewSyncPath([]byte(paths[i])), }) } + reader, err := srcDb.Reader(srcTrie.Hash()) + if err != nil { + t.Fatalf("State is not available %x", srcTrie.Hash()) + } for len(elements) > 0 { // Fetch a batch of trie nodes results := make([]NodeSyncResult, len(elements)) for i, element := range elements { owner, inner := ResolvePath([]byte(element.path)) - data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash) + data, err := reader.Node(owner, inner, element.hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err) } @@ -585,12 +608,15 @@ func testSyncOrdering(t *testing.T, scheme string) { }) reqs = append(reqs, NewSyncPath([]byte(paths[i]))) } - + reader, err := srcDb.Reader(srcTrie.Hash()) + if err != nil { + t.Fatalf("State is not available %x", srcTrie.Hash()) + } for len(elements) > 0 { results := make([]NodeSyncResult, len(elements)) for i, element := range elements { owner, inner := ResolvePath([]byte(element.path)) - data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash) + data, err := reader.Node(owner, inner, element.hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err) } @@ -649,11 +675,15 @@ func syncWith(t *testing.T, root common.Hash, db ethdb.Database, srcDb *Database syncPath: NewSyncPath([]byte(paths[i])), }) } + reader, err := srcDb.Reader(root) + if err != nil { + t.Fatalf("State is not available %x", root) + } for len(elements) > 0 { results := make([]NodeSyncResult, len(elements)) for i, element := range elements { owner, inner := ResolvePath([]byte(element.path)) - data, err := srcDb.Reader(root).Node(owner, inner, element.hash) + data, err := reader.Node(owner, inner, element.hash) if err != nil { t.Fatalf("failed to retrieve node data for hash %x: %v", element.hash, err) } diff --git a/trie/tracer_test.go b/trie/tracer_test.go index 32949a5a99..4cec53688e 100644 --- a/trie/tracer_test.go +++ b/trie/tracer_test.go @@ -226,14 +226,14 @@ func TestAccessListLeak(t *testing.T) { }{ { func(tr *Trie) { - it := tr.NodeIterator(nil) + it := tr.MustNodeIterator(nil) for it.Next(true) { } }, }, { func(tr *Trie) { - it := NewIterator(tr.NodeIterator(nil)) + it := NewIterator(tr.MustNodeIterator(nil)) for it.Next() { } }, @@ -300,7 +300,7 @@ func compareSet(setA, setB map[string]struct{}) bool { func forNodes(tr *Trie) map[string][]byte { var ( - it = tr.NodeIterator(nil) + it = tr.MustNodeIterator(nil) nodes = make(map[string][]byte) ) for it.Next(true) { @@ -319,7 +319,7 @@ func iterNodes(db *Database, root common.Hash) map[string][]byte { func forHashedNodes(tr *Trie) map[string][]byte { var ( - it = tr.NodeIterator(nil) + it = tr.MustNodeIterator(nil) nodes = make(map[string][]byte) ) for it.Next(true) { diff --git a/trie/trie.go b/trie/trie.go index e63b2d2830..cf5d181ff7 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -39,6 +39,10 @@ type Trie struct { root node owner common.Hash + // Flag whether the commit operation is already performed. If so the + // trie is not usable(latest states is invisible). + committed bool + // Keep track of the number leaves which have been inserted since the last // hashing operation. This number will not directly map to the number of // actually unhashed nodes. @@ -60,11 +64,12 @@ func (t *Trie) newFlag() nodeFlag { // Copy returns a copy of Trie. func (t *Trie) Copy() *Trie { return &Trie{ - root: t.root, - owner: t.owner, - unhashed: t.unhashed, - reader: t.reader, - tracer: t.tracer.copy(), + root: t.root, + owner: t.owner, + committed: t.committed, + unhashed: t.unhashed, + reader: t.reader, + tracer: t.tracer.copy(), } } @@ -74,7 +79,7 @@ func (t *Trie) Copy() *Trie { // zero hash or the sha3 hash of an empty string, then trie is initially // empty, otherwise, the root node must be present in database or returns // a MissingNodeError if not. -func New(id *ID, db NodeReader) (*Trie, error) { +func New(id *ID, db *Database) (*Trie, error) { reader, err := newTrieReader(id.StateRoot, id.Owner, db) if err != nil { return nil, err @@ -100,10 +105,24 @@ func NewEmpty(db *Database) *Trie { return tr } +// MustNodeIterator is a wrapper of NodeIterator and will omit any encountered +// error but just print out an error message. +func (t *Trie) MustNodeIterator(start []byte) NodeIterator { + it, err := t.NodeIterator(start) + if err != nil { + log.Error("Unhandled trie error in Trie.NodeIterator", "err", err) + } + return it +} + // NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at // the key after the given start key. -func (t *Trie) NodeIterator(start []byte) NodeIterator { - return newNodeIterator(t, start) +func (t *Trie) NodeIterator(start []byte) (NodeIterator, error) { + // Short circuit if the trie is already committed and not usable. + if t.committed { + return nil, ErrCommitted + } + return newNodeIterator(t, start), nil } // MustGet is a wrapper of Get and will omit any encountered error but just @@ -122,6 +141,10 @@ func (t *Trie) MustGet(key []byte) []byte { // If the requested node is not present in trie, no error will be returned. // If the trie is corrupted, a MissingNodeError is returned. func (t *Trie) Get(key []byte) ([]byte, error) { + // Short circuit if the trie is already committed and not usable. + if t.committed { + return nil, ErrCommitted + } value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0) if err == nil && didResolve { t.root = newroot @@ -181,6 +204,10 @@ func (t *Trie) MustGetNode(path []byte) ([]byte, int) { // If the requested node is not present in trie, no error will be returned. // If the trie is corrupted, a MissingNodeError is returned. func (t *Trie) GetNode(path []byte) ([]byte, int, error) { + // Short circuit if the trie is already committed and not usable. + if t.committed { + return nil, 0, ErrCommitted + } item, newroot, resolved, err := t.getNode(t.root, compactToHex(path), 0) if err != nil { return nil, resolved, err @@ -273,6 +300,10 @@ func (t *Trie) MustUpdate(key, value []byte) { // If the requested node is not present in trie, no error will be returned. // If the trie is corrupted, a MissingNodeError is returned. func (t *Trie) Update(key, value []byte) error { + // Short circuit if the trie is already committed and not usable. + if t.committed { + return ErrCommitted + } return t.update(key, value) } @@ -387,6 +418,10 @@ func (t *Trie) MustDelete(key []byte) { // If the requested node is not present in trie, no error will be returned. // If the trie is corrupted, a MissingNodeError is returned. func (t *Trie) Delete(key []byte) error { + // Short circuit if the trie is already committed and not usable. + if t.committed { + return ErrCommitted + } t.unhashed++ k := keybytesToHex(key) _, n, err := t.delete(t.root, nil, k) @@ -574,7 +609,9 @@ func (t *Trie) Hash() common.Hash { // be created with new root and updated trie database for following usage func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { defer t.tracer.reset() - + defer func() { + t.committed = true + }() nodes := trienode.NewNodeSet(t.owner) t.tracer.markDeletions(nodes) @@ -621,4 +658,5 @@ func (t *Trie) Reset() { t.owner = common.Hash{} t.unhashed = 0 t.tracer.reset() + t.committed = false } diff --git a/trie/trie_reader.go b/trie/trie_reader.go index 51855320a6..d42adad2c2 100644 --- a/trie/trie_reader.go +++ b/trie/trie_reader.go @@ -17,9 +17,9 @@ package trie import ( - "fmt" - "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" ) // Reader wraps the Node method of a backing trie store. @@ -30,13 +30,6 @@ type Reader interface { Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) } -// NodeReader wraps all the necessary functions for accessing trie node. -type NodeReader interface { - // Reader returns a reader for accessing all trie nodes with provided - // state root. Nil is returned in case the state is not available. - Reader(root common.Hash) Reader -} - // trieReader is a wrapper of the underlying node reader. It's not safe // for concurrent usage. type trieReader struct { @@ -46,10 +39,16 @@ type trieReader struct { } // newTrieReader initializes the trie reader with the given node reader. -func newTrieReader(stateRoot, owner common.Hash, db NodeReader) (*trieReader, error) { - reader := db.Reader(stateRoot) - if reader == nil { - return nil, fmt.Errorf("state not found #%x", stateRoot) +func newTrieReader(stateRoot, owner common.Hash, db *Database) (*trieReader, error) { + if stateRoot == (common.Hash{}) || stateRoot == types.EmptyRootHash { + if stateRoot == (common.Hash{}) { + log.Error("Zero state root hash!") + } + return &trieReader{owner: owner}, nil + } + reader, err := db.Reader(stateRoot) + if err != nil { + return nil, &MissingNodeError{Owner: owner, NodeHash: stateRoot, err: err} } return &trieReader{owner: owner, reader: reader}, nil } diff --git a/trie/trie_test.go b/trie/trie_test.go index bd21a572d2..e9278e4670 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -521,7 +521,7 @@ func runRandTest(rt randTest) bool { origin = root case opItercheckhash: checktr := NewEmpty(triedb) - it := NewIterator(tr.NodeIterator(nil)) + it := NewIterator(tr.MustNodeIterator(nil)) for it.Next() { checktr.MustUpdate(it.Key, it.Value) } @@ -530,8 +530,8 @@ func runRandTest(rt randTest) bool { } case opNodeDiff: var ( - origIter = origTrie.NodeIterator(nil) - curIter = tr.NodeIterator(nil) + origIter = origTrie.MustNodeIterator(nil) + curIter = tr.MustNodeIterator(nil) origSeen = make(map[string]struct{}) curSeen = make(map[string]struct{}) ) @@ -710,7 +710,7 @@ func TestTinyTrie(t *testing.T) { t.Errorf("3: got %x, exp %x", root, exp) } checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) - it := NewIterator(trie.NodeIterator(nil)) + it := NewIterator(trie.MustNodeIterator(nil)) for it.Next() { checktr.MustUpdate(it.Key, it.Value) } diff --git a/trie/triedb/hashdb/database.go b/trie/triedb/hashdb/database.go index eaccfe944a..ed20ff9cde 100644 --- a/trie/triedb/hashdb/database.go +++ b/trie/triedb/hashdb/database.go @@ -18,6 +18,7 @@ package hashdb import ( "errors" + "fmt" "reflect" "sync" "time" @@ -621,8 +622,12 @@ func (db *Database) Scheme() string { } // Reader retrieves a node reader belonging to the given state root. -func (db *Database) Reader(root common.Hash) *reader { - return &reader{db: db} +// An error will be returned if the requested state is not available. +func (db *Database) Reader(root common.Hash) (*reader, error) { + if _, err := db.Node(root); err != nil { + return nil, fmt.Errorf("state %#x is not available, %v", root, err) + } + return &reader{db: db}, nil } // reader is a state reader of Database which implements the Reader interface.