Merge pull request #2035 from weiihann/v1.3.4-snapsync

all: pull snap sync PRs from upstream v1.13.5
This commit is contained in:
zzzckck 2023-12-19 11:25:56 +08:00 committed by GitHub
commit 474860ef77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 721 additions and 420 deletions

@ -1206,7 +1206,7 @@ func GenDoc(ctx *cli.Context) error {
URL: accounts.URL{Path: ".. ignored .."}, URL: accounts.URL{Path: ".. ignored .."},
}, },
{ {
Address: common.HexToAddress("0xffffffffffffffffffffffffffffffffffffffff"), Address: common.MaxAddress,
}, },
}}) }})
} }

@ -58,7 +58,7 @@ type accRangeTest struct {
func (s *Suite) TestSnapGetAccountRange(t *utesting.T) { func (s *Suite) TestSnapGetAccountRange(t *utesting.T) {
var ( var (
root = s.chain.RootAt(999) root = s.chain.RootAt(999)
ffHash = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") ffHash = common.MaxHash
zero = common.Hash{} zero = common.Hash{}
firstKeyMinus1 = common.HexToHash("0x00bf49f440a1cd0527e4d06e2765654c0f56452257516d793a9b8d604dcfdf29") firstKeyMinus1 = common.HexToHash("0x00bf49f440a1cd0527e4d06e2765654c0f56452257516d793a9b8d604dcfdf29")
firstKey = common.HexToHash("0x00bf49f440a1cd0527e4d06e2765654c0f56452257516d793a9b8d604dcfdf2a") firstKey = common.HexToHash("0x00bf49f440a1cd0527e4d06e2765654c0f56452257516d793a9b8d604dcfdf2a")
@ -125,7 +125,7 @@ type stRangesTest struct {
// TestSnapGetStorageRanges various forms of GetStorageRanges requests. // TestSnapGetStorageRanges various forms of GetStorageRanges requests.
func (s *Suite) TestSnapGetStorageRanges(t *utesting.T) { func (s *Suite) TestSnapGetStorageRanges(t *utesting.T) {
var ( var (
ffHash = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") ffHash = common.MaxHash
zero = common.Hash{} zero = common.Hash{}
firstKey = common.HexToHash("0x00bf49f440a1cd0527e4d06e2765654c0f56452257516d793a9b8d604dcfdf2a") firstKey = common.HexToHash("0x00bf49f440a1cd0527e4d06e2765654c0f56452257516d793a9b8d604dcfdf2a")
secondKey = common.HexToHash("0x09e47cd5056a689e708f22fe1f932709a320518e444f5f7d8d46a3da523d6606") secondKey = common.HexToHash("0x09e47cd5056a689e708f22fe1f932709a320518e444f5f7d8d46a3da523d6606")

@ -44,6 +44,12 @@ const (
var ( var (
hashT = reflect.TypeOf(Hash{}) hashT = reflect.TypeOf(Hash{})
addressT = reflect.TypeOf(Address{}) addressT = reflect.TypeOf(Address{})
// MaxAddress represents the maximum possible address value.
MaxAddress = HexToAddress("0xffffffffffffffffffffffffffffffffffffffff")
// MaxHash represents the maximum possible hash value.
MaxHash = HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
) )
// Hash represents the 32 byte Keccak256 hash of arbitrary data. // Hash represents the 32 byte Keccak256 hash of arbitrary data.

@ -22,6 +22,7 @@ import (
"sync" "sync"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
) )
const tmpSuffix = ".tmp" const tmpSuffix = ".tmp"
@ -240,6 +241,7 @@ func cleanup(path string) error {
} }
for _, name := range names { for _, name := range names {
if name == filepath.Base(path)+tmpSuffix { if name == filepath.Base(path)+tmpSuffix {
log.Info("Removed leftover freezer directory", "name", name)
return os.RemoveAll(filepath.Join(parent, name)) return os.RemoveAll(filepath.Join(parent, name))
} }
} }

@ -265,6 +265,12 @@ func (t *freezerTable) repair() error {
t.index.ReadAt(buffer, offsetsSize-indexEntrySize) t.index.ReadAt(buffer, offsetsSize-indexEntrySize)
lastIndex.unmarshalBinary(buffer) lastIndex.unmarshalBinary(buffer)
} }
// Print an error log if the index is corrupted due to an incorrect
// last index item. While it is theoretically possible to have a zero offset
// by storing all zero-size items, it is highly unlikely to occur in practice.
//
// The number of index entries is the index size divided by the entry size;
// the remainder would only measure stray trailing bytes and can never
// describe how many entries the file holds.
if lastIndex.offset == 0 && offsetsSize/indexEntrySize > 1 {
	log.Error("Corrupted index file detected", "lastOffset", lastIndex.offset, "items", offsetsSize/indexEntrySize-1)
}
if t.readonly { if t.readonly {
t.head, err = t.openFile(lastIndex.filenum, openFreezerFileForReadOnly) t.head, err = t.openFile(lastIndex.filenum, openFreezerFileForReadOnly)
} else { } else {
@ -357,7 +363,7 @@ func (t *freezerTable) repair() error {
return err return err
} }
if verbose { if verbose {
t.logger.Info("Chain freezer table opened", "items", t.items.Load(), "size", t.headBytes) t.logger.Info("Chain freezer table opened", "items", t.items.Load(), "deleted", t.itemOffset.Load(), "hidden", t.itemHidden.Load(), "tailId", t.tailId, "headId", t.headId, "size", t.headBytes)
} else { } else {
t.logger.Debug("Chain freezer table opened", "items", t.items.Load(), "size", common.StorageSize(t.headBytes)) t.logger.Debug("Chain freezer table opened", "items", t.items.Load(), "size", common.StorageSize(t.headBytes))
} }
@ -530,6 +536,10 @@ func (t *freezerTable) truncateTail(items uint64) error {
if err := t.meta.Sync(); err != nil { if err := t.meta.Sync(); err != nil {
return err return err
} }
// Close the index file before shortening it.
if err := t.index.Close(); err != nil {
return err
}
// Truncate the deleted index entries from the index file. // Truncate the deleted index entries from the index file.
err = copyFrom(t.index.Name(), t.index.Name(), indexEntrySize*(newDeleted-deleted+1), func(f *os.File) error { err = copyFrom(t.index.Name(), t.index.Name(), indexEntrySize*(newDeleted-deleted+1), func(f *os.File) error {
tailIndex := indexEntry{ tailIndex := indexEntry{
@ -543,13 +553,14 @@ func (t *freezerTable) truncateTail(items uint64) error {
return err return err
} }
// Reopen the modified index file to load the changes // Reopen the modified index file to load the changes
if err := t.index.Close(); err != nil {
return err
}
t.index, err = openFreezerFileForAppend(t.index.Name()) t.index, err = openFreezerFileForAppend(t.index.Name())
if err != nil { if err != nil {
return err return err
} }
// Sync the file to ensure changes are flushed to disk
if err := t.index.Sync(); err != nil {
return err
}
// Release any files before the current tail // Release any files before the current tail
t.tailId = newTailId t.tailId = newTailId
t.itemOffset.Store(newDeleted) t.itemOffset.Store(newDeleted)
@ -782,7 +793,7 @@ func (t *freezerTable) retrieveItems(start, count, maxBytes uint64) ([]byte, []i
return fmt.Errorf("missing data file %d", fileId) return fmt.Errorf("missing data file %d", fileId)
} }
if _, err := dataFile.ReadAt(output[len(output)-length:], int64(start)); err != nil { if _, err := dataFile.ReadAt(output[len(output)-length:], int64(start)); err != nil {
return err return fmt.Errorf("%w, fileid: %d, start: %d, length: %d", err, fileId, start, length)
} }
return nil return nil
} }

@ -365,21 +365,15 @@ func generateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, accou
} }
func stackTrieGenerate(db ethdb.KeyValueWriter, scheme string, owner common.Hash, in chan trieKV, out chan common.Hash) { func stackTrieGenerate(db ethdb.KeyValueWriter, scheme string, owner common.Hash, in chan trieKV, out chan common.Hash) {
var nodeWriter trie.NodeWriteFunc options := trie.NewStackTrieOptions()
if db != nil { if db != nil {
nodeWriter = func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(db, owner, path, hash, blob, scheme) rawdb.WriteTrieNode(db, owner, path, hash, blob, scheme)
} })
} }
t := trie.NewStackTrieWithOwner(nodeWriter, owner) t := trie.NewStackTrie(options)
for leaf := range in { for leaf := range in {
t.Update(leaf.key[:], leaf.value) t.Update(leaf.key[:], leaf.value)
} }
var root common.Hash out <- t.Commit()
if db == nil {
root = t.Hash()
} else {
root, _ = t.Commit()
}
out <- root
} }

@ -1363,10 +1363,12 @@ func (s *StateDB) fastDeleteStorage(addrHash common.Hash, root common.Hash) (boo
nodes = trienode.NewNodeSet(addrHash) nodes = trienode.NewNodeSet(addrHash)
slots = make(map[common.Hash][]byte) slots = make(map[common.Hash][]byte)
) )
stack := trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { options := trie.NewStackTrieOptions()
options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
nodes.AddNode(path, trienode.NewDeleted()) nodes.AddNode(path, trienode.NewDeleted())
size += common.StorageSize(len(path)) size += common.StorageSize(len(path))
}) })
stack := trie.NewStackTrie(options)
for iter.Next() { for iter.Next() {
if size > storageDeleteLimit { if size > storageDeleteLimit {
return true, size, nil, nil, nil return true, size, nil, nil, nil

@ -138,7 +138,7 @@ func TestStateProcessorErrors(t *testing.T) {
) )
defer blockchain.Stop() defer blockchain.Stop()
bigNumber := new(big.Int).SetBytes(common.FromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")) bigNumber := new(big.Int).SetBytes(common.MaxHash.Bytes())
tooBigNumber := new(big.Int).Set(bigNumber) tooBigNumber := new(big.Int).Set(bigNumber)
tooBigNumber.Add(tooBigNumber, common.Big1) tooBigNumber.Add(tooBigNumber, common.Big1)
for i, tt := range []struct { for i, tt := range []struct {

@ -367,7 +367,7 @@ func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesP
if len(req.Origin) > 0 { if len(req.Origin) > 0 {
origin, req.Origin = common.BytesToHash(req.Origin), nil origin, req.Origin = common.BytesToHash(req.Origin), nil
} }
var limit = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") var limit = common.MaxHash
if len(req.Limit) > 0 { if len(req.Limit) > 0 {
limit, req.Limit = common.BytesToHash(req.Limit), nil limit, req.Limit = common.BytesToHash(req.Limit), nil
} }

@ -26,4 +26,32 @@ var (
IngressRegistrationErrorMeter = metrics.NewRegisteredMeter(ingressRegistrationErrorName, nil) IngressRegistrationErrorMeter = metrics.NewRegisteredMeter(ingressRegistrationErrorName, nil)
EgressRegistrationErrorMeter = metrics.NewRegisteredMeter(egressRegistrationErrorName, nil) EgressRegistrationErrorMeter = metrics.NewRegisteredMeter(egressRegistrationErrorName, nil)
// deletionGauge is the metric to track how many trie node deletions
// are performed in total during the sync process.
deletionGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/delete", nil)
// lookupGauge is the metric to track how many trie node lookups are
// performed to determine if node needs to be deleted.
lookupGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/lookup", nil)
// boundaryAccountNodesGauge is the metric to track how many boundary trie
// nodes in account trie are met.
boundaryAccountNodesGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/boundary/account", nil)
// boundaryStorageNodesGauge is the metric to track how many boundary trie
// nodes in storage tries are met.
boundaryStorageNodesGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/boundary/storage", nil)
// smallStorageGauge is the metric to track how many storages are small enough
// to be retrieved in one or two requests.
smallStorageGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/storage/small", nil)
// largeStorageGauge is the metric to track how many storages are large enough
// to be retrieved concurrently.
largeStorageGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/storage/large", nil)
// skipStorageHealingGauge is the metric to track how many storages are retrieved
// in multiple requests but healing is not necessary.
skipStorageHealingGauge = metrics.NewRegisteredGauge("eth/protocols/snap/sync/storage/noheal", nil)
) )

@ -67,7 +67,7 @@ func (r *hashRange) End() common.Hash {
// If the end overflows (non divisible range), return a shorter interval // If the end overflows (non divisible range), return a shorter interval
next, overflow := new(uint256.Int).AddOverflow(r.current, r.step) next, overflow := new(uint256.Int).AddOverflow(r.current, r.step)
if overflow { if overflow {
return common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") return common.MaxHash
} }
return next.SubUint64(next, 1).Bytes32() return next.SubUint64(next, 1).Bytes32()
} }

@ -45,7 +45,7 @@ func TestHashRanges(t *testing.T) {
common.HexToHash("0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), common.HexToHash("0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
common.HexToHash("0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), common.HexToHash("0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
common.HexToHash("0xbfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), common.HexToHash("0xbfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), common.MaxHash,
}, },
}, },
// Split a divisible part of the hash range up into 2 chunks // Split a divisible part of the hash range up into 2 chunks
@ -58,7 +58,7 @@ func TestHashRanges(t *testing.T) {
}, },
ends: []common.Hash{ ends: []common.Hash{
common.HexToHash("0x8fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), common.HexToHash("0x8fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"),
common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), common.MaxHash,
}, },
}, },
// Split the entire hash range into a non divisible 3 chunks // Split the entire hash range into a non divisible 3 chunks
@ -73,7 +73,7 @@ func TestHashRanges(t *testing.T) {
ends: []common.Hash{ ends: []common.Hash{
common.HexToHash("0x5555555555555555555555555555555555555555555555555555555555555555"), common.HexToHash("0x5555555555555555555555555555555555555555555555555555555555555555"),
common.HexToHash("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"), common.HexToHash("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"),
common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), common.MaxHash,
}, },
}, },
// Split a part of hash range into a non divisible 3 chunks // Split a part of hash range into a non divisible 3 chunks
@ -88,7 +88,7 @@ func TestHashRanges(t *testing.T) {
ends: []common.Hash{ ends: []common.Hash{
common.HexToHash("0x6aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), common.HexToHash("0x6aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
common.HexToHash("0xb555555555555555555555555555555555555555555555555555555555555555"), common.HexToHash("0xb555555555555555555555555555555555555555555555555555555555555555"),
common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), common.MaxHash,
}, },
}, },
// Split a part of hash range into a non divisible 3 chunks, but with a // Split a part of hash range into a non divisible 3 chunks, but with a
@ -108,7 +108,7 @@ func TestHashRanges(t *testing.T) {
ends: []common.Hash{ ends: []common.Hash{
common.HexToHash("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff5"), common.HexToHash("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff5"),
common.HexToHash("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffb"), common.HexToHash("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffb"),
common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), common.MaxHash,
}, },
}, },
} }

