diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 17670750ed..deb4b52b4c 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -44,7 +45,7 @@ func makeTestState() (Database, common.Hash, []*testAccount) { state, _ := New(common.Hash{}, db, nil) // Fill it with some arbitrary data - accounts := []*testAccount{} + var accounts []*testAccount for i := byte(0); i < 96; i++ { obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i})) acc := &testAccount{address: common.BytesToAddress([]byte{i})} @@ -59,6 +60,11 @@ func makeTestState() (Database, common.Hash, []*testAccount) { obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) acc.code = []byte{i, i, i, i, i} } + if i%5 == 0 { + for j := byte(0); j < 5; j++ { + obj.SetState(db, crypto.Keccak256Hash([]byte{i, i, i, i, i, j, j}), crypto.Keccak256Hash([]byte{i, i, i, i, i, j, j})) + } + } state.updateStateObject(obj) accounts = append(accounts, acc) } @@ -126,44 +132,94 @@ func checkStateConsistency(db ethdb.Database, root common.Hash) error { // Tests that an empty state is not scheduled for syncing. func TestEmptyStateSync(t *testing.T) { empty := common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - if req := NewStateSync(empty, rawdb.NewMemoryDatabase(), trie.NewSyncBloom(1, memorydb.New())).Missing(1); len(req) != 0 { - t.Errorf("content requested for empty state: %v", req) + sync := NewStateSync(empty, rawdb.NewMemoryDatabase(), trie.NewSyncBloom(1, memorydb.New())) + if nodes, paths, codes := sync.Missing(1); len(nodes) != 0 || len(paths) != 0 || len(codes) != 0 { + t.Errorf(" content requested for empty state: %v, %v, %v", nodes, paths, codes) } } // Tests that given a root hash, a state can sync iteratively on a single thread, // requesting retrieval tasks and returning all of them in one go. -func TestIterativeStateSyncIndividual(t *testing.T) { testIterativeStateSync(t, 1, false) } -func TestIterativeStateSyncBatched(t *testing.T) { testIterativeStateSync(t, 100, false) } -func TestIterativeStateSyncIndividualFromDisk(t *testing.T) { testIterativeStateSync(t, 1, true) } -func TestIterativeStateSyncBatchedFromDisk(t *testing.T) { testIterativeStateSync(t, 100, true) } +func TestIterativeStateSyncIndividual(t *testing.T) { + testIterativeStateSync(t, 1, false, false) +} +func TestIterativeStateSyncBatched(t *testing.T) { + testIterativeStateSync(t, 100, false, false) +} +func TestIterativeStateSyncIndividualFromDisk(t *testing.T) { + testIterativeStateSync(t, 1, true, false) +} +func TestIterativeStateSyncBatchedFromDisk(t *testing.T) { + testIterativeStateSync(t, 100, true, false) +} +func TestIterativeStateSyncIndividualByPath(t *testing.T) { + testIterativeStateSync(t, 1, false, true) +} +func TestIterativeStateSyncBatchedByPath(t *testing.T) { + testIterativeStateSync(t, 100, false, true) +} -func testIterativeStateSync(t *testing.T, count int, commit bool) { +func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) { // Create a random state to copy srcDb, srcRoot, srcAccounts := makeTestState() if commit { srcDb.TrieDB().Commit(srcRoot, false, nil) } + srcTrie, _ := trie.New(srcRoot, srcDb.TrieDB()) + // Create a destination state and sync with the scheduler dstDb := rawdb.NewMemoryDatabase() sched := NewStateSync(srcRoot, dstDb, trie.NewSyncBloom(1, dstDb)) - queue := append([]common.Hash{}, sched.Missing(count)...) - for len(queue) > 0 { - results := make([]trie.SyncResult, len(queue)) - for i, hash := range queue { + nodes, paths, codes := sched.Missing(count) + var ( + hashQueue []common.Hash + pathQueue []trie.SyncPath + ) + if !bypath { + hashQueue = append(append(hashQueue[:0], nodes...), codes...) + } else { + hashQueue = append(hashQueue[:0], codes...) + pathQueue = append(pathQueue[:0], paths...) + } + for len(hashQueue)+len(pathQueue) > 0 { + results := make([]trie.SyncResult, len(hashQueue)+len(pathQueue)) + for i, hash := range hashQueue { data, err := srcDb.TrieDB().Node(hash) if err != nil { data, err = srcDb.ContractCode(common.Hash{}, hash) } if err != nil { - t.Fatalf("failed to retrieve node data for %x", hash) + t.Fatalf("failed to retrieve node data for hash %x", hash) } results[i] = trie.SyncResult{Hash: hash, Data: data} } + for i, path := range pathQueue { + if len(path) == 1 { + data, _, err := srcTrie.TryGetNode(path[0]) + if err != nil { + t.Fatalf("failed to retrieve node data for path %x: %v", path, err) + } + results[len(hashQueue)+i] = trie.SyncResult{Hash: crypto.Keccak256Hash(data), Data: data} + } else { + var acc Account + if err := rlp.DecodeBytes(srcTrie.Get(path[0]), &acc); err != nil { + t.Fatalf("failed to decode account on path %x: %v", path, err) + } + stTrie, err := trie.New(acc.Root, srcDb.TrieDB()) + if err != nil { + t.Fatalf("failed to retriev storage trie for path %x: %v", path, err) + } + data, _, err := stTrie.TryGetNode(path[1]) + if err != nil { + t.Fatalf("failed to retrieve node data for path %x: %v", path, err) + } + results[len(hashQueue)+i] = trie.SyncResult{Hash: crypto.Keccak256Hash(data), Data: data} + } + } for _, result := range results { if err := sched.Process(result); err != nil { - t.Fatalf("failed to process result %v", err) + t.Errorf("failed to process result %v", err) } } batch := dstDb.NewBatch() @@ -171,7 +227,14 @@ func testIterativeStateSync(t *testing.T, count int, commit bool) { t.Fatalf("failed to commit data: %v", err) } batch.Write() - queue = append(queue[:0], sched.Missing(count)...) + + nodes, paths, codes = sched.Missing(count) + if !bypath { + hashQueue = append(append(hashQueue[:0], nodes...), codes...) + } else { + hashQueue = append(hashQueue[:0], codes...) + pathQueue = append(pathQueue[:0], paths...) + } } // Cross check that the two states are in sync checkStateAccounts(t, dstDb, srcRoot, srcAccounts) @@ -187,7 +250,9 @@ func TestIterativeDelayedStateSync(t *testing.T) { dstDb := rawdb.NewMemoryDatabase() sched := NewStateSync(srcRoot, dstDb, trie.NewSyncBloom(1, dstDb)) - queue := append([]common.Hash{}, sched.Missing(0)...) + nodes, _, codes := sched.Missing(0) + queue := append(append([]common.Hash{}, nodes...), codes...) + for len(queue) > 0 { // Sync only half of the scheduled nodes results := make([]trie.SyncResult, len(queue)/2+1) @@ -211,7 +276,9 @@ func TestIterativeDelayedStateSync(t *testing.T) { t.Fatalf("failed to commit data: %v", err) } batch.Write() - queue = append(queue[len(results):], sched.Missing(0)...) + + nodes, _, codes = sched.Missing(0) + queue = append(append(queue[len(results):], nodes...), codes...) } // Cross check that the two states are in sync checkStateAccounts(t, dstDb, srcRoot, srcAccounts) @@ -232,7 +299,8 @@ func testIterativeRandomStateSync(t *testing.T, count int) { sched := NewStateSync(srcRoot, dstDb, trie.NewSyncBloom(1, dstDb)) queue := make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(count) { + nodes, _, codes := sched.Missing(count) + for _, hash := range append(nodes, codes...) { queue[hash] = struct{}{} } for len(queue) > 0 { @@ -259,8 +327,10 @@ func testIterativeRandomStateSync(t *testing.T, count int) { t.Fatalf("failed to commit data: %v", err) } batch.Write() + queue = make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(count) { + nodes, _, codes = sched.Missing(count) + for _, hash := range append(nodes, codes...) { queue[hash] = struct{}{} } } @@ -279,7 +349,8 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { sched := NewStateSync(srcRoot, dstDb, trie.NewSyncBloom(1, dstDb)) queue := make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(0) { + nodes, _, codes := sched.Missing(0) + for _, hash := range append(nodes, codes...) { queue[hash] = struct{}{} } for len(queue) > 0 { @@ -312,7 +383,11 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { t.Fatalf("failed to commit data: %v", err) } batch.Write() - for _, hash := range sched.Missing(0) { + for _, result := range results { + delete(queue, result.Hash) + } + nodes, _, codes = sched.Missing(0) + for _, hash := range append(nodes, codes...) { queue[hash] = struct{}{} } } @@ -341,8 +416,11 @@ func TestIncompleteStateSync(t *testing.T) { dstDb := rawdb.NewMemoryDatabase() sched := NewStateSync(srcRoot, dstDb, trie.NewSyncBloom(1, dstDb)) - added := []common.Hash{} - queue := append([]common.Hash{}, sched.Missing(1)...) + var added []common.Hash + + nodes, _, codes := sched.Missing(1) + queue := append(append([]common.Hash{}, nodes...), codes...) + for len(queue) > 0 { // Fetch a batch of state nodes results := make([]trie.SyncResult, len(queue)) @@ -382,7 +460,8 @@ func TestIncompleteStateSync(t *testing.T) { } } // Fetch the next batch to retrieve - queue = append(queue[:0], sched.Missing(1)...) + nodes, _, codes = sched.Missing(1) + queue = append(append(queue[:0], nodes...), codes...) } // Sanity check that removing any node from the database is detected for _, node := range added[1:] { diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index bf9e96fe2a..6745aa54ac 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -34,14 +34,15 @@ import ( // stateReq represents a batch of state fetch requests grouped together into // a single data retrieval network packet. type stateReq struct { - nItems uint16 // Number of items requested for download (max is 384, so uint16 is sufficient) - tasks map[common.Hash]*stateTask // Download tasks to track previous attempts - timeout time.Duration // Maximum round trip time for this to complete - timer *time.Timer // Timer to fire when the RTT timeout expires - peer *peerConnection // Peer that we're requesting from - delivered time.Time // Time when the packet was delivered (independent when we process it) - response [][]byte // Response data of the peer (nil for timeouts) - dropped bool // Flag whether the peer dropped off early + nItems uint16 // Number of items requested for download (max is 384, so uint16 is sufficient) + trieTasks map[common.Hash]*trieTask // Trie node download tasks to track previous attempts + codeTasks map[common.Hash]*codeTask // Byte code download tasks to track previous attempts + timeout time.Duration // Maximum round trip time for this to complete + timer *time.Timer // Timer to fire when the RTT timeout expires + peer *peerConnection // Peer that we're requesting from + delivered time.Time // Time when the packet was delivered (independent when we process it) + response [][]byte // Response data of the peer (nil for timeouts) + dropped bool // Flag whether the peer dropped off early } // timedOut returns if this request timed out. @@ -251,9 +252,11 @@ func (d *Downloader) spindownStateSync(active map[string]*stateReq, finished []* type stateSync struct { d *Downloader // Downloader instance to access and manage current peerset - sched *trie.Sync // State trie sync scheduler defining the tasks - keccak hash.Hash // Keccak256 hasher to verify deliveries with - tasks map[common.Hash]*stateTask // Set of tasks currently queued for retrieval + sched *trie.Sync // State trie sync scheduler defining the tasks + keccak hash.Hash // Keccak256 hasher to verify deliveries with + + trieTasks map[common.Hash]*trieTask // Set of trie node tasks currently queued for retrieval + codeTasks map[common.Hash]*codeTask // Set of byte code tasks currently queued for retrieval numUncommitted int bytesUncommitted int @@ -269,9 +272,16 @@ type stateSync struct { root common.Hash } -// stateTask represents a single trie node download task, containing a set of +// trieTask represents a single trie node download task, containing a set of // peers already attempted retrieval from to detect stalled syncs and abort. -type stateTask struct { +type trieTask struct { + path [][]byte + attempts map[string]struct{} +} + +// codeTask represents a single byte code download task, containing a set of +// peers already attempted retrieval from to detect stalled syncs and abort. +type codeTask struct { attempts map[string]struct{} } @@ -279,15 +289,16 @@ type stateTask struct { // yet start the sync. The user needs to call run to initiate. func newStateSync(d *Downloader, root common.Hash) *stateSync { return &stateSync{ - d: d, - sched: state.NewStateSync(root, d.stateDB, d.stateBloom), - keccak: sha3.NewLegacyKeccak256(), - tasks: make(map[common.Hash]*stateTask), - deliver: make(chan *stateReq), - cancel: make(chan struct{}), - done: make(chan struct{}), - started: make(chan struct{}), - root: root, + d: d, + sched: state.NewStateSync(root, d.stateDB, d.stateBloom), + keccak: sha3.NewLegacyKeccak256(), + trieTasks: make(map[common.Hash]*trieTask), + codeTasks: make(map[common.Hash]*codeTask), + deliver: make(chan *stateReq), + cancel: make(chan struct{}), + done: make(chan struct{}), + started: make(chan struct{}), + root: root, } } @@ -411,14 +422,15 @@ func (s *stateSync) assignTasks() { // Assign a batch of fetches proportional to the estimated latency/bandwidth cap := p.NodeDataCapacity(s.d.requestRTT()) req := &stateReq{peer: p, timeout: s.d.requestTTL()} - items := s.fillTasks(cap, req) + + nodes, _, codes := s.fillTasks(cap, req) // If the peer was assigned tasks to fetch, send the network request - if len(items) > 0 { - req.peer.log.Trace("Requesting new batch of data", "type", "state", "count", len(items), "root", s.root) + if len(nodes)+len(codes) > 0 { + req.peer.log.Trace("Requesting batch of state data", "nodes", len(nodes), "codes", len(codes), "root", s.root) select { case s.d.trackStateReq <- req: - req.peer.FetchNodeData(items) + req.peer.FetchNodeData(append(nodes, codes...)) // Unified retrieval under eth/6x case <-s.cancel: case <-s.d.cancelCh: } @@ -428,20 +440,34 @@ func (s *stateSync) assignTasks() { // fillTasks fills the given request object with a maximum of n state download // tasks to send to the remote peer. -func (s *stateSync) fillTasks(n int, req *stateReq) []common.Hash { +func (s *stateSync) fillTasks(n int, req *stateReq) (nodes []common.Hash, paths []trie.SyncPath, codes []common.Hash) { // Refill available tasks from the scheduler. - if len(s.tasks) < n { - new := s.sched.Missing(n - len(s.tasks)) - for _, hash := range new { - s.tasks[hash] = &stateTask{make(map[string]struct{})} + if fill := n - (len(s.trieTasks) + len(s.codeTasks)); fill > 0 { + nodes, paths, codes := s.sched.Missing(fill) + for i, hash := range nodes { + s.trieTasks[hash] = &trieTask{ + path: paths[i], + attempts: make(map[string]struct{}), + } + } + for _, hash := range codes { + s.codeTasks[hash] = &codeTask{ + attempts: make(map[string]struct{}), + } } } - // Find tasks that haven't been tried with the request's peer. - items := make([]common.Hash, 0, n) - req.tasks = make(map[common.Hash]*stateTask, n) - for hash, t := range s.tasks { + // Find tasks that haven't been tried with the request's peer. Prefer code + // over trie nodes as those can be written to disk and forgotten about. + nodes = make([]common.Hash, 0, n) + paths = make([]trie.SyncPath, 0, n) + codes = make([]common.Hash, 0, n) + + req.trieTasks = make(map[common.Hash]*trieTask, n) + req.codeTasks = make(map[common.Hash]*codeTask, n) + + for hash, t := range s.codeTasks { // Stop when we've gathered enough requests - if len(items) == n { + if len(nodes)+len(codes) == n { break } // Skip any requests we've already tried from this peer @@ -450,12 +476,30 @@ func (s *stateSync) fillTasks(n int, req *stateReq) []common.Hash { } // Assign the request to this peer t.attempts[req.peer.id] = struct{}{} - items = append(items, hash) - req.tasks[hash] = t - delete(s.tasks, hash) + codes = append(codes, hash) + req.codeTasks[hash] = t + delete(s.codeTasks, hash) } - req.nItems = uint16(len(items)) - return items + for hash, t := range s.trieTasks { + // Stop when we've gathered enough requests + if len(nodes)+len(codes) == n { + break + } + // Skip any requests we've already tried from this peer + if _, ok := t.attempts[req.peer.id]; ok { + continue + } + // Assign the request to this peer + t.attempts[req.peer.id] = struct{}{} + + nodes = append(nodes, hash) + paths = append(paths, t.path) + + req.trieTasks[hash] = t + delete(s.trieTasks, hash) + } + req.nItems = uint16(len(nodes) + len(codes)) + return nodes, paths, codes } // process iterates over a batch of delivered state data, injecting each item @@ -487,11 +531,13 @@ func (s *stateSync) process(req *stateReq) (int, error) { default: return successful, fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err) } - delete(req.tasks, hash) + // Delete from both queues (one delivery is enough for the syncer) + delete(req.trieTasks, hash) + delete(req.codeTasks, hash) } // Put unfulfilled tasks back into the retry queue npeers := s.d.peers.Len() - for hash, task := range req.tasks { + for hash, task := range req.trieTasks { // If the node did deliver something, missing items may be due to a protocol // limit or a previous timeout + delayed delivery. Both cases should permit // the node to retry the missing items (to avoid single-peer stalls). @@ -501,10 +547,25 @@ func (s *stateSync) process(req *stateReq) (int, error) { // If we've requested the node too many times already, it may be a malicious // sync where nobody has the right data. Abort. if len(task.attempts) >= npeers { - return successful, fmt.Errorf("state node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers) + return successful, fmt.Errorf("trie node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers) } // Missing item, place into the retry queue. - s.tasks[hash] = task + s.trieTasks[hash] = task + } + for hash, task := range req.codeTasks { + // If the node did deliver something, missing items may be due to a protocol + // limit or a previous timeout + delayed delivery. Both cases should permit + // the node to retry the missing items (to avoid single-peer stalls). + if len(req.response) > 0 || req.timedOut() { + delete(task.attempts, req.peer.id) + } + // If we've requested the node too many times already, it may be a malicious + // sync where nobody has the right data. Abort. + if len(task.attempts) >= npeers { + return successful, fmt.Errorf("byte code %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers) + } + // Missing item, place into the retry queue. + s.codeTasks[hash] = task } return successful, nil } @@ -533,7 +594,7 @@ func (s *stateSync) updateStats(written, duplicate, unexpected int, duration tim s.d.syncStatsState.unexpected += uint64(unexpected) if written > 0 || duplicate > 0 || unexpected > 0 { - log.Info("Imported new state entries", "count", written, "elapsed", common.PrettyDuration(duration), "processed", s.d.syncStatsState.processed, "pending", s.d.syncStatsState.pending, "retry", len(s.tasks), "duplicate", s.d.syncStatsState.duplicate, "unexpected", s.d.syncStatsState.unexpected) + log.Info("Imported new state entries", "count", written, "elapsed", common.PrettyDuration(duration), "processed", s.d.syncStatsState.processed, "pending", s.d.syncStatsState.pending, "trieretry", len(s.trieTasks), "coderetry", len(s.codeTasks), "duplicate", s.d.syncStatsState.duplicate, "unexpected", s.d.syncStatsState.unexpected) } if written > 0 { rawdb.WriteFastTrieProgress(s.d.stateDB, s.d.syncStatsState.processed) diff --git a/trie/secure_trie.go b/trie/secure_trie.go index ae1bbc6aa9..87b364fb1b 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -79,6 +79,12 @@ func (t *SecureTrie) TryGet(key []byte) ([]byte, error) { return t.trie.TryGet(t.hashKey(key)) } +// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not +// possible to use keybyte-encoding as the path might contain odd nibbles. +func (t *SecureTrie) TryGetNode(path []byte) ([]byte, int, error) { + return t.trie.TryGetNode(path) +} + // Update associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. diff --git a/trie/sync.go b/trie/sync.go index 147307fe71..bc93ddd3fb 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -52,6 +52,39 @@ type request struct { callback LeafCallback // Callback to invoke if a leaf node it reached on this branch } +// SyncPath is a path tuple identifying a particular trie node either in a single +// trie (account) or a layered trie (account -> storage). +// +// Content wise the tuple either has 1 element if it addresses a node in a single +// trie or 2 elements if it addresses a node in a stacked trie. +// +// To support aiming arbitrary trie nodes, the path needs to support odd nibble +// lengths. To avoid transferring expanded hex form over the network, the last +// part of the tuple (which needs to index into the middle of a trie) is compact +// encoded. In case of a 2-tuple, the first item is always 32 bytes so that is +// simple binary encoded. +// +// Examples: +// - Path 0x9 -> {0x19} +// - Path 0x99 -> {0x0099} +// - Path 0x01234567890123456789012345678901012345678901234567890123456789019 -> {0x0123456789012345678901234567890101234567890123456789012345678901, 0x19} +// - Path 0x012345678901234567890123456789010123456789012345678901234567890199 -> {0x0123456789012345678901234567890101234567890123456789012345678901, 0x0099} +type SyncPath [][]byte + +// newSyncPath converts an expanded trie path from nibble form into a compact +// version that can be sent over the network. +func newSyncPath(path []byte) SyncPath { + // If the hash is from the account trie, append a single item, if it + // is from the a storage trie, append a tuple. Note, the length 64 is + // clashing between account leaf and storage root. It's fine though + // because having a trie node at 64 depth means a hash collision was + // found and we're long dead. + if len(path) < 64 { + return SyncPath{hexToCompact(path)} + } + return SyncPath{hexToKeybytes(path[:64]), hexToCompact(path[64:])} +} + // SyncResult is a response with requested data along with it's hash. type SyncResult struct { Hash common.Hash // Hash of the originally unknown trie node @@ -193,10 +226,16 @@ func (s *Sync) AddCodeEntry(hash common.Hash, path []byte, parent common.Hash) { s.schedule(req) } -// Missing retrieves the known missing nodes from the trie for retrieval. -func (s *Sync) Missing(max int) []common.Hash { - var requests []common.Hash - for !s.queue.Empty() && (max == 0 || len(requests) < max) { +// Missing retrieves the known missing nodes from the trie for retrieval. To aid +// both eth/6x style fast sync and snap/1x style state sync, the paths of trie +// nodes are returned too, as well as separate hash list for codes. +func (s *Sync) Missing(max int) (nodes []common.Hash, paths []SyncPath, codes []common.Hash) { + var ( + nodeHashes []common.Hash + nodePaths []SyncPath + codeHashes []common.Hash + ) + for !s.queue.Empty() && (max == 0 || len(nodeHashes)+len(codeHashes) < max) { // Retrieve th enext item in line item, prio := s.queue.Peek() @@ -208,9 +247,16 @@ func (s *Sync) Missing(max int) []common.Hash { // Item is allowed to be scheduled, add it to the task list s.queue.Pop() s.fetches[depth]++ - requests = append(requests, item.(common.Hash)) + + hash := item.(common.Hash) + if req, ok := s.nodeReqs[hash]; ok { + nodeHashes = append(nodeHashes, hash) + nodePaths = append(nodePaths, newSyncPath(req.path)) + } else { + codeHashes = append(codeHashes, hash) + } } - return requests + return nodeHashes, nodePaths, codeHashes } // Process injects the received data for requested item. Note it can @@ -322,9 +368,13 @@ func (s *Sync) children(req *request, object node) ([]*request, error) { switch node := (object).(type) { case *shortNode: + key := node.Key + if hasTerm(key) { + key = key[:len(key)-1] + } children = []child{{ node: node.Val, - path: append(append([]byte(nil), req.path...), node.Key...), + path: append(append([]byte(nil), req.path...), key...), }} case *fullNode: for i := 0; i < 17; i++ { @@ -344,7 +394,7 @@ func (s *Sync) children(req *request, object node) ([]*request, error) { // Notify any external watcher of a new key/value node if req.callback != nil { if node, ok := (child.node).(valueNode); ok { - if err := req.callback(req.path, node, req.hash); err != nil { + if err := req.callback(child.path, node, req.hash); err != nil { return nil, err } } diff --git a/trie/sync_test.go b/trie/sync_test.go index 34f3990576..39e0f9575e 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -21,14 +21,15 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb/memorydb" ) // makeTestTrie create a sample test trie to test node-wise reconstruction. -func makeTestTrie() (*Database, *Trie, map[string][]byte) { +func makeTestTrie() (*Database, *SecureTrie, map[string][]byte) { // Create an empty trie triedb := NewDatabase(memorydb.New()) - trie, _ := New(common.Hash{}, triedb) + trie, _ := NewSecure(common.Hash{}, triedb) // Fill it with some arbitrary data content := make(map[string][]byte) @@ -59,7 +60,7 @@ func makeTestTrie() (*Database, *Trie, map[string][]byte) { // content map. func checkTrieContents(t *testing.T, db *Database, root []byte, content map[string][]byte) { // Check root availability and trie contents - trie, err := New(common.BytesToHash(root), db) + trie, err := NewSecure(common.BytesToHash(root), db) if err != nil { t.Fatalf("failed to create trie at %x: %v", root, err) } @@ -76,7 +77,7 @@ func checkTrieContents(t *testing.T, db *Database, root []byte, content map[stri // checkTrieConsistency checks that all nodes in a trie are indeed present. func checkTrieConsistency(db *Database, root common.Hash) error { // Create and iterate a trie rooted in a subnode - trie, err := New(root, db) + trie, err := NewSecure(root, db) if err != nil { return nil // Consider a non existent state consistent } @@ -94,18 +95,21 @@ func TestEmptySync(t *testing.T) { emptyB, _ := New(emptyRoot, dbB) for i, trie := range []*Trie{emptyA, emptyB} { - if req := NewSync(trie.Hash(), memorydb.New(), nil, NewSyncBloom(1, memorydb.New())).Missing(1); len(req) != 0 { - t.Errorf("test %d: content requested for empty trie: %v", i, req) + sync := NewSync(trie.Hash(), memorydb.New(), nil, NewSyncBloom(1, memorydb.New())) + if nodes, paths, codes := sync.Missing(1); len(nodes) != 0 || len(paths) != 0 || len(codes) != 0 { + t.Errorf("test %d: content requested for empty trie: %v, %v, %v", i, nodes, paths, codes) } } } // Tests that given a root hash, a trie can sync iteratively on a single thread, // requesting retrieval tasks and returning all of them in one go. -func TestIterativeSyncIndividual(t *testing.T) { testIterativeSync(t, 1) } -func TestIterativeSyncBatched(t *testing.T) { testIterativeSync(t, 100) } +func TestIterativeSyncIndividual(t *testing.T) { testIterativeSync(t, 1, false) } +func TestIterativeSyncBatched(t *testing.T) { testIterativeSync(t, 100, false) } +func TestIterativeSyncIndividualByPath(t *testing.T) { testIterativeSync(t, 1, true) } +func TestIterativeSyncBatchedByPath(t *testing.T) { testIterativeSync(t, 100, true) } -func testIterativeSync(t *testing.T, count int) { +func testIterativeSync(t *testing.T, count int, bypath bool) { // Create a random trie to copy srcDb, srcTrie, srcData := makeTestTrie() @@ -114,16 +118,33 @@ func testIterativeSync(t *testing.T, count int) { triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) - queue := append([]common.Hash{}, sched.Missing(count)...) - for len(queue) > 0 { - results := make([]SyncResult, len(queue)) - for i, hash := range queue { + nodes, paths, codes := sched.Missing(count) + var ( + hashQueue []common.Hash + pathQueue []SyncPath + ) + if !bypath { + hashQueue = append(append(hashQueue[:0], nodes...), codes...) + } else { + hashQueue = append(hashQueue[:0], codes...) + pathQueue = append(pathQueue[:0], paths...) + } + for len(hashQueue)+len(pathQueue) > 0 { + results := make([]SyncResult, len(hashQueue)+len(pathQueue)) + for i, hash := range hashQueue { data, err := srcDb.Node(hash) if err != nil { - t.Fatalf("failed to retrieve node data for %x: %v", hash, err) + t.Fatalf("failed to retrieve node data for hash %x: %v", hash, err) } results[i] = SyncResult{hash, data} } + for i, path := range pathQueue { + data, _, err := srcTrie.TryGetNode(path[0]) + if err != nil { + t.Fatalf("failed to retrieve node data for path %x: %v", path, err) + } + results[len(hashQueue)+i] = SyncResult{crypto.Keccak256Hash(data), data} + } for _, result := range results { if err := sched.Process(result); err != nil { t.Fatalf("failed to process result %v", err) @@ -134,7 +155,14 @@ func testIterativeSync(t *testing.T, count int) { t.Fatalf("failed to commit data: %v", err) } batch.Write() - queue = append(queue[:0], sched.Missing(count)...) + + nodes, paths, codes = sched.Missing(count) + if !bypath { + hashQueue = append(append(hashQueue[:0], nodes...), codes...) + } else { + hashQueue = append(hashQueue[:0], codes...) + pathQueue = append(pathQueue[:0], paths...) + } } // Cross check that the two tries are in sync checkTrieContents(t, triedb, srcTrie.Hash().Bytes(), srcData) @@ -151,7 +179,9 @@ func TestIterativeDelayedSync(t *testing.T) { triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) - queue := append([]common.Hash{}, sched.Missing(10000)...) + nodes, _, codes := sched.Missing(10000) + queue := append(append([]common.Hash{}, nodes...), codes...) + for len(queue) > 0 { // Sync only half of the scheduled nodes results := make([]SyncResult, len(queue)/2+1) @@ -172,7 +202,9 @@ func TestIterativeDelayedSync(t *testing.T) { t.Fatalf("failed to commit data: %v", err) } batch.Write() - queue = append(queue[len(results):], sched.Missing(10000)...) + + nodes, _, codes = sched.Missing(10000) + queue = append(append(queue[len(results):], nodes...), codes...) } // Cross check that the two tries are in sync checkTrieContents(t, triedb, srcTrie.Hash().Bytes(), srcData) @@ -194,7 +226,8 @@ func testIterativeRandomSync(t *testing.T, count int) { sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) queue := make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(count) { + nodes, _, codes := sched.Missing(count) + for _, hash := range append(nodes, codes...) { queue[hash] = struct{}{} } for len(queue) > 0 { @@ -218,8 +251,10 @@ func testIterativeRandomSync(t *testing.T, count int) { t.Fatalf("failed to commit data: %v", err) } batch.Write() + queue = make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(count) { + nodes, _, codes = sched.Missing(count) + for _, hash := range append(nodes, codes...) { queue[hash] = struct{}{} } } @@ -239,7 +274,8 @@ func TestIterativeRandomDelayedSync(t *testing.T) { sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) queue := make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(10000) { + nodes, _, codes := sched.Missing(10000) + for _, hash := range append(nodes, codes...) { queue[hash] = struct{}{} } for len(queue) > 0 { @@ -270,7 +306,8 @@ func TestIterativeRandomDelayedSync(t *testing.T) { for _, result := range results { delete(queue, result.Hash) } - for _, hash := range sched.Missing(10000) { + nodes, _, codes = sched.Missing(10000) + for _, hash := range append(nodes, codes...) { queue[hash] = struct{}{} } } @@ -289,7 +326,8 @@ func TestDuplicateAvoidanceSync(t *testing.T) { triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) - queue := append([]common.Hash{}, sched.Missing(0)...) + nodes, _, codes := sched.Missing(0) + queue := append(append([]common.Hash{}, nodes...), codes...) requested := make(map[common.Hash]struct{}) for len(queue) > 0 { @@ -316,7 +354,9 @@ func TestDuplicateAvoidanceSync(t *testing.T) { t.Fatalf("failed to commit data: %v", err) } batch.Write() - queue = append(queue[:0], sched.Missing(0)...) + + nodes, _, codes = sched.Missing(0) + queue = append(append(queue[:0], nodes...), codes...) } // Cross check that the two tries are in sync checkTrieContents(t, triedb, srcTrie.Hash().Bytes(), srcData) @@ -334,7 +374,10 @@ func TestIncompleteSync(t *testing.T) { sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) var added []common.Hash - queue := append([]common.Hash{}, sched.Missing(1)...) + + nodes, _, codes := sched.Missing(1) + queue := append(append([]common.Hash{}, nodes...), codes...) + for len(queue) > 0 { // Fetch a batch of trie nodes results := make([]SyncResult, len(queue)) @@ -366,7 +409,8 @@ func TestIncompleteSync(t *testing.T) { } } // Fetch the next batch to retrieve - queue = append(queue[:0], sched.Missing(1)...) + nodes, _, codes = sched.Missing(1) + queue = append(append(queue[:0], nodes...), codes...) } // Sanity check that removing any node from the database is detected for _, node := range added[1:] { @@ -380,3 +424,58 @@ func TestIncompleteSync(t *testing.T) { diskdb.Put(key, value) } } + +// Tests that trie nodes get scheduled lexicographically when having the same +// depth. +func TestSyncOrdering(t *testing.T) { + // Create a random trie to copy + srcDb, srcTrie, srcData := makeTestTrie() + + // Create a destination trie and sync with the scheduler, tracking the requests + diskdb := memorydb.New() + triedb := NewDatabase(diskdb) + sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) + + nodes, paths, _ := sched.Missing(1) + queue := append([]common.Hash{}, nodes...) + reqs := append([]SyncPath{}, paths...) + + for len(queue) > 0 { + results := make([]SyncResult, len(queue)) + for i, hash := range queue { + data, err := srcDb.Node(hash) + if err != nil { + t.Fatalf("failed to retrieve node data for %x: %v", hash, err) + } + results[i] = SyncResult{hash, data} + } + for _, result := range results { + if err := sched.Process(result); err != nil { + t.Fatalf("failed to process result %v", err) + } + } + batch := diskdb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) + } + batch.Write() + + nodes, paths, _ = sched.Missing(1) + queue = append(queue[:0], nodes...) + reqs = append(reqs, paths...) + } + // Cross check that the two tries are in sync + checkTrieContents(t, triedb, srcTrie.Hash().Bytes(), srcData) + + // Check that the trie nodes have been requested path-ordered + for i := 0; i < len(reqs)-1; i++ { + if len(reqs[i]) > 1 || len(reqs[i+1]) > 1 { + // In the case of the trie tests, there's no storage so the tuples + // must always be single items. 2-tuples should be tested in state. + t.Errorf("Invalid request tuples: len(%v) or len(%v) > 1", reqs[i], reqs[i+1]) + } + if bytes.Compare(compactToHex(reqs[i][0]), compactToHex(reqs[i+1][0])) > 0 { + t.Errorf("Invalid request order: %v before %v", compactToHex(reqs[i][0]), compactToHex(reqs[i+1][0])) + } + } +} diff --git a/trie/trie.go b/trie/trie.go index 7ccd37f872..1e1749a4ff 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rlp" ) var ( @@ -102,8 +103,7 @@ func (t *Trie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryGet(key []byte) ([]byte, error) { - key = keybytesToHex(key) - value, newroot, didResolve, err := t.tryGet(t.root, key, 0) + value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0) if err == nil && didResolve { t.root = newroot } @@ -146,6 +146,86 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode } } +// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not +// possible to use keybyte-encoding as the path might contain odd nibbles. +func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { + item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0) + if err != nil { + return nil, resolved, err + } + if resolved > 0 { + t.root = newroot + } + if item == nil { + return nil, resolved, nil + } + enc, err := rlp.EncodeToBytes(item) + if err != nil { + log.Error("Encoding existing trie node failed", "err", err) + return nil, resolved, err + } + return enc, resolved, err +} + +func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item node, newnode node, resolved int, err error) { + // If we reached the requested path, return the current node + if pos >= len(path) { + // Don't return collapsed hash nodes though + if _, ok := origNode.(hashNode); !ok { + // Short nodes have expanded keys, compact them before returning + item := origNode + if sn, ok := item.(*shortNode); ok { + item = &shortNode{ + Key: hexToCompact(sn.Key), + Val: sn.Val, + } + } + return item, origNode, 0, nil + } + } + // Path still needs to be traversed, descend into children + switch n := (origNode).(type) { + case nil: + // Non-existent path requested, abort + return nil, nil, 0, nil + + case valueNode: + // Path prematurely ended, abort + return nil, nil, 0, nil + + case *shortNode: + if len(path)-pos < len(n.Key) || !bytes.Equal(n.Key, path[pos:pos+len(n.Key)]) { + // Path branches off from short node + return nil, n, 0, nil + } + item, newnode, resolved, err = t.tryGetNode(n.Val, path, pos+len(n.Key)) + if err == nil && resolved > 0 { + n = n.copy() + n.Val = newnode + } + return item, n, resolved, err + + case *fullNode: + item, newnode, resolved, err = t.tryGetNode(n.Children[path[pos]], path, pos+1) + if err == nil && resolved > 0 { + n = n.copy() + n.Children[path[pos]] = newnode + } + return item, n, resolved, err + + case hashNode: + child, err := t.resolveHash(n, path[:pos]) + if err != nil { + return nil, n, 1, err + } + item, newnode, resolved, err := t.tryGetNode(child, path, pos) + return item, newnode, resolved + 1, err + + default: + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) + } +} + // Update associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil.