core/state, light, les: make signature of ContractCode hash-independent (#27209)

* core/state, light, les: make signature of ContractCode hash-independent

* push current state for feedback

* les: fix unit test

* core, les, light: fix les unittests

* core/state, trie, les, light: fix state iterator

* core, les: address comments

* les: fix lint

---------

Co-authored-by: Gary Rong <garyrong0905@gmail.com>
This commit is contained in:
Guillaume Ballet 2023-06-28 11:11:02 +02:00 committed by GitHub
parent 85b8d1c06c
commit 8bbb16b70e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 164 additions and 142 deletions

@ -306,9 +306,11 @@ func (bc *BlockChain) TrieNode(hash common.Hash) ([]byte, error) {
// new code scheme. // new code scheme.
func (bc *BlockChain) ContractCodeWithPrefix(hash common.Hash) ([]byte, error) { func (bc *BlockChain) ContractCodeWithPrefix(hash common.Hash) ([]byte, error) {
type codeReader interface { type codeReader interface {
ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]byte, error) ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error)
} }
return bc.stateCache.(codeReader).ContractCodeWithPrefix(common.Hash{}, hash) // TODO(rjl493456442) The associated account address is also required
// in Verkle scheme. Fix it once snap-sync is supported for Verkle.
return bc.stateCache.(codeReader).ContractCodeWithPrefix(common.Address{}, hash)
} }
// State returns a new mutable state based on the current HEAD block. // State returns a new mutable state based on the current HEAD block.

@ -24,6 +24,7 @@ import (
"github.com/ethereum/go-ethereum/common/lru" "github.com/ethereum/go-ethereum/common/lru"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/trienode" "github.com/ethereum/go-ethereum/trie/trienode"
@ -43,16 +44,16 @@ type Database interface {
OpenTrie(root common.Hash) (Trie, error) OpenTrie(root common.Hash) (Trie, error)
// OpenStorageTrie opens the storage trie of an account. // OpenStorageTrie opens the storage trie of an account.
OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (Trie, error)
// CopyTrie returns an independent copy of the given trie. // CopyTrie returns an independent copy of the given trie.
CopyTrie(Trie) Trie CopyTrie(Trie) Trie
// ContractCode retrieves a particular contract's code. // ContractCode retrieves a particular contract's code.
ContractCode(addrHash, codeHash common.Hash) ([]byte, error) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error)
// ContractCodeSize retrieves a particular contracts code's size. // ContractCodeSize retrieves a particular contracts code's size.
ContractCodeSize(addrHash, codeHash common.Hash) (int, error) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error)
// DiskDB returns the underlying key-value disk database. // DiskDB returns the underlying key-value disk database.
DiskDB() ethdb.KeyValueStore DiskDB() ethdb.KeyValueStore
@ -177,8 +178,8 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
} }
// OpenStorageTrie opens the storage trie of an account. // OpenStorageTrie opens the storage trie of an account.
func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) { func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (Trie, error) {
tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb) tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, crypto.Keccak256Hash(address.Bytes()), root), db.triedb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -196,7 +197,7 @@ func (db *cachingDB) CopyTrie(t Trie) Trie {
} }
// ContractCode retrieves a particular contract's code. // ContractCode retrieves a particular contract's code.
func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { func (db *cachingDB) ContractCode(address common.Address, codeHash common.Hash) ([]byte, error) {
code, _ := db.codeCache.Get(codeHash) code, _ := db.codeCache.Get(codeHash)
if len(code) > 0 { if len(code) > 0 {
return code, nil return code, nil
@ -213,7 +214,7 @@ func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error
// ContractCodeWithPrefix retrieves a particular contract's code. If the // ContractCodeWithPrefix retrieves a particular contract's code. If the
// code can't be found in the cache, then check the existence with **new** // code can't be found in the cache, then check the existence with **new**
// db scheme. // db scheme.
func (db *cachingDB) ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]byte, error) { func (db *cachingDB) ContractCodeWithPrefix(address common.Address, codeHash common.Hash) ([]byte, error) {
code, _ := db.codeCache.Get(codeHash) code, _ := db.codeCache.Get(codeHash)
if len(code) > 0 { if len(code) > 0 {
return code, nil return code, nil
@ -228,11 +229,11 @@ func (db *cachingDB) ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]b
} }
// ContractCodeSize retrieves a particular contracts code's size. // ContractCodeSize retrieves a particular contracts code's size.
func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { func (db *cachingDB) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok { if cached, ok := db.codeSizeCache.Get(codeHash); ok {
return cached, nil return cached, nil
} }
code, err := db.ContractCode(addrHash, codeHash) code, err := db.ContractCode(addr, codeHash)
return len(code), err return len(code), err
} }

@ -18,6 +18,7 @@ package state
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -27,7 +28,8 @@ import (
) )
// nodeIterator is an iterator to traverse the entire state trie post-order, // nodeIterator is an iterator to traverse the entire state trie post-order,
// including all of the contract code and contract state tries. // including all of the contract code and contract state tries. Preimage is
// required in order to resolve the contract address.
type nodeIterator struct { type nodeIterator struct {
state *StateDB // State being iterated state *StateDB // State being iterated
@ -113,7 +115,15 @@ func (it *nodeIterator) step() error {
if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil {
return err return err
} }
dataTrie, err := it.state.db.OpenStorageTrie(it.state.originalRoot, common.BytesToHash(it.stateIt.LeafKey()), account.Root) // Lookup the preimage of account hash
preimage := it.state.trie.GetKey(it.stateIt.LeafKey())
if preimage == nil {
return errors.New("account address is not available")
}
address := common.BytesToAddress(preimage)
// Traverse the storage slots belong to the account
dataTrie, err := it.state.db.OpenStorageTrie(it.state.originalRoot, address, account.Root)
if err != nil { if err != nil {
return err return err
} }
@ -126,8 +136,7 @@ func (it *nodeIterator) step() error {
} }
if !bytes.Equal(account.CodeHash, types.EmptyCodeHash.Bytes()) { if !bytes.Equal(account.CodeHash, types.EmptyCodeHash.Bytes()) {
it.codeHash = common.BytesToHash(account.CodeHash) it.codeHash = common.BytesToHash(account.CodeHash)
addrHash := common.BytesToHash(it.stateIt.LeafKey()) it.code, err = it.state.db.ContractCode(address, common.BytesToHash(account.CodeHash))
it.code, err = it.state.db.ContractCode(addrHash, common.BytesToHash(account.CodeHash))
if err != nil { if err != nil {
return fmt.Errorf("code %x: %v", account.CodeHash, err) return fmt.Errorf("code %x: %v", account.CodeHash, err)
} }

@ -142,7 +142,7 @@ func (s *stateObject) getTrie(db Database) (Trie, error) {
s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root)
} }
if s.trie == nil { if s.trie == nil {
tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root) tr, err := db.OpenStorageTrie(s.db.originalRoot, s.address, s.data.Root)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -441,7 +441,7 @@ func (s *stateObject) Code(db Database) []byte {
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return nil return nil
} }
code, err := db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash())) code, err := db.ContractCode(s.address, common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
} }
@ -459,7 +459,7 @@ func (s *stateObject) CodeSize(db Database) int {
if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return 0 return 0
} }
size, err := db.ContractCodeSize(s.addrHash, common.BytesToHash(s.CodeHash())) size, err := db.ContractCodeSize(s.address, common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
} }

