diff --git a/core/types/hashing_test.go b/core/types/hashing_test.go index 6d1ebf897c..de71ee41a4 100644 --- a/core/types/hashing_test.go +++ b/core/types/hashing_test.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" @@ -38,7 +39,8 @@ func TestDeriveSha(t *testing.T) { t.Fatal(err) } for len(txs) < 1000 { - exp := types.DeriveSha(txs, new(trie.Trie)) + tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) + exp := types.DeriveSha(txs, tr) got := types.DeriveSha(txs, trie.NewStackTrie(nil)) if !bytes.Equal(got[:], exp[:]) { t.Fatalf("%d txs: got %x exp %x", len(txs), got, exp) @@ -85,7 +87,8 @@ func BenchmarkDeriveSha200(b *testing.B) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - exp = types.DeriveSha(txs, new(trie.Trie)) + tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) + exp = types.DeriveSha(txs, tr) } }) @@ -106,7 +109,8 @@ func TestFuzzDeriveSha(t *testing.T) { rndSeed := mrand.Int() for i := 0; i < 10; i++ { seed := rndSeed + i - exp := types.DeriveSha(newDummy(i), new(trie.Trie)) + tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) + exp := types.DeriveSha(newDummy(i), tr) got := types.DeriveSha(newDummy(i), trie.NewStackTrie(nil)) if !bytes.Equal(got[:], exp[:]) { printList(newDummy(seed)) @@ -134,7 +138,8 @@ func TestDerivableList(t *testing.T) { }, } for i, tc := range tcs[1:] { - exp := types.DeriveSha(flatList(tc), new(trie.Trie)) + tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) + exp := types.DeriveSha(flatList(tc), tr) got := types.DeriveSha(flatList(tc), trie.NewStackTrie(nil)) if !bytes.Equal(got[:], exp[:]) { t.Fatalf("case %d: got %x exp %x", i, got, exp) diff --git a/les/server_handler.go b/les/server_handler.go index da06ac315e..ef1af844c2 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -374,7 +374,7 @@ func getAccount(triedb *trie.Database, root, hash common.Hash) (types.StateAccou return acc, nil } -// getHelperTrie returns the post-processed trie root for the given trie ID and section index +// GetHelperTrie returns the post-processed trie root for the given trie ID and section index func (h *serverHandler) GetHelperTrie(typ uint, index uint64) *trie.Trie { var ( root common.Hash diff --git a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go index 09ee6bb9c7..18717e70d0 100644 --- a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go +++ b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go @@ -24,6 +24,7 @@ import ( "sort" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/trie" ) @@ -61,8 +62,7 @@ func (f *fuzzer) readInt() uint64 { } func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) { - - trie := new(trie.Trie) + trie, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) vals := make(map[string]*kv) size := f.readInt() // Fill it with some fluff diff --git a/trie/committer.go b/trie/committer.go index db753e2fa0..20be7e9690 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -89,7 +89,7 @@ func (c *committer) commit(n node, db *Database) (node, int, error) { if hash != nil && !dirty { return hash, 0, nil } - // Commit children, then parent, and remove remove the dirty flag. + // Commit children, then parent, and remove the dirty flag. switch cn := n.(type) { case *shortNode: // Commit child diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 9a46e9b995..ea8a46bb43 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb/memorydb" @@ -296,7 +297,7 @@ func TestUnionIterator(t *testing.T) { } func TestIteratorNoDups(t *testing.T) { - var tr Trie + tr, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) for _, val := range testdata1 { tr.Update([]byte(val.k), []byte(val.v)) } diff --git a/trie/proof.go b/trie/proof.go index 88ca80b0e7..f42dcc761b 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -23,7 +23,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/log" ) @@ -552,7 +551,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key } // Rebuild the trie with the leaf stream, the shape of trie // should be same with the original one. - tr := &Trie{root: root, db: NewDatabase(memorydb.New())} + tr := newWithRootNode(root) if empty { tr.root = nil } diff --git a/trie/proof_test.go b/trie/proof_test.go index 29866714c2..cdf5cf6050 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb/memorydb" ) @@ -79,7 +80,7 @@ func TestProof(t *testing.T) { } func TestOneElementProof(t *testing.T) { - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) updateString(trie, "k", "v") for i, prover := range makeProvers(trie) { proof := prover([]byte("k")) @@ -130,7 +131,7 @@ func TestBadProof(t *testing.T) { // Tests that missing keys can also be proven. The test explicitly uses a single // entry trie and checks for missing keys both before and after the single entry. func TestMissingKeyProof(t *testing.T) { - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) updateString(trie, "k", "v") for i, key := range []string{"a", "j", "l", "z"} { @@ -386,7 +387,7 @@ func TestOneElementRangeProof(t *testing.T) { } // Test the mini trie with only a single element. - tinyTrie := new(Trie) + tinyTrie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) entry := &kv{randBytes(32), randBytes(20), false} tinyTrie.Update(entry.k, entry.v) @@ -458,7 +459,7 @@ func TestAllElementsProof(t *testing.T) { // TestSingleSideRangeProof tests the range starts from zero. func TestSingleSideRangeProof(t *testing.T) { for i := 0; i < 64; i++ { - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) var entries entrySlice for i := 0; i < 4096; i++ { value := &kv{randBytes(32), randBytes(20), false} @@ -493,7 +494,7 @@ func TestSingleSideRangeProof(t *testing.T) { // TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff. func TestReverseSingleSideRangeProof(t *testing.T) { for i := 0; i < 64; i++ { - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) var entries entrySlice for i := 0; i < 4096; i++ { value := &kv{randBytes(32), randBytes(20), false} @@ -600,7 +601,7 @@ func TestBadRangeProof(t *testing.T) { // TestGappedRangeProof focuses on the small trie with embedded nodes. // If the gapped node is embedded in the trie, it should be detected too. func TestGappedRangeProof(t *testing.T) { - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) var entries []*kv // Sorted entries for i := byte(0); i < 10; i++ { value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} @@ -674,7 +675,7 @@ func TestSameSideProofs(t *testing.T) { } func TestHasRightElement(t *testing.T) { - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) var entries entrySlice for i := 0; i < 4096; i++ { value := &kv{randBytes(32), randBytes(20), false} @@ -1027,7 +1028,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) { } func randomTrie(n int) (*Trie, map[string]*kv) { - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) vals := make(map[string]*kv) for i := byte(0); i < 100; i++ { value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} @@ -1052,7 +1053,7 @@ func randBytes(n int) []byte { } func nonRandomTrie(n int) (*Trie, map[string]*kv) { - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) vals := make(map[string]*kv) max := uint64(0xffffffffffffffff) for i := uint64(0); i < uint64(n); i++ { @@ -1077,7 +1078,7 @@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) { common.Hex2Bytes("02"), common.Hex2Bytes("03"), } - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) for i, key := range keys { trie.Update(key, vals[i]) } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 18be12d34a..248b93544d 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -87,7 +87,7 @@ func (t *SecureTrie) TryGetNode(path []byte) ([]byte, int, error) { return t.trie.TryGetNode(path) } -// TryUpdate account will abstract the write of an account to the +// TryUpdateAccount account will abstract the write of an account to the // secure trie. func (t *SecureTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error { hk := t.hashKey(key) @@ -185,8 +185,10 @@ func (t *SecureTrie) Hash() common.Hash { // Copy returns a copy of SecureTrie. func (t *SecureTrie) Copy() *SecureTrie { - cpy := *t - return &cpy + return &SecureTrie{ + trie: *t.trie.Copy(), + secKeyCache: t.secKeyCache, + } } // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index fb6c38ee22..a3ece84b57 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -112,8 +112,7 @@ func TestSecureTrieConcurrency(t *testing.T) { threads := runtime.NumCPU() tries := make([]*SecureTrie, threads) for i := 0; i < threads; i++ { - cpy := *trie - tries[i] = &cpy + tries[i] = trie.Copy() } // Start a batch of goroutines interactng with the trie pend := new(sync.WaitGroup) diff --git a/trie/trie.go b/trie/trie.go index e40b03be38..fe7d6dc17e 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -24,6 +24,7 @@ import ( "sync" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" @@ -62,10 +63,15 @@ type LeafCallback func(paths [][]byte, hexpath []byte, leaf []byte, parent commo type Trie struct { db *Database root node - // Keep track of the number leafs which have been inserted since the last + + // Keep track of the number leaves which have been inserted since the last // hashing operation. This number will not directly map to the number of // actually unhashed nodes unhashed int + + // tracer is the state diff tracer can be used to track newly added/deleted + // trie node. It will be reset after each commit operation. + tracer *tracer } // newFlag returns the cache flag value for a newly created node. @@ -73,6 +79,16 @@ func (t *Trie) newFlag() nodeFlag { return nodeFlag{dirty: true} } +// Copy returns a copy of Trie. +func (t *Trie) Copy() *Trie { + return &Trie{ + db: t.db, + root: t.root, + unhashed: t.unhashed, + tracer: t.tracer.copy(), + } +} + // New creates a trie with an existing root node from db. // // If root is the zero hash or the sha3 hash of an empty string, the @@ -85,6 +101,7 @@ func New(root common.Hash, db *Database) (*Trie, error) { } trie := &Trie{ db: db, + //tracer: newTracer(), } if root != (common.Hash{}) && root != emptyRoot { rootnode, err := trie.resolveHash(root[:], nil) @@ -96,6 +113,16 @@ func New(root common.Hash, db *Database) (*Trie, error) { return trie, nil } +// newWithRootNode initializes the trie with the given root node. +// It's only used by range prover. +func newWithRootNode(root node) *Trie { + return &Trie{ + root: root, + //tracer: newTracer(), + db: NewDatabase(rawdb.NewMemoryDatabase()), + } +} + // NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at // the key after the given start key. func (t *Trie) NodeIterator(start []byte) NodeIterator { @@ -317,7 +344,12 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error if matchlen == 0 { return true, branch, nil } - // Otherwise, replace it with a short node leading up to the branch. + // New branch node is created as a child of the original short node. + // Track the newly inserted node in the tracer. The node identifier + // passed is the path from the root node. + t.tracer.onInsert(append(prefix, key[:matchlen]...)) + + // Replace it with a short node leading up to the branch. return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil case *fullNode: @@ -331,6 +363,11 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error return true, n, nil case nil: + // New short node is created and track it in the tracer. The node identifier + // passed is the path from the root node. Note the valueNode won't be tracked + // since it's always embedded in its parent. + t.tracer.onInsert(prefix) + return true, &shortNode{key, value, t.newFlag()}, nil case hashNode: @@ -383,6 +420,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { return false, n, nil // don't replace n on mismatch } if matchlen == len(key) { + // The matched short node is deleted entirely and track + // it in the deletion set. The same the valueNode doesn't + // need to be tracked at all since it's always embedded. + t.tracer.onDelete(prefix) + return true, nil, nil // remove n entirely for whole matches } // The key is longer than n.Key. Remove the remaining suffix @@ -395,6 +437,10 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { } switch child := child.(type) { case *shortNode: + // The child shortNode is merged into its parent, track + // is deleted as well. + t.tracer.onDelete(append(prefix, n.Key...)) + // Deleting from the subtrie reduced it to another // short node. Merge the nodes to avoid creating a // shortNode{..., shortNode{...}}. Use concat (which @@ -456,6 +502,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { return false, nil, err } if cnode, ok := cnode.(*shortNode); ok { + // Replace the entire full node with the short node. + // Mark the original short node as deleted since the + // value is embedded into the parent now. + t.tracer.onDelete(append(prefix, byte(pos))) + k := append([]byte{byte(pos)}, cnode.Key...) return true, &shortNode{k, cnode.Val, t.newFlag()}, nil } @@ -537,6 +588,8 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { if t.db == nil { panic("commit called on trie with nil database") } + defer t.tracer.reset() + if t.root == nil { return emptyRoot, 0, nil } @@ -595,4 +648,5 @@ func (t *Trie) hashRoot() (node, node, error) { func (t *Trie) Reset() { t.root = nil t.unhashed = 0 + t.tracer.reset() } diff --git a/trie/trie_test.go b/trie/trie_test.go index a1fdc8cd58..fd9556622d 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -32,6 +32,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" @@ -53,7 +54,7 @@ func newEmpty() *Trie { } func TestEmptyTrie(t *testing.T) { - var trie Trie + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) res := trie.Hash() exp := emptyRoot if res != exp { @@ -62,7 +63,7 @@ func TestEmptyTrie(t *testing.T) { } func TestNull(t *testing.T) { - var trie Trie + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) key := make([]byte, 32) value := []byte("test") trie.Update(key, value) @@ -374,6 +375,7 @@ const ( opHash opReset opItercheckhash + opNodeDiff opMax // boundary value, not an actual op ) @@ -408,10 +410,13 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value { } func runRandTest(rt randTest) bool { - triedb := NewDatabase(memorydb.New()) - - tr, _ := New(common.Hash{}, triedb) - values := make(map[string]string) // tracks content of the trie + var ( + triedb = NewDatabase(memorydb.New()) + tr, _ = New(common.Hash{}, triedb) + values = make(map[string]string) // tracks content of the trie + origTrie, _ = New(common.Hash{}, triedb) + ) + tr.tracer = newTracer() for i, step := range rt { // fmt.Printf("{op: %d, key: common.Hex2Bytes(\"%x\"), value: common.Hex2Bytes(\"%x\")}, // step %d\n", @@ -432,6 +437,7 @@ func runRandTest(rt randTest) bool { } case opCommit: _, _, rt[i].err = tr.Commit(nil) + origTrie = tr.Copy() case opHash: tr.Hash() case opReset: @@ -446,6 +452,9 @@ func runRandTest(rt randTest) bool { return false } tr = newtr + tr.tracer = newTracer() + + origTrie = tr.Copy() case opItercheckhash: checktr, _ := New(common.Hash{}, triedb) it := NewIterator(tr.NodeIterator(nil)) @@ -455,6 +464,59 @@ func runRandTest(rt randTest) bool { if tr.Hash() != checktr.Hash() { rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash") } + case opNodeDiff: + var ( + inserted = tr.tracer.insertList() + deleted = tr.tracer.deleteList() + origIter = origTrie.NodeIterator(nil) + curIter = tr.NodeIterator(nil) + origSeen = make(map[string]struct{}) + curSeen = make(map[string]struct{}) + ) + for origIter.Next(true) { + if origIter.Leaf() { + continue + } + origSeen[string(origIter.Path())] = struct{}{} + } + for curIter.Next(true) { + if curIter.Leaf() { + continue + } + curSeen[string(curIter.Path())] = struct{}{} + } + var ( + insertExp = make(map[string]struct{}) + deleteExp = make(map[string]struct{}) + ) + for path := range curSeen { + _, present := origSeen[path] + if !present { + insertExp[path] = struct{}{} + } + } + for path := range origSeen { + _, present := curSeen[path] + if !present { + deleteExp[path] = struct{}{} + } + } + if len(insertExp) != len(inserted) { + rt[i].err = fmt.Errorf("insert set mismatch") + } + if len(deleteExp) != len(deleted) { + rt[i].err = fmt.Errorf("delete set mismatch") + } + for _, insert := range inserted { + if _, present := insertExp[string(insert)]; !present { + rt[i].err = fmt.Errorf("missing inserted node") + } + } + for _, del := range deleted { + if _, present := deleteExp[string(del)]; !present { + rt[i].err = fmt.Errorf("missing deleted node") + } + } } // Abort the test on error. if rt[i].err != nil { @@ -481,7 +543,7 @@ func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) } const benchElemCount = 20000 func benchGet(b *testing.B, commit bool) { - trie := new(Trie) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) if commit { _, tmpdb := tempDB() trie, _ = New(common.Hash{}, tmpdb) diff --git a/trie/util_test.go b/trie/util_test.go new file mode 100644 index 0000000000..fadb0553b5 --- /dev/null +++ b/trie/util_test.go @@ -0,0 +1,122 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" +) + +// Tests if the trie diffs are tracked correctly. +func TestTrieTracer(t *testing.T) { + db := NewDatabase(rawdb.NewMemoryDatabase()) + trie, _ := New(common.Hash{}, db) + trie.tracer = newTracer() + + // Insert a batch of entries, all the nodes should be marked as inserted + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"dog", "puppy"}, + {"somethingveryoddindeedthis is", "myothernodedata"}, + } + for _, val := range vals { + trie.Update([]byte(val.k), []byte(val.v)) + } + trie.Hash() + + seen := make(map[string]struct{}) + it := trie.NodeIterator(nil) + for it.Next(true) { + if it.Leaf() { + continue + } + seen[string(it.Path())] = struct{}{} + } + inserted := trie.tracer.insertList() + if len(inserted) != len(seen) { + t.Fatalf("Unexpected inserted node tracked want %d got %d", len(seen), len(inserted)) + } + for _, k := range inserted { + _, ok := seen[string(k)] + if !ok { + t.Fatalf("Unexpected inserted node") + } + } + deleted := trie.tracer.deleteList() + if len(deleted) != 0 { + t.Fatalf("Unexpected deleted node tracked %d", len(deleted)) + } + + // Commit the changes + trie.Commit(nil) + + // Delete all the elements, check deletion set + for _, val := range vals { + trie.Delete([]byte(val.k)) + } + trie.Hash() + + inserted = trie.tracer.insertList() + if len(inserted) != 0 { + t.Fatalf("Unexpected inserted node tracked %d", len(inserted)) + } + deleted = trie.tracer.deleteList() + if len(deleted) != len(seen) { + t.Fatalf("Unexpected deleted node tracked want %d got %d", len(seen), len(deleted)) + } + for _, k := range deleted { + _, ok := seen[string(k)] + if !ok { + t.Fatalf("Unexpected inserted node") + } + } +} + +func TestTrieTracerNoop(t *testing.T) { + db := NewDatabase(rawdb.NewMemoryDatabase()) + trie, _ := New(common.Hash{}, db) + trie.tracer = newTracer() + + // Insert a batch of entries, all the nodes should be marked as inserted + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"dog", "puppy"}, + {"somethingveryoddindeedthis is", "myothernodedata"}, + } + for _, val := range vals { + trie.Update([]byte(val.k), []byte(val.v)) + } + for _, val := range vals { + trie.Delete([]byte(val.k)) + } + if len(trie.tracer.insertList()) != 0 { + t.Fatalf("Unexpected inserted node tracked %d", len(trie.tracer.insertList())) + } + if len(trie.tracer.deleteList()) != 0 { + t.Fatalf("Unexpected deleted node tracked %d", len(trie.tracer.deleteList())) + } +} diff --git a/trie/utils.go b/trie/utils.go new file mode 100644 index 0000000000..5f9e3ba58e --- /dev/null +++ b/trie/utils.go @@ -0,0 +1,133 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +// tracer tracks the changes of trie nodes. During the trie operations, +// some nodes can be deleted from the trie, while these deleted nodes +// won't be captured by trie.Hasher or trie.Committer. Thus, these deleted +// nodes won't be removed from the disk at all. Tracer is an auxiliary tool +// used to track all insert and delete operations of trie and capture all +// deleted nodes eventually. +// +// The changed nodes can be mainly divided into two categories: the leaf +// node and intermediate node. The former is inserted/deleted by callers +// while the latter is inserted/deleted in order to follow the rule of trie. +// This tool can track all of them no matter the node is embedded in its +// parent or not, but valueNode is never tracked. +// +// Note tracer is not thread-safe, callers should be responsible for handling +// the concurrency issues by themselves. +type tracer struct { + insert map[string]struct{} + delete map[string]struct{} +} + +// newTracer initializes trie node diff tracer. +func newTracer() *tracer { + return &tracer{ + insert: make(map[string]struct{}), + delete: make(map[string]struct{}), + } +} + +// onInsert tracks the newly inserted trie node. If it's already +// in the deletion set(resurrected node), then just wipe it from +// the deletion set as it's untouched. +func (t *tracer) onInsert(key []byte) { + // Tracer isn't used right now, remove this check later. + if t == nil { + return + } + if _, present := t.delete[string(key)]; present { + delete(t.delete, string(key)) + return + } + t.insert[string(key)] = struct{}{} +} + +// onDelete tracks the newly deleted trie node. If it's already +// in the addition set, then just wipe it from the addition set +// as it's untouched. +func (t *tracer) onDelete(key []byte) { + // Tracer isn't used right now, remove this check later. + if t == nil { + return + } + if _, present := t.insert[string(key)]; present { + delete(t.insert, string(key)) + return + } + t.delete[string(key)] = struct{}{} +} + +// insertList returns the tracked inserted trie nodes in list format. +func (t *tracer) insertList() [][]byte { + // Tracer isn't used right now, remove this check later. + if t == nil { + return nil + } + var ret [][]byte + for key := range t.insert { + ret = append(ret, []byte(key)) + } + return ret +} + +// deleteList returns the tracked deleted trie nodes in list format. +func (t *tracer) deleteList() [][]byte { + // Tracer isn't used right now, remove this check later. + if t == nil { + return nil + } + var ret [][]byte + for key := range t.delete { + ret = append(ret, []byte(key)) + } + return ret +} + +// reset clears the content tracked by tracer. +func (t *tracer) reset() { + // Tracer isn't used right now, remove this check later. + if t == nil { + return + } + t.insert = make(map[string]struct{}) + t.delete = make(map[string]struct{}) +} + +// copy returns a deep copied tracer instance. +func (t *tracer) copy() *tracer { + // Tracer isn't used right now, remove this check later. + if t == nil { + return nil + } + var ( + insert = make(map[string]struct{}) + delete = make(map[string]struct{}) + ) + for key := range t.insert { + insert[key] = struct{}{} + } + for key := range t.delete { + delete[key] = struct{}{} + } + return &tracer{ + insert: insert, + delete: delete, + } +}