@ -717,6 +717,19 @@ func (s *Syncer) Sync(root common.Hash, cancel chan struct{}) error {
} }
} }
// cleanPath drops a dangling trie node left behind at the given path, if the
// database currently holds one there. An empty owner addresses the account
// trie; any other owner addresses that account's storage trie. The removal is
// staged in the supplied batch rather than applied to the database directly.
//
// Every call counts as one lookup, and every staged removal counts as one
// deletion, in the corresponding sync gauges.
func (s *Syncer) cleanPath(batch ethdb.Batch, owner common.Hash, path []byte) {
	inAccountTrie := owner == (common.Hash{})

	switch {
	case inAccountTrie && rawdb.ExistsAccountTrieNode(s.db, path):
		rawdb.DeleteAccountTrieNode(batch, path)
		deletionGauge.Inc(1)

	case !inAccountTrie && rawdb.ExistsStorageTrieNode(s.db, owner, path):
		rawdb.DeleteStorageTrieNode(batch, owner, path)
		deletionGauge.Inc(1)
	}
	lookupGauge.Inc(1)
}
// loadSyncStatus retrieves a previously aborted sync status from the database, // loadSyncStatus retrieves a previously aborted sync status from the database,
// or generates a fresh one if none is available. // or generates a fresh one if none is available.
func (s *Syncer) loadSyncStatus() { func (s *Syncer) loadSyncStatus() {
@ -739,9 +752,22 @@ func (s *Syncer) loadSyncStatus() {
s.accountBytes += common.StorageSize(len(key) + len(value)) s.accountBytes += common.StorageSize(len(key) + len(value))
}, },
} }
task.genTrie = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { options := trie.NewStackTrieOptions()
rawdb.WriteTrieNode(task.genBatch, owner, path, hash, val, s.scheme) options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(task.genBatch, common.Hash{}, path, hash, blob, s.scheme)
}) })
if s.scheme == rawdb.PathScheme {
// Configure the dangling node cleaner and also filter out boundary nodes
// only in the context of the path scheme. Deletion is forbidden in the
// hash scheme, as it can disrupt state completeness.
options = options.WithCleaner(func(path []byte) {
s.cleanPath(task.genBatch, common.Hash{}, path)
})
// Skip the left boundary if it's not the first range.
// Skip the right boundary if it's not the last range.
options = options.WithSkipBoundary(task.Next != (common.Hash{}), task.Last != common.MaxHash, boundaryAccountNodesGauge)
}
task.genTrie = trie.NewStackTrie(options)
for accountHash, subtasks := range task.SubTasks { for accountHash, subtasks := range task.SubTasks {
for _, subtask := range subtasks { for _, subtask := range subtasks {
subtask := subtask // closure for subtask.genBatch in the stacktrie writer callback subtask := subtask // closure for subtask.genBatch in the stacktrie writer callback
@ -752,9 +778,23 @@ func (s *Syncer) loadSyncStatus() {
s.storageBytes += common.StorageSize(len(key) + len(value)) s.storageBytes += common.StorageSize(len(key) + len(value))
}, },
} }
subtask.genTrie = trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { owner := accountHash // local assignment for stacktrie writer closure
rawdb.WriteTrieNode(subtask.genBatch, owner, path, hash, val, s.scheme) options := trie.NewStackTrieOptions()
}, accountHash) options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(subtask.genBatch, owner, path, hash, blob, s.scheme)
})
if s.scheme == rawdb.PathScheme {
// Configure the dangling node cleaner and also filter out boundary nodes
// only in the context of the path scheme. Deletion is forbidden in the
// hash scheme, as it can disrupt state completeness.
options = options.WithCleaner(func(path []byte) {
s.cleanPath(subtask.genBatch, owner, path)
})
// Skip the left boundary if it's not the first range.
// Skip the right boundary if it's not the last range.
options = options.WithSkipBoundary(subtask.Next != common.Hash{}, subtask.Last != common.MaxHash, boundaryStorageNodesGauge)
}
subtask.genTrie = trie.NewStackTrie(options)
} }
} }
} }
@ -798,7 +838,7 @@ func (s *Syncer) loadSyncStatus() {
last := common.BigToHash(new(big.Int).Add(next.Big(), step)) last := common.BigToHash(new(big.Int).Add(next.Big(), step))
if i == accountConcurrency-1 { if i == accountConcurrency-1 {
// Make sure we don't overflow if the step is not a proper divisor // Make sure we don't overflow if the step is not a proper divisor
last = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") last = common.MaxHash
} }
batch := ethdb.HookedBatch{ batch := ethdb.HookedBatch{
Batch: s.db.NewBatch(), Batch: s.db.NewBatch(),
@ -806,14 +846,27 @@ func (s *Syncer) loadSyncStatus() {
s.accountBytes += common.StorageSize(len(key) + len(value)) s.accountBytes += common.StorageSize(len(key) + len(value))
}, },
} }
options := trie.NewStackTrieOptions()
options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(batch, common.Hash{}, path, hash, blob, s.scheme)
})
if s.scheme == rawdb.PathScheme {
// Configure the dangling node cleaner and also filter out boundary nodes
// only in the context of the path scheme. Deletion is forbidden in the
// hash scheme, as it can disrupt state completeness.
options = options.WithCleaner(func(path []byte) {
s.cleanPath(batch, common.Hash{}, path)
})
// Skip the left boundary if it's not the first range.
// Skip the right boundary if it's not the last range.
options = options.WithSkipBoundary(next != common.Hash{}, last != common.MaxHash, boundaryAccountNodesGauge)
}
s.tasks = append(s.tasks, &accountTask{ s.tasks = append(s.tasks, &accountTask{
Next: next, Next: next,
Last: last, Last: last,
SubTasks: make(map[common.Hash][]*storageTask), SubTasks: make(map[common.Hash][]*storageTask),
genBatch: batch, genBatch: batch,
genTrie: trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { genTrie: trie.NewStackTrie(options),
rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme)
}),
}) })
log.Debug("Created account sync task", "from", next, "last", last) log.Debug("Created account sync task", "from", next, "last", last)
next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1)) next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
@ -1877,7 +1930,7 @@ func (s *Syncer) processAccountResponse(res *accountResponse) {
return return
} }
// Some accounts are incomplete, leave as is for the storage and contract // Some accounts are incomplete, leave as is for the storage and contract
// task assigners to pick up and fill. // task assigners to pick up and fill
} }
// processBytecodeResponse integrates an already validated bytecode response // processBytecodeResponse integrates an already validated bytecode response
@ -1965,6 +2018,7 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
if res.subTask == nil && res.mainTask.needState[j] && (i < len(res.hashes)-1 || !res.cont) { if res.subTask == nil && res.mainTask.needState[j] && (i < len(res.hashes)-1 || !res.cont) {
res.mainTask.needState[j] = false res.mainTask.needState[j] = false
res.mainTask.pend-- res.mainTask.pend--
smallStorageGauge.Inc(1)
} }
// If the last contract was chunked, mark it as needing healing // If the last contract was chunked, mark it as needing healing
// to avoid writing it out to disk prematurely. // to avoid writing it out to disk prematurely.
@ -2000,7 +2054,11 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
log.Debug("Chunked large contract", "initiators", len(keys), "tail", lastKey, "chunks", chunks) log.Debug("Chunked large contract", "initiators", len(keys), "tail", lastKey, "chunks", chunks)
} }
r := newHashRange(lastKey, chunks) r := newHashRange(lastKey, chunks)
if chunks == 1 {
smallStorageGauge.Inc(1)
} else {
largeStorageGauge.Inc(1)
}
// Our first task is the one that was just filled by this response. // Our first task is the one that was just filled by this response.
batch := ethdb.HookedBatch{ batch := ethdb.HookedBatch{
Batch: s.db.NewBatch(), Batch: s.db.NewBatch(),
@ -2008,14 +2066,25 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
s.storageBytes += common.StorageSize(len(key) + len(value)) s.storageBytes += common.StorageSize(len(key) + len(value))
}, },
} }
owner := account // local assignment for stacktrie writer closure
options := trie.NewStackTrieOptions()
options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(batch, owner, path, hash, blob, s.scheme)
})
if s.scheme == rawdb.PathScheme {
options = options.WithCleaner(func(path []byte) {
s.cleanPath(batch, owner, path)
})
// Keep the left boundary as it's the first range.
// Skip the right boundary if it's not the last range.
options = options.WithSkipBoundary(false, r.End() != common.MaxHash, boundaryStorageNodesGauge)
}
tasks = append(tasks, &storageTask{ tasks = append(tasks, &storageTask{
Next: common.Hash{}, Next: common.Hash{},
Last: r.End(), Last: r.End(),
root: acc.Root, root: acc.Root,
genBatch: batch, genBatch: batch,
genTrie: trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { genTrie: trie.NewStackTrie(options),
rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme)
}, account),
}) })
for r.Next() { for r.Next() {
batch := ethdb.HookedBatch{ batch := ethdb.HookedBatch{
@ -2024,14 +2093,27 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
s.storageBytes += common.StorageSize(len(key) + len(value)) s.storageBytes += common.StorageSize(len(key) + len(value))
}, },
} }
options := trie.NewStackTrieOptions()
options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(batch, owner, path, hash, blob, s.scheme)
})
if s.scheme == rawdb.PathScheme {
// Configure the dangling node cleaner and also filter out boundary nodes
// only in the context of the path scheme. Deletion is forbidden in the
// hash scheme, as it can disrupt state completeness.
options = options.WithCleaner(func(path []byte) {
s.cleanPath(batch, owner, path)
})
// Skip the left boundary as it's not the first range
// Skip the right boundary if it's not the last range.
options = options.WithSkipBoundary(true, r.End() != common.MaxHash, boundaryStorageNodesGauge)
}
tasks = append(tasks, &storageTask{ tasks = append(tasks, &storageTask{
Next: r.Start(), Next: r.Start(),
Last: r.End(), Last: r.End(),
root: acc.Root, root: acc.Root,
genBatch: batch, genBatch: batch,
genTrie: trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { genTrie: trie.NewStackTrie(options),
rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme)
}, account),
}) })
} }
for _, task := range tasks { for _, task := range tasks {
@ -2076,9 +2158,23 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
slots += len(res.hashes[i]) slots += len(res.hashes[i])
if i < len(res.hashes)-1 || res.subTask == nil { if i < len(res.hashes)-1 || res.subTask == nil {
tr := trie.NewStackTrieWithOwner(func(owner common.Hash, path []byte, hash common.Hash, val []byte) { // no need to make local reassignment of account: this closure does not outlive the loop
rawdb.WriteTrieNode(batch, owner, path, hash, val, s.scheme) options := trie.NewStackTrieOptions()
}, account) options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(batch, account, path, hash, blob, s.scheme)
})
if s.scheme == rawdb.PathScheme {
// Configure the dangling node cleaner only in the context of the
// path scheme. Deletion is forbidden in the hash scheme, as it can
// disrupt state completeness.
//
// Notably, boundary nodes can be also kept because the whole storage
// trie is complete.
options = options.WithCleaner(func(path []byte) {
s.cleanPath(batch, account, path)
})
}
tr := trie.NewStackTrie(options)
for j := 0; j < len(res.hashes[i]); j++ { for j := 0; j < len(res.hashes[i]); j++ {
tr.Update(res.hashes[i][j][:], res.slots[i][j]) tr.Update(res.hashes[i][j][:], res.slots[i][j])
} }
@ -2100,18 +2196,25 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
// Large contracts could have generated new trie nodes, flush them to disk // Large contracts could have generated new trie nodes, flush them to disk
if res.subTask != nil { if res.subTask != nil {
if res.subTask.done { if res.subTask.done {
if root, err := res.subTask.genTrie.Commit(); err != nil { root := res.subTask.genTrie.Commit()
log.Error("Failed to commit stack slots", "err", err) if err := res.subTask.genBatch.Write(); err != nil {
} else if root == res.subTask.root { log.Error("Failed to persist stack slots", "err", err)
// If the chunk's root is an overflown but full delivery, clear the heal request }
res.subTask.genBatch.Reset()
// If the chunk's root is an overflown but full delivery,
// clear the heal request.
accountHash := res.accounts[len(res.accounts)-1]
if root == res.subTask.root && rawdb.HasStorageTrieNode(s.db, accountHash, nil, root) {
for i, account := range res.mainTask.res.hashes { for i, account := range res.mainTask.res.hashes {
if account == res.accounts[len(res.accounts)-1] { if account == accountHash {
res.mainTask.needHeal[i] = false res.mainTask.needHeal[i] = false
skipStorageHealingGauge.Inc(1)
} }
} }
} }
} }
if res.subTask.genBatch.ValueSize() > ethdb.IdealBatchSize || res.subTask.done { if res.subTask.genBatch.ValueSize() > ethdb.IdealBatchSize {
if err := res.subTask.genBatch.Write(); err != nil { if err := res.subTask.genBatch.Write(); err != nil {
log.Error("Failed to persist stack slots", "err", err) log.Error("Failed to persist stack slots", "err", err)
} }
@ -2318,9 +2421,7 @@ func (s *Syncer) forwardAccountTask(task *accountTask) {
// flush after finalizing task.done. It's fine even if we crash and lose this // flush after finalizing task.done. It's fine even if we crash and lose this
// write as it will only cause more data to be downloaded during heal. // write as it will only cause more data to be downloaded during heal.
if task.done { if task.done {
if _, err := task.genTrie.Commit(); err != nil { task.genTrie.Commit()
log.Error("Failed to commit stack account", "err", err)
}
} }
if task.genBatch.ValueSize() > ethdb.IdealBatchSize || task.done { if task.genBatch.ValueSize() > ethdb.IdealBatchSize || task.done {
if err := task.genBatch.Write(); err != nil { if err := task.genBatch.Write(); err != nil {
@ -2625,7 +2726,7 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo
// the requested data. For storage range queries that means the state being // the requested data. For storage range queries that means the state being
// retrieved was either already pruned remotely, or the peer is not yet // retrieved was either already pruned remotely, or the peer is not yet
// synced to our head. // synced to our head.
if len(hashes) == 0 { if len(hashes) == 0 && len(proof) == 0 {
logger.Debug("Peer rejected storage request") logger.Debug("Peer rejected storage request")
s.statelessPeers[peer.ID()] = struct{}{} s.statelessPeers[peer.ID()] = struct{}{}
s.lock.Unlock() s.lock.Unlock()
@ -2637,6 +2738,13 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo
// Reconstruct the partial tries from the response and verify them // Reconstruct the partial tries from the response and verify them
var cont bool var cont bool
// If a proof was attached while the response is empty, it indicates that the
// requested range specified with 'origin' is empty. Construct an empty state
// response locally to finalize the range.
if len(hashes) == 0 && len(proof) > 0 {
hashes = append(hashes, []common.Hash{})
slots = append(slots, [][]byte{})
}
for i := 0; i < len(hashes); i++ { for i := 0; i < len(hashes); i++ {
// Convert the keys and proofs into an internal format // Convert the keys and proofs into an internal format
keys := make([][]byte, len(hashes[i])) keys := make([][]byte, len(hashes[i]))

@ -22,6 +22,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"math/big" "math/big"
mrand "math/rand"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -35,6 +36,7 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/testutil"
"github.com/ethereum/go-ethereum/trie/triedb/pathdb" "github.com/ethereum/go-ethereum/trie/triedb/pathdb"
"github.com/ethereum/go-ethereum/trie/trienode" "github.com/ethereum/go-ethereum/trie/trienode"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
@ -254,7 +256,7 @@ func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, orig
func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) (keys []common.Hash, vals [][]byte, proofs [][]byte) { func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) (keys []common.Hash, vals [][]byte, proofs [][]byte) {
var size uint64 var size uint64
if limit == (common.Hash{}) { if limit == (common.Hash{}) {
limit = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") limit = common.MaxHash
} }
for _, entry := range t.accountValues { for _, entry := range t.accountValues {
if size > cap { if size > cap {
@ -319,7 +321,7 @@ func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []comm
if len(origin) > 0 { if len(origin) > 0 {
originHash = common.BytesToHash(origin) originHash = common.BytesToHash(origin)
} }
var limitHash = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") var limitHash = common.MaxHash
if len(limit) > 0 { if len(limit) > 0 {
limitHash = common.BytesToHash(limit) limitHash = common.BytesToHash(limit)
} }
@ -762,7 +764,7 @@ func testSyncWithStorage(t *testing.T, scheme string) {
}) })
} }
) )
nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 3, 3000, true, false) sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 3, 3000, true, false, false)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
@ -772,7 +774,7 @@ func testSyncWithStorage(t *testing.T, scheme string) {
source.storageValues = storageElems source.storageValues = storageElems
return source return source
} }
syncer := setupSyncer(nodeScheme, mkSource("sourceA")) syncer := setupSyncer(scheme, mkSource("sourceA"))
done := checkStall(t, term) done := checkStall(t, term)
if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil { if err := syncer.Sync(sourceAccountTrie.Hash(), cancel); err != nil {
t.Fatalf("sync failed: %v", err) t.Fatalf("sync failed: %v", err)
@ -799,7 +801,7 @@ func testMultiSyncManyUseless(t *testing.T, scheme string) {
}) })
} }
) )
nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
@ -821,7 +823,7 @@ func testMultiSyncManyUseless(t *testing.T, scheme string) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme, scheme,
mkSource("full", true, true, true), mkSource("full", true, true, true),
mkSource("noAccounts", false, true, true), mkSource("noAccounts", false, true, true),
mkSource("noStorage", true, false, true), mkSource("noStorage", true, false, true),
@ -853,7 +855,7 @@ func testMultiSyncManyUselessWithLowTimeout(t *testing.T, scheme string) {
}) })
} }
) )
nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
@ -875,7 +877,7 @@ func testMultiSyncManyUselessWithLowTimeout(t *testing.T, scheme string) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme, scheme,
mkSource("full", true, true, true), mkSource("full", true, true, true),
mkSource("noAccounts", false, true, true), mkSource("noAccounts", false, true, true),
mkSource("noStorage", true, false, true), mkSource("noStorage", true, false, true),
@ -912,7 +914,7 @@ func testMultiSyncManyUnresponsive(t *testing.T, scheme string) {
}) })
} }
) )
nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer { mkSource := func(name string, noAccount, noStorage, noTrieNode bool) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
@ -934,7 +936,7 @@ func testMultiSyncManyUnresponsive(t *testing.T, scheme string) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme, scheme,
mkSource("full", true, true, true), mkSource("full", true, true, true),
mkSource("noAccounts", false, true, true), mkSource("noAccounts", false, true, true),
mkSource("noStorage", true, false, true), mkSource("noStorage", true, false, true),
@ -1215,7 +1217,7 @@ func testSyncBoundaryStorageTrie(t *testing.T, scheme string) {
}) })
} }
) )
nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 10, 1000, false, true) sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 10, 1000, false, true, false)
mkSource := func(name string) *testPeer { mkSource := func(name string) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
@ -1226,7 +1228,7 @@ func testSyncBoundaryStorageTrie(t *testing.T, scheme string) {
return source return source
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme, scheme,
mkSource("peer-a"), mkSource("peer-a"),
mkSource("peer-b"), mkSource("peer-b"),
) )
@ -1257,7 +1259,7 @@ func testSyncWithStorageAndOneCappedPeer(t *testing.T, scheme string) {
}) })
} }
) )
nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 300, 1000, false, false) sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 300, 1000, false, false, false)
mkSource := func(name string, slow bool) *testPeer { mkSource := func(name string, slow bool) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
@ -1273,7 +1275,7 @@ func testSyncWithStorageAndOneCappedPeer(t *testing.T, scheme string) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme, scheme,
mkSource("nice-a", false), mkSource("nice-a", false),
mkSource("slow", true), mkSource("slow", true),
) )
@ -1304,7 +1306,7 @@ func testSyncWithStorageAndCorruptPeer(t *testing.T, scheme string) {
}) })
} }
) )
nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
mkSource := func(name string, handler storageHandlerFunc) *testPeer { mkSource := func(name string, handler storageHandlerFunc) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
@ -1317,7 +1319,7 @@ func testSyncWithStorageAndCorruptPeer(t *testing.T, scheme string) {
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme, scheme,
mkSource("nice-a", defaultStorageRequestHandler), mkSource("nice-a", defaultStorageRequestHandler),
mkSource("nice-b", defaultStorageRequestHandler), mkSource("nice-b", defaultStorageRequestHandler),
mkSource("nice-c", defaultStorageRequestHandler), mkSource("nice-c", defaultStorageRequestHandler),
@ -1348,7 +1350,7 @@ func testSyncWithStorageAndNonProvingPeer(t *testing.T, scheme string) {
}) })
} }
) )
nodeScheme, sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false) sourceAccountTrie, elems, storageTries, storageElems := makeAccountTrieWithStorage(scheme, 100, 3000, true, false, false)
mkSource := func(name string, handler storageHandlerFunc) *testPeer { mkSource := func(name string, handler storageHandlerFunc) *testPeer {
source := newTestPeer(name, t, term) source := newTestPeer(name, t, term)
@ -1360,7 +1362,7 @@ func testSyncWithStorageAndNonProvingPeer(t *testing.T, scheme string) {
return source return source
} }
syncer := setupSyncer( syncer := setupSyncer(
nodeScheme, scheme,
mkSource("nice-a", defaultStorageRequestHandler), mkSource("nice-a", defaultStorageRequestHandler),
mkSource("nice-b", defaultStorageRequestHandler), mkSource("nice-b", defaultStorageRequestHandler),
mkSource("nice-c", defaultStorageRequestHandler), mkSource("nice-c", defaultStorageRequestHandler),
@ -1413,6 +1415,45 @@ func testSyncWithStorageMisbehavingProve(t *testing.T, scheme string) {
verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t) verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t)
} }
// TestSyncWithUnevenStorage tests sync where the storage trie is not even
// and with a few empty ranges.
func TestSyncWithUnevenStorage(t *testing.T) {
t.Parallel()
testSyncWithUnevenStorage(t, rawdb.HashScheme)
testSyncWithUnevenStorage(t, rawdb.PathScheme)
}
// testSyncWithUnevenStorage builds an account trie with three accounts whose
// storage tries are generated in "uneven" mode, syncs it from a single test
// peer using the given scheme, and finally verifies the synced database
// against the source account root.
func testSyncWithUnevenStorage(t *testing.T, scheme string) {
	var (
		closeOnce sync.Once
		cancel    = make(chan struct{})
	)
	// term closes the cancel channel at most once, no matter how often it
	// is invoked.
	term := func() {
		closeOnce.Do(func() { close(cancel) })
	}
	srcTrie, srcAccounts, srcStorageTries, srcStorageElems := makeAccountTrieWithStorage(scheme, 3, 256, false, false, true)

	// The handler discards the size requested by the syncer and always
	// forwards a fixed value of 128 to the default storage handler.
	fixedSizeHandler := func(peer *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, _ uint64) error {
		return defaultStorageRequestHandler(peer, reqId, root, accounts, origin, limit, 128) // retrieve storage in large mode
	}
	newSource := func(name string) *testPeer {
		p := newTestPeer(name, t, term)
		p.accountTrie = srcTrie.Copy()
		p.accountValues = srcAccounts
		p.setStorageTries(srcStorageTries)
		p.storageValues = srcStorageElems
		p.storageRequestHandler = fixedSizeHandler
		return p
	}
	syncer := setupSyncer(scheme, newSource("source"))

	err := syncer.Sync(srcTrie.Hash(), cancel)
	if err != nil {
		t.Fatalf("sync failed: %v", err)
	}
	verifyTrie(scheme, syncer.db, srcTrie.Hash(), t)
}
type kv struct { type kv struct {
k, v []byte k, v []byte
} }
@ -1511,7 +1552,7 @@ func makeBoundaryAccountTrie(scheme string, n int) (string, *trie.Trie, []*kv) {
for i := 0; i < accountConcurrency; i++ { for i := 0; i < accountConcurrency; i++ {
last := common.BigToHash(new(big.Int).Add(next.Big(), step)) last := common.BigToHash(new(big.Int).Add(next.Big(), step))
if i == accountConcurrency-1 { if i == accountConcurrency-1 {
last = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") last = common.MaxHash
} }
boundaries = append(boundaries, last) boundaries = append(boundaries, last)
next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1)) next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
@ -1608,7 +1649,7 @@ func makeAccountTrieWithStorageWithUniqueStorage(scheme string, accounts, slots
} }
// makeAccountTrieWithStorage spits out a trie, along with the leafs // makeAccountTrieWithStorage spits out a trie, along with the leafs
func makeAccountTrieWithStorage(scheme string, accounts, slots int, code, boundary bool) (string, *trie.Trie, []*kv, map[common.Hash]*trie.Trie, map[common.Hash][]*kv) { func makeAccountTrieWithStorage(scheme string, accounts, slots int, code, boundary bool, uneven bool) (*trie.Trie, []*kv, map[common.Hash]*trie.Trie, map[common.Hash][]*kv) {
var ( var (
db = trie.NewDatabase(rawdb.NewMemoryDatabase(), newDbConfig(scheme)) db = trie.NewDatabase(rawdb.NewMemoryDatabase(), newDbConfig(scheme))
accTrie = trie.NewEmpty(db) accTrie = trie.NewEmpty(db)
@ -1633,6 +1674,8 @@ func makeAccountTrieWithStorage(scheme string, accounts, slots int, code, bounda
) )
if boundary { if boundary {
stRoot, stNodes, stEntries = makeBoundaryStorageTrie(common.BytesToHash(key), slots, db) stRoot, stNodes, stEntries = makeBoundaryStorageTrie(common.BytesToHash(key), slots, db)
} else if uneven {
stRoot, stNodes, stEntries = makeUnevenStorageTrie(common.BytesToHash(key), slots, db)
} else { } else {
stRoot, stNodes, stEntries = makeStorageTrieWithSeed(common.BytesToHash(key), uint64(slots), 0, db) stRoot, stNodes, stEntries = makeStorageTrieWithSeed(common.BytesToHash(key), uint64(slots), 0, db)
} }
@ -1675,7 +1718,7 @@ func makeAccountTrieWithStorage(scheme string, accounts, slots int, code, bounda
} }
storageTries[common.BytesToHash(key)] = trie storageTries[common.BytesToHash(key)] = trie
} }
return db.Scheme(), accTrie, entries, storageTries, storageEntries return accTrie, entries, storageTries, storageEntries
} }
// makeStorageTrieWithSeed fills a storage trie with n items, returning the // makeStorageTrieWithSeed fills a storage trie with n items, returning the
@ -1721,7 +1764,7 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo
for i := 0; i < accountConcurrency; i++ { for i := 0; i < accountConcurrency; i++ {
last := common.BigToHash(new(big.Int).Add(next.Big(), step)) last := common.BigToHash(new(big.Int).Add(next.Big(), step))
if i == accountConcurrency-1 { if i == accountConcurrency-1 {
last = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") last = common.MaxHash
} }
boundaries = append(boundaries, last) boundaries = append(boundaries, last)
next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1)) next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
@ -1752,6 +1795,38 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo
return root, nodes, entries return root, nodes, entries
} }
// makeUnevenStorageTrie constructs a storage trie for the given owner whose
// entries are distributed unevenly over the key space: three distinct leading
// key bytes are drawn at random, and slots/3 entries are inserted under each
// of them. It returns the committed root, the node set produced by the commit
// and the inserted entries ordered with (*kv).cmp.
func makeUnevenStorageTrie(owner common.Hash, slots int, db *trie.Database) (common.Hash, *trienode.NodeSet, []*kv) {
	var (
		entries  []*kv
		used     = make(map[byte]struct{})
		perGroup = slots / 3
	)
	tr, _ := trie.New(trie.StorageTrieID(types.EmptyRootHash, owner, types.EmptyRootHash), db)

	for group := 0; group < 3; group++ {
		// Keep drawing until a leading byte comes up that no earlier group
		// has claimed. Values are taken from [0, 15), so no key ever starts
		// with a byte of 15 or above.
		prefix := byte(mrand.Intn(15)) // the last range is set empty deliberately
		for _, taken := used[prefix]; taken; _, taken = used[prefix] {
			prefix = byte(mrand.Intn(15))
		}
		used[prefix] = struct{}{}

		for added := 0; added < perGroup; added++ {
			k := append([]byte{prefix}, testutil.RandBytes(31)...)
			v, _ := rlp.EncodeToBytes(testutil.RandBytes(32))
			item := &kv{k, v}
			tr.MustUpdate(item.k, item.v)
			entries = append(entries, item)
		}
	}
	slices.SortFunc(entries, (*kv).cmp)

	root, nodes, _ := tr.Commit(false)
	return root, nodes, entries
}
func verifyTrie(scheme string, db ethdb.KeyValueStore, root common.Hash, t *testing.T) { func verifyTrie(scheme string, db ethdb.KeyValueStore, root common.Hash, t *testing.T) {
t.Helper() t.Helper()
triedb := trie.NewDatabase(rawdb.NewDatabase(db), newDbConfig(scheme)) triedb := trie.NewDatabase(rawdb.NewDatabase(db), newDbConfig(scheme))

@ -140,9 +140,11 @@ func (f *fuzzer) fuzz() int {
trieA = trie.NewEmpty(dbA) trieA = trie.NewEmpty(dbA)
spongeB = &spongeDb{sponge: sha3.NewLegacyKeccak256()} spongeB = &spongeDb{sponge: sha3.NewLegacyKeccak256()}
dbB = trie.NewDatabase(rawdb.NewDatabase(spongeB), nil) dbB = trie.NewDatabase(rawdb.NewDatabase(spongeB), nil)
trieB = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(spongeB, owner, path, hash, blob, dbB.Scheme()) options = trie.NewStackTrieOptions().WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(spongeB, common.Hash{}, path, hash, blob, dbB.Scheme())
}) })
trieB = trie.NewStackTrie(options)
vals []kv vals []kv
useful bool useful bool
maxElements = 10000 maxElements = 10000
@ -204,22 +206,20 @@ func (f *fuzzer) fuzz() int {
// Ensure all the nodes are persisted correctly // Ensure all the nodes are persisted correctly
var ( var (
nodeset = make(map[string][]byte) // path -> blob nodeset = make(map[string][]byte) // path -> blob
trieC = trie.NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) { optionsC = trie.NewStackTrieOptions().WithWriter(func(path []byte, hash common.Hash, blob []byte) {
if crypto.Keccak256Hash(blob) != hash { if crypto.Keccak256Hash(blob) != hash {
panic("invalid node blob") panic("invalid node blob")
} }
if owner != (common.Hash{}) {
panic("invalid node owner")
}
nodeset[string(path)] = common.CopyBytes(blob) nodeset[string(path)] = common.CopyBytes(blob)
}) })
trieC = trie.NewStackTrie(optionsC)
checked int checked int
) )
for _, kv := range vals { for _, kv := range vals {
trieC.MustUpdate(kv.k, kv.v) trieC.MustUpdate(kv.k, kv.v)
} }
rootC, _ := trieC.Commit() rootC := trieC.Commit()
if rootA != rootC { if rootA != rootC {
panic(fmt.Sprintf("roots differ: (trie) %x != %x (stacktrie)", rootA, rootC)) panic(fmt.Sprintf("roots differ: (trie) %x != %x (stacktrie)", rootA, rootC))
} }

