core/state, light, les: make signature of ContractCode hash-independent (#27209)

* core/state, light, les: make signature of ContractCode hash-independent

* push current state for feedback

* les: fix unit test

* core, les, light: fix les unittests

* core/state, trie, les, light: fix state iterator

* core, les: address comments

* les: fix lint

---------

Co-authored-by: Gary Rong <garyrong0905@gmail.com>
This commit is contained in:
Guillaume Ballet 2023-06-28 11:11:02 +02:00 committed by GitHub
parent 85b8d1c06c
commit 8bbb16b70e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 164 additions and 142 deletions

@ -306,9 +306,11 @@ func (bc *BlockChain) TrieNode(hash common.Hash) ([]byte, error) {
// new code scheme.
func (bc *BlockChain) ContractCodeWithPrefix(hash common.Hash) ([]byte, error) {
type codeReader interface {
ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]byte, error)
ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error)
}
return bc.stateCache.(codeReader).ContractCodeWithPrefix(common.Hash{}, hash)
// TODO(rjl493456442) The associated account address is also required
// in Verkle scheme. Fix it once snap-sync is supported for Verkle.
return bc.stateCache.(codeReader).ContractCodeWithPrefix(common.Address{}, hash)
}
// State returns a new mutable state based on the current HEAD block.

@ -24,6 +24,7 @@ import (
"github.com/ethereum/go-ethereum/common/lru"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/trienode"
@ -43,16 +44,16 @@ type Database interface {
OpenTrie(root common.Hash) (Trie, error)
// OpenStorageTrie opens the storage trie of an account.
OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error)
OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (Trie, error)
// CopyTrie returns an independent copy of the given trie.
CopyTrie(Trie) Trie
// ContractCode retrieves a particular contract's code.
ContractCode(addrHash, codeHash common.Hash) ([]byte, error)
ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error)
// ContractCodeSize retrieves a particular contracts code's size.
ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error)
// DiskDB returns the underlying key-value disk database.
DiskDB() ethdb.KeyValueStore
@ -177,8 +178,8 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
}
// OpenStorageTrie opens the storage trie of an account.
func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) {
tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb)
func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (Trie, error) {
tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, crypto.Keccak256Hash(address.Bytes()), root), db.triedb)
if err != nil {
return nil, err
}
@ -196,7 +197,7 @@ func (db *cachingDB) CopyTrie(t Trie) Trie {
}
// ContractCode retrieves a particular contract's code.
func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) {
func (db *cachingDB) ContractCode(address common.Address, codeHash common.Hash) ([]byte, error) {
code, _ := db.codeCache.Get(codeHash)
if len(code) > 0 {
return code, nil
@ -213,7 +214,7 @@ func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error
// ContractCodeWithPrefix retrieves a particular contract's code. If the
// code can't be found in the cache, then check the existence with **new**
// db scheme.
func (db *cachingDB) ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]byte, error) {
func (db *cachingDB) ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error) {
code, _ := db.codeCache.Get(codeHash)
if len(code) > 0 {
return code, nil
@ -228,11 +229,11 @@ func (db *cachingDB) ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]b
}
// ContractCodeSize retrieves a particular contracts code's size.
func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) {
func (db *cachingDB) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok {
return cached, nil
}
code, err := db.ContractCode(addrHash, codeHash)
code, err := db.ContractCode(addr, codeHash)
return len(code), err
}

