core/state: better randomized testing (postcheck) on journalling (#29627)
This PR fixes some flaws with the existing tests. The randomized testing (TestSnapshotRandom) executes a series of steps which modify the state and create journal-events. Later on, we compare the forward-going-states against the backwards-unrolling-journal-states, and check that they are identical. The "identical" check is performed using various accessors. It turned out that we failed to check some things: - the accesslist contents - the transient storage contents - the 'newContract' flag - the dirty storage map This change adds these new checks
This commit is contained in:
parent
a13b92524d
commit
243cde0f54
@ -17,7 +17,10 @@
|
|||||||
package state
|
package state
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/common"
|
"github.com/ethereum/go-ethereum/common"
|
||||||
)
|
)
|
||||||
@ -130,3 +133,35 @@ func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) {
|
|||||||
func (al *accessList) DeleteAddress(address common.Address) {
|
func (al *accessList) DeleteAddress(address common.Address) {
|
||||||
delete(al.addresses, address)
|
delete(al.addresses, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Equal returns true if the two access lists are identical
|
||||||
|
func (al *accessList) Equal(other *accessList) bool {
|
||||||
|
if !maps.Equal(al.addresses, other.addresses) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return slices.EqualFunc(al.slots, other.slots,
|
||||||
|
func(m map[common.Hash]struct{}, m2 map[common.Hash]struct{}) bool {
|
||||||
|
return maps.Equal(m, m2)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrettyPrint prints the contents of the access list in a human-readable form
|
||||||
|
func (al *accessList) PrettyPrint() string {
|
||||||
|
out := new(strings.Builder)
|
||||||
|
var sortedAddrs []common.Address
|
||||||
|
for addr := range al.addresses {
|
||||||
|
sortedAddrs = append(sortedAddrs, addr)
|
||||||
|
}
|
||||||
|
slices.SortFunc(sortedAddrs, common.Address.Cmp)
|
||||||
|
for _, addr := range sortedAddrs {
|
||||||
|
idx := al.addresses[addr]
|
||||||
|
fmt.Fprintf(out, "%#x : (idx %d)\n", addr, idx)
|
||||||
|
if idx >= 0 {
|
||||||
|
slotmap := al.slots[idx]
|
||||||
|
for h := range slotmap {
|
||||||
|
fmt.Fprintf(out, " %#x\n", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
@ -459,22 +459,22 @@ func (s *stateObject) setBalance(amount *uint256.Int) {
|
|||||||
|
|
||||||
func (s *stateObject) deepCopy(db *StateDB) *stateObject {
|
func (s *stateObject) deepCopy(db *StateDB) *stateObject {
|
||||||
obj := &stateObject{
|
obj := &stateObject{
|
||||||
db: db,
|
db: db,
|
||||||
address: s.address,
|
address: s.address,
|
||||||
addrHash: s.addrHash,
|
addrHash: s.addrHash,
|
||||||
origin: s.origin,
|
origin: s.origin,
|
||||||
data: s.data,
|
data: s.data,
|
||||||
|
code: s.code,
|
||||||
|
originStorage: s.originStorage.Copy(),
|
||||||
|
pendingStorage: s.pendingStorage.Copy(),
|
||||||
|
dirtyStorage: s.dirtyStorage.Copy(),
|
||||||
|
dirtyCode: s.dirtyCode,
|
||||||
|
selfDestructed: s.selfDestructed,
|
||||||
|
newContract: s.newContract,
|
||||||
}
|
}
|
||||||
if s.trie != nil {
|
if s.trie != nil {
|
||||||
obj.trie = db.db.CopyTrie(s.trie)
|
obj.trie = db.db.CopyTrie(s.trie)
|
||||||
}
|
}
|
||||||
obj.code = s.code
|
|
||||||
obj.originStorage = s.originStorage.Copy()
|
|
||||||
obj.pendingStorage = s.pendingStorage.Copy()
|
|
||||||
obj.dirtyStorage = s.dirtyStorage.Copy()
|
|
||||||
obj.dirtyCode = s.dirtyCode
|
|
||||||
obj.selfDestructed = s.selfDestructed
|
|
||||||
obj.newContract = s.newContract
|
|
||||||
return obj
|
return obj
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,9 +21,11 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"maps"
|
||||||
"math"
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
it := trie.NewIterator(trieIt)
|
var (
|
||||||
|
it = trie.NewIterator(trieIt)
|
||||||
|
visited = make(map[common.Hash]bool)
|
||||||
|
)
|
||||||
|
|
||||||
for it.Next() {
|
for it.Next() {
|
||||||
key := common.BytesToHash(s.trie.GetKey(it.Key))
|
key := common.BytesToHash(s.trie.GetKey(it.Key))
|
||||||
|
visited[key] = true
|
||||||
if value, dirty := so.dirtyStorage[key]; dirty {
|
if value, dirty := so.dirtyStorage[key]; dirty {
|
||||||
if !cb(key, value) {
|
if !cb(key, value) {
|
||||||
return nil
|
return nil
|
||||||
@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
|
|||||||
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
|
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
|
||||||
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
|
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
|
||||||
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
|
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
|
||||||
|
// Check newContract-flag
|
||||||
|
if obj := state.getStateObject(addr); obj != nil {
|
||||||
|
checkeq("IsNewContract", obj.newContract, checkstate.getStateObject(addr).newContract)
|
||||||
|
}
|
||||||
// Check storage.
|
// Check storage.
|
||||||
if obj := state.getStateObject(addr); obj != nil {
|
if obj := state.getStateObject(addr); obj != nil {
|
||||||
forEachStorage(state, addr, func(key, value common.Hash) bool {
|
forEachStorage(state, addr, func(key, value common.Hash) bool {
|
||||||
@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
|
|||||||
forEachStorage(checkstate, addr, func(key, value common.Hash) bool {
|
forEachStorage(checkstate, addr, func(key, value common.Hash) bool {
|
||||||
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
|
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
|
||||||
})
|
})
|
||||||
|
other := checkstate.getStateObject(addr)
|
||||||
|
// Check dirty storage which is not in trie
|
||||||
|
if !maps.Equal(obj.dirtyStorage, other.dirtyStorage) {
|
||||||
|
print := func(dirty map[common.Hash]common.Hash) string {
|
||||||
|
var keys []common.Hash
|
||||||
|
out := new(strings.Builder)
|
||||||
|
for key := range dirty {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
slices.SortFunc(keys, common.Hash.Cmp)
|
||||||
|
for i, key := range keys {
|
||||||
|
fmt.Fprintf(out, " %d. %v %v\n", i, key, dirty[key])
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("dirty storage err, have\n%v\nwant\n%v",
|
||||||
|
print(obj.dirtyStorage),
|
||||||
|
print(other.dirtyStorage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check transient storage.
|
||||||
|
{
|
||||||
|
have := state.transientStorage
|
||||||
|
want := checkstate.transientStorage
|
||||||
|
eq := maps.EqualFunc(have, want,
|
||||||
|
func(a Storage, b Storage) bool {
|
||||||
|
return maps.Equal(a, b)
|
||||||
|
})
|
||||||
|
if !eq {
|
||||||
|
return fmt.Errorf("transient storage differs ,have\n%v\nwant\n%v",
|
||||||
|
have.PrettyPrint(),
|
||||||
|
want.PrettyPrint())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !checkstate.accessList.Equal(state.accessList) { // Check access lists
|
||||||
|
return fmt.Errorf("AccessLists are wrong, have \n%v\nwant\n%v",
|
||||||
|
checkstate.accessList.PrettyPrint(),
|
||||||
|
state.accessList.PrettyPrint())
|
||||||
|
}
|
||||||
if state.GetRefund() != checkstate.GetRefund() {
|
if state.GetRefund() != checkstate.GetRefund() {
|
||||||
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
|
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
|
||||||
state.GetRefund(), checkstate.GetRefund())
|
state.GetRefund(), checkstate.GetRefund())
|
||||||
@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
|
|||||||
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
|
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
|
||||||
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
|
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
|
||||||
}
|
}
|
||||||
|
if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) {
|
||||||
|
getKeys := func(dirty map[common.Address]int) string {
|
||||||
|
var keys []common.Address
|
||||||
|
out := new(strings.Builder)
|
||||||
|
for key := range dirty {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
slices.SortFunc(keys, common.Address.Cmp)
|
||||||
|
for i, key := range keys {
|
||||||
|
fmt.Fprintf(out, " %d. %v\n", i, key)
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
have := getKeys(state.journal.dirties)
|
||||||
|
want := getKeys(checkstate.journal.dirties)
|
||||||
|
return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,6 +17,10 @@
|
|||||||
package state
|
package state
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/common"
|
"github.com/ethereum/go-ethereum/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,10 +34,19 @@ func newTransientStorage() transientStorage {
|
|||||||
|
|
||||||
// Set sets the transient-storage `value` for `key` at the given `addr`.
|
// Set sets the transient-storage `value` for `key` at the given `addr`.
|
||||||
func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
|
func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
|
||||||
if _, ok := t[addr]; !ok {
|
if value == (common.Hash{}) { // this is a 'delete'
|
||||||
t[addr] = make(Storage)
|
if _, ok := t[addr]; ok {
|
||||||
|
delete(t[addr], key)
|
||||||
|
if len(t[addr]) == 0 {
|
||||||
|
delete(t, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, ok := t[addr]; !ok {
|
||||||
|
t[addr] = make(Storage)
|
||||||
|
}
|
||||||
|
t[addr][key] = value
|
||||||
}
|
}
|
||||||
t[addr][key] = value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets the transient storage for `key` at the given `addr`.
|
// Get gets the transient storage for `key` at the given `addr`.
|
||||||
@ -53,3 +66,27 @@ func (t transientStorage) Copy() transientStorage {
|
|||||||
}
|
}
|
||||||
return storage
|
return storage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrettyPrint prints the contents of the access list in a human-readable form
|
||||||
|
func (t transientStorage) PrettyPrint() string {
|
||||||
|
out := new(strings.Builder)
|
||||||
|
var sortedAddrs []common.Address
|
||||||
|
for addr := range t {
|
||||||
|
sortedAddrs = append(sortedAddrs, addr)
|
||||||
|
slices.SortFunc(sortedAddrs, common.Address.Cmp)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range sortedAddrs {
|
||||||
|
fmt.Fprintf(out, "%#x:", addr)
|
||||||
|
var sortedKeys []common.Hash
|
||||||
|
storage := t[addr]
|
||||||
|
for key := range storage {
|
||||||
|
sortedKeys = append(sortedKeys, key)
|
||||||
|
}
|
||||||
|
slices.SortFunc(sortedKeys, common.Hash.Cmp)
|
||||||
|
for _, key := range sortedKeys {
|
||||||
|
fmt.Fprintf(out, " %X : %X\n", key, storage[key])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user