core/state: better randomized testing (postcheck) on journalling (#29627)

This PR fixes some flaws with the existing tests.

The randomized testing (TestSnapshotRandom) executes a series of steps which modify the state and create journal-events. Later on, we compare the forward-going-states against the backwards-unrolling-journal-states, and check that they are identical.

The "identical" check is performed using various accessors. It turned out that we failed to check some things: 
- the accesslist contents
- the transient storage contents
- the 'newContract' flag
- the dirty storage map

This change adds these new checks
This commit is contained in:
Martin HS 2024-04-25 09:56:25 +02:00 committed by GitHub
parent a13b92524d
commit 243cde0f54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 153 additions and 17 deletions

@ -17,7 +17,10 @@
package state package state
import ( import (
"fmt"
"maps" "maps"
"slices"
"strings"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
@ -130,3 +133,35 @@ func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) {
func (al *accessList) DeleteAddress(address common.Address) { func (al *accessList) DeleteAddress(address common.Address) {
delete(al.addresses, address) delete(al.addresses, address)
} }
// Equal returns true if the two access lists are identical
func (al *accessList) Equal(other *accessList) bool {
if !maps.Equal(al.addresses, other.addresses) {
return false
}
return slices.EqualFunc(al.slots, other.slots,
func(m map[common.Hash]struct{}, m2 map[common.Hash]struct{}) bool {
return maps.Equal(m, m2)
})
}
// PrettyPrint prints the contents of the access list in a human-readable form
func (al *accessList) PrettyPrint() string {
out := new(strings.Builder)
var sortedAddrs []common.Address
for addr := range al.addresses {
sortedAddrs = append(sortedAddrs, addr)
}
slices.SortFunc(sortedAddrs, common.Address.Cmp)
for _, addr := range sortedAddrs {
idx := al.addresses[addr]
fmt.Fprintf(out, "%#x : (idx %d)\n", addr, idx)
if idx >= 0 {
slotmap := al.slots[idx]
for h := range slotmap {
fmt.Fprintf(out, " %#x\n", h)
}
}
}
return out.String()
}

@ -459,22 +459,22 @@ func (s *stateObject) setBalance(amount *uint256.Int) {
func (s *stateObject) deepCopy(db *StateDB) *stateObject { func (s *stateObject) deepCopy(db *StateDB) *stateObject {
obj := &stateObject{ obj := &stateObject{
db: db, db: db,
address: s.address, address: s.address,
addrHash: s.addrHash, addrHash: s.addrHash,
origin: s.origin, origin: s.origin,
data: s.data, data: s.data,
code: s.code,
originStorage: s.originStorage.Copy(),
pendingStorage: s.pendingStorage.Copy(),
dirtyStorage: s.dirtyStorage.Copy(),
dirtyCode: s.dirtyCode,
selfDestructed: s.selfDestructed,
newContract: s.newContract,
} }
if s.trie != nil { if s.trie != nil {
obj.trie = db.db.CopyTrie(s.trie) obj.trie = db.db.CopyTrie(s.trie)
} }
obj.code = s.code
obj.originStorage = s.originStorage.Copy()
obj.pendingStorage = s.pendingStorage.Copy()
obj.dirtyStorage = s.dirtyStorage.Copy()
obj.dirtyCode = s.dirtyCode
obj.selfDestructed = s.selfDestructed
obj.newContract = s.newContract
return obj return obj
} }

@ -21,9 +21,11 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"maps"
"math" "math"
"math/rand" "math/rand"
"reflect" "reflect"
"slices"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H
if err != nil { if err != nil {
return err return err
} }
it := trie.NewIterator(trieIt) var (
it = trie.NewIterator(trieIt)
visited = make(map[common.Hash]bool)
)
for it.Next() { for it.Next() {
key := common.BytesToHash(s.trie.GetKey(it.Key)) key := common.BytesToHash(s.trie.GetKey(it.Key))
visited[key] = true
if value, dirty := so.dirtyStorage[key]; dirty { if value, dirty := so.dirtyStorage[key]; dirty {
if !cb(key, value) { if !cb(key, value) {
return nil return nil
@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check newContract-flag
if obj := state.getStateObject(addr); obj != nil {
checkeq("IsNewContract", obj.newContract, checkstate.getStateObject(addr).newContract)
}
// Check storage. // Check storage.
if obj := state.getStateObject(addr); obj != nil { if obj := state.getStateObject(addr); obj != nil {
forEachStorage(state, addr, func(key, value common.Hash) bool { forEachStorage(state, addr, func(key, value common.Hash) bool {
@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
forEachStorage(checkstate, addr, func(key, value common.Hash) bool { forEachStorage(checkstate, addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
}) })
other := checkstate.getStateObject(addr)
// Check dirty storage which is not in trie
if !maps.Equal(obj.dirtyStorage, other.dirtyStorage) {
print := func(dirty map[common.Hash]common.Hash) string {
var keys []common.Hash
out := new(strings.Builder)
for key := range dirty {
keys = append(keys, key)
}
slices.SortFunc(keys, common.Hash.Cmp)
for i, key := range keys {
fmt.Fprintf(out, " %d. %v %v\n", i, key, dirty[key])
}
return out.String()
}
return fmt.Errorf("dirty storage err, have\n%v\nwant\n%v",
print(obj.dirtyStorage),
print(other.dirtyStorage))
}
}
// Check transient storage.
{
have := state.transientStorage
want := checkstate.transientStorage
eq := maps.EqualFunc(have, want,
func(a Storage, b Storage) bool {
return maps.Equal(a, b)
})
if !eq {
return fmt.Errorf("transient storage differs ,have\n%v\nwant\n%v",
have.PrettyPrint(),
want.PrettyPrint())
}
} }
if err != nil { if err != nil {
return err return err
} }
} }
if !checkstate.accessList.Equal(state.accessList) { // Check access lists
return fmt.Errorf("AccessLists are wrong, have \n%v\nwant\n%v",
checkstate.accessList.PrettyPrint(),
state.accessList.PrettyPrint())
}
if state.GetRefund() != checkstate.GetRefund() { if state.GetRefund() != checkstate.GetRefund() {
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
state.GetRefund(), checkstate.GetRefund()) state.GetRefund(), checkstate.GetRefund())
@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
} }
if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) {
getKeys := func(dirty map[common.Address]int) string {
var keys []common.Address
out := new(strings.Builder)
for key := range dirty {
keys = append(keys, key)
}
slices.SortFunc(keys, common.Address.Cmp)
for i, key := range keys {
fmt.Fprintf(out, " %d. %v\n", i, key)
}
return out.String()
}
have := getKeys(state.journal.dirties)
want := getKeys(checkstate.journal.dirties)
return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want)
}
return nil return nil
} }

@ -17,6 +17,10 @@
package state package state
import ( import (
"fmt"
"slices"
"strings"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
@ -30,10 +34,19 @@ func newTransientStorage() transientStorage {
// Set sets the transient-storage `value` for `key` at the given `addr`. // Set sets the transient-storage `value` for `key` at the given `addr`.
func (t transientStorage) Set(addr common.Address, key, value common.Hash) { func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
if _, ok := t[addr]; !ok { if value == (common.Hash{}) { // this is a 'delete'
t[addr] = make(Storage) if _, ok := t[addr]; ok {
delete(t[addr], key)
if len(t[addr]) == 0 {
delete(t, addr)
}
}
} else {
if _, ok := t[addr]; !ok {
t[addr] = make(Storage)
}
t[addr][key] = value
} }
t[addr][key] = value
} }
// Get gets the transient storage for `key` at the given `addr`. // Get gets the transient storage for `key` at the given `addr`.
@ -53,3 +66,27 @@ func (t transientStorage) Copy() transientStorage {
} }
return storage return storage
} }
// PrettyPrint prints the contents of the access list in a human-readable form
func (t transientStorage) PrettyPrint() string {
out := new(strings.Builder)
var sortedAddrs []common.Address
for addr := range t {
sortedAddrs = append(sortedAddrs, addr)
slices.SortFunc(sortedAddrs, common.Address.Cmp)
}
for _, addr := range sortedAddrs {
fmt.Fprintf(out, "%#x:", addr)
var sortedKeys []common.Hash
storage := t[addr]
for key := range storage {
sortedKeys = append(sortedKeys, key)
}
slices.SortFunc(sortedKeys, common.Hash.Cmp)
for _, key := range sortedKeys {
fmt.Fprintf(out, " %X : %X\n", key, storage[key])
}
}
return out.String()
}