@ -18,6 +18,7 @@ package state
import (
"bytes"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
@ -27,7 +28,8 @@ import (
)
// nodeIterator is an iterator to traverse the entire state trie post-order,
// including all of the contract code and contract state tries.
// including all of the contract code and contract state tries. Preimage is
// required in order to resolve the contract address.
type nodeIterator struct {
state *StateDB // State being iterated
@ -113,7 +115,15 @@ func (it *nodeIterator) step() error {
if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil {
return err
}
dataTrie, err := it.state.db.OpenStorageTrie(it.state.originalRoot, common.BytesToHash(it.stateIt.LeafKey()), account.Root)
// Lookup the preimage of account hash
preimage := it.state.trie.GetKey(it.stateIt.LeafKey())
if preimage == nil {
return errors.New("account address is not available")
}
address := common.BytesToAddress(preimage)
// Traverse the storage slots belong to the account
dataTrie, err := it.state.db.OpenStorageTrie(it.state.originalRoot, address, account.Root)
if err != nil {
return err
}
@ -126,8 +136,7 @@ func (it *nodeIterator) step() error {
}
if !bytes.Equal(account.CodeHash, types.EmptyCodeHash.Bytes()) {
it.codeHash = common.BytesToHash(account.CodeHash)
addrHash := common.BytesToHash(it.stateIt.LeafKey())
it.code, err = it.state.db.ContractCode(addrHash, common.BytesToHash(account.CodeHash))
it.code, err = it.state.db.ContractCode(address, common.BytesToHash(account.CodeHash))
if err != nil {
return fmt.Errorf("code %x: %v", account.CodeHash, err)
}

@ -142,7 +142,7 @@ func (s *stateObject) getTrie(db Database) (Trie, error) {
s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root)
}
if s.trie == nil {
tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root)
tr, err := db.OpenStorageTrie(s.db.originalRoot, s.address, s.data.Root)
if err != nil {
return nil, err
}
@ -441,7 +441,7 @@ func (s *stateObject) Code(db Database) []byte {
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return nil
}
code, err := db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash()))
code, err := db.ContractCode(s.address, common.BytesToHash(s.CodeHash()))
if err != nil {
s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
}
@ -459,7 +459,7 @@ func (s *stateObject) CodeSize(db Database) int {
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return 0
}
size, err := db.ContractCodeSize(s.addrHash, common.BytesToHash(s.CodeHash()))
size, err := db.ContractCodeSize(s.address, common.BytesToHash(s.CodeHash()))
if err != nil {
s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
}

@ -42,7 +42,7 @@ type testAccount struct {
func makeTestState() (ethdb.Database, Database, common.Hash, []*testAccount) {
// Create an empty state
db := rawdb.NewMemoryDatabase()
sdb := NewDatabase(db)
sdb := NewDatabaseWithConfig(db, &trie.Config{Preimages: true})
state, _ := New(types.EmptyRootHash, sdb, nil)
// Fill it with some arbitrary data
@ -100,28 +100,9 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou
}
}
// checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present.
func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
if v, _ := db.Get(root[:]); v == nil {
return nil // Consider a non existent state consistent.
}
trie, err := trie.New(trie.StateTrieID(root), trie.NewDatabase(db))
if err != nil {
return err
}
it := trie.MustNodeIterator(nil)
for it.Next(true) {
}
return it.Error()
}
// checkStateConsistency checks that all data of a state root is present.
func checkStateConsistency(db ethdb.Database, root common.Hash) error {
// Create and iterate a state trie rooted in a sub-node
if _, err := db.Get(root.Bytes()); err != nil {
return nil // Consider a non existent state consistent.
}
state, err := New(root, NewDatabase(db), nil)
state, err := New(root, NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), nil)
if err != nil {
return err
}
@ -171,7 +152,7 @@ type stateElement struct {
func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
// Create a random state to copy
_, srcDb, srcRoot, srcAccounts := makeTestState()
srcDisk, srcDb, srcRoot, srcAccounts := makeTestState()
if commit {
srcDb.TrieDB().Commit(srcRoot, false)
}
@ -204,7 +185,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
codeResults = make([]trie.CodeSyncResult, len(codeElements))
)
for i, element := range codeElements {
data, err := srcDb.ContractCode(common.Hash{}, element.code)
data, err := srcDb.ContractCode(common.Address{}, element.code)
if err != nil {
t.Fatalf("failed to retrieve contract bytecode for hash %x", element.code)
}
@ -274,6 +255,10 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
})
}
}
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(srcDisk, dstDb)
// Cross check that the two states are in sync
checkStateAccounts(t, dstDb, srcRoot, srcAccounts)
}
@ -282,7 +267,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
// partial results are returned, and the others sent only later.
func TestIterativeDelayedStateSync(t *testing.T) {
// Create a random state to copy
_, srcDb, srcRoot, srcAccounts := makeTestState()
srcDisk, srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase()
@ -312,7 +297,7 @@ func TestIterativeDelayedStateSync(t *testing.T) {
if len(codeElements) > 0 {
codeResults := make([]trie.CodeSyncResult, len(codeElements)/2+1)
for i, element := range codeElements[:len(codeResults)] {
data, err := srcDb.ContractCode(common.Hash{}, element.code)
data, err := srcDb.ContractCode(common.Address{}, element.code)
if err != nil {
t.Fatalf("failed to retrieve contract bytecode for %x", element.code)
}
@ -363,6 +348,10 @@ func TestIterativeDelayedStateSync(t *testing.T) {
})
}
}
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(srcDisk, dstDb)
// Cross check that the two states are in sync
checkStateAccounts(t, dstDb, srcRoot, srcAccounts)
}
@ -375,7 +364,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS
func testIterativeRandomStateSync(t *testing.T, count int) {
// Create a random state to copy
_, srcDb, srcRoot, srcAccounts := makeTestState()
srcDisk, srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase()
@ -399,7 +388,7 @@ func testIterativeRandomStateSync(t *testing.T, count int) {
if len(codeQueue) > 0 {
results := make([]trie.CodeSyncResult, 0, len(codeQueue))
for hash := range codeQueue {
data, err := srcDb.ContractCode(common.Hash{}, hash)
data, err := srcDb.ContractCode(common.Address{}, hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x", hash)
}
@ -447,6 +436,10 @@ func testIterativeRandomStateSync(t *testing.T, count int) {
codeQueue[hash] = struct{}{}
}
}
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(srcDisk, dstDb)
// Cross check that the two states are in sync
checkStateAccounts(t, dstDb, srcRoot, srcAccounts)
}
@ -455,7 +448,7 @@ func testIterativeRandomStateSync(t *testing.T, count int) {
// partial results are returned (Even those randomly), others sent only later.
func TestIterativeRandomDelayedStateSync(t *testing.T) {
// Create a random state to copy
_, srcDb, srcRoot, srcAccounts := makeTestState()
srcDisk, srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase()
@ -481,7 +474,7 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
for hash := range codeQueue {
delete(codeQueue, hash)
data, err := srcDb.ContractCode(common.Hash{}, hash)
data, err := srcDb.ContractCode(common.Address{}, hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x", hash)
}
@ -537,6 +530,10 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
codeQueue[hash] = struct{}{}
}
}
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(srcDisk, dstDb)
// Cross check that the two states are in sync
checkStateAccounts(t, dstDb, srcRoot, srcAccounts)
}
@ -555,7 +552,6 @@ func TestIncompleteStateSync(t *testing.T) {
}
}
isCode[types.EmptyCodeHash] = struct{}{}
checkTrieConsistency(db, srcRoot)
// Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase()
@ -588,7 +584,7 @@ func TestIncompleteStateSync(t *testing.T) {
if len(codeQueue) > 0 {
results := make([]trie.CodeSyncResult, 0, len(codeQueue))
for hash := range codeQueue {
data, err := srcDb.ContractCode(common.Hash{}, hash)
data, err := srcDb.ContractCode(common.Address{}, hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x", hash)
}
@ -602,7 +598,6 @@ func TestIncompleteStateSync(t *testing.T) {
}
}
}
var nodehashes []common.Hash
if len(nodeQueue) > 0 {
results := make([]trie.NodeSyncResult, 0, len(nodeQueue))
for path, element := range nodeQueue {
@ -617,7 +612,6 @@ func TestIncompleteStateSync(t *testing.T) {
addedPaths = append(addedPaths, element.path)
addedHashes = append(addedHashes, element.hash)
}
nodehashes = append(nodehashes, element.hash)
}
// Process each of the state nodes
for _, result := range results {
@ -632,13 +626,6 @@ func TestIncompleteStateSync(t *testing.T) {
}
batch.Write()
for _, root := range nodehashes {
// Can't use checkStateConsistency here because subtrie keys may have odd
// length and crash in LeafKey.
if err := checkTrieConsistency(dstDb, root); err != nil {
t.Fatalf("state inconsistent: %v", err)
}
}
// Fetch the next batch to retrieve
nodeQueue = make(map[string]stateElement)
codeQueue = make(map[common.Hash]struct{})
@ -654,6 +641,10 @@ func TestIncompleteStateSync(t *testing.T) {
codeQueue[hash] = struct{}{}
}
}
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(db, dstDb)
// Sanity check that removing any node from the database is detected
for _, node := range addedCodes {
val := rawdb.ReadCode(dstDb, node)
@ -678,3 +669,15 @@ func TestIncompleteStateSync(t *testing.T) {
rawdb.WriteTrieNode(dstDb, owner, inner, hash, val, scheme)
}
}
func copyPreimages(srcDb, dstDb ethdb.Database) {
it := srcDb.NewIterator(rawdb.PreimagePrefix, nil)
defer it.Release()
preimages := make(map[common.Hash][]byte)
for it.Next() {
hash := it.Key()[len(rawdb.PreimagePrefix):]
preimages[common.BytesToHash(hash)] = common.CopyBytes(it.Value())
}
rawdb.WritePreimages(dstDb, preimages)
}

@ -302,7 +302,7 @@ func (sf *subfetcher) loop() {
}
sf.trie = trie
} else {
trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root)
trie, err := sf.db.OpenStorageTrie(sf.state, sf.addr, sf.root)
if err != nil {
log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
return

@ -117,7 +117,7 @@ func (b *benchmarkProofsOrCode) request(peer *serverPeer, index int) error {
key := make([]byte, 32)
crand.Read(key)
if b.code {
return peer.requestCode(0, []CodeReq{{BHash: b.headHash, AccKey: key}})
return peer.requestCode(0, []CodeReq{{BHash: b.headHash, AccountAddress: key}})
}
return peer.requestProofs(0, []ProofReq{{BHash: b.headHash, Key: key}})
}

@ -296,7 +296,7 @@ func testGetCode(t *testing.T, protocol int) {
header := bc.GetHeaderByNumber(i)
req := &CodeReq{
BHash: header.Hash(),
AccKey: crypto.Keccak256(testContractAddr[:]),
AccountAddress: testContractAddr[:],
}
codereqs = append(codereqs, req)
if i >= testContractDeployed {
@ -332,7 +332,7 @@ func testGetStaleCode(t *testing.T, protocol int) {
check := func(number uint64, expected [][]byte) {
req := &CodeReq{
BHash: bc.GetHeaderByNumber(number).Hash(),
AccKey: crypto.Keccak256(testContractAddr[:]),
AccountAddress: testContractAddr[:],
}
sendRequest(rawPeer.app, GetCodeMsg, 42, []*CodeReq{req})
if err := expectResponse(rawPeer.app, CodeMsg, 42, testBufLimit, expected); err != nil {

@ -184,7 +184,7 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error {
type ProofReq struct {
BHash common.Hash
AccKey, Key []byte
AccountAddress, Key []byte
FromLevel uint
}
@ -207,7 +207,7 @@ func (r *TrieRequest) Request(reqID uint64, peer *serverPeer) error {
peer.Log().Debug("Requesting trie proof", "root", r.Id.Root, "key", r.Key)
req := ProofReq{
BHash: r.Id.BlockHash,
AccKey: r.Id.AccKey,
AccountAddress: r.Id.AccountAddress,
Key: r.Key,
}
return peer.requestProofs(reqID, []ProofReq{req})
@ -239,7 +239,7 @@ func (r *TrieRequest) Validate(db ethdb.Database, msg *Msg) error {
type CodeReq struct {
BHash common.Hash
AccKey []byte
AccountAddress []byte
}
// CodeRequest is the ODR request type for node data (used for retrieving contract code), see LesOdrRequest interface
@ -261,7 +261,7 @@ func (r *CodeRequest) Request(reqID uint64, peer *serverPeer) error {
peer.Log().Debug("Requesting code data", "hash", r.Hash)
req := CodeReq{
BHash: r.Id.BlockHash,
AccKey: r.Id.AccKey,
AccountAddress: r.Id.AccountAddress,
}
return peer.requestCode(reqID, []CodeReq{req})
}

@ -23,6 +23,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/light"
@ -77,7 +78,7 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrReq
return nil
}
sti := light.StateTrieID(header)
ci := light.StorageTrieID(sti, crypto.Keccak256Hash(testContractAddr[:]), common.Hash{})
ci := light.StorageTrieID(sti, testContractAddr, types.EmptyRootHash)
return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)}
}

@ -18,6 +18,7 @@ package les
import (
"errors"
"fmt"
"sync"
"time"
@ -34,7 +35,6 @@ import (
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
@ -358,20 +358,19 @@ func (h *serverHandler) AddTxsSync() bool {
}
// getAccount retrieves an account from the state based on root.
func getAccount(triedb *trie.Database, root, hash common.Hash) (types.StateAccount, error) {
trie, err := trie.New(trie.StateTrieID(root), triedb)
func getAccount(triedb *trie.Database, root common.Hash, addr common.Address) (types.StateAccount, error) {
trie, err := trie.NewStateTrie(trie.StateTrieID(root), triedb)
if err != nil {
return types.StateAccount{}, err
}
blob, err := trie.Get(hash[:])
acc, err := trie.GetAccount(addr)
if err != nil {
return types.StateAccount{}, err
}
var acc types.StateAccount
if err = rlp.DecodeBytes(blob, &acc); err != nil {
return types.StateAccount{}, err
if acc == nil {
return types.StateAccount{}, fmt.Errorf("account %#x is not present", addr)
}
return acc, nil
return *acc, nil
}
// GetHelperTrie returns the post-processed trie root for the given trie ID and section index

@ -304,16 +304,16 @@ func handleGetCode(msg Decoder) (serveRequestFn, uint64, uint64, error) {
continue
}
triedb := bc.StateCache().TrieDB()
account, err := getAccount(triedb, header.Root, common.BytesToHash(request.AccKey))
address := common.BytesToAddress(request.AccountAddress)
account, err := getAccount(triedb, header.Root, address)
if err != nil {
p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", address, "err", err)
p.bumpInvalid()
continue
}
code, err := bc.StateCache().ContractCode(common.BytesToHash(request.AccKey), common.BytesToHash(account.CodeHash))
code, err := bc.StateCache().ContractCode(address, common.BytesToHash(account.CodeHash))
if err != nil {
p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err)
p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", address, "codehash", common.BytesToHash(account.CodeHash), "err", err)
continue
}
// Accumulate the code and abort if enough data was retrieved
@ -413,7 +413,7 @@ func handleGetProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) {
statedb := bc.StateCache()
var trie state.Trie
switch len(request.AccKey) {
switch len(request.AccountAddress) {
case 0:
// No account key specified, open an account trie
trie, err = statedb.OpenTrie(root)
@ -423,15 +423,16 @@ func handleGetProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) {
}
default:
// Account key specified, open a storage trie
account, err := getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey))
address := common.BytesToAddress(request.AccountAddress)
account, err := getAccount(statedb.TrieDB(), root, address)
if err != nil {
p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err)
p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", address, "err", err)
p.bumpInvalid()
continue
}
trie, err = statedb.OpenStorageTrie(root, common.BytesToHash(request.AccKey), account.Root)
trie, err = statedb.OpenStorageTrie(root, address, account.Root)
if trie == nil || err != nil {
p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err)
p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", address, "root", account.Root, "err", err)
continue
}
}

@ -58,7 +58,7 @@ type TrieID struct {
BlockNumber uint64
StateRoot common.Hash
Root common.Hash
AccKey []byte
AccountAddress []byte
}
// StateTrieID returns a TrieID for a state trie belonging to a certain block
@ -69,19 +69,19 @@ func StateTrieID(header *types.Header) *TrieID {
BlockNumber: header.Number.Uint64(),
StateRoot: header.Root,
Root: header.Root,
AccKey: nil,
AccountAddress: nil,
}
}
// StorageTrieID returns a TrieID for a contract storage trie at a given account
// of a given state trie. It also requires the root hash of the trie for
// checking Merkle proofs.
func StorageTrieID(state *TrieID, addrHash, root common.Hash) *TrieID {
func StorageTrieID(state *TrieID, address common.Address, root common.Hash) *TrieID {
return &TrieID{
BlockHash: state.BlockHash,
BlockNumber: state.BlockNumber,
StateRoot: state.StateRoot,
AccKey: addrHash[:],
AccountAddress: address[:],
Root: root,
}
}

@ -86,8 +86,8 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error {
err error
t state.Trie
)
if len(req.Id.AccKey) > 0 {
t, err = odr.serverState.OpenStorageTrie(req.Id.StateRoot, common.BytesToHash(req.Id.AccKey), req.Id.Root)
if len(req.Id.AccountAddress) > 0 {
t, err = odr.serverState.OpenStorageTrie(req.Id.StateRoot, common.BytesToAddress(req.Id.AccountAddress), req.Id.Root)
} else {
t, err = odr.serverState.OpenTrie(req.Id.Root)
}

@ -55,8 +55,8 @@ func (db *odrDatabase) OpenTrie(root common.Hash) (state.Trie, error) {
return &odrTrie{db: db, id: db.id}, nil
}
func (db *odrDatabase) OpenStorageTrie(state, addrHash, root common.Hash) (state.Trie, error) {
return &odrTrie{db: db, id: StorageTrieID(db.id, addrHash, root)}, nil
func (db *odrDatabase) OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (state.Trie, error) {
return &odrTrie{db: db, id: StorageTrieID(db.id, address, root)}, nil
}
func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie {
@ -72,7 +72,7 @@ func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie {
}
}
func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) {
func (db *odrDatabase) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) {
if codeHash == sha3Nil {
return nil, nil
}
@ -81,14 +81,14 @@ func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, err
return code, nil
}
id := *db.id
id.AccKey = addrHash[:]
id.AccountAddress = addr[:]
req := &CodeRequest{Id: &id, Hash: codeHash}
err := db.backend.Retrieve(db.ctx, req)
return req.Data, err
}
func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) {
code, err := db.ContractCode(addrHash, codeHash)
func (db *odrDatabase) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) {
code, err := db.ContractCode(addr, codeHash)
return len(code), err
}
@ -207,8 +207,8 @@ func (t *odrTrie) do(key []byte, fn func() error) error {
var err error
if t.trie == nil {
var id *trie.ID
if len(t.id.AccKey) > 0 {
id = trie.StorageTrieID(t.id.StateRoot, common.BytesToHash(t.id.AccKey), t.id.Root)
if len(t.id.AccountAddress) > 0 {
id = trie.StorageTrieID(t.id.StateRoot, crypto.Keccak256Hash(t.id.AccountAddress), t.id.Root)
} else {
id = trie.StateTrieID(t.id.StateRoot)
}
@ -239,8 +239,8 @@ func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator {
if t.trie == nil {
it.do(func() error {
var id *trie.ID
if len(t.id.AccKey) > 0 {
id = trie.StorageTrieID(t.id.StateRoot, common.BytesToHash(t.id.AccKey), t.id.Root)
if len(t.id.AccountAddress) > 0 {
id = trie.StorageTrieID(t.id.StateRoot, crypto.Keccak256Hash(t.id.AccountAddress), t.id.Root)
} else {
id = trie.StateTrieID(t.id.StateRoot)
}

@ -46,7 +46,7 @@ var (
testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056")
chain *core.BlockChain
addrHashes []common.Hash
addresses []common.Address
txHashes []common.Hash
chtTrie *trie.Trie
@ -55,7 +55,7 @@ var (
bloomKeys [][]byte
)
func makechain() (bc *core.BlockChain, addrHashes, txHashes []common.Hash) {
func makechain() (bc *core.BlockChain, addresses []common.Address, txHashes []common.Hash) {
gspec := &core.Genesis{
Config: params.TestChainConfig,
Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}},
@ -77,7 +77,7 @@ func makechain() (bc *core.BlockChain, addrHashes, txHashes []common.Hash) {
tx, _ = types.SignTx(types.NewTransaction(nonce, addr, big.NewInt(10000), params.TxGas, big.NewInt(params.GWei), nil), signer, bankKey)
}
gen.AddTx(tx)
addrHashes = append(addrHashes, crypto.Keccak256Hash(addr[:]))
addresses = append(addresses, addr)
txHashes = append(txHashes, tx.Hash())
})
bc, _ = core.NewBlockChain(rawdb.NewMemoryDatabase(), nil, gspec, nil, ethash.NewFaker(), vm.Config{}, nil, nil)
@ -107,7 +107,7 @@ func makeTries() (chtTrie *trie.Trie, bloomTrie *trie.Trie, chtKeys, bloomKeys [
}
func init() {
chain, addrHashes, txHashes = makechain()
chain, addresses, txHashes = makechain()
chtTrie, bloomTrie, chtKeys, bloomKeys = makeTries()
}
@ -116,7 +116,8 @@ type fuzzer struct {
pool *txpool.TxPool
chainLen int
addr, txs []common.Hash
addresses []common.Address
txs []common.Hash
nonce uint64
chtKeys [][]byte
@ -135,7 +136,7 @@ func newFuzzer(input []byte) *fuzzer {
return &fuzzer{
chain: chain,
chainLen: testChainLen,
addr: addrHashes,
addresses: addresses,
txs: txHashes,
chtTrie: chtTrie,
bloomTrie: bloomTrie,
@ -198,12 +199,12 @@ func (f *fuzzer) randomBlockHash() common.Hash {
return common.BytesToHash(f.read(common.HashLength))
}
func (f *fuzzer) randomAddrHash() []byte {
i := f.randomInt(3 * len(f.addr))
if i < len(f.addr) {
return f.addr[i].Bytes()
func (f *fuzzer) randomAddress() []byte {
i := f.randomInt(3 * len(f.addresses))
if i < len(f.addresses) {
return f.addresses[i].Bytes()
}
return f.read(common.HashLength)
return f.read(common.AddressLength)
}
func (f *fuzzer) randomCHTTrieKey() []byte {
@ -316,7 +317,7 @@ func Fuzz(input []byte) int {
for i := range req.Reqs {
req.Reqs[i] = l.CodeReq{
BHash: f.randomBlockHash(),
AccKey: f.randomAddrHash(),
AccountAddress: f.randomAddress(),
}
}
f.doFuzz(l.GetCodeMsg, req)
@ -334,14 +335,14 @@ func Fuzz(input []byte) int {
if f.randomBool() {
req.Reqs[i] = l.ProofReq{
BHash: f.randomBlockHash(),
AccKey: f.randomAddrHash(),
Key: f.randomAddrHash(),
AccountAddress: f.randomAddress(),
Key: f.randomAddress(),
FromLevel: uint(f.randomX(3)),
}
} else {
req.Reqs[i] = l.ProofReq{
BHash: f.randomBlockHash(),
Key: f.randomAddrHash(),
Key: f.randomAddress(),
FromLevel: uint(f.randomX(3)),
}
}

@ -168,10 +168,15 @@ func (db *Database) Scheme() string {
// It is meant to be called when closing the blockchain object, so that all
// resources held can be released correctly.
func (db *Database) Close() error {
db.WritePreimages()
return db.backend.Close()
}
// WritePreimages flushes all accumulated preimages to disk forcibly.
func (db *Database) WritePreimages() {
if db.preimages != nil {
db.preimages.commit(true)
}
return db.backend.Close()
}
// saveCache saves clean state cache to given directory path