@ -51,9 +51,8 @@ func hexToCompact(hex []byte) []byte {
return buf return buf
} }
// hexToCompactInPlace places the compact key in input buffer, returning the length // hexToCompactInPlace places the compact key in input buffer, returning the compacted key.
// needed for the representation func hexToCompactInPlace(hex []byte) []byte {
func hexToCompactInPlace(hex []byte) int {
var ( var (
hexLen = len(hex) // length of the hex input hexLen = len(hex) // length of the hex input
firstByte = byte(0) firstByte = byte(0)
@ -77,7 +76,7 @@ func hexToCompactInPlace(hex []byte) int {
hex[bi] = hex[ni]<<4 | hex[ni+1] hex[bi] = hex[ni]<<4 | hex[ni+1]
} }
hex[0] = firstByte hex[0] = firstByte
return binLen return hex[:binLen]
} }
func compactToHex(compact []byte) []byte { func compactToHex(compact []byte) []byte {

@ -86,8 +86,7 @@ func TestHexToCompactInPlace(t *testing.T) {
} { } {
hexBytes, _ := hex.DecodeString(key) hexBytes, _ := hex.DecodeString(key)
exp := hexToCompact(hexBytes) exp := hexToCompact(hexBytes)
sz := hexToCompactInPlace(hexBytes) got := hexToCompactInPlace(hexBytes)
got := hexBytes[:sz]
if !bytes.Equal(exp, got) { if !bytes.Equal(exp, got) {
t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp) t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp)
} }
@ -102,8 +101,7 @@ func TestHexToCompactInPlaceRandom(t *testing.T) {
hexBytes := keybytesToHex(key) hexBytes := keybytesToHex(key)
hexOrig := []byte(string(hexBytes)) hexOrig := []byte(string(hexBytes))
exp := hexToCompact(hexBytes) exp := hexToCompact(hexBytes)
sz := hexToCompactInPlace(hexBytes) got := hexToCompactInPlace(hexBytes)
got := hexBytes[:sz]
if !bytes.Equal(exp, got) { if !bytes.Equal(exp, got) {
t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n", t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n",
@ -119,6 +117,13 @@ func BenchmarkHexToCompact(b *testing.B) {
} }
} }
// BenchmarkHexToCompactInPlace measures hexToCompactInPlace on a short nibble
// key that ends with the terminator flag. The same buffer is handed to every
// iteration.
func BenchmarkHexToCompactInPlace(b *testing.B) {
	nibbles := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
	for remaining := b.N; remaining > 0; remaining-- {
		hexToCompactInPlace(nibbles)
	}
}
func BenchmarkCompactToHex(b *testing.B) { func BenchmarkCompactToHex(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

@ -250,7 +250,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
// Special case, two edge proofs for two edge key. // Special case, two edge proofs for two edge key.
proof := memorydb.New() proof := memorydb.New()
first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes() last := common.MaxHash.Bytes()
if err := trie.Prove(first, proof); err != nil { if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
@ -451,7 +451,7 @@ func TestAllElementsProof(t *testing.T) {
// Even with non-existent edge proofs, it should still work. // Even with non-existent edge proofs, it should still work.
proof = memorydb.New() proof = memorydb.New()
first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes() last := common.MaxHash.Bytes()
if err := trie.Prove(first, proof); err != nil { if err := trie.Prove(first, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
@ -517,7 +517,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
if err := trie.Prove(entries[pos].k, proof); err != nil { if err := trie.Prove(entries[pos].k, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") last := common.MaxHash
if err := trie.Prove(last.Bytes(), proof); err != nil { if err := trie.Prove(last.Bytes(), proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
@ -728,7 +728,7 @@ func TestHasRightElement(t *testing.T) {
} }
} }
if c.end == -1 { if c.end == -1 {
lastKey, end = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes(), len(entries) lastKey, end = common.MaxHash.Bytes(), len(entries)
if err := trie.Prove(lastKey, proof); err != nil { if err := trie.Prove(lastKey, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }

@ -17,183 +17,146 @@
package trie package trie
import ( import (
"bufio"
"bytes" "bytes"
"encoding/gob"
"errors"
"io"
"sync" "sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
) )
var ErrCommitDisabled = errors.New("no database for committing") var (
stPool = sync.Pool{New: func() any { return new(stNode) }}
_ = types.TrieHasher((*StackTrie)(nil))
)
var stPool = sync.Pool{ // StackTrieOptions contains the configured options for manipulating the stackTrie.
New: func() interface{} { type StackTrieOptions struct {
return NewStackTrie(nil) Writer func(path []byte, hash common.Hash, blob []byte) // The function to commit the dirty nodes
}, Cleaner func(path []byte) // The function to clean up dangling nodes
SkipLeftBoundary bool // Flag whether the nodes on the left boundary are skipped for committing
SkipRightBoundary bool // Flag whether the nodes on the right boundary are skipped for committing
boundaryGauge metrics.Gauge // Gauge to track how many boundary nodes are met
} }
// NodeWriteFunc is used to provide all information of a dirty node for committing // NewStackTrieOptions initializes an empty options for stackTrie.
// so that callers can flush nodes into database with desired scheme. func NewStackTrieOptions() *StackTrieOptions { return &StackTrieOptions{} }
type NodeWriteFunc = func(owner common.Hash, path []byte, hash common.Hash, blob []byte)
func stackTrieFromPool(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { // WithWriter configures trie node writer within the options.
st := stPool.Get().(*StackTrie) func (o *StackTrieOptions) WithWriter(writer func(path []byte, hash common.Hash, blob []byte)) *StackTrieOptions {
st.owner = owner o.Writer = writer
st.writeFn = writeFn return o
return st
} }
func returnToPool(st *StackTrie) { // WithCleaner configures the cleaner in the option for removing dangling nodes.
st.Reset() func (o *StackTrieOptions) WithCleaner(cleaner func(path []byte)) *StackTrieOptions {
stPool.Put(st) o.Cleaner = cleaner
return o
}
// WithSkipBoundary configures whether the left and right boundary nodes are
// filtered for committing, along with a gauge metrics to track how many
// boundary nodes are met.
func (o *StackTrieOptions) WithSkipBoundary(skipLeft, skipRight bool, gauge metrics.Gauge) *StackTrieOptions {
o.SkipLeftBoundary = skipLeft
o.SkipRightBoundary = skipRight
o.boundaryGauge = gauge
return o
} }
// StackTrie is a trie implementation that expects keys to be inserted // StackTrie is a trie implementation that expects keys to be inserted
// in order. Once it determines that a subtree will no longer be inserted // in order. Once it determines that a subtree will no longer be inserted
// into, it will hash it and free up the memory it uses. // into, it will hash it and free up the memory it uses.
type StackTrie struct { type StackTrie struct {
owner common.Hash // the owner of the trie options *StackTrieOptions
nodeType uint8 // node type (as in branch, ext, leaf) root *stNode
val []byte // value contained by this node if it's a leaf h *hasher
key []byte // key chunk covered by this (leaf|ext) node
children [16]*StackTrie // list of children (for branch and exts) first []byte // The (hex-encoded without terminator) key of first inserted entry, tracked as left boundary.
writeFn NodeWriteFunc // function for committing nodes, can be nil last []byte // The (hex-encoded without terminator) key of last inserted entry, tracked as right boundary.
} }
// NewStackTrie allocates and initializes an empty trie. // NewStackTrie allocates and initializes an empty trie.
func NewStackTrie(writeFn NodeWriteFunc) *StackTrie { func NewStackTrie(options *StackTrieOptions) *StackTrie {
if options == nil {
options = NewStackTrieOptions()
}
return &StackTrie{ return &StackTrie{
nodeType: emptyNode, options: options,
writeFn: writeFn, root: stPool.Get().(*stNode),
h: newHasher(false),
} }
} }
// NewStackTrieWithOwner allocates and initializes an empty trie, but with // Update inserts a (key, value) pair into the stack trie.
// the additional owner field. func (t *StackTrie) Update(key, value []byte) error {
func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { k := keybytesToHex(key)
return &StackTrie{ if len(value) == 0 {
owner: owner, panic("deletion not supported")
nodeType: emptyNode,
writeFn: writeFn,
} }
} k = k[:len(k)-1] // chop the termination flag
// NewFromBinary initialises a serialized stacktrie with the given db. // track the first and last inserted entries.
func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) { if t.first == nil {
var st StackTrie t.first = append([]byte{}, k...)
if err := st.UnmarshalBinary(data); err != nil {
return nil, err
} }
// If a database is used, we need to recursively add it to every child if t.last == nil {
if writeFn != nil { t.last = append([]byte{}, k...) // allocate key slice
st.setWriter(writeFn) } else {
} t.last = append(t.last[:0], k...) // reuse key slice
return &st, nil
}
// MarshalBinary implements encoding.BinaryMarshaler
func (st *StackTrie) MarshalBinary() (data []byte, err error) {
var (
b bytes.Buffer
w = bufio.NewWriter(&b)
)
if err := gob.NewEncoder(w).Encode(struct {
Owner common.Hash
NodeType uint8
Val []byte
Key []byte
}{
st.owner,
st.nodeType,
st.val,
st.key,
}); err != nil {
return nil, err
}
for _, child := range st.children {
if child == nil {
w.WriteByte(0)
continue
}
w.WriteByte(1)
if childData, err := child.MarshalBinary(); err != nil {
return nil, err
} else {
w.Write(childData)
}
}
w.Flush()
return b.Bytes(), nil
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler
func (st *StackTrie) UnmarshalBinary(data []byte) error {
r := bytes.NewReader(data)
return st.unmarshalBinary(r)
}
func (st *StackTrie) unmarshalBinary(r io.Reader) error {
var dec struct {
Owner common.Hash
NodeType uint8
Val []byte
Key []byte
}
if err := gob.NewDecoder(r).Decode(&dec); err != nil {
return err
}
st.owner = dec.Owner
st.nodeType = dec.NodeType
st.val = dec.Val
st.key = dec.Key
var hasChild = make([]byte, 1)
for i := range st.children {
if _, err := r.Read(hasChild); err != nil {
return err
} else if hasChild[0] == 0 {
continue
}
var child StackTrie
if err := child.unmarshalBinary(r); err != nil {
return err
}
st.children[i] = &child
} }
t.insert(t.root, k, value, nil)
return nil return nil
} }
func (st *StackTrie) setWriter(writeFn NodeWriteFunc) { // MustUpdate is a wrapper of Update and will omit any encountered error but
st.writeFn = writeFn // just print out an error message.
for _, child := range st.children { func (t *StackTrie) MustUpdate(key, value []byte) {
if child != nil { if err := t.Update(key, value); err != nil {
child.setWriter(writeFn) log.Error("Unhandled trie error in StackTrie.Update", "err", err)
}
} }
} }
func newLeaf(owner common.Hash, key, val []byte, writeFn NodeWriteFunc) *StackTrie { // Reset resets the stack trie object to empty state.
st := stackTrieFromPool(writeFn, owner) func (t *StackTrie) Reset() {
st.nodeType = leafNode t.options = NewStackTrieOptions()
t.root = stPool.Get().(*stNode)
t.first = nil
t.last = nil
}
// stNode represents a node within a StackTrie. Instances are obtained from
// stPool (see newLeaf and newExt) and carry only the per-node state; the
// options and hasher live on the owning StackTrie.
type stNode struct {
typ uint8 // node type (as in branch, ext, leaf); holds one of the emptyNode..hashedNode constants
key []byte // key chunk covered by this (leaf|ext) node
val []byte // value contained by this node if it's a leaf
children [16]*stNode // list of children (for branch and exts); an ext node keeps its single child in slot 0
}
// newLeaf constructs a leaf node with provided node key and value. The key
// will be deep-copied in the function and safe to modify afterwards, but
// value is not.
func newLeaf(key, val []byte) *stNode {
st := stPool.Get().(*stNode)
st.typ = leafNode
st.key = append(st.key, key...) st.key = append(st.key, key...)
st.val = val st.val = val
return st return st
} }
func newExt(owner common.Hash, key []byte, child *StackTrie, writeFn NodeWriteFunc) *StackTrie { // newExt constructs an extension node with provided node key and child. The
st := stackTrieFromPool(writeFn, owner) // key will be deep-copied in the function and safe to modify afterwards.
st.nodeType = extNode func newExt(key []byte, child *stNode) *stNode {
st := stPool.Get().(*stNode)
st.typ = extNode
st.key = append(st.key, key...) st.key = append(st.key, key...)
st.children[0] = child st.children[0] = child
return st return st
} }
// List all values that StackTrie#nodeType can hold // List all values that stNode#nodeType can hold
const ( const (
emptyNode = iota emptyNode = iota
branchNode branchNode
@ -202,59 +165,40 @@ const (
hashedNode hashedNode
) )
// Update inserts a (key, value) pair into the stack trie. func (n *stNode) reset() *stNode {
func (st *StackTrie) Update(key, value []byte) error { n.key = n.key[:0]
k := keybytesToHex(key) n.val = nil
if len(value) == 0 { for i := range n.children {
panic("deletion not supported") n.children[i] = nil
} }
st.insert(k[:len(k)-1], value, nil) n.typ = emptyNode
return nil return n
}
// MustUpdate is a wrapper of Update and will omit any encountered error but
// just print out an error message.
func (st *StackTrie) MustUpdate(key, value []byte) {
if err := st.Update(key, value); err != nil {
log.Error("Unhandled trie error in StackTrie.Update", "err", err)
}
}
func (st *StackTrie) Reset() {
st.owner = common.Hash{}
st.writeFn = nil
st.key = st.key[:0]
st.val = nil
for i := range st.children {
st.children[i] = nil
}
st.nodeType = emptyNode
} }
// Helper function that, given a full key, determines the index // Helper function that, given a full key, determines the index
// at which the chunk pointed by st.keyOffset is different from // at which the chunk pointed by st.keyOffset is different from
// the same chunk in the full key. // the same chunk in the full key.
func (st *StackTrie) getDiffIndex(key []byte) int { func (n *stNode) getDiffIndex(key []byte) int {
for idx, nibble := range st.key { for idx, nibble := range n.key {
if nibble != key[idx] { if nibble != key[idx] {
return idx return idx
} }
} }
return len(st.key) return len(n.key)
} }
// Helper function that inserts a (key, value) pair into // Helper function that inserts a (key, value) pair into
// the trie. // the trie.
func (st *StackTrie) insert(key, value []byte, prefix []byte) { func (t *StackTrie) insert(st *stNode, key, value []byte, path []byte) {
switch st.nodeType { switch st.typ {
case branchNode: /* Branch */ case branchNode: /* Branch */
idx := int(key[0]) idx := int(key[0])
// Unresolve elder siblings // Unresolve elder siblings
for i := idx - 1; i >= 0; i-- { for i := idx - 1; i >= 0; i-- {
if st.children[i] != nil { if st.children[i] != nil {
if st.children[i].nodeType != hashedNode { if st.children[i].typ != hashedNode {
st.children[i].hash(append(prefix, byte(i))) t.hash(st.children[i], append(path, byte(i)))
} }
break break
} }
@ -262,9 +206,9 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) {
// Add new child // Add new child
if st.children[idx] == nil { if st.children[idx] == nil {
st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn) st.children[idx] = newLeaf(key[1:], value)
} else { } else {
st.children[idx].insert(key[1:], value, append(prefix, key[0])) t.insert(st.children[idx], key[1:], value, append(path, key[0]))
} }
case extNode: /* Ext */ case extNode: /* Ext */
@ -279,46 +223,46 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) {
if diffidx == len(st.key) { if diffidx == len(st.key) {
// Ext key and key segment are identical, recurse into // Ext key and key segment are identical, recurse into
// the child node. // the child node.
st.children[0].insert(key[diffidx:], value, append(prefix, key[:diffidx]...)) t.insert(st.children[0], key[diffidx:], value, append(path, key[:diffidx]...))
return return
} }
// Save the original part. Depending if the break is // Save the original part. Depending if the break is
// at the extension's last byte or not, create an // at the extension's last byte or not, create an
// intermediate extension or use the extension's child // intermediate extension or use the extension's child
// node directly. // node directly.
var n *StackTrie var n *stNode
if diffidx < len(st.key)-1 { if diffidx < len(st.key)-1 {
// Break on the non-last byte, insert an intermediate // Break on the non-last byte, insert an intermediate
// extension. The path prefix of the newly-inserted // extension. The path prefix of the newly-inserted
// extension should also contain the different byte. // extension should also contain the different byte.
n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn) n = newExt(st.key[diffidx+1:], st.children[0])
n.hash(append(prefix, st.key[:diffidx+1]...)) t.hash(n, append(path, st.key[:diffidx+1]...))
} else { } else {
// Break on the last byte, no need to insert // Break on the last byte, no need to insert
// an extension node: reuse the current node. // an extension node: reuse the current node.
// The path prefix of the original part should // The path prefix of the original part should
// still be same. // still be same.
n = st.children[0] n = st.children[0]
n.hash(append(prefix, st.key...)) t.hash(n, append(path, st.key...))
} }
var p *StackTrie var p *stNode
if diffidx == 0 { if diffidx == 0 {
// the break is on the first byte, so // the break is on the first byte, so
// the current node is converted into // the current node is converted into
// a branch node. // a branch node.
st.children[0] = nil st.children[0] = nil
p = st p = st
st.nodeType = branchNode st.typ = branchNode
} else { } else {
// the common prefix is at least one byte // the common prefix is at least one byte
// long, insert a new intermediate branch // long, insert a new intermediate branch
// node. // node.
st.children[0] = stackTrieFromPool(st.writeFn, st.owner) st.children[0] = stPool.Get().(*stNode)
st.children[0].nodeType = branchNode st.children[0].typ = branchNode
p = st.children[0] p = st.children[0]
} }
// Create a leaf for the inserted part // Create a leaf for the inserted part
o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) o := newLeaf(key[diffidx+1:], value)
// Insert both child leaves where they belong: // Insert both child leaves where they belong:
origIdx := st.key[diffidx] origIdx := st.key[diffidx]
@ -344,18 +288,18 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) {
// Check if the split occurs at the first nibble of the // Check if the split occurs at the first nibble of the
// chunk. In that case, no prefix extnode is necessary. // chunk. In that case, no prefix extnode is necessary.
// Otherwise, create that // Otherwise, create that
var p *StackTrie var p *stNode
if diffidx == 0 { if diffidx == 0 {
// Convert current leaf into a branch // Convert current leaf into a branch
st.nodeType = branchNode st.typ = branchNode
p = st p = st
st.children[0] = nil st.children[0] = nil
} else { } else {
// Convert current node into an ext, // Convert current node into an ext,
// and insert a child branch node. // and insert a child branch node.
st.nodeType = extNode st.typ = extNode
st.children[0] = NewStackTrieWithOwner(st.writeFn, st.owner) st.children[0] = stPool.Get().(*stNode)
st.children[0].nodeType = branchNode st.children[0].typ = branchNode
p = st.children[0] p = st.children[0]
} }
@ -363,11 +307,11 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) {
// value and another containing the new value. The child leaf // value and another containing the new value. The child leaf
// is hashed directly in order to free up some memory. // is hashed directly in order to free up some memory.
origIdx := st.key[diffidx] origIdx := st.key[diffidx]
p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn) p.children[origIdx] = newLeaf(st.key[diffidx+1:], st.val)
p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...)) t.hash(p.children[origIdx], append(path, st.key[:diffidx+1]...))
newIdx := key[diffidx] newIdx := key[diffidx]
p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) p.children[newIdx] = newLeaf(key[diffidx+1:], value)
// Finally, cut off the key part that has been passed // Finally, cut off the key part that has been passed
// over to the children. // over to the children.
@ -375,7 +319,7 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) {
st.val = nil st.val = nil
case emptyNode: /* Empty */ case emptyNode: /* Empty */
st.nodeType = leafNode st.typ = leafNode
st.key = key st.key = key
st.val = value st.val = value
@ -398,25 +342,19 @@ func (st *StackTrie) insert(key, value []byte, prefix []byte) {
// - And the 'st.type' will be 'hashedNode' AGAIN // - And the 'st.type' will be 'hashedNode' AGAIN
// //
// This method also sets 'st.type' to hashedNode, and clears 'st.key'. // This method also sets 'st.type' to hashedNode, and clears 'st.key'.
func (st *StackTrie) hash(path []byte) { func (t *StackTrie) hash(st *stNode, path []byte) {
h := newHasher(false) var (
defer returnHasherToPool(h) blob []byte // RLP-encoded node blob
internal [][]byte // List of node paths covered by the extension node
st.hashRec(h, path) )
} switch st.typ {
func (st *StackTrie) hashRec(hasher *hasher, path []byte) {
// The switch below sets this to the RLP-encoding of this node.
var encodedNode []byte
switch st.nodeType {
case hashedNode: case hashedNode:
return return
case emptyNode: case emptyNode:
st.val = types.EmptyRootHash.Bytes() st.val = types.EmptyRootHash.Bytes()
st.key = st.key[:0] st.key = st.key[:0]
st.nodeType = hashedNode st.typ = hashedNode
return return
case branchNode: case branchNode:
@ -426,109 +364,113 @@ func (st *StackTrie) hashRec(hasher *hasher, path []byte) {
nodes.Children[i] = nilValueNode nodes.Children[i] = nilValueNode
continue continue
} }
child.hashRec(hasher, append(path, byte(i))) t.hash(child, append(path, byte(i)))
if len(child.val) < 32 { if len(child.val) < 32 {
nodes.Children[i] = rawNode(child.val) nodes.Children[i] = rawNode(child.val)
} else { } else {
nodes.Children[i] = hashNode(child.val) nodes.Children[i] = hashNode(child.val)
} }
// Release child back to pool.
st.children[i] = nil st.children[i] = nil
returnToPool(child) stPool.Put(child.reset()) // Release child back to pool.
} }
nodes.encode(t.h.encbuf)
nodes.encode(hasher.encbuf) blob = t.h.encodedBytes()
encodedNode = hasher.encodedBytes()
case extNode: case extNode:
st.children[0].hashRec(hasher, append(path, st.key...)) // recursively hash and commit child as the first step
t.hash(st.children[0], append(path, st.key...))
n := shortNode{Key: hexToCompact(st.key)} // Collect the path of internal nodes between shortNode and its **in disk**
// child. This is essential in the case of path mode scheme to avoid leaving
// dangling nodes within the range of this internal path on disk, which would
// break the guarantee for state healing.
if len(st.children[0].val) >= 32 && t.options.Cleaner != nil {
for i := 1; i < len(st.key); i++ {
internal = append(internal, append(path, st.key[:i]...))
}
}
// encode the extension node
n := shortNode{Key: hexToCompactInPlace(st.key)}
if len(st.children[0].val) < 32 { if len(st.children[0].val) < 32 {
n.Val = rawNode(st.children[0].val) n.Val = rawNode(st.children[0].val)
} else { } else {
n.Val = hashNode(st.children[0].val) n.Val = hashNode(st.children[0].val)
} }
n.encode(t.h.encbuf)
blob = t.h.encodedBytes()
n.encode(hasher.encbuf) stPool.Put(st.children[0].reset()) // Release child back to pool.
encodedNode = hasher.encodedBytes()
// Release child back to pool.
returnToPool(st.children[0])
st.children[0] = nil st.children[0] = nil
case leafNode: case leafNode:
st.key = append(st.key, byte(16)) st.key = append(st.key, byte(16))
n := shortNode{Key: hexToCompact(st.key), Val: valueNode(st.val)} n := shortNode{Key: hexToCompactInPlace(st.key), Val: valueNode(st.val)}
n.encode(hasher.encbuf) n.encode(t.h.encbuf)
encodedNode = hasher.encodedBytes() blob = t.h.encodedBytes()
default: default:
panic("invalid node type") panic("invalid node type")
} }
st.nodeType = hashedNode st.typ = hashedNode
st.key = st.key[:0] st.key = st.key[:0]
if len(encodedNode) < 32 {
st.val = common.CopyBytes(encodedNode) // Skip committing the non-root node if the size is smaller than 32 bytes.
if len(blob) < 32 && len(path) > 0 {
st.val = common.CopyBytes(blob)
return return
} }
// Write the hash to the 'val'. We allocate a new val here to not mutate // Write the hash to the 'val'. We allocate a new val here to not mutate
// input values // input values.
st.val = hasher.hashData(encodedNode) st.val = t.h.hashData(blob)
if st.writeFn != nil {
st.writeFn(st.owner, path, common.BytesToHash(st.val), encodedNode) // Short circuit if the stack trie is not configured for writing.
if t.options.Writer == nil {
return
} }
// Skip committing if the node is on the left boundary and stackTrie is
// configured to filter the boundary.
if t.options.SkipLeftBoundary && bytes.HasPrefix(t.first, path) {
if t.options.boundaryGauge != nil {
t.options.boundaryGauge.Inc(1)
}
return
}
// Skip committing if the node is on the right boundary and stackTrie is
// configured to filter the boundary.
if t.options.SkipRightBoundary && bytes.HasPrefix(t.last, path) {
if t.options.boundaryGauge != nil {
t.options.boundaryGauge.Inc(1)
}
return
}
// Clean up the internal dangling nodes covered by the extension node.
// This should be done before writing the node to adhere to the committing
// order from bottom to top.
for _, path := range internal {
t.options.Cleaner(path)
}
t.options.Writer(path, common.BytesToHash(st.val), blob)
} }
// Hash returns the hash of the current node. // Hash will firstly hash the entire trie if it's still not hashed and then commit
func (st *StackTrie) Hash() (h common.Hash) { // all nodes to the associated database. Actually most of the trie nodes have been
hasher := newHasher(false) // committed already. The main purpose here is to commit the nodes on right boundary.
defer returnHasherToPool(hasher)
st.hashRec(hasher, nil)
if len(st.val) == 32 {
copy(h[:], st.val)
return h
}
// If the node's RLP isn't 32 bytes long, the node will not
// be hashed, and instead contain the rlp-encoding of the
// node. For the top level node, we need to force the hashing.
hasher.sha.Reset()
hasher.sha.Write(st.val)
hasher.sha.Read(h[:])
return h
}
// Commit will firstly hash the entire trie if it's still not hashed
// and then commit all nodes to the associated database. Actually most
// of the trie nodes MAY have been committed already. The main purpose
// here is to commit the root node.
// //
// The associated database is expected, otherwise the whole commit // For stack trie, Hash and Commit are functionally identical.
// functionality should be disabled. func (t *StackTrie) Hash() common.Hash {
func (st *StackTrie) Commit() (h common.Hash, err error) { n := t.root
if st.writeFn == nil { t.hash(n, nil)
return common.Hash{}, ErrCommitDisabled return common.BytesToHash(n.val)
} }
hasher := newHasher(false)
defer returnHasherToPool(hasher) // Commit will firstly hash the entire trie if it's still not hashed and then commit
// all nodes to the associated database. Actually most of the trie nodes have been
st.hashRec(hasher, nil) // committed already. The main purpose here is to commit the nodes on right boundary.
if len(st.val) == 32 { //
copy(h[:], st.val) // For stack trie, Hash and Commit are functionally identical.
return h, nil func (t *StackTrie) Commit() common.Hash {
} return t.Hash()
// If the node's RLP isn't 32 bytes long, the node will not
// be hashed (and committed), and instead contain the rlp-encoding of the
// node. For the top level node, we need to force the hashing+commit.
hasher.sha.Reset()
hasher.sha.Write(st.val)
hasher.sha.Read(h[:])
st.writeFn(st.owner, nil, h, st.val)
return h, nil
} }

@ -19,11 +19,14 @@ package trie
import ( import (
"bytes" "bytes"
"math/big" "math/big"
"math/rand"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/trie/testutil"
"golang.org/x/exp/slices"
) )
func TestStackTrieInsertAndHash(t *testing.T) { func TestStackTrieInsertAndHash(t *testing.T) {
@ -166,12 +169,11 @@ func TestStackTrieInsertAndHash(t *testing.T) {
{"13aa", "x___________________________3", "ff0dc70ce2e5db90ee42a4c2ad12139596b890e90eb4e16526ab38fa465b35cf"}, {"13aa", "x___________________________3", "ff0dc70ce2e5db90ee42a4c2ad12139596b890e90eb4e16526ab38fa465b35cf"},
}, },
} }
st := NewStackTrie(nil)
for i, test := range tests { for i, test := range tests {
// The StackTrie does not allow Insert(), Hash(), Insert(), ... // The StackTrie does not allow Insert(), Hash(), Insert(), ...
// so we will create new trie for every sequence length of inserts. // so we will create new trie for every sequence length of inserts.
for l := 1; l <= len(test); l++ { for l := 1; l <= len(test); l++ {
st.Reset() st := NewStackTrie(nil)
for j := 0; j < l; j++ { for j := 0; j < l; j++ {
kv := &test[j] kv := &test[j]
if err := st.Update(common.FromHex(kv.K), []byte(kv.V)); err != nil { if err := st.Update(common.FromHex(kv.K), []byte(kv.V)); err != nil {
@ -346,47 +348,86 @@ func TestStacktrieNotModifyValues(t *testing.T) {
} }
} }
// TestStacktrieSerialization tests that the stacktrie works well if we func buildPartialTree(entries []*kv, t *testing.T) map[string]common.Hash {
// serialize/unserialize it a lot
func TestStacktrieSerialization(t *testing.T) {
var ( var (
st = NewStackTrie(nil) options = NewStackTrieOptions()
nt = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), nil)) nodes = make(map[string]common.Hash)
keyB = big.NewInt(1)
keyDelta = big.NewInt(1)
vals [][]byte
keys [][]byte
) )
getValue := func(i int) []byte { var (
if i%2 == 0 { // large first int
return crypto.Keccak256(big.NewInt(int64(i)).Bytes()) last = len(entries) - 1
} else { //small
return big.NewInt(int64(i)).Bytes()
}
}
for i := 0; i < 10; i++ {
vals = append(vals, getValue(i))
keys = append(keys, common.BigToHash(keyB).Bytes())
keyB = keyB.Add(keyB, keyDelta)
keyDelta.Add(keyDelta, common.Big1)
}
for i, k := range keys {
nt.Update(k, common.CopyBytes(vals[i]))
}
for i, k := range keys { noLeft bool
blob, err := st.MarshalBinary() noRight bool
if err != nil { )
t.Fatal(err) // Enter split mode if there are at least two elements
if rand.Intn(5) != 0 {
for {
first = rand.Intn(len(entries))
last = rand.Intn(len(entries))
if first <= last {
break
}
} }
newSt, err := NewFromBinary(blob, nil) if first != 0 {
if err != nil { noLeft = true
t.Fatal(err) }
if last != len(entries)-1 {
noRight = true
} }
st = newSt
st.Update(k, common.CopyBytes(vals[i]))
} }
if have, want := st.Hash(), nt.Hash(); have != want { options = options.WithSkipBoundary(noLeft, noRight, nil)
t.Fatalf("have %#x want %#x", have, want) options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
nodes[string(path)] = hash
})
tr := NewStackTrie(options)
for i := first; i <= last; i++ {
tr.MustUpdate(entries[i].k, entries[i].v)
}
tr.Commit()
return nodes
}
func TestPartialStackTrie(t *testing.T) {
for round := 0; round < 100; round++ {
var (
n = rand.Intn(100) + 1
entries []*kv
)
for i := 0; i < n; i++ {
var val []byte
if rand.Intn(3) == 0 {
val = testutil.RandBytes(3)
} else {
val = testutil.RandBytes(32)
}
entries = append(entries, &kv{
k: testutil.RandBytes(32),
v: val,
})
}
slices.SortFunc(entries, (*kv).cmp)
var (
nodes = make(map[string]common.Hash)
options = NewStackTrieOptions().WithWriter(func(path []byte, hash common.Hash, blob []byte) {
nodes[string(path)] = hash
})
)
tr := NewStackTrie(options)
for i := 0; i < len(entries); i++ {
tr.MustUpdate(entries[i].k, entries[i].v)
}
tr.Commit()
for j := 0; j < 100; j++ {
for path, hash := range buildPartialTree(entries, t) {
if nodes[path] != hash {
t.Errorf("%v, want %x, got %x", []byte(path), nodes[path], hash)
}
}
}
} }
} }

@ -51,6 +51,18 @@ var (
// lookupGauge is the metric to track how many trie node lookups are // lookupGauge is the metric to track how many trie node lookups are
// performed to determine if node needs to be deleted. // performed to determine if node needs to be deleted.
lookupGauge = metrics.NewRegisteredGauge("trie/sync/lookup", nil) lookupGauge = metrics.NewRegisteredGauge("trie/sync/lookup", nil)
// accountNodeSyncedGauge is the metric to track how many account trie
// nodes are written during the sync.
accountNodeSyncedGauge = metrics.NewRegisteredGauge("trie/sync/nodes/account", nil)
// storageNodeSyncedGauge is the metric to track how many storage trie
// nodes are written during the sync.
storageNodeSyncedGauge = metrics.NewRegisteredGauge("trie/sync/nodes/storage", nil)
// codeSyncedGauge is the metric to track how many contract codes are
// written during the sync.
codeSyncedGauge = metrics.NewRegisteredGauge("trie/sync/codes", nil)
) )
// SyncPath is a path tuple identifying a particular trie node either in a single // SyncPath is a path tuple identifying a particular trie node either in a single
@ -362,10 +374,22 @@ func (s *Sync) ProcessNode(result NodeSyncResult) error {
// storage, returning any occurred error. // storage, returning any occurred error.
func (s *Sync) Commit(dbw ethdb.Batch) error { func (s *Sync) Commit(dbw ethdb.Batch) error {
// Flush the pending node writes into database batch. // Flush the pending node writes into database batch.
var (
account int
storage int
)
for path, value := range s.membatch.nodes { for path, value := range s.membatch.nodes {
owner, inner := ResolvePath([]byte(path)) owner, inner := ResolvePath([]byte(path))
if owner == (common.Hash{}) {
account += 1
} else {
storage += 1
}
rawdb.WriteTrieNode(dbw, owner, inner, s.membatch.hashes[path], value, s.scheme) rawdb.WriteTrieNode(dbw, owner, inner, s.membatch.hashes[path], value, s.scheme)
} }
accountNodeSyncedGauge.Inc(int64(account))
storageNodeSyncedGauge.Inc(int64(storage))
// Flush the pending node deletes into the database batch. // Flush the pending node deletes into the database batch.
// Please note that each written and deleted node has a // Please note that each written and deleted node has a
// unique path, ensuring no duplication occurs. // unique path, ensuring no duplication occurs.
@ -377,6 +401,8 @@ func (s *Sync) Commit(dbw ethdb.Batch) error {
for hash, value := range s.membatch.codes { for hash, value := range s.membatch.codes {
rawdb.WriteCode(dbw, hash, value) rawdb.WriteCode(dbw, hash, value)
} }
codeSyncedGauge.Inc(int64(len(s.membatch.codes)))
s.membatch = newSyncMemBatch() // reset the batch s.membatch = newSyncMemBatch() // reset the batch
return nil return nil
} }

@ -908,9 +908,12 @@ func TestCommitSequenceStackTrie(t *testing.T) {
trie := NewEmpty(db) trie := NewEmpty(db)
// Another sponge is used for the stacktrie commits // Another sponge is used for the stacktrie commits
stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"} stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"}
stTrie := NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(stackTrieSponge, owner, path, hash, blob, db.Scheme()) options := NewStackTrieOptions()
options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(stackTrieSponge, common.Hash{}, path, hash, blob, db.Scheme())
}) })
stTrie := NewStackTrie(options)
// Fill the trie with elements // Fill the trie with elements
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
// For the stack trie, we need to do inserts in proper order // For the stack trie, we need to do inserts in proper order
@ -933,10 +936,7 @@ func TestCommitSequenceStackTrie(t *testing.T) {
db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil)
db.Commit(root, false) db.Commit(root, false)
// And flush stacktrie -> disk // And flush stacktrie -> disk
stRoot, err := stTrie.Commit() stRoot := stTrie.Commit()
if err != nil {
t.Fatalf("Failed to commit stack trie %v", err)
}
if stRoot != root { if stRoot != root {
t.Fatalf("root wrong, got %x exp %x", stRoot, root) t.Fatalf("root wrong, got %x exp %x", stRoot, root)
} }
@ -967,9 +967,12 @@ func TestCommitSequenceSmallRoot(t *testing.T) {
trie := NewEmpty(db) trie := NewEmpty(db)
// Another sponge is used for the stacktrie commits // Another sponge is used for the stacktrie commits
stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"} stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"}
stTrie := NewStackTrie(func(owner common.Hash, path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(stackTrieSponge, owner, path, hash, blob, db.Scheme()) options := NewStackTrieOptions()
options = options.WithWriter(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(stackTrieSponge, common.Hash{}, path, hash, blob, db.Scheme())
}) })
stTrie := NewStackTrie(options)
// Add a single small-element to the trie(s) // Add a single small-element to the trie(s)
key := make([]byte, 5) key := make([]byte, 5)
key[0] = 1 key[0] = 1
@ -981,10 +984,7 @@ func TestCommitSequenceSmallRoot(t *testing.T) {
db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil) db.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes), nil)
db.Commit(root, false) db.Commit(root, false)
// And flush stacktrie -> disk // And flush stacktrie -> disk
stRoot, err := stTrie.Commit() stRoot := stTrie.Commit()
if err != nil {
t.Fatalf("Failed to commit stack trie %v", err)
}
if stRoot != root { if stRoot != root {
t.Fatalf("root wrong, got %x exp %x", stRoot, root) t.Fatalf("root wrong, got %x exp %x", stRoot, root)
} }

