diff --git a/core/types/hashing_test.go b/core/types/hashing_test.go
index 6d1ebf897c..de71ee41a4 100644
--- a/core/types/hashing_test.go
+++ b/core/types/hashing_test.go
@@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
@@ -38,7 +39,8 @@ func TestDeriveSha(t *testing.T) {
t.Fatal(err)
}
for len(txs) < 1000 {
- exp := types.DeriveSha(txs, new(trie.Trie))
+ tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ exp := types.DeriveSha(txs, tr)
got := types.DeriveSha(txs, trie.NewStackTrie(nil))
if !bytes.Equal(got[:], exp[:]) {
t.Fatalf("%d txs: got %x exp %x", len(txs), got, exp)
@@ -85,7 +87,8 @@ func BenchmarkDeriveSha200(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
- exp = types.DeriveSha(txs, new(trie.Trie))
+ tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ exp = types.DeriveSha(txs, tr)
}
})
@@ -106,7 +109,8 @@ func TestFuzzDeriveSha(t *testing.T) {
rndSeed := mrand.Int()
for i := 0; i < 10; i++ {
seed := rndSeed + i
- exp := types.DeriveSha(newDummy(i), new(trie.Trie))
+ tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ exp := types.DeriveSha(newDummy(i), tr)
got := types.DeriveSha(newDummy(i), trie.NewStackTrie(nil))
if !bytes.Equal(got[:], exp[:]) {
printList(newDummy(seed))
@@ -134,7 +138,8 @@ func TestDerivableList(t *testing.T) {
},
}
for i, tc := range tcs[1:] {
- exp := types.DeriveSha(flatList(tc), new(trie.Trie))
+ tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ exp := types.DeriveSha(flatList(tc), tr)
got := types.DeriveSha(flatList(tc), trie.NewStackTrie(nil))
if !bytes.Equal(got[:], exp[:]) {
t.Fatalf("case %d: got %x exp %x", i, got, exp)
diff --git a/les/server_handler.go b/les/server_handler.go
index da06ac315e..ef1af844c2 100644
--- a/les/server_handler.go
+++ b/les/server_handler.go
@@ -374,7 +374,7 @@ func getAccount(triedb *trie.Database, root, hash common.Hash) (types.StateAccou
return acc, nil
}
-// getHelperTrie returns the post-processed trie root for the given trie ID and section index
+// GetHelperTrie returns the post-processed trie root for the given trie ID and section index
func (h *serverHandler) GetHelperTrie(typ uint, index uint64) *trie.Trie {
var (
root common.Hash
diff --git a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
index 09ee6bb9c7..18717e70d0 100644
--- a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
+++ b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
@@ -24,6 +24,7 @@ import (
"sort"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/trie"
)
@@ -61,8 +62,7 @@ func (f *fuzzer) readInt() uint64 {
}
func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) {
-
- trie := new(trie.Trie)
+ trie, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
size := f.readInt()
// Fill it with some fluff
diff --git a/trie/committer.go b/trie/committer.go
index db753e2fa0..20be7e9690 100644
--- a/trie/committer.go
+++ b/trie/committer.go
@@ -89,7 +89,7 @@ func (c *committer) commit(n node, db *Database) (node, int, error) {
if hash != nil && !dirty {
return hash, 0, nil
}
- // Commit children, then parent, and remove remove the dirty flag.
+ // Commit children, then parent, and remove the dirty flag.
switch cn := n.(type) {
case *shortNode:
// Commit child
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 9a46e9b995..ea8a46bb43 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -24,6 +24,7 @@ import (
"testing"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
@@ -296,7 +297,7 @@ func TestUnionIterator(t *testing.T) {
}
func TestIteratorNoDups(t *testing.T) {
- var tr Trie
+ tr, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v))
}
diff --git a/trie/proof.go b/trie/proof.go
index 88ca80b0e7..f42dcc761b 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
- "github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/log"
)
@@ -552,7 +551,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
}
// Rebuild the trie with the leaf stream, the shape of trie
// should be same with the original one.
- tr := &Trie{root: root, db: NewDatabase(memorydb.New())}
+ tr := newWithRootNode(root)
if empty {
tr.root = nil
}
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 29866714c2..cdf5cf6050 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -26,6 +26,7 @@ import (
"time"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
)
@@ -79,7 +80,7 @@ func TestProof(t *testing.T) {
}
func TestOneElementProof(t *testing.T) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
updateString(trie, "k", "v")
for i, prover := range makeProvers(trie) {
proof := prover([]byte("k"))
@@ -130,7 +131,7 @@ func TestBadProof(t *testing.T) {
// Tests that missing keys can also be proven. The test explicitly uses a single
// entry trie and checks for missing keys both before and after the single entry.
func TestMissingKeyProof(t *testing.T) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
updateString(trie, "k", "v")
for i, key := range []string{"a", "j", "l", "z"} {
@@ -386,7 +387,7 @@ func TestOneElementRangeProof(t *testing.T) {
}
// Test the mini trie with only a single element.
- tinyTrie := new(Trie)
+ tinyTrie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
entry := &kv{randBytes(32), randBytes(20), false}
tinyTrie.Update(entry.k, entry.v)
@@ -458,7 +459,7 @@ func TestAllElementsProof(t *testing.T) {
// TestSingleSideRangeProof tests the range starts from zero.
func TestSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
@@ -493,7 +494,7 @@ func TestSingleSideRangeProof(t *testing.T) {
// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff.
func TestReverseSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
@@ -600,7 +601,7 @@ func TestBadRangeProof(t *testing.T) {
// TestGappedRangeProof focuses on the small trie with embedded nodes.
// If the gapped node is embedded in the trie, it should be detected too.
func TestGappedRangeProof(t *testing.T) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
var entries []*kv // Sorted entries
for i := byte(0); i < 10; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
@@ -674,7 +675,7 @@ func TestSameSideProofs(t *testing.T) {
}
func TestHasRightElement(t *testing.T) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
@@ -1027,7 +1028,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
}
func randomTrie(n int) (*Trie, map[string]*kv) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
for i := byte(0); i < 100; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
@@ -1052,7 +1053,7 @@ func randBytes(n int) []byte {
}
func nonRandomTrie(n int) (*Trie, map[string]*kv) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
max := uint64(0xffffffffffffffff)
for i := uint64(0); i < uint64(n); i++ {
@@ -1077,7 +1078,7 @@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) {
common.Hex2Bytes("02"),
common.Hex2Bytes("03"),
}
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
for i, key := range keys {
trie.Update(key, vals[i])
}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 18be12d34a..248b93544d 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -87,7 +87,7 @@ func (t *SecureTrie) TryGetNode(path []byte) ([]byte, int, error) {
return t.trie.TryGetNode(path)
}
-// TryUpdate account will abstract the write of an account to the
+// TryUpdateAccount account will abstract the write of an account to the
// secure trie.
func (t *SecureTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error {
hk := t.hashKey(key)
@@ -185,8 +185,10 @@ func (t *SecureTrie) Hash() common.Hash {
// Copy returns a copy of SecureTrie.
func (t *SecureTrie) Copy() *SecureTrie {
- cpy := *t
- return &cpy
+ return &SecureTrie{
+ trie: *t.trie.Copy(),
+ secKeyCache: t.secKeyCache,
+ }
}
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
index fb6c38ee22..a3ece84b57 100644
--- a/trie/secure_trie_test.go
+++ b/trie/secure_trie_test.go
@@ -112,8 +112,7 @@ func TestSecureTrieConcurrency(t *testing.T) {
threads := runtime.NumCPU()
tries := make([]*SecureTrie, threads)
for i := 0; i < threads; i++ {
- cpy := *trie
- tries[i] = &cpy
+ tries[i] = trie.Copy()
}
// Start a batch of goroutines interactng with the trie
pend := new(sync.WaitGroup)
diff --git a/trie/trie.go b/trie/trie.go
index e40b03be38..fe7d6dc17e 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -24,6 +24,7 @@ import (
"sync"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
@@ -62,10 +63,15 @@ type LeafCallback func(paths [][]byte, hexpath []byte, leaf []byte, parent commo
type Trie struct {
db *Database
root node
- // Keep track of the number leafs which have been inserted since the last
+
+ // Keep track of the number leaves which have been inserted since the last
// hashing operation. This number will not directly map to the number of
// actually unhashed nodes
unhashed int
+
+ // tracer is the state diff tracer can be used to track newly added/deleted
+ // trie node. It will be reset after each commit operation.
+ tracer *tracer
}
// newFlag returns the cache flag value for a newly created node.
@@ -73,6 +79,16 @@ func (t *Trie) newFlag() nodeFlag {
return nodeFlag{dirty: true}
}
+// Copy returns a copy of Trie.
+func (t *Trie) Copy() *Trie {
+ return &Trie{
+ db: t.db,
+ root: t.root,
+ unhashed: t.unhashed,
+ tracer: t.tracer.copy(),
+ }
+}
+
// New creates a trie with an existing root node from db.
//
// If root is the zero hash or the sha3 hash of an empty string, the
@@ -85,6 +101,7 @@ func New(root common.Hash, db *Database) (*Trie, error) {
}
trie := &Trie{
db: db,
+ //tracer: newTracer(),
}
if root != (common.Hash{}) && root != emptyRoot {
rootnode, err := trie.resolveHash(root[:], nil)
@@ -96,6 +113,16 @@ func New(root common.Hash, db *Database) (*Trie, error) {
return trie, nil
}
+// newWithRootNode initializes the trie with the given root node.
+// It's only used by range prover.
+func newWithRootNode(root node) *Trie {
+ return &Trie{
+ root: root,
+ //tracer: newTracer(),
+ db: NewDatabase(rawdb.NewMemoryDatabase()),
+ }
+}
+
// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at
// the key after the given start key.
func (t *Trie) NodeIterator(start []byte) NodeIterator {
@@ -317,7 +344,12 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
if matchlen == 0 {
return true, branch, nil
}
- // Otherwise, replace it with a short node leading up to the branch.
+ // New branch node is created as a child of the original short node.
+ // Track the newly inserted node in the tracer. The node identifier
+ // passed is the path from the root node.
+ t.tracer.onInsert(append(prefix, key[:matchlen]...))
+
+ // Replace it with a short node leading up to the branch.
return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
case *fullNode:
@@ -331,6 +363,11 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return true, n, nil
case nil:
+ // New short node is created and track it in the tracer. The node identifier
+ // passed is the path from the root node. Note the valueNode won't be tracked
+ // since it's always embedded in its parent.
+ t.tracer.onInsert(prefix)
+
return true, &shortNode{key, value, t.newFlag()}, nil
case hashNode:
@@ -383,6 +420,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, n, nil // don't replace n on mismatch
}
if matchlen == len(key) {
+ // The matched short node is deleted entirely and track
+ // it in the deletion set. The same the valueNode doesn't
+ // need to be tracked at all since it's always embedded.
+ t.tracer.onDelete(prefix)
+
return true, nil, nil // remove n entirely for whole matches
}
// The key is longer than n.Key. Remove the remaining suffix
@@ -395,6 +437,10 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
}
switch child := child.(type) {
case *shortNode:
+ // The child shortNode is merged into its parent, track
+ // is deleted as well.
+ t.tracer.onDelete(append(prefix, n.Key...))
+
// Deleting from the subtrie reduced it to another
// short node. Merge the nodes to avoid creating a
// shortNode{..., shortNode{...}}. Use concat (which
@@ -456,6 +502,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, nil, err
}
if cnode, ok := cnode.(*shortNode); ok {
+ // Replace the entire full node with the short node.
+ // Mark the original short node as deleted since the
+ // value is embedded into the parent now.
+ t.tracer.onDelete(append(prefix, byte(pos)))
+
k := append([]byte{byte(pos)}, cnode.Key...)
return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
}
@@ -537,6 +588,8 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) {
if t.db == nil {
panic("commit called on trie with nil database")
}
+ defer t.tracer.reset()
+
if t.root == nil {
return emptyRoot, 0, nil
}
@@ -595,4 +648,5 @@ func (t *Trie) hashRoot() (node, node, error) {
func (t *Trie) Reset() {
t.root = nil
t.unhashed = 0
+ t.tracer.reset()
}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index a1fdc8cd58..fd9556622d 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -32,6 +32,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
@@ -53,7 +54,7 @@ func newEmpty() *Trie {
}
func TestEmptyTrie(t *testing.T) {
- var trie Trie
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
res := trie.Hash()
exp := emptyRoot
if res != exp {
@@ -62,7 +63,7 @@ func TestEmptyTrie(t *testing.T) {
}
func TestNull(t *testing.T) {
- var trie Trie
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
key := make([]byte, 32)
value := []byte("test")
trie.Update(key, value)
@@ -374,6 +375,7 @@ const (
opHash
opReset
opItercheckhash
+ opNodeDiff
opMax // boundary value, not an actual op
)
@@ -408,10 +410,13 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
}
func runRandTest(rt randTest) bool {
- triedb := NewDatabase(memorydb.New())
-
- tr, _ := New(common.Hash{}, triedb)
- values := make(map[string]string) // tracks content of the trie
+ var (
+ triedb = NewDatabase(memorydb.New())
+ tr, _ = New(common.Hash{}, triedb)
+ values = make(map[string]string) // tracks content of the trie
+ origTrie, _ = New(common.Hash{}, triedb)
+ )
+ tr.tracer = newTracer()
for i, step := range rt {
// fmt.Printf("{op: %d, key: common.Hex2Bytes(\"%x\"), value: common.Hex2Bytes(\"%x\")}, // step %d\n",
@@ -432,6 +437,7 @@ func runRandTest(rt randTest) bool {
}
case opCommit:
_, _, rt[i].err = tr.Commit(nil)
+ origTrie = tr.Copy()
case opHash:
tr.Hash()
case opReset:
@@ -446,6 +452,9 @@ func runRandTest(rt randTest) bool {
return false
}
tr = newtr
+ tr.tracer = newTracer()
+
+ origTrie = tr.Copy()
case opItercheckhash:
checktr, _ := New(common.Hash{}, triedb)
it := NewIterator(tr.NodeIterator(nil))
@@ -455,6 +464,59 @@ func runRandTest(rt randTest) bool {
if tr.Hash() != checktr.Hash() {
rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
}
+ case opNodeDiff:
+ var (
+ inserted = tr.tracer.insertList()
+ deleted = tr.tracer.deleteList()
+ origIter = origTrie.NodeIterator(nil)
+ curIter = tr.NodeIterator(nil)
+ origSeen = make(map[string]struct{})
+ curSeen = make(map[string]struct{})
+ )
+ for origIter.Next(true) {
+ if origIter.Leaf() {
+ continue
+ }
+ origSeen[string(origIter.Path())] = struct{}{}
+ }
+ for curIter.Next(true) {
+ if curIter.Leaf() {
+ continue
+ }
+ curSeen[string(curIter.Path())] = struct{}{}
+ }
+ var (
+ insertExp = make(map[string]struct{})
+ deleteExp = make(map[string]struct{})
+ )
+ for path := range curSeen {
+ _, present := origSeen[path]
+ if !present {
+ insertExp[path] = struct{}{}
+ }
+ }
+ for path := range origSeen {
+ _, present := curSeen[path]
+ if !present {
+ deleteExp[path] = struct{}{}
+ }
+ }
+ if len(insertExp) != len(inserted) {
+ rt[i].err = fmt.Errorf("insert set mismatch")
+ }
+ if len(deleteExp) != len(deleted) {
+ rt[i].err = fmt.Errorf("delete set mismatch")
+ }
+ for _, insert := range inserted {
+ if _, present := insertExp[string(insert)]; !present {
+ rt[i].err = fmt.Errorf("missing inserted node")
+ }
+ }
+ for _, del := range deleted {
+ if _, present := deleteExp[string(del)]; !present {
+ rt[i].err = fmt.Errorf("missing deleted node")
+ }
+ }
}
// Abort the test on error.
if rt[i].err != nil {
@@ -481,7 +543,7 @@ func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) }
const benchElemCount = 20000
func benchGet(b *testing.B, commit bool) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
if commit {
_, tmpdb := tempDB()
trie, _ = New(common.Hash{}, tmpdb)
diff --git a/trie/util_test.go b/trie/util_test.go
new file mode 100644
index 0000000000..fadb0553b5
--- /dev/null
+++ b/trie/util_test.go
@@ -0,0 +1,122 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+)
+
+// Tests if the trie diffs are tracked correctly.
+func TestTrieTracer(t *testing.T) {
+ db := NewDatabase(rawdb.NewMemoryDatabase())
+ trie, _ := New(common.Hash{}, db)
+ trie.tracer = newTracer()
+
+ // Insert a batch of entries, all the nodes should be marked as inserted
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ trie.Hash()
+
+ seen := make(map[string]struct{})
+ it := trie.NodeIterator(nil)
+ for it.Next(true) {
+ if it.Leaf() {
+ continue
+ }
+ seen[string(it.Path())] = struct{}{}
+ }
+ inserted := trie.tracer.insertList()
+ if len(inserted) != len(seen) {
+ t.Fatalf("Unexpected inserted node tracked want %d got %d", len(seen), len(inserted))
+ }
+ for _, k := range inserted {
+ _, ok := seen[string(k)]
+ if !ok {
+ t.Fatalf("Unexpected inserted node")
+ }
+ }
+ deleted := trie.tracer.deleteList()
+ if len(deleted) != 0 {
+ t.Fatalf("Unexpected deleted node tracked %d", len(deleted))
+ }
+
+ // Commit the changes
+ trie.Commit(nil)
+
+ // Delete all the elements, check deletion set
+ for _, val := range vals {
+ trie.Delete([]byte(val.k))
+ }
+ trie.Hash()
+
+ inserted = trie.tracer.insertList()
+ if len(inserted) != 0 {
+ t.Fatalf("Unexpected inserted node tracked %d", len(inserted))
+ }
+ deleted = trie.tracer.deleteList()
+ if len(deleted) != len(seen) {
+ t.Fatalf("Unexpected deleted node tracked want %d got %d", len(seen), len(deleted))
+ }
+ for _, k := range deleted {
+ _, ok := seen[string(k)]
+ if !ok {
+ t.Fatalf("Unexpected inserted node")
+ }
+ }
+}
+
+func TestTrieTracerNoop(t *testing.T) {
+ db := NewDatabase(rawdb.NewMemoryDatabase())
+ trie, _ := New(common.Hash{}, db)
+ trie.tracer = newTracer()
+
+ // Insert a batch of entries, all the nodes should be marked as inserted
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ for _, val := range vals {
+ trie.Delete([]byte(val.k))
+ }
+ if len(trie.tracer.insertList()) != 0 {
+ t.Fatalf("Unexpected inserted node tracked %d", len(trie.tracer.insertList()))
+ }
+ if len(trie.tracer.deleteList()) != 0 {
+ t.Fatalf("Unexpected deleted node tracked %d", len(trie.tracer.deleteList()))
+ }
+}
diff --git a/trie/utils.go b/trie/utils.go
new file mode 100644
index 0000000000..5f9e3ba58e
--- /dev/null
+++ b/trie/utils.go
@@ -0,0 +1,133 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+// tracer tracks the changes of trie nodes. During the trie operations,
+// some nodes can be deleted from the trie, while these deleted nodes
+// won't be captured by trie.Hasher or trie.Committer. Thus, these deleted
+// nodes won't be removed from the disk at all. Tracer is an auxiliary tool
+// used to track all insert and delete operations of trie and capture all
+// deleted nodes eventually.
+//
+// The changed nodes can be mainly divided into two categories: the leaf
+// node and intermediate node. The former is inserted/deleted by callers
+// while the latter is inserted/deleted in order to follow the rule of trie.
+// This tool can track all of them no matter the node is embedded in its
+// parent or not, but valueNode is never tracked.
+//
+// Note tracer is not thread-safe, callers should be responsible for handling
+// the concurrency issues by themselves.
+type tracer struct {
+ insert map[string]struct{}
+ delete map[string]struct{}
+}
+
+// newTracer initializes trie node diff tracer.
+func newTracer() *tracer {
+ return &tracer{
+ insert: make(map[string]struct{}),
+ delete: make(map[string]struct{}),
+ }
+}
+
+// onInsert tracks the newly inserted trie node. If it's already
+// in the deletion set(resurrected node), then just wipe it from
+// the deletion set as it's untouched.
+func (t *tracer) onInsert(key []byte) {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return
+ }
+ if _, present := t.delete[string(key)]; present {
+ delete(t.delete, string(key))
+ return
+ }
+ t.insert[string(key)] = struct{}{}
+}
+
+// onDelete tracks the newly deleted trie node. If it's already
+// in the addition set, then just wipe it from the addition set
+// as it's untouched.
+func (t *tracer) onDelete(key []byte) {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return
+ }
+ if _, present := t.insert[string(key)]; present {
+ delete(t.insert, string(key))
+ return
+ }
+ t.delete[string(key)] = struct{}{}
+}
+
+// insertList returns the tracked inserted trie nodes in list format.
+func (t *tracer) insertList() [][]byte {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+ var ret [][]byte
+ for key := range t.insert {
+ ret = append(ret, []byte(key))
+ }
+ return ret
+}
+
+// deleteList returns the tracked deleted trie nodes in list format.
+func (t *tracer) deleteList() [][]byte {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+ var ret [][]byte
+ for key := range t.delete {
+ ret = append(ret, []byte(key))
+ }
+ return ret
+}
+
+// reset clears the content tracked by tracer.
+func (t *tracer) reset() {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return
+ }
+ t.insert = make(map[string]struct{})
+ t.delete = make(map[string]struct{})
+}
+
+// copy returns a deep copied tracer instance.
+func (t *tracer) copy() *tracer {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+ var (
+ insert = make(map[string]struct{})
+ delete = make(map[string]struct{})
+ )
+ for key := range t.insert {
+ insert[key] = struct{}{}
+ }
+ for key := range t.delete {
+ delete[key] = struct{}{}
+ }
+ return &tracer{
+ insert: insert,
+ delete: delete,
+ }
+}