@ -42,7 +42,7 @@ type testAccount struct {
func makeTestState() (ethdb.Database, Database, common.Hash, []*testAccount) { func makeTestState() (ethdb.Database, Database, common.Hash, []*testAccount) {
// Create an empty state // Create an empty state
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
sdb := NewDatabase(db) sdb := NewDatabaseWithConfig(db, &trie.Config{Preimages: true})
state, _ := New(types.EmptyRootHash, sdb, nil) state, _ := New(types.EmptyRootHash, sdb, nil)
// Fill it with some arbitrary data // Fill it with some arbitrary data
@ -100,28 +100,9 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou
} }
} }
// checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present.
func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
if v, _ := db.Get(root[:]); v == nil {
return nil // Consider a non existent state consistent.
}
trie, err := trie.New(trie.StateTrieID(root), trie.NewDatabase(db))
if err != nil {
return err
}
it := trie.MustNodeIterator(nil)
for it.Next(true) {
}
return it.Error()
}
// checkStateConsistency checks that all data of a state root is present. // checkStateConsistency checks that all data of a state root is present.
func checkStateConsistency(db ethdb.Database, root common.Hash) error { func checkStateConsistency(db ethdb.Database, root common.Hash) error {
// Create and iterate a state trie rooted in a sub-node state, err := New(root, NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), nil)
if _, err := db.Get(root.Bytes()); err != nil {
return nil // Consider a non existent state consistent.
}
state, err := New(root, NewDatabase(db), nil)
if err != nil { if err != nil {
return err return err
} }
@ -171,7 +152,7 @@ type stateElement struct {
func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
// Create a random state to copy // Create a random state to copy
_, srcDb, srcRoot, srcAccounts := makeTestState() srcDisk, srcDb, srcRoot, srcAccounts := makeTestState()
if commit { if commit {
srcDb.TrieDB().Commit(srcRoot, false) srcDb.TrieDB().Commit(srcRoot, false)
} }
@ -204,7 +185,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
codeResults = make([]trie.CodeSyncResult, len(codeElements)) codeResults = make([]trie.CodeSyncResult, len(codeElements))
) )
for i, element := range codeElements { for i, element := range codeElements {
data, err := srcDb.ContractCode(common.Hash{}, element.code) data, err := srcDb.ContractCode(common.Address{}, element.code)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve contract bytecode for hash %x", element.code) t.Fatalf("failed to retrieve contract bytecode for hash %x", element.code)
} }
@ -274,6 +255,10 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
}) })
} }
} }
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(srcDisk, dstDb)
// Cross check that the two states are in sync // Cross check that the two states are in sync
checkStateAccounts(t, dstDb, srcRoot, srcAccounts) checkStateAccounts(t, dstDb, srcRoot, srcAccounts)
} }
@ -282,7 +267,7 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
// partial results are returned, and the others sent only later. // partial results are returned, and the others sent only later.
func TestIterativeDelayedStateSync(t *testing.T) { func TestIterativeDelayedStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
_, srcDb, srcRoot, srcAccounts := makeTestState() srcDisk, srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase() dstDb := rawdb.NewMemoryDatabase()
@ -312,7 +297,7 @@ func TestIterativeDelayedStateSync(t *testing.T) {
if len(codeElements) > 0 { if len(codeElements) > 0 {
codeResults := make([]trie.CodeSyncResult, len(codeElements)/2+1) codeResults := make([]trie.CodeSyncResult, len(codeElements)/2+1)
for i, element := range codeElements[:len(codeResults)] { for i, element := range codeElements[:len(codeResults)] {
data, err := srcDb.ContractCode(common.Hash{}, element.code) data, err := srcDb.ContractCode(common.Address{}, element.code)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve contract bytecode for %x", element.code) t.Fatalf("failed to retrieve contract bytecode for %x", element.code)
} }
@ -363,6 +348,10 @@ func TestIterativeDelayedStateSync(t *testing.T) {
}) })
} }
} }
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(srcDisk, dstDb)
// Cross check that the two states are in sync // Cross check that the two states are in sync
checkStateAccounts(t, dstDb, srcRoot, srcAccounts) checkStateAccounts(t, dstDb, srcRoot, srcAccounts)
} }
@ -375,7 +364,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS
func testIterativeRandomStateSync(t *testing.T, count int) { func testIterativeRandomStateSync(t *testing.T, count int) {
// Create a random state to copy // Create a random state to copy
_, srcDb, srcRoot, srcAccounts := makeTestState() srcDisk, srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase() dstDb := rawdb.NewMemoryDatabase()
@ -399,7 +388,7 @@ func testIterativeRandomStateSync(t *testing.T, count int) {
if len(codeQueue) > 0 { if len(codeQueue) > 0 {
results := make([]trie.CodeSyncResult, 0, len(codeQueue)) results := make([]trie.CodeSyncResult, 0, len(codeQueue))
for hash := range codeQueue { for hash := range codeQueue {
data, err := srcDb.ContractCode(common.Hash{}, hash) data, err := srcDb.ContractCode(common.Address{}, hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x", hash) t.Fatalf("failed to retrieve node data for %x", hash)
} }
@ -447,6 +436,10 @@ func testIterativeRandomStateSync(t *testing.T, count int) {
codeQueue[hash] = struct{}{} codeQueue[hash] = struct{}{}
} }
} }
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(srcDisk, dstDb)
// Cross check that the two states are in sync // Cross check that the two states are in sync
checkStateAccounts(t, dstDb, srcRoot, srcAccounts) checkStateAccounts(t, dstDb, srcRoot, srcAccounts)
} }
@ -455,7 +448,7 @@ func testIterativeRandomStateSync(t *testing.T, count int) {
// partial results are returned (Even those randomly), others sent only later. // partial results are returned (Even those randomly), others sent only later.
func TestIterativeRandomDelayedStateSync(t *testing.T) { func TestIterativeRandomDelayedStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
_, srcDb, srcRoot, srcAccounts := makeTestState() srcDisk, srcDb, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase() dstDb := rawdb.NewMemoryDatabase()
@ -481,7 +474,7 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
for hash := range codeQueue { for hash := range codeQueue {
delete(codeQueue, hash) delete(codeQueue, hash)
data, err := srcDb.ContractCode(common.Hash{}, hash) data, err := srcDb.ContractCode(common.Address{}, hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x", hash) t.Fatalf("failed to retrieve node data for %x", hash)
} }
@ -537,6 +530,10 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
codeQueue[hash] = struct{}{} codeQueue[hash] = struct{}{}
} }
} }
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(srcDisk, dstDb)
// Cross check that the two states are in sync // Cross check that the two states are in sync
checkStateAccounts(t, dstDb, srcRoot, srcAccounts) checkStateAccounts(t, dstDb, srcRoot, srcAccounts)
} }
@ -555,7 +552,6 @@ func TestIncompleteStateSync(t *testing.T) {
} }
} }
isCode[types.EmptyCodeHash] = struct{}{} isCode[types.EmptyCodeHash] = struct{}{}
checkTrieConsistency(db, srcRoot)
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb := rawdb.NewMemoryDatabase() dstDb := rawdb.NewMemoryDatabase()
@ -588,7 +584,7 @@ func TestIncompleteStateSync(t *testing.T) {
if len(codeQueue) > 0 { if len(codeQueue) > 0 {
results := make([]trie.CodeSyncResult, 0, len(codeQueue)) results := make([]trie.CodeSyncResult, 0, len(codeQueue))
for hash := range codeQueue { for hash := range codeQueue {
data, err := srcDb.ContractCode(common.Hash{}, hash) data, err := srcDb.ContractCode(common.Address{}, hash)
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x", hash) t.Fatalf("failed to retrieve node data for %x", hash)
} }
@ -602,7 +598,6 @@ func TestIncompleteStateSync(t *testing.T) {
} }
} }
} }
var nodehashes []common.Hash
if len(nodeQueue) > 0 { if len(nodeQueue) > 0 {
results := make([]trie.NodeSyncResult, 0, len(nodeQueue)) results := make([]trie.NodeSyncResult, 0, len(nodeQueue))
for path, element := range nodeQueue { for path, element := range nodeQueue {
@ -617,7 +612,6 @@ func TestIncompleteStateSync(t *testing.T) {
addedPaths = append(addedPaths, element.path) addedPaths = append(addedPaths, element.path)
addedHashes = append(addedHashes, element.hash) addedHashes = append(addedHashes, element.hash)
} }
nodehashes = append(nodehashes, element.hash)
} }
// Process each of the state nodes // Process each of the state nodes
for _, result := range results { for _, result := range results {
@ -632,13 +626,6 @@ func TestIncompleteStateSync(t *testing.T) {
} }
batch.Write() batch.Write()
for _, root := range nodehashes {
// Can't use checkStateConsistency here because subtrie keys may have odd
// length and crash in LeafKey.
if err := checkTrieConsistency(dstDb, root); err != nil {
t.Fatalf("state inconsistent: %v", err)
}
}
// Fetch the next batch to retrieve // Fetch the next batch to retrieve
nodeQueue = make(map[string]stateElement) nodeQueue = make(map[string]stateElement)
codeQueue = make(map[common.Hash]struct{}) codeQueue = make(map[common.Hash]struct{})
@ -654,6 +641,10 @@ func TestIncompleteStateSync(t *testing.T) {
codeQueue[hash] = struct{}{} codeQueue[hash] = struct{}{}
} }
} }
// Copy the preimages from source db in order to traverse the state.
srcDb.TrieDB().WritePreimages()
copyPreimages(db, dstDb)
// Sanity check that removing any node from the database is detected // Sanity check that removing any node from the database is detected
for _, node := range addedCodes { for _, node := range addedCodes {
val := rawdb.ReadCode(dstDb, node) val := rawdb.ReadCode(dstDb, node)
@ -678,3 +669,15 @@ func TestIncompleteStateSync(t *testing.T) {
rawdb.WriteTrieNode(dstDb, owner, inner, hash, val, scheme) rawdb.WriteTrieNode(dstDb, owner, inner, hash, val, scheme)
} }
} }
func copyPreimages(srcDb, dstDb ethdb.Database) {
it := srcDb.NewIterator(rawdb.PreimagePrefix, nil)
defer it.Release()
preimages := make(map[common.Hash][]byte)
for it.Next() {
hash := it.Key()[len(rawdb.PreimagePrefix):]
preimages[common.BytesToHash(hash)] = common.CopyBytes(it.Value())
}
rawdb.WritePreimages(dstDb, preimages)
}