@ -571,7 +571,16 @@ func truncateFromHead(db ethdb.Batcher, freezer *rawdb.ResettableFreezer, nhead
if err != nil { if err != nil {
return 0, err return 0, err
} }
if ohead <= nhead { otail, err := freezer.Tail()
if err != nil {
return 0, err
}
// Ensure that the truncation target falls within the specified range.
if ohead < nhead || nhead < otail {
return 0, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", otail, ohead, nhead)
}
// Short circuit if nothing to truncate.
if ohead == nhead {
return 0, nil return 0, nil
} }
// Load the meta objects in range [nhead+1, ohead] // Load the meta objects in range [nhead+1, ohead]
@ -600,11 +609,20 @@ func truncateFromHead(db ethdb.Batcher, freezer *rawdb.ResettableFreezer, nhead
// truncateFromTail removes the extra state histories from the tail with the given // truncateFromTail removes the extra state histories from the tail with the given
// parameters. It returns the number of items removed from the tail. // parameters. It returns the number of items removed from the tail.
func truncateFromTail(db ethdb.Batcher, freezer *rawdb.ResettableFreezer, ntail uint64) (int, error) { func truncateFromTail(db ethdb.Batcher, freezer *rawdb.ResettableFreezer, ntail uint64) (int, error) {
ohead, err := freezer.Ancients()
if err != nil {
return 0, err
}
otail, err := freezer.Tail() otail, err := freezer.Tail()
if err != nil { if err != nil {
return 0, err return 0, err
} }
if otail >= ntail { // Ensure that the truncation target falls within the specified range.
if otail > ntail || ntail > ohead {
return 0, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", otail, ohead, ntail)
}
// Short circuit if nothing to truncate.
if otail == ntail {
return 0, nil return 0, nil
} }
// Load the meta objects in range [otail+1, ntail] // Load the meta objects in range [otail+1, ntail]

@ -224,6 +224,50 @@ func TestTruncateTailHistories(t *testing.T) {
} }
} }
func TestTruncateOutOfRange(t *testing.T) {
var (
hs = makeHistories(10)
db = rawdb.NewMemoryDatabase()
freezer, _ = openFreezer(t.TempDir(), false)
)
defer freezer.Close()
for i := 0; i < len(hs); i++ {
accountData, storageData, accountIndex, storageIndex := hs[i].encode()
rawdb.WriteStateHistory(freezer, uint64(i+1), hs[i].meta.encode(), accountIndex, storageIndex, accountData, storageData)
rawdb.WriteStateID(db, hs[i].meta.root, uint64(i+1))
}
truncateFromTail(db, freezer, uint64(len(hs)/2))
// Ensure out-of-range truncations are rejected correctly.
head, _ := freezer.Ancients()
tail, _ := freezer.Tail()
cases := []struct {
mode int
target uint64
expErr error
}{
{0, head, nil}, // nothing to delete
{0, head + 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, head+1)},
{0, tail - 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, tail-1)},
{1, tail, nil}, // nothing to delete
{1, head + 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, head+1)},
{1, tail - 1, fmt.Errorf("out of range, tail: %d, head: %d, target: %d", tail, head, tail-1)},
}
for _, c := range cases {
var gotErr error
if c.mode == 0 {
_, gotErr = truncateFromHead(db, freezer, c.target)
} else {
_, gotErr = truncateFromTail(db, freezer, c.target)
}
if !reflect.DeepEqual(gotErr, c.expErr) {
t.Errorf("Unexpected error, want: %v, got: %v", c.expErr, gotErr)
}
}
}
// openFreezer initializes the freezer instance for storing state histories. // openFreezer initializes the freezer instance for storing state histories.
func openFreezer(datadir string, readOnly bool) (*rawdb.ResettableFreezer, error) { func openFreezer(datadir string, readOnly bool) (*rawdb.ResettableFreezer, error) {
return rawdb.NewStateFreezer(datadir, readOnly, 0) return rawdb.NewStateFreezer(datadir, readOnly, 0)