core/state: better randomized testing (postcheck) on journalling (#29627)
This PR fixes some flaws with the existing tests. The randomized testing (TestSnapshotRandom) executes a series of steps which modify the state and create journal-events. Later on, we compare the forward-going-states against the backwards-unrolling-journal-states, and check that they are identical. The "identical" check is performed using various accessors. It turned out that we failed to check some things: - the accesslist contents - the transient storage contents - the 'newContract' flag - the dirty storage map This change adds these new checks
This commit is contained in:
parent
a13b92524d
commit
243cde0f54
@ -17,7 +17,10 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
@ -130,3 +133,35 @@ func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) {
|
||||
func (al *accessList) DeleteAddress(address common.Address) {
|
||||
delete(al.addresses, address)
|
||||
}
|
||||
|
||||
// Equal returns true if the two access lists are identical
|
||||
func (al *accessList) Equal(other *accessList) bool {
|
||||
if !maps.Equal(al.addresses, other.addresses) {
|
||||
return false
|
||||
}
|
||||
return slices.EqualFunc(al.slots, other.slots,
|
||||
func(m map[common.Hash]struct{}, m2 map[common.Hash]struct{}) bool {
|
||||
return maps.Equal(m, m2)
|
||||
})
|
||||
}
|
||||
|
||||
// PrettyPrint prints the contents of the access list in a human-readable form
|
||||
func (al *accessList) PrettyPrint() string {
|
||||
out := new(strings.Builder)
|
||||
var sortedAddrs []common.Address
|
||||
for addr := range al.addresses {
|
||||
sortedAddrs = append(sortedAddrs, addr)
|
||||
}
|
||||
slices.SortFunc(sortedAddrs, common.Address.Cmp)
|
||||
for _, addr := range sortedAddrs {
|
||||
idx := al.addresses[addr]
|
||||
fmt.Fprintf(out, "%#x : (idx %d)\n", addr, idx)
|
||||
if idx >= 0 {
|
||||
slotmap := al.slots[idx]
|
||||
for h := range slotmap {
|
||||
fmt.Fprintf(out, " %#x\n", h)
|
||||
}
|
||||
}
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
@ -464,17 +464,17 @@ func (s *stateObject) deepCopy(db *StateDB) *stateObject {
|
||||
addrHash: s.addrHash,
|
||||
origin: s.origin,
|
||||
data: s.data,
|
||||
code: s.code,
|
||||
originStorage: s.originStorage.Copy(),
|
||||
pendingStorage: s.pendingStorage.Copy(),
|
||||
dirtyStorage: s.dirtyStorage.Copy(),
|
||||
dirtyCode: s.dirtyCode,
|
||||
selfDestructed: s.selfDestructed,
|
||||
newContract: s.newContract,
|
||||
}
|
||||
if s.trie != nil {
|
||||
obj.trie = db.db.CopyTrie(s.trie)
|
||||
}
|
||||
obj.code = s.code
|
||||
obj.originStorage = s.originStorage.Copy()
|
||||
obj.pendingStorage = s.pendingStorage.Copy()
|
||||
obj.dirtyStorage = s.dirtyStorage.Copy()
|
||||
obj.dirtyCode = s.dirtyCode
|
||||
obj.selfDestructed = s.selfDestructed
|
||||
obj.newContract = s.newContract
|
||||
return obj
|
||||
}
|
||||
|
||||
|
@ -21,9 +21,11 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
it := trie.NewIterator(trieIt)
|
||||
var (
|
||||
it = trie.NewIterator(trieIt)
|
||||
visited = make(map[common.Hash]bool)
|
||||
)
|
||||
|
||||
for it.Next() {
|
||||
key := common.BytesToHash(s.trie.GetKey(it.Key))
|
||||
visited[key] = true
|
||||
if value, dirty := so.dirtyStorage[key]; dirty {
|
||||
if !cb(key, value) {
|
||||
return nil
|
||||
@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
|
||||
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
|
||||
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
|
||||
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
|
||||
// Check newContract-flag
|
||||
if obj := state.getStateObject(addr); obj != nil {
|
||||
checkeq("IsNewContract", obj.newContract, checkstate.getStateObject(addr).newContract)
|
||||
}
|
||||
// Check storage.
|
||||
if obj := state.getStateObject(addr); obj != nil {
|
||||
forEachStorage(state, addr, func(key, value common.Hash) bool {
|
||||
@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
|
||||
forEachStorage(checkstate, addr, func(key, value common.Hash) bool {
|
||||
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
|
||||
})
|
||||
other := checkstate.getStateObject(addr)
|
||||
// Check dirty storage which is not in trie
|
||||
if !maps.Equal(obj.dirtyStorage, other.dirtyStorage) {
|
||||
print := func(dirty map[common.Hash]common.Hash) string {
|
||||
var keys []common.Hash
|
||||
out := new(strings.Builder)
|
||||
for key := range dirty {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
slices.SortFunc(keys, common.Hash.Cmp)
|
||||
for i, key := range keys {
|
||||
fmt.Fprintf(out, " %d. %v %v\n", i, key, dirty[key])
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
return fmt.Errorf("dirty storage err, have\n%v\nwant\n%v",
|
||||
print(obj.dirtyStorage),
|
||||
print(other.dirtyStorage))
|
||||
}
|
||||
}
|
||||
// Check transient storage.
|
||||
{
|
||||
have := state.transientStorage
|
||||
want := checkstate.transientStorage
|
||||
eq := maps.EqualFunc(have, want,
|
||||
func(a Storage, b Storage) bool {
|
||||
return maps.Equal(a, b)
|
||||
})
|
||||
if !eq {
|
||||
return fmt.Errorf("transient storage differs ,have\n%v\nwant\n%v",
|
||||
have.PrettyPrint(),
|
||||
want.PrettyPrint())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !checkstate.accessList.Equal(state.accessList) { // Check access lists
|
||||
return fmt.Errorf("AccessLists are wrong, have \n%v\nwant\n%v",
|
||||
checkstate.accessList.PrettyPrint(),
|
||||
state.accessList.PrettyPrint())
|
||||
}
|
||||
if state.GetRefund() != checkstate.GetRefund() {
|
||||
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
|
||||
state.GetRefund(), checkstate.GetRefund())
|
||||
@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
|
||||
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
|
||||
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
|
||||
}
|
||||
if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) {
|
||||
getKeys := func(dirty map[common.Address]int) string {
|
||||
var keys []common.Address
|
||||
out := new(strings.Builder)
|
||||
for key := range dirty {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
slices.SortFunc(keys, common.Address.Cmp)
|
||||
for i, key := range keys {
|
||||
fmt.Fprintf(out, " %d. %v\n", i, key)
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
have := getKeys(state.journal.dirties)
|
||||
want := getKeys(checkstate.journal.dirties)
|
||||
return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,10 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
@ -30,11 +34,20 @@ func newTransientStorage() transientStorage {
|
||||
|
||||
// Set sets the transient-storage `value` for `key` at the given `addr`.
|
||||
func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
|
||||
if value == (common.Hash{}) { // this is a 'delete'
|
||||
if _, ok := t[addr]; ok {
|
||||
delete(t[addr], key)
|
||||
if len(t[addr]) == 0 {
|
||||
delete(t, addr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, ok := t[addr]; !ok {
|
||||
t[addr] = make(Storage)
|
||||
}
|
||||
t[addr][key] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Get gets the transient storage for `key` at the given `addr`.
|
||||
func (t transientStorage) Get(addr common.Address, key common.Hash) common.Hash {
|
||||
@ -53,3 +66,27 @@ func (t transientStorage) Copy() transientStorage {
|
||||
}
|
||||
return storage
|
||||
}
|
||||
|
||||
// PrettyPrint prints the contents of the access list in a human-readable form
|
||||
func (t transientStorage) PrettyPrint() string {
|
||||
out := new(strings.Builder)
|
||||
var sortedAddrs []common.Address
|
||||
for addr := range t {
|
||||
sortedAddrs = append(sortedAddrs, addr)
|
||||
slices.SortFunc(sortedAddrs, common.Address.Cmp)
|
||||
}
|
||||
|
||||
for _, addr := range sortedAddrs {
|
||||
fmt.Fprintf(out, "%#x:", addr)
|
||||
var sortedKeys []common.Hash
|
||||
storage := t[addr]
|
||||
for key := range storage {
|
||||
sortedKeys = append(sortedKeys, key)
|
||||
}
|
||||
slices.SortFunc(sortedKeys, common.Hash.Cmp)
|
||||
for _, key := range sortedKeys {
|
||||
fmt.Fprintf(out, " %X : %X\n", key, storage[key])
|
||||
}
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user