@ -302,7 +302,7 @@ func (sf *subfetcher) loop() {
} }
sf.trie = trie sf.trie = trie
} else { } else {
trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root) trie, err := sf.db.OpenStorageTrie(sf.state, sf.addr, sf.root)
if err != nil { if err != nil {
log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err) log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
return return

@ -117,7 +117,7 @@ func (b *benchmarkProofsOrCode) request(peer *serverPeer, index int) error {
key := make([]byte, 32) key := make([]byte, 32)
crand.Read(key) crand.Read(key)
if b.code { if b.code {
return peer.requestCode(0, []CodeReq{{BHash: b.headHash, AccKey: key}}) return peer.requestCode(0, []CodeReq{{BHash: b.headHash, AccountAddress: key}})
} }
return peer.requestProofs(0, []ProofReq{{BHash: b.headHash, Key: key}}) return peer.requestProofs(0, []ProofReq{{BHash: b.headHash, Key: key}})
} }

@ -296,7 +296,7 @@ func testGetCode(t *testing.T, protocol int) {
header := bc.GetHeaderByNumber(i) header := bc.GetHeaderByNumber(i)
req := &CodeReq{ req := &CodeReq{
BHash: header.Hash(), BHash: header.Hash(),
AccKey: crypto.Keccak256(testContractAddr[:]), AccountAddress: testContractAddr[:],
} }
codereqs = append(codereqs, req) codereqs = append(codereqs, req)
if i >= testContractDeployed { if i >= testContractDeployed {
@ -332,7 +332,7 @@ func testGetStaleCode(t *testing.T, protocol int) {
check := func(number uint64, expected [][]byte) { check := func(number uint64, expected [][]byte) {
req := &CodeReq{ req := &CodeReq{
BHash: bc.GetHeaderByNumber(number).Hash(), BHash: bc.GetHeaderByNumber(number).Hash(),
AccKey: crypto.Keccak256(testContractAddr[:]), AccountAddress: testContractAddr[:],
} }
sendRequest(rawPeer.app, GetCodeMsg, 42, []*CodeReq{req}) sendRequest(rawPeer.app, GetCodeMsg, 42, []*CodeReq{req})
if err := expectResponse(rawPeer.app, CodeMsg, 42, testBufLimit, expected); err != nil { if err := expectResponse(rawPeer.app, CodeMsg, 42, testBufLimit, expected); err != nil {

@ -184,7 +184,7 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error {
type ProofReq struct { type ProofReq struct {
BHash common.Hash BHash common.Hash
AccKey, Key []byte AccountAddress, Key []byte
FromLevel uint FromLevel uint
} }
@ -207,7 +207,7 @@ func (r *TrieRequest) Request(reqID uint64, peer *serverPeer) error {
peer.Log().Debug("Requesting trie proof", "root", r.Id.Root, "key", r.Key) peer.Log().Debug("Requesting trie proof", "root", r.Id.Root, "key", r.Key)
req := ProofReq{ req := ProofReq{
BHash: r.Id.BlockHash, BHash: r.Id.BlockHash,
AccKey: r.Id.AccKey, AccountAddress: r.Id.AccountAddress,
Key: r.Key, Key: r.Key,
} }
return peer.requestProofs(reqID, []ProofReq{req}) return peer.requestProofs(reqID, []ProofReq{req})
@ -239,7 +239,7 @@ func (r *TrieRequest) Validate(db ethdb.Database, msg *Msg) error {
type CodeReq struct { type CodeReq struct {
BHash common.Hash BHash common.Hash
AccKey []byte AccountAddress []byte
} }
// CodeRequest is the ODR request type for node data (used for retrieving contract code), see LesOdrRequest interface // CodeRequest is the ODR request type for node data (used for retrieving contract code), see LesOdrRequest interface
@ -261,7 +261,7 @@ func (r *CodeRequest) Request(reqID uint64, peer *serverPeer) error {
peer.Log().Debug("Requesting code data", "hash", r.Hash) peer.Log().Debug("Requesting code data", "hash", r.Hash)
req := CodeReq{ req := CodeReq{
BHash: r.Id.BlockHash, BHash: r.Id.BlockHash,
AccKey: r.Id.AccKey, AccountAddress: r.Id.AccountAddress,
} }
return peer.requestCode(reqID, []CodeReq{req}) return peer.requestCode(reqID, []CodeReq{req})
} }

@ -23,6 +23,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
@ -77,7 +78,7 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrReq
return nil return nil
} }
sti := light.StateTrieID(header) sti := light.StateTrieID(header)
ci := light.StorageTrieID(sti, crypto.Keccak256Hash(testContractAddr[:]), common.Hash{}) ci := light.StorageTrieID(sti, testContractAddr, types.EmptyRootHash)
return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)} return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)}
} }

@ -18,6 +18,7 @@ package les
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
@ -34,7 +35,6 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
) )
@ -358,20 +358,19 @@ func (h *serverHandler) AddTxsSync() bool {
} }
// getAccount retrieves an account from the state based on root. // getAccount retrieves an account from the state based on root.
func getAccount(triedb *trie.Database, root, hash common.Hash) (types.StateAccount, error) { func getAccount(triedb *trie.Database, root common.Hash, addr common.Address) (types.StateAccount, error) {
trie, err := trie.New(trie.StateTrieID(root), triedb) trie, err := trie.NewStateTrie(trie.StateTrieID(root), triedb)
if err != nil { if err != nil {
return types.StateAccount{}, err return types.StateAccount{}, err
} }
blob, err := trie.Get(hash[:]) acc, err := trie.GetAccount(addr)
if err != nil { if err != nil {
return types.StateAccount{}, err return types.StateAccount{}, err
} }
var acc types.StateAccount if acc == nil {
if err = rlp.DecodeBytes(blob, &acc); err != nil { return types.StateAccount{}, fmt.Errorf("account %#x is not present", addr)
return types.StateAccount{}, err
} }
return acc, nil return *acc, nil
} }
// GetHelperTrie returns the post-processed trie root for the given trie ID and section index // GetHelperTrie returns the post-processed trie root for the given trie ID and section index

@ -304,16 +304,16 @@ func handleGetCode(msg Decoder) (serveRequestFn, uint64, uint64, error) {
continue continue
} }
triedb := bc.StateCache().TrieDB() triedb := bc.StateCache().TrieDB()
address := common.BytesToAddress(request.AccountAddress)
account, err := getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) account, err := getAccount(triedb, header.Root, address)
if err != nil { if err != nil {
p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", address, "err", err)
p.bumpInvalid() p.bumpInvalid()
continue continue
} }
code, err := bc.StateCache().ContractCode(common.BytesToHash(request.AccKey), common.BytesToHash(account.CodeHash)) code, err := bc.StateCache().ContractCode(address, common.BytesToHash(account.CodeHash))
if err != nil { if err != nil {
p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", address, "codehash", common.BytesToHash(account.CodeHash), "err", err)
continue continue
} }
// Accumulate the code and abort if enough data was retrieved // Accumulate the code and abort if enough data was retrieved
@ -413,7 +413,7 @@ func handleGetProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) {
statedb := bc.StateCache() statedb := bc.StateCache()
var trie state.Trie var trie state.Trie
switch len(request.AccKey) { switch len(request.AccountAddress) {
case 0: case 0:
// No account key specified, open an account trie // No account key specified, open an account trie
trie, err = statedb.OpenTrie(root) trie, err = statedb.OpenTrie(root)
@ -423,15 +423,16 @@ func handleGetProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) {
} }
default: default:
// Account key specified, open a storage trie // Account key specified, open a storage trie
account, err := getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) address := common.BytesToAddress(request.AccountAddress)
account, err := getAccount(statedb.TrieDB(), root, address)
if err != nil { if err != nil {
p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", address, "err", err)
p.bumpInvalid() p.bumpInvalid()
continue continue
} }
trie, err = statedb.OpenStorageTrie(root, common.BytesToHash(request.AccKey), account.Root) trie, err = statedb.OpenStorageTrie(root, address, account.Root)
if trie == nil || err != nil { if trie == nil || err != nil {
p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", address, "root", account.Root, "err", err)
continue continue
} }
} }

@ -58,7 +58,7 @@ type TrieID struct {
BlockNumber uint64 BlockNumber uint64
StateRoot common.Hash StateRoot common.Hash
Root common.Hash Root common.Hash
AccKey []byte AccountAddress []byte
} }
// StateTrieID returns a TrieID for a state trie belonging to a certain block // StateTrieID returns a TrieID for a state trie belonging to a certain block
@ -69,19 +69,19 @@ func StateTrieID(header *types.Header) *TrieID {
BlockNumber: header.Number.Uint64(), BlockNumber: header.Number.Uint64(),
StateRoot: header.Root, StateRoot: header.Root,
Root: header.Root, Root: header.Root,
AccKey: nil, AccountAddress: nil,
} }
} }
// StorageTrieID returns a TrieID for a contract storage trie at a given account // StorageTrieID returns a TrieID for a contract storage trie at a given account
// of a given state trie. It also requires the root hash of the trie for // of a given state trie. It also requires the root hash of the trie for
// checking Merkle proofs. // checking Merkle proofs.
func StorageTrieID(state *TrieID, addrHash, root common.Hash) *TrieID { func StorageTrieID(state *TrieID, address common.Address, root common.Hash) *TrieID {
return &TrieID{ return &TrieID{
BlockHash: state.BlockHash, BlockHash: state.BlockHash,
BlockNumber: state.BlockNumber, BlockNumber: state.BlockNumber,
StateRoot: state.StateRoot, StateRoot: state.StateRoot,
AccKey: addrHash[:], AccountAddress: address[:],
Root: root, Root: root,
} }
} }

@ -86,8 +86,8 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error {
err error err error
t state.Trie t state.Trie
) )
if len(req.Id.AccKey) > 0 { if len(req.Id.AccountAddress) > 0 {
t, err = odr.serverState.OpenStorageTrie(req.Id.StateRoot, common.BytesToHash(req.Id.AccKey), req.Id.Root) t, err = odr.serverState.OpenStorageTrie(req.Id.StateRoot, common.BytesToAddress(req.Id.AccountAddress), req.Id.Root)
} else { } else {
t, err = odr.serverState.OpenTrie(req.Id.Root) t, err = odr.serverState.OpenTrie(req.Id.Root)
} }

@ -55,8 +55,8 @@ func (db *odrDatabase) OpenTrie(root common.Hash) (state.Trie, error) {
return &odrTrie{db: db, id: db.id}, nil return &odrTrie{db: db, id: db.id}, nil
} }
func (db *odrDatabase) OpenStorageTrie(state, addrHash, root common.Hash) (state.Trie, error) { func (db *odrDatabase) OpenStorageTrie(stateRoot common.Hash, address common.Address, root common.Hash) (state.Trie, error) {
return &odrTrie{db: db, id: StorageTrieID(db.id, addrHash, root)}, nil return &odrTrie{db: db, id: StorageTrieID(db.id, address, root)}, nil
} }
func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie { func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie {
@ -72,7 +72,7 @@ func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie {
} }
} }
func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { func (db *odrDatabase) ContractCode(addr common.Address, codeHash common.Hash) ([]byte, error) {
if codeHash == sha3Nil { if codeHash == sha3Nil {
return nil, nil return nil, nil
} }
@ -81,14 +81,14 @@ func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, err
return code, nil return code, nil
} }
id := *db.id id := *db.id
id.AccKey = addrHash[:] id.AccountAddress = addr[:]
req := &CodeRequest{Id: &id, Hash: codeHash} req := &CodeRequest{Id: &id, Hash: codeHash}
err := db.backend.Retrieve(db.ctx, req) err := db.backend.Retrieve(db.ctx, req)
return req.Data, err return req.Data, err
} }
func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { func (db *odrDatabase) ContractCodeSize(addr common.Address, codeHash common.Hash) (int, error) {
code, err := db.ContractCode(addrHash, codeHash) code, err := db.ContractCode(addr, codeHash)
return len(code), err return len(code), err
} }
@ -207,8 +207,8 @@ func (t *odrTrie) do(key []byte, fn func() error) error {
var err error var err error
if t.trie == nil { if t.trie == nil {
var id *trie.ID var id *trie.ID
if len(t.id.AccKey) > 0 { if len(t.id.AccountAddress) > 0 {
id = trie.StorageTrieID(t.id.StateRoot, common.BytesToHash(t.id.AccKey), t.id.Root) id = trie.StorageTrieID(t.id.StateRoot, crypto.Keccak256Hash(t.id.AccountAddress), t.id.Root)
} else { } else {
id = trie.StateTrieID(t.id.StateRoot) id = trie.StateTrieID(t.id.StateRoot)
} }
@ -239,8 +239,8 @@ func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator {
if t.trie == nil { if t.trie == nil {
it.do(func() error { it.do(func() error {
var id *trie.ID var id *trie.ID
if len(t.id.AccKey) > 0 { if len(t.id.AccountAddress) > 0 {
id = trie.StorageTrieID(t.id.StateRoot, common.BytesToHash(t.id.AccKey), t.id.Root) id = trie.StorageTrieID(t.id.StateRoot, crypto.Keccak256Hash(t.id.AccountAddress), t.id.Root)
} else { } else {
id = trie.StateTrieID(t.id.StateRoot) id = trie.StateTrieID(t.id.StateRoot)
} }

@ -46,7 +46,7 @@ var (
testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056") testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056")
chain *core.BlockChain chain *core.BlockChain
addrHashes []common.Hash addresses []common.Address
txHashes []common.Hash txHashes []common.Hash
chtTrie *trie.Trie chtTrie *trie.Trie
@ -55,7 +55,7 @@ var (
bloomKeys [][]byte bloomKeys [][]byte
) )
func makechain() (bc *core.BlockChain, addrHashes, txHashes []common.Hash) { func makechain() (bc *core.BlockChain, addresses []common.Address, txHashes []common.Hash) {
gspec := &core.Genesis{ gspec := &core.Genesis{
Config: params.TestChainConfig, Config: params.TestChainConfig,
Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}}, Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}},
@ -77,7 +77,7 @@ func makechain() (bc *core.BlockChain, addrHashes, txHashes []common.Hash) {
tx, _ = types.SignTx(types.NewTransaction(nonce, addr, big.NewInt(10000), params.TxGas, big.NewInt(params.GWei), nil), signer, bankKey) tx, _ = types.SignTx(types.NewTransaction(nonce, addr, big.NewInt(10000), params.TxGas, big.NewInt(params.GWei), nil), signer, bankKey)
} }
gen.AddTx(tx) gen.AddTx(tx)
addrHashes = append(addrHashes, crypto.Keccak256Hash(addr[:])) addresses = append(addresses, addr)
txHashes = append(txHashes, tx.Hash()) txHashes = append(txHashes, tx.Hash())
}) })
bc, _ = core.NewBlockChain(rawdb.NewMemoryDatabase(), nil, gspec, nil, ethash.NewFaker(), vm.Config{}, nil, nil) bc, _ = core.NewBlockChain(rawdb.NewMemoryDatabase(), nil, gspec, nil, ethash.NewFaker(), vm.Config{}, nil, nil)
@ -107,7 +107,7 @@ func makeTries() (chtTrie *trie.Trie, bloomTrie *trie.Trie, chtKeys, bloomKeys [
} }
func init() { func init() {
chain, addrHashes, txHashes = makechain() chain, addresses, txHashes = makechain()
chtTrie, bloomTrie, chtKeys, bloomKeys = makeTries() chtTrie, bloomTrie, chtKeys, bloomKeys = makeTries()
} }
@ -116,7 +116,8 @@ type fuzzer struct {
pool *txpool.TxPool pool *txpool.TxPool
chainLen int chainLen int
addr, txs []common.Hash addresses []common.Address
txs []common.Hash
nonce uint64 nonce uint64
chtKeys [][]byte chtKeys [][]byte
@ -135,7 +136,7 @@ func newFuzzer(input []byte) *fuzzer {
return &fuzzer{ return &fuzzer{
chain: chain, chain: chain,
chainLen: testChainLen, chainLen: testChainLen,
addr: addrHashes, addresses: addresses,
txs: txHashes, txs: txHashes,
chtTrie: chtTrie, chtTrie: chtTrie,
bloomTrie: bloomTrie, bloomTrie: bloomTrie,
@ -198,12 +199,12 @@ func (f *fuzzer) randomBlockHash() common.Hash {
return common.BytesToHash(f.read(common.HashLength)) return common.BytesToHash(f.read(common.HashLength))
} }
func (f *fuzzer) randomAddrHash() []byte { func (f *fuzzer) randomAddress() []byte {
i := f.randomInt(3 * len(f.addr)) i := f.randomInt(3 * len(f.addresses))
if i < len(f.addr) { if i < len(f.addresses) {
return f.addr[i].Bytes() return f.addresses[i].Bytes()
} }
return f.read(common.HashLength) return f.read(common.AddressLength)
} }
func (f *fuzzer) randomCHTTrieKey() []byte { func (f *fuzzer) randomCHTTrieKey() []byte {
@ -316,7 +317,7 @@ func Fuzz(input []byte) int {
for i := range req.Reqs { for i := range req.Reqs {
req.Reqs[i] = l.CodeReq{ req.Reqs[i] = l.CodeReq{
BHash: f.randomBlockHash(), BHash: f.randomBlockHash(),
AccKey: f.randomAddrHash(), AccountAddress: f.randomAddress(),
} }
} }
f.doFuzz(l.GetCodeMsg, req) f.doFuzz(l.GetCodeMsg, req)
@ -334,14 +335,14 @@ func Fuzz(input []byte) int {
if f.randomBool() { if f.randomBool() {
req.Reqs[i] = l.ProofReq{ req.Reqs[i] = l.ProofReq{
BHash: f.randomBlockHash(), BHash: f.randomBlockHash(),
AccKey: f.randomAddrHash(), AccountAddress: f.randomAddress(),
Key: f.randomAddrHash(), Key: f.randomAddress(),
FromLevel: uint(f.randomX(3)), FromLevel: uint(f.randomX(3)),
} }
} else { } else {
req.Reqs[i] = l.ProofReq{ req.Reqs[i] = l.ProofReq{
BHash: f.randomBlockHash(), BHash: f.randomBlockHash(),
Key: f.randomAddrHash(), Key: f.randomAddress(),
FromLevel: uint(f.randomX(3)), FromLevel: uint(f.randomX(3)),
} }
} }

@ -168,10 +168,15 @@ func (db *Database) Scheme() string {
// It is meant to be called when closing the blockchain object, so that all // It is meant to be called when closing the blockchain object, so that all
// resources held can be released correctly. // resources held can be released correctly.
func (db *Database) Close() error { func (db *Database) Close() error {
db.WritePreimages()
return db.backend.Close()
}
// WritePreimages flushes all accumulated preimages to disk forcibly.
func (db *Database) WritePreimages() {
if db.preimages != nil { if db.preimages != nil {
db.preimages.commit(true) db.preimages.commit(true)
} }
return db.backend.Close()
} }
// saveCache saves clean state cache to given directory path // saveCache saves clean state cache to given directory path