diff --git a/eth/api_backend.go b/eth/api_backend.go
index 49de70e210..7b40a7edd3 100644
--- a/eth/api_backend.go
+++ b/eth/api_backend.go
@@ -21,6 +21,7 @@ import (
"errors"
"math/big"
+ "github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
@@ -30,7 +31,6 @@ import (
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
- "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
@@ -279,8 +279,8 @@ func (b *EthAPIBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.S
return b.eth.TxPool().SubscribeNewTxsEvent(ch)
}
-func (b *EthAPIBackend) Downloader() *downloader.Downloader {
- return b.eth.Downloader()
+func (b *EthAPIBackend) SyncProgress() ethereum.SyncProgress {
+ return b.eth.Downloader().Progress()
}
func (b *EthAPIBackend) SuggestGasTipCap(ctx context.Context) (*big.Int, error) {
diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go
index 148359110c..55c0c880f3 100644
--- a/ethstats/ethstats.go
+++ b/ethstats/ethstats.go
@@ -30,12 +30,12 @@ import (
"sync"
"time"
+ "github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/eth/downloader"
ethproto "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/les"
@@ -67,7 +67,7 @@ type backend interface {
HeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Header, error)
GetTd(ctx context.Context, hash common.Hash) *big.Int
Stats() (pending int, queued int)
- Downloader() *downloader.Downloader
+ SyncProgress() ethereum.SyncProgress
}
// fullNodeBackend encompasses the functionality necessary for a full node
@@ -777,7 +777,7 @@ func (s *Service) reportStats(conn *connWrapper) error {
mining = fullBackend.Miner().Mining()
hashrate = int(fullBackend.Miner().Hashrate())
- sync := fullBackend.Downloader().Progress()
+ sync := fullBackend.SyncProgress()
syncing = fullBackend.CurrentHeader().Number.Uint64() >= sync.HighestBlock
price, _ := fullBackend.SuggestGasTipCap(context.Background())
@@ -786,7 +786,7 @@ func (s *Service) reportStats(conn *connWrapper) error {
gasprice += int(basefee.Uint64())
}
} else {
- sync := s.backend.Downloader().Progress()
+ sync := s.backend.SyncProgress()
syncing = s.backend.CurrentHeader().Number.Uint64() >= sync.HighestBlock
}
// Assemble the node stats and send it to the server
diff --git a/graphql/graphql.go b/graphql/graphql.go
index d35994234e..4dd96c4b9d 100644
--- a/graphql/graphql.go
+++ b/graphql/graphql.go
@@ -1248,7 +1248,7 @@ func (s *SyncState) KnownStates() *hexutil.Uint64 {
// - pulledStates: number of state entries processed until now
// - knownStates: number of known state entries that still need to be pulled
func (r *Resolver) Syncing() (*SyncState, error) {
- progress := r.backend.Downloader().Progress()
+ progress := r.backend.SyncProgress()
// Return not syncing if the synchronisation already completed
if progress.CurrentBlock >= progress.HighestBlock {
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go
index 9a82824ada..6997f2c828 100644
--- a/internal/ethapi/api.go
+++ b/internal/ethapi/api.go
@@ -122,7 +122,7 @@ func (s *PublicEthereumAPI) FeeHistory(ctx context.Context, blockCount rpc.Decim
// - pulledStates: number of state entries processed until now
// - knownStates: number of known state entries that still need to be pulled
func (s *PublicEthereumAPI) Syncing() (interface{}, error) {
- progress := s.b.Downloader().Progress()
+ progress := s.b.SyncProgress()
// Return not syncing if the synchronisation already completed
if progress.CurrentBlock >= progress.HighestBlock {
diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go
index 9954545821..1624f49635 100644
--- a/internal/ethapi/backend.go
+++ b/internal/ethapi/backend.go
@@ -21,6 +21,7 @@ import (
"context"
"math/big"
+ "github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
@@ -29,7 +30,6 @@ import (
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
- "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/params"
@@ -40,7 +40,8 @@ import (
// both full and light clients) with access to necessary functions.
type Backend interface {
// General Ethereum API
- Downloader() *downloader.Downloader
+ SyncProgress() ethereum.SyncProgress
+
SuggestGasTipCap(ctx context.Context) (*big.Int, error)
FeeHistory(ctx context.Context, blockCount int, lastBlock rpc.BlockNumber, rewardPercentiles []float64) (*big.Int, [][]*big.Int, []*big.Int, []float64, error)
ChainDb() ethdb.Database
diff --git a/les/api_backend.go b/les/api_backend.go
index 9c80270da0..e12984cb49 100644
--- a/les/api_backend.go
+++ b/les/api_backend.go
@@ -21,6 +21,7 @@ import (
"errors"
"math/big"
+ "github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
@@ -30,7 +31,6 @@ import (
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
- "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
@@ -257,8 +257,8 @@ func (b *LesApiBackend) SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEven
return b.eth.blockchain.SubscribeRemovedLogsEvent(ch)
}
-func (b *LesApiBackend) Downloader() *downloader.Downloader {
- return b.eth.Downloader()
+func (b *LesApiBackend) SyncProgress() ethereum.SyncProgress {
+ return b.eth.Downloader().Progress()
}
func (b *LesApiBackend) ProtocolVersion() int {
diff --git a/les/api_test.go b/les/api_test.go
index f7017c5d98..6a19b0fe4f 100644
--- a/les/api_test.go
+++ b/les/api_test.go
@@ -32,8 +32,9 @@ import (
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/eth"
- "github.com/ethereum/go-ethereum/eth/downloader"
+ ethdownloader "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/ethconfig"
+ "github.com/ethereum/go-ethereum/les/downloader"
"github.com/ethereum/go-ethereum/les/flowcontrol"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/node"
@@ -494,14 +495,14 @@ func testSim(t *testing.T, serverCount, clientCount int, serverDir, clientDir []
func newLesClientService(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
config := ethconfig.Defaults
- config.SyncMode = downloader.LightSync
+ config.SyncMode = (ethdownloader.SyncMode)(downloader.LightSync)
config.Ethash.PowMode = ethash.ModeFake
return New(stack, &config)
}
func newLesServerService(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) {
config := ethconfig.Defaults
- config.SyncMode = downloader.FullSync
+ config.SyncMode = (ethdownloader.SyncMode)(downloader.FullSync)
config.LightServ = testServerCapacity
config.LightPeers = testMaxClients
ethereum, err := eth.New(stack, &config)
diff --git a/les/client.go b/les/client.go
index 1d8a2c6f9a..5d07c783e9 100644
--- a/les/client.go
+++ b/les/client.go
@@ -30,12 +30,12 @@ import (
"github.com/ethereum/go-ethereum/core/bloombits"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/ethconfig"
"github.com/ethereum/go-ethereum/eth/filters"
"github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/internal/ethapi"
+ "github.com/ethereum/go-ethereum/les/downloader"
"github.com/ethereum/go-ethereum/les/vflux"
vfc "github.com/ethereum/go-ethereum/les/vflux/client"
"github.com/ethereum/go-ethereum/light"
diff --git a/les/client_handler.go b/les/client_handler.go
index 4a550b2074..9583bd57ca 100644
--- a/les/client_handler.go
+++ b/les/client_handler.go
@@ -28,8 +28,8 @@ import (
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/les/downloader"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
diff --git a/les/downloader/api.go b/les/downloader/api.go
new file mode 100644
index 0000000000..2024d23dea
--- /dev/null
+++ b/les/downloader/api.go
@@ -0,0 +1,166 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package downloader
+
+import (
+ "context"
+ "sync"
+
+ "github.com/ethereum/go-ethereum"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/rpc"
+)
+
+// PublicDownloaderAPI provides an API which gives information about the current synchronisation status.
+// It offers only methods that operates on data that can be available to anyone without security risks.
+type PublicDownloaderAPI struct {
+ d *Downloader
+ mux *event.TypeMux
+ installSyncSubscription chan chan interface{}
+ uninstallSyncSubscription chan *uninstallSyncSubscriptionRequest
+}
+
+// NewPublicDownloaderAPI create a new PublicDownloaderAPI. The API has an internal event loop that
+// listens for events from the downloader through the global event mux. In case it receives one of
+// these events it broadcasts it to all syncing subscriptions that are installed through the
+// installSyncSubscription channel.
+func NewPublicDownloaderAPI(d *Downloader, m *event.TypeMux) *PublicDownloaderAPI {
+ api := &PublicDownloaderAPI{
+ d: d,
+ mux: m,
+ installSyncSubscription: make(chan chan interface{}),
+ uninstallSyncSubscription: make(chan *uninstallSyncSubscriptionRequest),
+ }
+
+ go api.eventLoop()
+
+ return api
+}
+
+// eventLoop runs a loop until the event mux closes. It will install and uninstall new
+// sync subscriptions and broadcasts sync status updates to the installed sync subscriptions.
+func (api *PublicDownloaderAPI) eventLoop() {
+ var (
+ sub = api.mux.Subscribe(StartEvent{}, DoneEvent{}, FailedEvent{})
+ syncSubscriptions = make(map[chan interface{}]struct{})
+ )
+
+ for {
+ select {
+ case i := <-api.installSyncSubscription:
+ syncSubscriptions[i] = struct{}{}
+ case u := <-api.uninstallSyncSubscription:
+ delete(syncSubscriptions, u.c)
+ close(u.uninstalled)
+ case event := <-sub.Chan():
+ if event == nil {
+ return
+ }
+
+ var notification interface{}
+ switch event.Data.(type) {
+ case StartEvent:
+ notification = &SyncingResult{
+ Syncing: true,
+ Status: api.d.Progress(),
+ }
+ case DoneEvent, FailedEvent:
+ notification = false
+ }
+ // broadcast
+ for c := range syncSubscriptions {
+ c <- notification
+ }
+ }
+ }
+}
+
+// Syncing provides information when this nodes starts synchronising with the Ethereum network and when it's finished.
+func (api *PublicDownloaderAPI) Syncing(ctx context.Context) (*rpc.Subscription, error) {
+ notifier, supported := rpc.NotifierFromContext(ctx)
+ if !supported {
+ return &rpc.Subscription{}, rpc.ErrNotificationsUnsupported
+ }
+
+ rpcSub := notifier.CreateSubscription()
+
+ go func() {
+ statuses := make(chan interface{})
+ sub := api.SubscribeSyncStatus(statuses)
+
+ for {
+ select {
+ case status := <-statuses:
+ notifier.Notify(rpcSub.ID, status)
+ case <-rpcSub.Err():
+ sub.Unsubscribe()
+ return
+ case <-notifier.Closed():
+ sub.Unsubscribe()
+ return
+ }
+ }
+ }()
+
+ return rpcSub, nil
+}
+
+// SyncingResult provides information about the current synchronisation status for this node.
+type SyncingResult struct {
+ Syncing bool `json:"syncing"`
+ Status ethereum.SyncProgress `json:"status"`
+}
+
+// uninstallSyncSubscriptionRequest uninstalles a syncing subscription in the API event loop.
+type uninstallSyncSubscriptionRequest struct {
+ c chan interface{}
+ uninstalled chan interface{}
+}
+
+// SyncStatusSubscription represents a syncing subscription.
+type SyncStatusSubscription struct {
+ api *PublicDownloaderAPI // register subscription in event loop of this api instance
+ c chan interface{} // channel where events are broadcasted to
+ unsubOnce sync.Once // make sure unsubscribe logic is executed once
+}
+
+// Unsubscribe uninstalls the subscription from the DownloadAPI event loop.
+// The status channel that was passed to subscribeSyncStatus isn't used anymore
+// after this method returns.
+func (s *SyncStatusSubscription) Unsubscribe() {
+ s.unsubOnce.Do(func() {
+ req := uninstallSyncSubscriptionRequest{s.c, make(chan interface{})}
+ s.api.uninstallSyncSubscription <- &req
+
+ for {
+ select {
+ case <-s.c:
+ // drop new status events until uninstall confirmation
+ continue
+ case <-req.uninstalled:
+ return
+ }
+ }
+ })
+}
+
+// SubscribeSyncStatus creates a subscription that will broadcast new synchronisation updates.
+// The given channel must receive interface values, the result can either
+func (api *PublicDownloaderAPI) SubscribeSyncStatus(status chan interface{}) *SyncStatusSubscription {
+ api.installSyncSubscription <- status
+ return &SyncStatusSubscription{api: api, c: status}
+}
diff --git a/les/downloader/downloader.go b/les/downloader/downloader.go
new file mode 100644
index 0000000000..e7dfc4158e
--- /dev/null
+++ b/les/downloader/downloader.go
@@ -0,0 +1,2014 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// This is a temporary package whilst working on the eth/66 blocking refactors.
+// After that work is done, les needs to be refactored to use the new package,
+// or alternatively use a stripped down version of it. Either way, we need to
+// keep the changes scoped so duplicating temporarily seems the sanest.
+package downloader
+
+import (
+ "errors"
+ "fmt"
+ "math/big"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/ethereum/go-ethereum"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/state/snapshot"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/eth/protocols/snap"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+var (
+ MaxBlockFetch = 128 // Amount of blocks to be fetched per retrieval request
+ MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request
+ MaxSkeletonSize = 128 // Number of header fetches to need for a skeleton assembly
+ MaxReceiptFetch = 256 // Amount of transaction receipts to allow fetching per request
+ MaxStateFetch = 384 // Amount of node state values to allow fetching per request
+
+ maxQueuedHeaders = 32 * 1024 // [eth/62] Maximum number of headers to queue for import (DOS protection)
+ maxHeadersProcess = 2048 // Number of header download results to import at once into the chain
+ maxResultsProcess = 2048 // Number of content download results to import at once into the chain
+ fullMaxForkAncestry uint64 = params.FullImmutabilityThreshold // Maximum chain reorganisation (locally redeclared so tests can reduce it)
+ lightMaxForkAncestry uint64 = params.LightImmutabilityThreshold // Maximum chain reorganisation (locally redeclared so tests can reduce it)
+
+ reorgProtThreshold = 48 // Threshold number of recent blocks to disable mini reorg protection
+ reorgProtHeaderDelay = 2 // Number of headers to delay delivering to cover mini reorgs
+
+ fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync
+ fsHeaderSafetyNet = 2048 // Number of headers to discard in case a chain violation is detected
+ fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it
+ fsHeaderContCheck = 3 * time.Second // Time interval to check for header continuations during state download
+ fsMinFullBlocks = 64 // Number of blocks to retrieve fully even in fast sync
+)
+
+var (
+ errBusy = errors.New("busy")
+ errUnknownPeer = errors.New("peer is unknown or unhealthy")
+ errBadPeer = errors.New("action from bad peer ignored")
+ errStallingPeer = errors.New("peer is stalling")
+ errUnsyncedPeer = errors.New("unsynced peer")
+ errNoPeers = errors.New("no peers to keep download active")
+ errTimeout = errors.New("timeout")
+ errEmptyHeaderSet = errors.New("empty header set by peer")
+ errPeersUnavailable = errors.New("no peers available or all tried for download")
+ errInvalidAncestor = errors.New("retrieved ancestor is invalid")
+ errInvalidChain = errors.New("retrieved hash chain is invalid")
+ errInvalidBody = errors.New("retrieved block body is invalid")
+ errInvalidReceipt = errors.New("retrieved receipt is invalid")
+ errCancelStateFetch = errors.New("state data download canceled (requested)")
+ errCancelContentProcessing = errors.New("content processing canceled (requested)")
+ errCanceled = errors.New("syncing canceled (requested)")
+ errNoSyncActive = errors.New("no sync active")
+ errTooOld = errors.New("peer's protocol version too old")
+ errNoAncestorFound = errors.New("no common ancestor found")
+)
+
+type Downloader struct {
+ mode uint32 // Synchronisation mode defining the strategy used (per sync cycle), use d.getMode() to get the SyncMode
+ mux *event.TypeMux // Event multiplexer to announce sync operation events
+
+ checkpoint uint64 // Checkpoint block number to enforce head against (e.g. fast sync)
+ genesis uint64 // Genesis block number to limit sync to (e.g. light client CHT)
+ queue *queue // Scheduler for selecting the hashes to download
+ peers *peerSet // Set of active peers from which download can proceed
+
+ stateDB ethdb.Database // Database to state sync into (and deduplicate via)
+ stateBloom *trie.SyncBloom // Bloom filter for fast trie node and contract code existence checks
+
+ // Statistics
+ syncStatsChainOrigin uint64 // Origin block number where syncing started at
+ syncStatsChainHeight uint64 // Highest block number known when syncing started
+ syncStatsState stateSyncStats
+ syncStatsLock sync.RWMutex // Lock protecting the sync stats fields
+
+ lightchain LightChain
+ blockchain BlockChain
+
+ // Callbacks
+ dropPeer peerDropFn // Drops a peer for misbehaving
+
+ // Status
+ synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing
+ synchronising int32
+ notified int32
+ committed int32
+ ancientLimit uint64 // The maximum block number which can be regarded as ancient data.
+
+ // Channels
+ headerCh chan dataPack // Channel receiving inbound block headers
+ bodyCh chan dataPack // Channel receiving inbound block bodies
+ receiptCh chan dataPack // Channel receiving inbound receipts
+ bodyWakeCh chan bool // Channel to signal the block body fetcher of new tasks
+ receiptWakeCh chan bool // Channel to signal the receipt fetcher of new tasks
+ headerProcCh chan []*types.Header // Channel to feed the header processor new tasks
+
+ // State sync
+ pivotHeader *types.Header // Pivot block header to dynamically push the syncing state root
+ pivotLock sync.RWMutex // Lock protecting pivot header reads from updates
+
+ snapSync bool // Whether to run state sync over the snap protocol
+ SnapSyncer *snap.Syncer // TODO(karalabe): make private! hack for now
+ stateSyncStart chan *stateSync
+ trackStateReq chan *stateReq
+ stateCh chan dataPack // Channel receiving inbound node state data
+
+ // Cancellation and termination
+ cancelPeer string // Identifier of the peer currently being used as the master (cancel on drop)
+ cancelCh chan struct{} // Channel to cancel mid-flight syncs
+ cancelLock sync.RWMutex // Lock to protect the cancel channel and peer in delivers
+ cancelWg sync.WaitGroup // Make sure all fetcher goroutines have exited.
+
+ quitCh chan struct{} // Quit channel to signal termination
+ quitLock sync.Mutex // Lock to prevent double closes
+
+ // Testing hooks
+ syncInitHook func(uint64, uint64) // Method to call upon initiating a new sync run
+ bodyFetchHook func([]*types.Header) // Method to call upon starting a block body fetch
+ receiptFetchHook func([]*types.Header) // Method to call upon starting a receipt fetch
+ chainInsertHook func([]*fetchResult) // Method to call upon inserting a chain of blocks (possibly in multiple invocations)
+}
+
+// LightChain encapsulates functions required to synchronise a light chain.
+type LightChain interface {
+ // HasHeader verifies a header's presence in the local chain.
+ HasHeader(common.Hash, uint64) bool
+
+ // GetHeaderByHash retrieves a header from the local chain.
+ GetHeaderByHash(common.Hash) *types.Header
+
+ // CurrentHeader retrieves the head header from the local chain.
+ CurrentHeader() *types.Header
+
+ // GetTd returns the total difficulty of a local block.
+ GetTd(common.Hash, uint64) *big.Int
+
+ // InsertHeaderChain inserts a batch of headers into the local chain.
+ InsertHeaderChain([]*types.Header, int) (int, error)
+
+ // SetHead rewinds the local chain to a new head.
+ SetHead(uint64) error
+}
+
+// BlockChain encapsulates functions required to sync a (full or fast) blockchain.
+type BlockChain interface {
+ LightChain
+
+ // HasBlock verifies a block's presence in the local chain.
+ HasBlock(common.Hash, uint64) bool
+
+ // HasFastBlock verifies a fast block's presence in the local chain.
+ HasFastBlock(common.Hash, uint64) bool
+
+ // GetBlockByHash retrieves a block from the local chain.
+ GetBlockByHash(common.Hash) *types.Block
+
+ // CurrentBlock retrieves the head block from the local chain.
+ CurrentBlock() *types.Block
+
+ // CurrentFastBlock retrieves the head fast block from the local chain.
+ CurrentFastBlock() *types.Block
+
+ // FastSyncCommitHead directly commits the head block to a certain entity.
+ FastSyncCommitHead(common.Hash) error
+
+ // InsertChain inserts a batch of blocks into the local chain.
+ InsertChain(types.Blocks) (int, error)
+
+ // InsertReceiptChain inserts a batch of receipts into the local chain.
+ InsertReceiptChain(types.Blocks, []types.Receipts, uint64) (int, error)
+
+ // Snapshots returns the blockchain snapshot tree to paused it during sync.
+ Snapshots() *snapshot.Tree
+}
+
+// New creates a new downloader to fetch hashes and blocks from remote peers.
+func New(checkpoint uint64, stateDb ethdb.Database, stateBloom *trie.SyncBloom, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn) *Downloader {
+ if lightchain == nil {
+ lightchain = chain
+ }
+ dl := &Downloader{
+ stateDB: stateDb,
+ stateBloom: stateBloom,
+ mux: mux,
+ checkpoint: checkpoint,
+ queue: newQueue(blockCacheMaxItems, blockCacheInitialItems),
+ peers: newPeerSet(),
+ blockchain: chain,
+ lightchain: lightchain,
+ dropPeer: dropPeer,
+ headerCh: make(chan dataPack, 1),
+ bodyCh: make(chan dataPack, 1),
+ receiptCh: make(chan dataPack, 1),
+ bodyWakeCh: make(chan bool, 1),
+ receiptWakeCh: make(chan bool, 1),
+ headerProcCh: make(chan []*types.Header, 1),
+ quitCh: make(chan struct{}),
+ stateCh: make(chan dataPack),
+ SnapSyncer: snap.NewSyncer(stateDb),
+ stateSyncStart: make(chan *stateSync),
+ syncStatsState: stateSyncStats{
+ processed: rawdb.ReadFastTrieProgress(stateDb),
+ },
+ trackStateReq: make(chan *stateReq),
+ }
+ go dl.stateFetcher()
+ return dl
+}
+
+// Progress retrieves the synchronisation boundaries, specifically the origin
+// block where synchronisation started at (may have failed/suspended); the block
+// or header sync is currently at; and the latest known block which the sync targets.
+//
+// In addition, during the state download phase of fast synchronisation the number
+// of processed and the total number of known states are also returned. Otherwise
+// these are zero.
+func (d *Downloader) Progress() ethereum.SyncProgress {
+ // Lock the current stats and return the progress
+ d.syncStatsLock.RLock()
+ defer d.syncStatsLock.RUnlock()
+
+ current := uint64(0)
+ mode := d.getMode()
+ switch {
+ case d.blockchain != nil && mode == FullSync:
+ current = d.blockchain.CurrentBlock().NumberU64()
+ case d.blockchain != nil && mode == FastSync:
+ current = d.blockchain.CurrentFastBlock().NumberU64()
+ case d.lightchain != nil:
+ current = d.lightchain.CurrentHeader().Number.Uint64()
+ default:
+ log.Error("Unknown downloader chain/mode combo", "light", d.lightchain != nil, "full", d.blockchain != nil, "mode", mode)
+ }
+ return ethereum.SyncProgress{
+ StartingBlock: d.syncStatsChainOrigin,
+ CurrentBlock: current,
+ HighestBlock: d.syncStatsChainHeight,
+ PulledStates: d.syncStatsState.processed,
+ KnownStates: d.syncStatsState.processed + d.syncStatsState.pending,
+ }
+}
+
+// Synchronising returns whether the downloader is currently retrieving blocks.
+func (d *Downloader) Synchronising() bool {
+ return atomic.LoadInt32(&d.synchronising) > 0
+}
+
+// RegisterPeer injects a new download peer into the set of block source to be
+// used for fetching hashes and blocks from.
+func (d *Downloader) RegisterPeer(id string, version uint, peer Peer) error {
+ var logger log.Logger
+ if len(id) < 16 {
+ // Tests use short IDs, don't choke on them
+ logger = log.New("peer", id)
+ } else {
+ logger = log.New("peer", id[:8])
+ }
+ logger.Trace("Registering sync peer")
+ if err := d.peers.Register(newPeerConnection(id, version, peer, logger)); err != nil {
+ logger.Error("Failed to register sync peer", "err", err)
+ return err
+ }
+ return nil
+}
+
+// RegisterLightPeer injects a light client peer, wrapping it so it appears as a regular peer.
+func (d *Downloader) RegisterLightPeer(id string, version uint, peer LightPeer) error {
+ return d.RegisterPeer(id, version, &lightPeerWrapper{peer})
+}
+
+// UnregisterPeer remove a peer from the known list, preventing any action from
+// the specified peer. An effort is also made to return any pending fetches into
+// the queue.
+func (d *Downloader) UnregisterPeer(id string) error {
+ // Unregister the peer from the active peer set and revoke any fetch tasks
+ var logger log.Logger
+ if len(id) < 16 {
+ // Tests use short IDs, don't choke on them
+ logger = log.New("peer", id)
+ } else {
+ logger = log.New("peer", id[:8])
+ }
+ logger.Trace("Unregistering sync peer")
+ if err := d.peers.Unregister(id); err != nil {
+ logger.Error("Failed to unregister sync peer", "err", err)
+ return err
+ }
+ d.queue.Revoke(id)
+
+ return nil
+}
+
+// Synchronise tries to sync up our local block chain with a remote peer, both
+// adding various sanity checks as well as wrapping it with various log entries.
+func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int, mode SyncMode) error {
+ err := d.synchronise(id, head, td, mode)
+
+ switch err {
+ case nil, errBusy, errCanceled:
+ return err
+ }
+ if errors.Is(err, errInvalidChain) || errors.Is(err, errBadPeer) || errors.Is(err, errTimeout) ||
+ errors.Is(err, errStallingPeer) || errors.Is(err, errUnsyncedPeer) || errors.Is(err, errEmptyHeaderSet) ||
+ errors.Is(err, errPeersUnavailable) || errors.Is(err, errTooOld) || errors.Is(err, errInvalidAncestor) {
+ log.Warn("Synchronisation failed, dropping peer", "peer", id, "err", err)
+ if d.dropPeer == nil {
+ // The dropPeer method is nil when `--copydb` is used for a local copy.
+ // Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored
+ log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", id)
+ } else {
+ d.dropPeer(id)
+ }
+ return err
+ }
+ log.Warn("Synchronisation failed, retrying", "err", err)
+ return err
+}
+
+// synchronise will select the peer and use it for synchronising. If an empty string is given
+// it will use the best peer possible and synchronize if its TD is higher than our own. If any of the
+// checks fail an error will be returned. This method is synchronous
+func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode SyncMode) error {
+ // Mock out the synchronisation if testing
+ if d.synchroniseMock != nil {
+ return d.synchroniseMock(id, hash)
+ }
+ // Make sure only one goroutine is ever allowed past this point at once
+ if !atomic.CompareAndSwapInt32(&d.synchronising, 0, 1) {
+ return errBusy
+ }
+ defer atomic.StoreInt32(&d.synchronising, 0)
+
+ // Post a user notification of the sync (only once per session)
+ if atomic.CompareAndSwapInt32(&d.notified, 0, 1) {
+ log.Info("Block synchronisation started")
+ }
+ // If we are already full syncing, but have a fast-sync bloom filter laying
+ // around, make sure it doesn't use memory any more. This is a special case
+ // when the user attempts to fast sync a new empty network.
+ if mode == FullSync && d.stateBloom != nil {
+ d.stateBloom.Close()
+ }
+ // If snap sync was requested, create the snap scheduler and switch to fast
+ // sync mode. Long term we could drop fast sync or merge the two together,
+ // but until snap becomes prevalent, we should support both. TODO(karalabe).
+ if mode == SnapSync {
+ if !d.snapSync {
+ // Snap sync uses the snapshot namespace to store potentially flakey data until
+ // sync completely heals and finishes. Pause snapshot maintenance in the mean
+ // time to prevent access.
+ if snapshots := d.blockchain.Snapshots(); snapshots != nil { // Only nil in tests
+ snapshots.Disable()
+ }
+ log.Warn("Enabling snapshot sync prototype")
+ d.snapSync = true
+ }
+ mode = FastSync
+ }
+ // Reset the queue, peer set and wake channels to clean any internal leftover state
+ d.queue.Reset(blockCacheMaxItems, blockCacheInitialItems)
+ d.peers.Reset()
+
+ for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh} {
+ select {
+ case <-ch:
+ default:
+ }
+ }
+ for _, ch := range []chan dataPack{d.headerCh, d.bodyCh, d.receiptCh} {
+ for empty := false; !empty; {
+ select {
+ case <-ch:
+ default:
+ empty = true
+ }
+ }
+ }
+ for empty := false; !empty; {
+ select {
+ case <-d.headerProcCh:
+ default:
+ empty = true
+ }
+ }
+ // Create cancel channel for aborting mid-flight and mark the master peer
+ d.cancelLock.Lock()
+ d.cancelCh = make(chan struct{})
+ d.cancelPeer = id
+ d.cancelLock.Unlock()
+
+ defer d.Cancel() // No matter what, we can't leave the cancel channel open
+
+ // Atomically set the requested sync mode
+ atomic.StoreUint32(&d.mode, uint32(mode))
+
+ // Retrieve the origin peer and initiate the downloading process
+ p := d.peers.Peer(id)
+ if p == nil {
+ return errUnknownPeer
+ }
+ return d.syncWithPeer(p, hash, td)
+}
+
+func (d *Downloader) getMode() SyncMode {
+ return SyncMode(atomic.LoadUint32(&d.mode))
+}
+
+// syncWithPeer starts a block synchronization based on the hash chain from the
+// specified peer and head hash.
+func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.Int) (err error) {
+ d.mux.Post(StartEvent{})
+ defer func() {
+ // reset on error
+ if err != nil {
+ d.mux.Post(FailedEvent{err})
+ } else {
+ latest := d.lightchain.CurrentHeader()
+ d.mux.Post(DoneEvent{latest})
+ }
+ }()
+ if p.version < eth.ETH66 {
+ return fmt.Errorf("%w: advertized %d < required %d", errTooOld, p.version, eth.ETH66)
+ }
+ mode := d.getMode()
+
+ log.Debug("Synchronising with the network", "peer", p.id, "eth", p.version, "head", hash, "td", td, "mode", mode)
+ defer func(start time.Time) {
+ log.Debug("Synchronisation terminated", "elapsed", common.PrettyDuration(time.Since(start)))
+ }(time.Now())
+
+ // Look up the sync boundaries: the common ancestor and the target block
+ latest, pivot, err := d.fetchHead(p)
+ if err != nil {
+ return err
+ }
+ if mode == FastSync && pivot == nil {
+ // If no pivot block was returned, the head is below the min full block
+ // threshold (i.e. new chain). In that case we won't really fast sync
+ // anyway, but still need a valid pivot block to avoid some code hitting
+ // nil panics on an access.
+ pivot = d.blockchain.CurrentBlock().Header()
+ }
+ height := latest.Number.Uint64()
+
+ origin, err := d.findAncestor(p, latest)
+ if err != nil {
+ return err
+ }
+ d.syncStatsLock.Lock()
+ if d.syncStatsChainHeight <= origin || d.syncStatsChainOrigin > origin {
+ d.syncStatsChainOrigin = origin
+ }
+ d.syncStatsChainHeight = height
+ d.syncStatsLock.Unlock()
+
+ // Ensure our origin point is below any fast sync pivot point
+ if mode == FastSync {
+ if height <= uint64(fsMinFullBlocks) {
+ origin = 0
+ } else {
+ pivotNumber := pivot.Number.Uint64()
+ if pivotNumber <= origin {
+ origin = pivotNumber - 1
+ }
+ // Write out the pivot into the database so a rollback beyond it will
+ // reenable fast sync
+ rawdb.WriteLastPivotNumber(d.stateDB, pivotNumber)
+ }
+ }
+ d.committed = 1
+ if mode == FastSync && pivot.Number.Uint64() != 0 {
+ d.committed = 0
+ }
+ if mode == FastSync {
+ // Set the ancient data limitation.
+ // If we are running fast sync, all block data older than ancientLimit will be
+ // written to the ancient store. More recent data will be written to the active
+ // database and will wait for the freezer to migrate.
+ //
+ // If there is a checkpoint available, then calculate the ancientLimit through
+ // that. Otherwise calculate the ancient limit through the advertised height
+ // of the remote peer.
+ //
+ // The reason for picking checkpoint first is that a malicious peer can give us
+ // a fake (very high) height, forcing the ancient limit to also be very high.
+ // The peer would start to feed us valid blocks until head, resulting in all of
+ // the blocks might be written into the ancient store. A following mini-reorg
+ // could cause issues.
+ if d.checkpoint != 0 && d.checkpoint > fullMaxForkAncestry+1 {
+ d.ancientLimit = d.checkpoint
+ } else if height > fullMaxForkAncestry+1 {
+ d.ancientLimit = height - fullMaxForkAncestry - 1
+ } else {
+ d.ancientLimit = 0
+ }
+ frozen, _ := d.stateDB.Ancients() // Ignore the error here since light client can also hit here.
+
+ // If a part of blockchain data has already been written into active store,
+ // disable the ancient style insertion explicitly.
+ if origin >= frozen && frozen != 0 {
+ d.ancientLimit = 0
+ log.Info("Disabling direct-ancient mode", "origin", origin, "ancient", frozen-1)
+ } else if d.ancientLimit > 0 {
+ log.Debug("Enabling direct-ancient mode", "ancient", d.ancientLimit)
+ }
+ // Rewind the ancient store and blockchain if reorg happens.
+ if origin+1 < frozen {
+ if err := d.lightchain.SetHead(origin + 1); err != nil {
+ return err
+ }
+ }
+ }
+ // Initiate the sync using a concurrent header and content retrieval algorithm
+ d.queue.Prepare(origin+1, mode)
+ if d.syncInitHook != nil {
+ d.syncInitHook(origin, height)
+ }
+ fetchers := []func() error{
+ func() error { return d.fetchHeaders(p, origin+1) }, // Headers are always retrieved
+ func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync
+ func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync
+ func() error { return d.processHeaders(origin+1, td) },
+ }
+ if mode == FastSync {
+ d.pivotLock.Lock()
+ d.pivotHeader = pivot
+ d.pivotLock.Unlock()
+
+ fetchers = append(fetchers, func() error { return d.processFastSyncContent() })
+ } else if mode == FullSync {
+ fetchers = append(fetchers, d.processFullSyncContent)
+ }
+ return d.spawnSync(fetchers)
+}
+
+// spawnSync runs d.process and all given fetcher functions to completion in
+// separate goroutines, returning the first error that appears.
+func (d *Downloader) spawnSync(fetchers []func() error) error {
+ errc := make(chan error, len(fetchers))
+ d.cancelWg.Add(len(fetchers))
+ for _, fn := range fetchers {
+ fn := fn
+ go func() { defer d.cancelWg.Done(); errc <- fn() }()
+ }
+ // Wait for the first error, then terminate the others.
+ var err error
+ for i := 0; i < len(fetchers); i++ {
+ if i == len(fetchers)-1 {
+ // Close the queue when all fetchers have exited.
+ // This will cause the block processor to end when
+ // it has processed the queue.
+ d.queue.Close()
+ }
+ if err = <-errc; err != nil && err != errCanceled {
+ break
+ }
+ }
+ d.queue.Close()
+ d.Cancel()
+ return err
+}
+
+// cancel aborts all of the operations and resets the queue. However, cancel does
+// not wait for the running download goroutines to finish. This method should be
+// used when cancelling the downloads from inside the downloader.
+func (d *Downloader) cancel() {
+ // Close the current cancel channel
+ d.cancelLock.Lock()
+ defer d.cancelLock.Unlock()
+
+ if d.cancelCh != nil {
+ select {
+ case <-d.cancelCh:
+ // Channel was already closed
+ default:
+ close(d.cancelCh)
+ }
+ }
+}
+
+// Cancel aborts all of the operations and waits for all download goroutines to
+// finish before returning.
+func (d *Downloader) Cancel() {
+ d.cancel()
+ d.cancelWg.Wait()
+}
+
+// Terminate interrupts the downloader, canceling all pending operations.
+// The downloader cannot be reused after calling Terminate.
+func (d *Downloader) Terminate() {
+ // Close the termination channel (make sure double close is allowed)
+ d.quitLock.Lock()
+ select {
+ case <-d.quitCh:
+ default:
+ close(d.quitCh)
+ }
+ if d.stateBloom != nil {
+ d.stateBloom.Close()
+ }
+ d.quitLock.Unlock()
+
+ // Cancel any pending download requests
+ d.Cancel()
+}
+
+// fetchHead retrieves the head header and prior pivot block (if available) from
+// a remote peer.
+func (d *Downloader) fetchHead(p *peerConnection) (head *types.Header, pivot *types.Header, err error) {
+ p.log.Debug("Retrieving remote chain head")
+ mode := d.getMode()
+
+ // Request the advertised remote head block and wait for the response
+ latest, _ := p.peer.Head()
+ fetch := 1
+ if mode == FastSync {
+ fetch = 2 // head + pivot headers
+ }
+ go p.peer.RequestHeadersByHash(latest, fetch, fsMinFullBlocks-1, true)
+
+ ttl := d.peers.rates.TargetTimeout()
+ timeout := time.After(ttl)
+ for {
+ select {
+ case <-d.cancelCh:
+ return nil, nil, errCanceled
+
+ case packet := <-d.headerCh:
+ // Discard anything not from the origin peer
+ if packet.PeerId() != p.id {
+ log.Debug("Received headers from incorrect peer", "peer", packet.PeerId())
+ break
+ }
+ // Make sure the peer gave us at least one and at most the requested headers
+ headers := packet.(*headerPack).headers
+ if len(headers) == 0 || len(headers) > fetch {
+ return nil, nil, fmt.Errorf("%w: returned headers %d != requested %d", errBadPeer, len(headers), fetch)
+ }
+ // The first header needs to be the head, validate against the checkpoint
+ // and request. If only 1 header was returned, make sure there's no pivot
+ // or there was not one requested.
+ head := headers[0]
+ if (mode == FastSync || mode == LightSync) && head.Number.Uint64() < d.checkpoint {
+ return nil, nil, fmt.Errorf("%w: remote head %d below checkpoint %d", errUnsyncedPeer, head.Number, d.checkpoint)
+ }
+ if len(headers) == 1 {
+ if mode == FastSync && head.Number.Uint64() > uint64(fsMinFullBlocks) {
+ return nil, nil, fmt.Errorf("%w: no pivot included along head header", errBadPeer)
+ }
+ p.log.Debug("Remote head identified, no pivot", "number", head.Number, "hash", head.Hash())
+ return head, nil, nil
+ }
+ // At this point we have 2 headers in total and the first is the
+ // validated head of the chain. Check the pivot number and return,
+ pivot := headers[1]
+ if pivot.Number.Uint64() != head.Number.Uint64()-uint64(fsMinFullBlocks) {
+ return nil, nil, fmt.Errorf("%w: remote pivot %d != requested %d", errInvalidChain, pivot.Number, head.Number.Uint64()-uint64(fsMinFullBlocks))
+ }
+ return head, pivot, nil
+
+ case <-timeout:
+ p.log.Debug("Waiting for head header timed out", "elapsed", ttl)
+ return nil, nil, errTimeout
+
+ case <-d.bodyCh:
+ case <-d.receiptCh:
+ // Out of bounds delivery, ignore
+ }
+ }
+}
+
+// calculateRequestSpan calculates what headers to request from a peer when trying to determine the
+// common ancestor.
+// It returns parameters to be used for peer.RequestHeadersByNumber:
+// from - starting block number
+// count - number of headers to request
+// skip - number of headers to skip
+// and also returns 'max', the last block which is expected to be returned by the remote peers,
+// given the (from,count,skip)
+func calculateRequestSpan(remoteHeight, localHeight uint64) (int64, int, int, uint64) {
+ var (
+ from int
+ count int
+ MaxCount = MaxHeaderFetch / 16
+ )
+ // requestHead is the highest block that we will ask for. If requestHead is not offset,
+ // the highest block that we will get is 16 blocks back from head, which means we
+ // will fetch 14 or 15 blocks unnecessarily in the case the height difference
+ // between us and the peer is 1-2 blocks, which is most common
+ requestHead := int(remoteHeight) - 1
+ if requestHead < 0 {
+ requestHead = 0
+ }
+ // requestBottom is the lowest block we want included in the query
+ // Ideally, we want to include the one just below our own head
+ requestBottom := int(localHeight - 1)
+ if requestBottom < 0 {
+ requestBottom = 0
+ }
+ totalSpan := requestHead - requestBottom
+ span := 1 + totalSpan/MaxCount
+ if span < 2 {
+ span = 2
+ }
+ if span > 16 {
+ span = 16
+ }
+
+ count = 1 + totalSpan/span
+ if count > MaxCount {
+ count = MaxCount
+ }
+ if count < 2 {
+ count = 2
+ }
+ from = requestHead - (count-1)*span
+ if from < 0 {
+ from = 0
+ }
+ max := from + (count-1)*span
+ return int64(from), count, span - 1, uint64(max)
+}
+
+// findAncestor tries to locate the common ancestor link of the local chain and
+// a remote peers blockchain. In the general case when our node was in sync and
+// on the correct chain, checking the top N links should already get us a match.
+// In the rare scenario when we ended up on a long reorganisation (i.e. none of
+// the head links match), we do a binary search to find the common ancestor.
+func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header) (uint64, error) {
+ // Figure out the valid ancestor range to prevent rewrite attacks
+ var (
+ floor = int64(-1)
+ localHeight uint64
+ remoteHeight = remoteHeader.Number.Uint64()
+ )
+ mode := d.getMode()
+ switch mode {
+ case FullSync:
+ localHeight = d.blockchain.CurrentBlock().NumberU64()
+ case FastSync:
+ localHeight = d.blockchain.CurrentFastBlock().NumberU64()
+ default:
+ localHeight = d.lightchain.CurrentHeader().Number.Uint64()
+ }
+ p.log.Debug("Looking for common ancestor", "local", localHeight, "remote", remoteHeight)
+
+ // Recap floor value for binary search
+ maxForkAncestry := fullMaxForkAncestry
+ if d.getMode() == LightSync {
+ maxForkAncestry = lightMaxForkAncestry
+ }
+ if localHeight >= maxForkAncestry {
+ // We're above the max reorg threshold, find the earliest fork point
+ floor = int64(localHeight - maxForkAncestry)
+ }
+ // If we're doing a light sync, ensure the floor doesn't go below the CHT, as
+ // all headers before that point will be missing.
+ if mode == LightSync {
+ // If we don't know the current CHT position, find it
+ if d.genesis == 0 {
+ header := d.lightchain.CurrentHeader()
+ for header != nil {
+ d.genesis = header.Number.Uint64()
+ if floor >= int64(d.genesis)-1 {
+ break
+ }
+ header = d.lightchain.GetHeaderByHash(header.ParentHash)
+ }
+ }
+ // We already know the "genesis" block number, cap floor to that
+ if floor < int64(d.genesis)-1 {
+ floor = int64(d.genesis) - 1
+ }
+ }
+
+ ancestor, err := d.findAncestorSpanSearch(p, mode, remoteHeight, localHeight, floor)
+ if err == nil {
+ return ancestor, nil
+ }
+ // The returned error was not nil.
+ // If the error returned does not reflect that a common ancestor was not found, return it.
+ // If the error reflects that a common ancestor was not found, continue to binary search,
+ // where the error value will be reassigned.
+ if !errors.Is(err, errNoAncestorFound) {
+ return 0, err
+ }
+
+ ancestor, err = d.findAncestorBinarySearch(p, mode, remoteHeight, floor)
+ if err != nil {
+ return 0, err
+ }
+ return ancestor, nil
+}
+
+func (d *Downloader) findAncestorSpanSearch(p *peerConnection, mode SyncMode, remoteHeight, localHeight uint64, floor int64) (commonAncestor uint64, err error) {
+ from, count, skip, max := calculateRequestSpan(remoteHeight, localHeight)
+
+ p.log.Trace("Span searching for common ancestor", "count", count, "from", from, "skip", skip)
+ go p.peer.RequestHeadersByNumber(uint64(from), count, skip, false)
+
+ // Wait for the remote response to the head fetch
+ number, hash := uint64(0), common.Hash{}
+
+ ttl := d.peers.rates.TargetTimeout()
+ timeout := time.After(ttl)
+
+ for finished := false; !finished; {
+ select {
+ case <-d.cancelCh:
+ return 0, errCanceled
+
+ case packet := <-d.headerCh:
+ // Discard anything not from the origin peer
+ if packet.PeerId() != p.id {
+ log.Debug("Received headers from incorrect peer", "peer", packet.PeerId())
+ break
+ }
+ // Make sure the peer actually gave something valid
+ headers := packet.(*headerPack).headers
+ if len(headers) == 0 {
+ p.log.Warn("Empty head header set")
+ return 0, errEmptyHeaderSet
+ }
+ // Make sure the peer's reply conforms to the request
+ for i, header := range headers {
+ expectNumber := from + int64(i)*int64(skip+1)
+ if number := header.Number.Int64(); number != expectNumber {
+ p.log.Warn("Head headers broke chain ordering", "index", i, "requested", expectNumber, "received", number)
+ return 0, fmt.Errorf("%w: %v", errInvalidChain, errors.New("head headers broke chain ordering"))
+ }
+ }
+ // Check if a common ancestor was found
+ finished = true
+ for i := len(headers) - 1; i >= 0; i-- {
+ // Skip any headers that underflow/overflow our requested set
+ if headers[i].Number.Int64() < from || headers[i].Number.Uint64() > max {
+ continue
+ }
+ // Otherwise check if we already know the header or not
+ h := headers[i].Hash()
+ n := headers[i].Number.Uint64()
+
+ var known bool
+ switch mode {
+ case FullSync:
+ known = d.blockchain.HasBlock(h, n)
+ case FastSync:
+ known = d.blockchain.HasFastBlock(h, n)
+ default:
+ known = d.lightchain.HasHeader(h, n)
+ }
+ if known {
+ number, hash = n, h
+ break
+ }
+ }
+
+ case <-timeout:
+ p.log.Debug("Waiting for head header timed out", "elapsed", ttl)
+ return 0, errTimeout
+
+ case <-d.bodyCh:
+ case <-d.receiptCh:
+ // Out of bounds delivery, ignore
+ }
+ }
+ // If the head fetch already found an ancestor, return
+ if hash != (common.Hash{}) {
+ if int64(number) <= floor {
+ p.log.Warn("Ancestor below allowance", "number", number, "hash", hash, "allowance", floor)
+ return 0, errInvalidAncestor
+ }
+ p.log.Debug("Found common ancestor", "number", number, "hash", hash)
+ return number, nil
+ }
+ return 0, errNoAncestorFound
+}
+
+func (d *Downloader) findAncestorBinarySearch(p *peerConnection, mode SyncMode, remoteHeight uint64, floor int64) (commonAncestor uint64, err error) {
+ hash := common.Hash{}
+
+ // Ancestor not found, we need to binary search over our chain
+ start, end := uint64(0), remoteHeight
+ if floor > 0 {
+ start = uint64(floor)
+ }
+ p.log.Trace("Binary searching for common ancestor", "start", start, "end", end)
+
+ for start+1 < end {
+ // Split our chain interval in two, and request the hash to cross check
+ check := (start + end) / 2
+
+ ttl := d.peers.rates.TargetTimeout()
+ timeout := time.After(ttl)
+
+ go p.peer.RequestHeadersByNumber(check, 1, 0, false)
+
+ // Wait until a reply arrives to this request
+ for arrived := false; !arrived; {
+ select {
+ case <-d.cancelCh:
+ return 0, errCanceled
+
+ case packet := <-d.headerCh:
+ // Discard anything not from the origin peer
+ if packet.PeerId() != p.id {
+ log.Debug("Received headers from incorrect peer", "peer", packet.PeerId())
+ break
+ }
+ // Make sure the peer actually gave something valid
+ headers := packet.(*headerPack).headers
+ if len(headers) != 1 {
+ p.log.Warn("Multiple headers for single request", "headers", len(headers))
+ return 0, fmt.Errorf("%w: multiple headers (%d) for single request", errBadPeer, len(headers))
+ }
+ arrived = true
+
+ // Modify the search interval based on the response
+ h := headers[0].Hash()
+ n := headers[0].Number.Uint64()
+
+ var known bool
+ switch mode {
+ case FullSync:
+ known = d.blockchain.HasBlock(h, n)
+ case FastSync:
+ known = d.blockchain.HasFastBlock(h, n)
+ default:
+ known = d.lightchain.HasHeader(h, n)
+ }
+ if !known {
+ end = check
+ break
+ }
+ header := d.lightchain.GetHeaderByHash(h) // Independent of sync mode, header surely exists
+ if header.Number.Uint64() != check {
+ p.log.Warn("Received non requested header", "number", header.Number, "hash", header.Hash(), "request", check)
+ return 0, fmt.Errorf("%w: non-requested header (%d)", errBadPeer, header.Number)
+ }
+ start = check
+ hash = h
+
+ case <-timeout:
+ p.log.Debug("Waiting for search header timed out", "elapsed", ttl)
+ return 0, errTimeout
+
+ case <-d.bodyCh:
+ case <-d.receiptCh:
+ // Out of bounds delivery, ignore
+ }
+ }
+ }
+ // Ensure valid ancestry and return
+ if int64(start) <= floor {
+ p.log.Warn("Ancestor below allowance", "number", start, "hash", hash, "allowance", floor)
+ return 0, errInvalidAncestor
+ }
+ p.log.Debug("Found common ancestor", "number", start, "hash", hash)
+ return start, nil
+}
+
+// fetchHeaders keeps retrieving headers concurrently from the number
+// requested, until no more are returned, potentially throttling on the way. To
+// facilitate concurrency but still protect against malicious nodes sending bad
+// headers, we construct a header chain skeleton using the "origin" peer we are
+// syncing with, and fill in the missing headers using anyone else. Headers from
+// other peers are only accepted if they map cleanly to the skeleton. If no one
+// can fill in the skeleton - not even the origin peer - it's assumed invalid and
+// the origin is dropped.
+func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error {
+ p.log.Debug("Directing header downloads", "origin", from)
+ defer p.log.Debug("Header download terminated")
+
+ // Create a timeout timer, and the associated header fetcher
+ skeleton := true // Skeleton assembly phase or finishing up
+ pivoting := false // Whether the next request is pivot verification
+ request := time.Now() // time of the last skeleton fetch request
+ timeout := time.NewTimer(0) // timer to dump a non-responsive active peer
+ <-timeout.C // timeout channel should be initially empty
+ defer timeout.Stop()
+
+ var ttl time.Duration
+ getHeaders := func(from uint64) {
+ request = time.Now()
+
+ ttl = d.peers.rates.TargetTimeout()
+ timeout.Reset(ttl)
+
+ if skeleton {
+ p.log.Trace("Fetching skeleton headers", "count", MaxHeaderFetch, "from", from)
+ go p.peer.RequestHeadersByNumber(from+uint64(MaxHeaderFetch)-1, MaxSkeletonSize, MaxHeaderFetch-1, false)
+ } else {
+ p.log.Trace("Fetching full headers", "count", MaxHeaderFetch, "from", from)
+ go p.peer.RequestHeadersByNumber(from, MaxHeaderFetch, 0, false)
+ }
+ }
+ getNextPivot := func() {
+ pivoting = true
+ request = time.Now()
+
+ ttl = d.peers.rates.TargetTimeout()
+ timeout.Reset(ttl)
+
+ d.pivotLock.RLock()
+ pivot := d.pivotHeader.Number.Uint64()
+ d.pivotLock.RUnlock()
+
+ p.log.Trace("Fetching next pivot header", "number", pivot+uint64(fsMinFullBlocks))
+ go p.peer.RequestHeadersByNumber(pivot+uint64(fsMinFullBlocks), 2, fsMinFullBlocks-9, false) // move +64 when it's 2x64-8 deep
+ }
+ // Start pulling the header chain skeleton until all is done
+ ancestor := from
+ getHeaders(from)
+
+ mode := d.getMode()
+ for {
+ select {
+ case <-d.cancelCh:
+ return errCanceled
+
+ case packet := <-d.headerCh:
+ // Make sure the active peer is giving us the skeleton headers
+ if packet.PeerId() != p.id {
+ log.Debug("Received skeleton from incorrect peer", "peer", packet.PeerId())
+ break
+ }
+ headerReqTimer.UpdateSince(request)
+ timeout.Stop()
+
+ // If the pivot is being checked, move if it became stale and run the real retrieval
+ var pivot uint64
+
+ d.pivotLock.RLock()
+ if d.pivotHeader != nil {
+ pivot = d.pivotHeader.Number.Uint64()
+ }
+ d.pivotLock.RUnlock()
+
+ if pivoting {
+ if packet.Items() == 2 {
+ // Retrieve the headers and do some sanity checks, just in case
+ headers := packet.(*headerPack).headers
+
+ if have, want := headers[0].Number.Uint64(), pivot+uint64(fsMinFullBlocks); have != want {
+ log.Warn("Peer sent invalid next pivot", "have", have, "want", want)
+ return fmt.Errorf("%w: next pivot number %d != requested %d", errInvalidChain, have, want)
+ }
+ if have, want := headers[1].Number.Uint64(), pivot+2*uint64(fsMinFullBlocks)-8; have != want {
+ log.Warn("Peer sent invalid pivot confirmer", "have", have, "want", want)
+ return fmt.Errorf("%w: next pivot confirmer number %d != requested %d", errInvalidChain, have, want)
+ }
+ log.Warn("Pivot seemingly stale, moving", "old", pivot, "new", headers[0].Number)
+ pivot = headers[0].Number.Uint64()
+
+ d.pivotLock.Lock()
+ d.pivotHeader = headers[0]
+ d.pivotLock.Unlock()
+
+ // Write out the pivot into the database so a rollback beyond
+ // it will reenable fast sync and update the state root that
+ // the state syncer will be downloading.
+ rawdb.WriteLastPivotNumber(d.stateDB, pivot)
+ }
+ pivoting = false
+ getHeaders(from)
+ continue
+ }
+ // If the skeleton's finished, pull any remaining head headers directly from the origin
+ if skeleton && packet.Items() == 0 {
+ skeleton = false
+ getHeaders(from)
+ continue
+ }
+ // If no more headers are inbound, notify the content fetchers and return
+ if packet.Items() == 0 {
+ // Don't abort header fetches while the pivot is downloading
+ if atomic.LoadInt32(&d.committed) == 0 && pivot <= from {
+ p.log.Debug("No headers, waiting for pivot commit")
+ select {
+ case <-time.After(fsHeaderContCheck):
+ getHeaders(from)
+ continue
+ case <-d.cancelCh:
+ return errCanceled
+ }
+ }
+ // Pivot done (or not in fast sync) and no more headers, terminate the process
+ p.log.Debug("No more headers available")
+ select {
+ case d.headerProcCh <- nil:
+ return nil
+ case <-d.cancelCh:
+ return errCanceled
+ }
+ }
+ headers := packet.(*headerPack).headers
+
+ // If we received a skeleton batch, resolve internals concurrently
+ if skeleton {
+ filled, proced, err := d.fillHeaderSkeleton(from, headers)
+ if err != nil {
+ p.log.Debug("Skeleton chain invalid", "err", err)
+ return fmt.Errorf("%w: %v", errInvalidChain, err)
+ }
+ headers = filled[proced:]
+ from += uint64(proced)
+ } else {
+ // If we're closing in on the chain head, but haven't yet reached it, delay
+ // the last few headers so mini reorgs on the head don't cause invalid hash
+ // chain errors.
+ if n := len(headers); n > 0 {
+ // Retrieve the current head we're at
+ var head uint64
+ if mode == LightSync {
+ head = d.lightchain.CurrentHeader().Number.Uint64()
+ } else {
+ head = d.blockchain.CurrentFastBlock().NumberU64()
+ if full := d.blockchain.CurrentBlock().NumberU64(); head < full {
+ head = full
+ }
+ }
+ // If the head is below the common ancestor, we're actually deduplicating
+ // already existing chain segments, so use the ancestor as the fake head.
+ // Otherwise we might end up delaying header deliveries pointlessly.
+ if head < ancestor {
+ head = ancestor
+ }
+ // If the head is way older than this batch, delay the last few headers
+ if head+uint64(reorgProtThreshold) < headers[n-1].Number.Uint64() {
+ delay := reorgProtHeaderDelay
+ if delay > n {
+ delay = n
+ }
+ headers = headers[:n-delay]
+ }
+ }
+ }
+ // Insert all the new headers and fetch the next batch
+ if len(headers) > 0 {
+ p.log.Trace("Scheduling new headers", "count", len(headers), "from", from)
+ select {
+ case d.headerProcCh <- headers:
+ case <-d.cancelCh:
+ return errCanceled
+ }
+ from += uint64(len(headers))
+
+ // If we're still skeleton filling fast sync, check pivot staleness
+ // before continuing to the next skeleton filling
+ if skeleton && pivot > 0 {
+ getNextPivot()
+ } else {
+ getHeaders(from)
+ }
+ } else {
+ // No headers delivered, or all of them being delayed, sleep a bit and retry
+ p.log.Trace("All headers delayed, waiting")
+ select {
+ case <-time.After(fsHeaderContCheck):
+ getHeaders(from)
+ continue
+ case <-d.cancelCh:
+ return errCanceled
+ }
+ }
+
+ case <-timeout.C:
+ if d.dropPeer == nil {
+ // The dropPeer method is nil when `--copydb` is used for a local copy.
+ // Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored
+ p.log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", p.id)
+ break
+ }
+ // Header retrieval timed out, consider the peer bad and drop
+ p.log.Debug("Header request timed out", "elapsed", ttl)
+ headerTimeoutMeter.Mark(1)
+ d.dropPeer(p.id)
+
+ // Finish the sync gracefully instead of dumping the gathered data though
+ for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh} {
+ select {
+ case ch <- false:
+ case <-d.cancelCh:
+ }
+ }
+ select {
+ case d.headerProcCh <- nil:
+ case <-d.cancelCh:
+ }
+ return fmt.Errorf("%w: header request timed out", errBadPeer)
+ }
+ }
+}
+
+// fillHeaderSkeleton concurrently retrieves headers from all our available peers
+// and maps them to the provided skeleton header chain.
+//
+// Any partial results from the beginning of the skeleton is (if possible) forwarded
+// immediately to the header processor to keep the rest of the pipeline full even
+// in the case of header stalls.
+//
+// The method returns the entire filled skeleton and also the number of headers
+// already forwarded for processing.
+func (d *Downloader) fillHeaderSkeleton(from uint64, skeleton []*types.Header) ([]*types.Header, int, error) {
+ log.Debug("Filling up skeleton", "from", from)
+ d.queue.ScheduleSkeleton(from, skeleton)
+
+ var (
+ deliver = func(packet dataPack) (int, error) {
+ pack := packet.(*headerPack)
+ return d.queue.DeliverHeaders(pack.peerID, pack.headers, d.headerProcCh)
+ }
+ expire = func() map[string]int { return d.queue.ExpireHeaders(d.peers.rates.TargetTimeout()) }
+ reserve = func(p *peerConnection, count int) (*fetchRequest, bool, bool) {
+ return d.queue.ReserveHeaders(p, count), false, false
+ }
+ fetch = func(p *peerConnection, req *fetchRequest) error { return p.FetchHeaders(req.From, MaxHeaderFetch) }
+ capacity = func(p *peerConnection) int { return p.HeaderCapacity(d.peers.rates.TargetRoundTrip()) }
+ setIdle = func(p *peerConnection, accepted int, deliveryTime time.Time) {
+ p.SetHeadersIdle(accepted, deliveryTime)
+ }
+ )
+ err := d.fetchParts(d.headerCh, deliver, d.queue.headerContCh, expire,
+ d.queue.PendingHeaders, d.queue.InFlightHeaders, reserve,
+ nil, fetch, d.queue.CancelHeaders, capacity, d.peers.HeaderIdlePeers, setIdle, "headers")
+
+ log.Debug("Skeleton fill terminated", "err", err)
+
+ filled, proced := d.queue.RetrieveHeaders()
+ return filled, proced, err
+}
+
+// fetchBodies iteratively downloads the scheduled block bodies, taking any
+// available peers, reserving a chunk of blocks for each, waiting for delivery
+// and also periodically checking for timeouts.
+func (d *Downloader) fetchBodies(from uint64) error {
+ log.Debug("Downloading block bodies", "origin", from)
+
+ var (
+ deliver = func(packet dataPack) (int, error) {
+ pack := packet.(*bodyPack)
+ return d.queue.DeliverBodies(pack.peerID, pack.transactions, pack.uncles)
+ }
+ expire = func() map[string]int { return d.queue.ExpireBodies(d.peers.rates.TargetTimeout()) }
+ fetch = func(p *peerConnection, req *fetchRequest) error { return p.FetchBodies(req) }
+ capacity = func(p *peerConnection) int { return p.BlockCapacity(d.peers.rates.TargetRoundTrip()) }
+ setIdle = func(p *peerConnection, accepted int, deliveryTime time.Time) { p.SetBodiesIdle(accepted, deliveryTime) }
+ )
+ err := d.fetchParts(d.bodyCh, deliver, d.bodyWakeCh, expire,
+ d.queue.PendingBlocks, d.queue.InFlightBlocks, d.queue.ReserveBodies,
+ d.bodyFetchHook, fetch, d.queue.CancelBodies, capacity, d.peers.BodyIdlePeers, setIdle, "bodies")
+
+ log.Debug("Block body download terminated", "err", err)
+ return err
+}
+
+// fetchReceipts iteratively downloads the scheduled block receipts, taking any
+// available peers, reserving a chunk of receipts for each, waiting for delivery
+// and also periodically checking for timeouts.
+func (d *Downloader) fetchReceipts(from uint64) error {
+ log.Debug("Downloading transaction receipts", "origin", from)
+
+ var (
+ deliver = func(packet dataPack) (int, error) {
+ pack := packet.(*receiptPack)
+ return d.queue.DeliverReceipts(pack.peerID, pack.receipts)
+ }
+ expire = func() map[string]int { return d.queue.ExpireReceipts(d.peers.rates.TargetTimeout()) }
+ fetch = func(p *peerConnection, req *fetchRequest) error { return p.FetchReceipts(req) }
+ capacity = func(p *peerConnection) int { return p.ReceiptCapacity(d.peers.rates.TargetRoundTrip()) }
+ setIdle = func(p *peerConnection, accepted int, deliveryTime time.Time) {
+ p.SetReceiptsIdle(accepted, deliveryTime)
+ }
+ )
+ err := d.fetchParts(d.receiptCh, deliver, d.receiptWakeCh, expire,
+ d.queue.PendingReceipts, d.queue.InFlightReceipts, d.queue.ReserveReceipts,
+ d.receiptFetchHook, fetch, d.queue.CancelReceipts, capacity, d.peers.ReceiptIdlePeers, setIdle, "receipts")
+
+ log.Debug("Transaction receipt download terminated", "err", err)
+ return err
+}
+
+// fetchParts iteratively downloads scheduled block parts, taking any available
+// peers, reserving a chunk of fetch requests for each, waiting for delivery and
+// also periodically checking for timeouts.
+//
+// As the scheduling/timeout logic mostly is the same for all downloaded data
+// types, this method is used by each for data gathering and is instrumented with
+// various callbacks to handle the slight differences between processing them.
+//
+// The instrumentation parameters:
+// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer)
+// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers)
+// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`)
+// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed)
+// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping)
+// - pending: task callback for the number of requests still needing download (detect completion/non-completability)
+// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish)
+// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use)
+// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions)
+// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic)
+// - fetch: network callback to actually send a particular download request to a physical remote peer
+// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer)
+// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping)
+// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks
+// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping)
+// - kind: textual label of the type being downloaded to display in log messages
+func (d *Downloader) fetchParts(deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool,
+ expire func() map[string]int, pending func() int, inFlight func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, bool),
+ fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int,
+ idle func() ([]*peerConnection, int), setIdle func(*peerConnection, int, time.Time), kind string) error {
+
+ // Create a ticker to detect expired retrieval tasks
+ ticker := time.NewTicker(100 * time.Millisecond)
+ defer ticker.Stop()
+
+ update := make(chan struct{}, 1)
+
+ // Prepare the queue and fetch block parts until the block header fetcher's done
+ finished := false
+ for {
+ select {
+ case <-d.cancelCh:
+ return errCanceled
+
+ case packet := <-deliveryCh:
+ deliveryTime := time.Now()
+ // If the peer was previously banned and failed to deliver its pack
+ // in a reasonable time frame, ignore its message.
+ if peer := d.peers.Peer(packet.PeerId()); peer != nil {
+ // Deliver the received chunk of data and check chain validity
+ accepted, err := deliver(packet)
+ if errors.Is(err, errInvalidChain) {
+ return err
+ }
+ // Unless a peer delivered something completely else than requested (usually
+ // caused by a timed out request which came through in the end), set it to
+ // idle. If the delivery's stale, the peer should have already been idled.
+ if !errors.Is(err, errStaleDelivery) {
+ setIdle(peer, accepted, deliveryTime)
+ }
+ // Issue a log to the user to see what's going on
+ switch {
+ case err == nil && packet.Items() == 0:
+ peer.log.Trace("Requested data not delivered", "type", kind)
+ case err == nil:
+ peer.log.Trace("Delivered new batch of data", "type", kind, "count", packet.Stats())
+ default:
+ peer.log.Debug("Failed to deliver retrieved data", "type", kind, "err", err)
+ }
+ }
+ // Blocks assembled, try to update the progress
+ select {
+ case update <- struct{}{}:
+ default:
+ }
+
+ case cont := <-wakeCh:
+ // The header fetcher sent a continuation flag, check if it's done
+ if !cont {
+ finished = true
+ }
+ // Headers arrive, try to update the progress
+ select {
+ case update <- struct{}{}:
+ default:
+ }
+
+ case <-ticker.C:
+ // Sanity check update the progress
+ select {
+ case update <- struct{}{}:
+ default:
+ }
+
+ case <-update:
+ // Short circuit if we lost all our peers
+ if d.peers.Len() == 0 {
+ return errNoPeers
+ }
+ // Check for fetch request timeouts and demote the responsible peers
+ for pid, fails := range expire() {
+ if peer := d.peers.Peer(pid); peer != nil {
+ // If a lot of retrieval elements expired, we might have overestimated the remote peer or perhaps
+ // ourselves. Only reset to minimal throughput but don't drop just yet. If even the minimal times
+ // out that sync wise we need to get rid of the peer.
+ //
+ // The reason the minimum threshold is 2 is because the downloader tries to estimate the bandwidth
+ // and latency of a peer separately, which requires pushing the measures capacity a bit and seeing
+ // how response times reacts, to it always requests one more than the minimum (i.e. min 2).
+ if fails > 2 {
+ peer.log.Trace("Data delivery timed out", "type", kind)
+ setIdle(peer, 0, time.Now())
+ } else {
+ peer.log.Debug("Stalling delivery, dropping", "type", kind)
+
+ if d.dropPeer == nil {
+ // The dropPeer method is nil when `--copydb` is used for a local copy.
+ // Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored
+ peer.log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", pid)
+ } else {
+ d.dropPeer(pid)
+
+ // If this peer was the master peer, abort sync immediately
+ d.cancelLock.RLock()
+ master := pid == d.cancelPeer
+ d.cancelLock.RUnlock()
+
+ if master {
+ d.cancel()
+ return errTimeout
+ }
+ }
+ }
+ }
+ }
+ // If there's nothing more to fetch, wait or terminate
+ if pending() == 0 {
+ if !inFlight() && finished {
+ log.Debug("Data fetching completed", "type", kind)
+ return nil
+ }
+ break
+ }
+ // Send a download request to all idle peers, until throttled
+ progressed, throttled, running := false, false, inFlight()
+ idles, total := idle()
+ pendCount := pending()
+ for _, peer := range idles {
+ // Short circuit if throttling activated
+ if throttled {
+ break
+ }
+ // Short circuit if there is no more available task.
+ if pendCount = pending(); pendCount == 0 {
+ break
+ }
+ // Reserve a chunk of fetches for a peer. A nil can mean either that
+ // no more headers are available, or that the peer is known not to
+ // have them.
+ request, progress, throttle := reserve(peer, capacity(peer))
+ if progress {
+ progressed = true
+ }
+ if throttle {
+ throttled = true
+ throttleCounter.Inc(1)
+ }
+ if request == nil {
+ continue
+ }
+ if request.From > 0 {
+ peer.log.Trace("Requesting new batch of data", "type", kind, "from", request.From)
+ } else {
+ peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number)
+ }
+ // Fetch the chunk and make sure any errors return the hashes to the queue
+ if fetchHook != nil {
+ fetchHook(request.Headers)
+ }
+ if err := fetch(peer, request); err != nil {
+ // Although we could try and make an attempt to fix this, this error really
+ // means that we've double allocated a fetch task to a peer. If that is the
+ // case, the internal state of the downloader and the queue is very wrong so
+ // better hard crash and note the error instead of silently accumulating into
+ // a much bigger issue.
+ panic(fmt.Sprintf("%v: %s fetch assignment failed", peer, kind))
+ }
+ running = true
+ }
+ // Make sure that we have peers available for fetching. If all peers have been tried
+ // and all failed throw an error
+ if !progressed && !throttled && !running && len(idles) == total && pendCount > 0 {
+ return errPeersUnavailable
+ }
+ }
+ }
+}
+
+// processHeaders takes batches of retrieved headers from an input channel and
+// keeps processing and scheduling them into the header chain and downloader's
+// queue until the stream ends or a failure occurs.
+func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
+ // Keep a count of uncertain headers to roll back
+ var (
+ rollback uint64 // Zero means no rollback (fine as you can't unroll the genesis)
+ rollbackErr error
+ mode = d.getMode()
+ )
+ defer func() {
+ if rollback > 0 {
+ lastHeader, lastFastBlock, lastBlock := d.lightchain.CurrentHeader().Number, common.Big0, common.Big0
+ if mode != LightSync {
+ lastFastBlock = d.blockchain.CurrentFastBlock().Number()
+ lastBlock = d.blockchain.CurrentBlock().Number()
+ }
+ if err := d.lightchain.SetHead(rollback - 1); err != nil { // -1 to target the parent of the first uncertain block
+ // We're already unwinding the stack, only print the error to make it more visible
+ log.Error("Failed to roll back chain segment", "head", rollback-1, "err", err)
+ }
+ curFastBlock, curBlock := common.Big0, common.Big0
+ if mode != LightSync {
+ curFastBlock = d.blockchain.CurrentFastBlock().Number()
+ curBlock = d.blockchain.CurrentBlock().Number()
+ }
+ log.Warn("Rolled back chain segment",
+ "header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number),
+ "fast", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock),
+ "block", fmt.Sprintf("%d->%d", lastBlock, curBlock), "reason", rollbackErr)
+ }
+ }()
+ // Wait for batches of headers to process
+ gotHeaders := false
+
+ for {
+ select {
+ case <-d.cancelCh:
+ rollbackErr = errCanceled
+ return errCanceled
+
+ case headers := <-d.headerProcCh:
+ // Terminate header processing if we synced up
+ if len(headers) == 0 {
+ // Notify everyone that headers are fully processed
+ for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh} {
+ select {
+ case ch <- false:
+ case <-d.cancelCh:
+ }
+ }
+ // If no headers were retrieved at all, the peer violated its TD promise that it had a
+ // better chain compared to ours. The only exception is if its promised blocks were
+ // already imported by other means (e.g. fetcher):
+ //
+ // R , L : Both at block 10
+ // R: Mine block 11, and propagate it to L
+ // L: Queue block 11 for import
+ // L: Notice that R's head and TD increased compared to ours, start sync
+ // L: Import of block 11 finishes
+ // L: Sync begins, and finds common ancestor at 11
+ // L: Request new headers up from 11 (R's TD was higher, it must have something)
+ // R: Nothing to give
+ if mode != LightSync {
+ head := d.blockchain.CurrentBlock()
+ if !gotHeaders && td.Cmp(d.blockchain.GetTd(head.Hash(), head.NumberU64())) > 0 {
+ return errStallingPeer
+ }
+ }
+ // If fast or light syncing, ensure promised headers are indeed delivered. This is
+ // needed to detect scenarios where an attacker feeds a bad pivot and then bails out
+ // of delivering the post-pivot blocks that would flag the invalid content.
+ //
+ // This check cannot be executed "as is" for full imports, since blocks may still be
+ // queued for processing when the header download completes. However, as long as the
+ // peer gave us something useful, we're already happy/progressed (above check).
+ if mode == FastSync || mode == LightSync {
+ head := d.lightchain.CurrentHeader()
+ if td.Cmp(d.lightchain.GetTd(head.Hash(), head.Number.Uint64())) > 0 {
+ return errStallingPeer
+ }
+ }
+ // Disable any rollback and return
+ rollback = 0
+ return nil
+ }
+ // Otherwise split the chunk of headers into batches and process them
+ gotHeaders = true
+ for len(headers) > 0 {
+ // Terminate if something failed in between processing chunks
+ select {
+ case <-d.cancelCh:
+ rollbackErr = errCanceled
+ return errCanceled
+ default:
+ }
+ // Select the next chunk of headers to import
+ limit := maxHeadersProcess
+ if limit > len(headers) {
+ limit = len(headers)
+ }
+ chunk := headers[:limit]
+
+ // In case of header only syncing, validate the chunk immediately
+ if mode == FastSync || mode == LightSync {
+ // If we're importing pure headers, verify based on their recentness
+ var pivot uint64
+
+ d.pivotLock.RLock()
+ if d.pivotHeader != nil {
+ pivot = d.pivotHeader.Number.Uint64()
+ }
+ d.pivotLock.RUnlock()
+
+ frequency := fsHeaderCheckFrequency
+ if chunk[len(chunk)-1].Number.Uint64()+uint64(fsHeaderForceVerify) > pivot {
+ frequency = 1
+ }
+ if n, err := d.lightchain.InsertHeaderChain(chunk, frequency); err != nil {
+ rollbackErr = err
+
+ // If some headers were inserted, track them as uncertain
+ if (mode == FastSync || frequency > 1) && n > 0 && rollback == 0 {
+ rollback = chunk[0].Number.Uint64()
+ }
+ log.Warn("Invalid header encountered", "number", chunk[n].Number, "hash", chunk[n].Hash(), "parent", chunk[n].ParentHash, "err", err)
+ return fmt.Errorf("%w: %v", errInvalidChain, err)
+ }
+ // All verifications passed, track all headers within the alloted limits
+ if mode == FastSync {
+ head := chunk[len(chunk)-1].Number.Uint64()
+ if head-rollback > uint64(fsHeaderSafetyNet) {
+ rollback = head - uint64(fsHeaderSafetyNet)
+ } else {
+ rollback = 1
+ }
+ }
+ }
+ // Unless we're doing light chains, schedule the headers for associated content retrieval
+ if mode == FullSync || mode == FastSync {
+ // If we've reached the allowed number of pending headers, stall a bit
+ for d.queue.PendingBlocks() >= maxQueuedHeaders || d.queue.PendingReceipts() >= maxQueuedHeaders {
+ select {
+ case <-d.cancelCh:
+ rollbackErr = errCanceled
+ return errCanceled
+ case <-time.After(time.Second):
+ }
+ }
+ // Otherwise insert the headers for content retrieval
+ inserts := d.queue.Schedule(chunk, origin)
+ if len(inserts) != len(chunk) {
+ rollbackErr = fmt.Errorf("stale headers: len inserts %v len(chunk) %v", len(inserts), len(chunk))
+ return fmt.Errorf("%w: stale headers", errBadPeer)
+ }
+ }
+ headers = headers[limit:]
+ origin += uint64(limit)
+ }
+ // Update the highest block number we know if a higher one is found.
+ d.syncStatsLock.Lock()
+ if d.syncStatsChainHeight < origin {
+ d.syncStatsChainHeight = origin - 1
+ }
+ d.syncStatsLock.Unlock()
+
+ // Signal the content downloaders of the availablility of new tasks
+ for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh} {
+ select {
+ case ch <- true:
+ default:
+ }
+ }
+ }
+ }
+}
+
+// processFullSyncContent takes fetch results from the queue and imports them into the chain.
+func (d *Downloader) processFullSyncContent() error {
+ for {
+ results := d.queue.Results(true)
+ if len(results) == 0 {
+ return nil
+ }
+ if d.chainInsertHook != nil {
+ d.chainInsertHook(results)
+ }
+ if err := d.importBlockResults(results); err != nil {
+ return err
+ }
+ }
+}
+
+func (d *Downloader) importBlockResults(results []*fetchResult) error {
+ // Check for any early termination requests
+ if len(results) == 0 {
+ return nil
+ }
+ select {
+ case <-d.quitCh:
+ return errCancelContentProcessing
+ default:
+ }
+ // Retrieve the a batch of results to import
+ first, last := results[0].Header, results[len(results)-1].Header
+ log.Debug("Inserting downloaded chain", "items", len(results),
+ "firstnum", first.Number, "firsthash", first.Hash(),
+ "lastnum", last.Number, "lasthash", last.Hash(),
+ )
+ blocks := make([]*types.Block, len(results))
+ for i, result := range results {
+ blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+ }
+ if index, err := d.blockchain.InsertChain(blocks); err != nil {
+ if index < len(results) {
+ log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
+ } else {
+ // The InsertChain method in blockchain.go will sometimes return an out-of-bounds index,
+ // when it needs to preprocess blocks to import a sidechain.
+ // The importer will put together a new list of blocks to import, which is a superset
+ // of the blocks delivered from the downloader, and the indexing will be off.
+ log.Debug("Downloaded item processing failed on sidechain import", "index", index, "err", err)
+ }
+ return fmt.Errorf("%w: %v", errInvalidChain, err)
+ }
+ return nil
+}
+
+// processFastSyncContent takes fetch results from the queue and writes them to the
+// database. It also controls the synchronisation of state nodes of the pivot block.
+func (d *Downloader) processFastSyncContent() error {
+ // Start syncing state of the reported head block. This should get us most of
+ // the state of the pivot block.
+ d.pivotLock.RLock()
+ sync := d.syncState(d.pivotHeader.Root)
+ d.pivotLock.RUnlock()
+
+ defer func() {
+ // The `sync` object is replaced every time the pivot moves. We need to
+ // defer close the very last active one, hence the lazy evaluation vs.
+ // calling defer sync.Cancel() !!!
+ sync.Cancel()
+ }()
+
+ closeOnErr := func(s *stateSync) {
+ if err := s.Wait(); err != nil && err != errCancelStateFetch && err != errCanceled && err != snap.ErrCancelled {
+ d.queue.Close() // wake up Results
+ }
+ }
+ go closeOnErr(sync)
+
+ // To cater for moving pivot points, track the pivot block and subsequently
+ // accumulated download results separately.
+ var (
+ oldPivot *fetchResult // Locked in pivot block, might change eventually
+ oldTail []*fetchResult // Downloaded content after the pivot
+ )
+ for {
+ // Wait for the next batch of downloaded data to be available, and if the pivot
+ // block became stale, move the goalpost
+ results := d.queue.Results(oldPivot == nil) // Block if we're not monitoring pivot staleness
+ if len(results) == 0 {
+ // If pivot sync is done, stop
+ if oldPivot == nil {
+ return sync.Cancel()
+ }
+ // If sync failed, stop
+ select {
+ case <-d.cancelCh:
+ sync.Cancel()
+ return errCanceled
+ default:
+ }
+ }
+ if d.chainInsertHook != nil {
+ d.chainInsertHook(results)
+ }
+ // If we haven't downloaded the pivot block yet, check pivot staleness
+ // notifications from the header downloader
+ d.pivotLock.RLock()
+ pivot := d.pivotHeader
+ d.pivotLock.RUnlock()
+
+ if oldPivot == nil {
+ if pivot.Root != sync.root {
+ sync.Cancel()
+ sync = d.syncState(pivot.Root)
+
+ go closeOnErr(sync)
+ }
+ } else {
+ results = append(append([]*fetchResult{oldPivot}, oldTail...), results...)
+ }
+ // Split around the pivot block and process the two sides via fast/full sync
+ if atomic.LoadInt32(&d.committed) == 0 {
+ latest := results[len(results)-1].Header
+ // If the height is above the pivot block by 2 sets, it means the pivot
+ // become stale in the network and it was garbage collected, move to a
+ // new pivot.
+ //
+ // Note, we have `reorgProtHeaderDelay` number of blocks withheld, Those
+ // need to be taken into account, otherwise we're detecting the pivot move
+ // late and will drop peers due to unavailable state!!!
+ if height := latest.Number.Uint64(); height >= pivot.Number.Uint64()+2*uint64(fsMinFullBlocks)-uint64(reorgProtHeaderDelay) {
+ log.Warn("Pivot became stale, moving", "old", pivot.Number.Uint64(), "new", height-uint64(fsMinFullBlocks)+uint64(reorgProtHeaderDelay))
+ pivot = results[len(results)-1-fsMinFullBlocks+reorgProtHeaderDelay].Header // must exist as lower old pivot is uncommitted
+
+ d.pivotLock.Lock()
+ d.pivotHeader = pivot
+ d.pivotLock.Unlock()
+
+ // Write out the pivot into the database so a rollback beyond it will
+ // reenable fast sync
+ rawdb.WriteLastPivotNumber(d.stateDB, pivot.Number.Uint64())
+ }
+ }
+ P, beforeP, afterP := splitAroundPivot(pivot.Number.Uint64(), results)
+ if err := d.commitFastSyncData(beforeP, sync); err != nil {
+ return err
+ }
+ if P != nil {
+ // If new pivot block found, cancel old state retrieval and restart
+ if oldPivot != P {
+ sync.Cancel()
+ sync = d.syncState(P.Header.Root)
+
+ go closeOnErr(sync)
+ oldPivot = P
+ }
+ // Wait for completion, occasionally checking for pivot staleness
+ select {
+ case <-sync.done:
+ if sync.err != nil {
+ return sync.err
+ }
+ if err := d.commitPivotBlock(P); err != nil {
+ return err
+ }
+ oldPivot = nil
+
+ case <-time.After(time.Second):
+ oldTail = afterP
+ continue
+ }
+ }
+ // Fast sync done, pivot commit done, full import
+ if err := d.importBlockResults(afterP); err != nil {
+ return err
+ }
+ }
+}
+
+func splitAroundPivot(pivot uint64, results []*fetchResult) (p *fetchResult, before, after []*fetchResult) {
+ if len(results) == 0 {
+ return nil, nil, nil
+ }
+ if lastNum := results[len(results)-1].Header.Number.Uint64(); lastNum < pivot {
+ // the pivot is somewhere in the future
+ return nil, results, nil
+ }
+ // This can also be optimized, but only happens very seldom
+ for _, result := range results {
+ num := result.Header.Number.Uint64()
+ switch {
+ case num < pivot:
+ before = append(before, result)
+ case num == pivot:
+ p = result
+ default:
+ after = append(after, result)
+ }
+ }
+ return p, before, after
+}
+
+func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *stateSync) error {
+ // Check for any early termination requests
+ if len(results) == 0 {
+ return nil
+ }
+ select {
+ case <-d.quitCh:
+ return errCancelContentProcessing
+ case <-stateSync.done:
+ if err := stateSync.Wait(); err != nil {
+ return err
+ }
+ default:
+ }
+ // Retrieve the a batch of results to import
+ first, last := results[0].Header, results[len(results)-1].Header
+ log.Debug("Inserting fast-sync blocks", "items", len(results),
+ "firstnum", first.Number, "firsthash", first.Hash(),
+ "lastnumn", last.Number, "lasthash", last.Hash(),
+ )
+ blocks := make([]*types.Block, len(results))
+ receipts := make([]types.Receipts, len(results))
+ for i, result := range results {
+ blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+ receipts[i] = result.Receipts
+ }
+ if index, err := d.blockchain.InsertReceiptChain(blocks, receipts, d.ancientLimit); err != nil {
+ log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
+ return fmt.Errorf("%w: %v", errInvalidChain, err)
+ }
+ return nil
+}
+
+func (d *Downloader) commitPivotBlock(result *fetchResult) error {
+ block := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+ log.Debug("Committing fast sync pivot as new head", "number", block.Number(), "hash", block.Hash())
+
+ // Commit the pivot block as the new head, will require full sync from here on
+ if _, err := d.blockchain.InsertReceiptChain([]*types.Block{block}, []types.Receipts{result.Receipts}, d.ancientLimit); err != nil {
+ return err
+ }
+ if err := d.blockchain.FastSyncCommitHead(block.Hash()); err != nil {
+ return err
+ }
+ atomic.StoreInt32(&d.committed, 1)
+
+ // If we had a bloom filter for the state sync, deallocate it now. Note, we only
+ // deallocate internally, but keep the empty wrapper. This ensures that if we do
+ // a rollback after committing the pivot and restarting fast sync, we don't end
+ // up using a nil bloom. Empty bloom is fine, it just returns that it does not
+ // have the info we need, so reach down to the database instead.
+ if d.stateBloom != nil {
+ d.stateBloom.Close()
+ }
+ return nil
+}
+
+// DeliverHeaders injects a new batch of block headers received from a remote
+// node into the download schedule.
+func (d *Downloader) DeliverHeaders(id string, headers []*types.Header) error {
+ return d.deliver(d.headerCh, &headerPack{id, headers}, headerInMeter, headerDropMeter)
+}
+
+// DeliverBodies injects a new batch of block bodies received from a remote node.
+func (d *Downloader) DeliverBodies(id string, transactions [][]*types.Transaction, uncles [][]*types.Header) error {
+ return d.deliver(d.bodyCh, &bodyPack{id, transactions, uncles}, bodyInMeter, bodyDropMeter)
+}
+
+// DeliverReceipts injects a new batch of receipts received from a remote node.
+func (d *Downloader) DeliverReceipts(id string, receipts [][]*types.Receipt) error {
+ return d.deliver(d.receiptCh, &receiptPack{id, receipts}, receiptInMeter, receiptDropMeter)
+}
+
+// DeliverNodeData injects a new batch of node state data received from a remote node.
+func (d *Downloader) DeliverNodeData(id string, data [][]byte) error {
+ return d.deliver(d.stateCh, &statePack{id, data}, stateInMeter, stateDropMeter)
+}
+
+// DeliverSnapPacket is invoked from a peer's message handler when it transmits a
+// data packet for the local node to consume.
+func (d *Downloader) DeliverSnapPacket(peer *snap.Peer, packet snap.Packet) error {
+ switch packet := packet.(type) {
+ case *snap.AccountRangePacket:
+ hashes, accounts, err := packet.Unpack()
+ if err != nil {
+ return err
+ }
+ return d.SnapSyncer.OnAccounts(peer, packet.ID, hashes, accounts, packet.Proof)
+
+ case *snap.StorageRangesPacket:
+ hashset, slotset := packet.Unpack()
+ return d.SnapSyncer.OnStorage(peer, packet.ID, hashset, slotset, packet.Proof)
+
+ case *snap.ByteCodesPacket:
+ return d.SnapSyncer.OnByteCodes(peer, packet.ID, packet.Codes)
+
+ case *snap.TrieNodesPacket:
+ return d.SnapSyncer.OnTrieNodes(peer, packet.ID, packet.Nodes)
+
+ default:
+ return fmt.Errorf("unexpected snap packet type: %T", packet)
+ }
+}
+
+// deliver injects a new batch of data received from a remote node.
+func (d *Downloader) deliver(destCh chan dataPack, packet dataPack, inMeter, dropMeter metrics.Meter) (err error) {
+ // Update the delivery metrics for both good and failed deliveries
+ inMeter.Mark(int64(packet.Items()))
+ defer func() {
+ if err != nil {
+ dropMeter.Mark(int64(packet.Items()))
+ }
+ }()
+ // Deliver or abort if the sync is canceled while queuing
+ d.cancelLock.RLock()
+ cancel := d.cancelCh
+ d.cancelLock.RUnlock()
+ if cancel == nil {
+ return errNoSyncActive
+ }
+ select {
+ case destCh <- packet:
+ return nil
+ case <-cancel:
+ return errNoSyncActive
+ }
+}
diff --git a/les/downloader/downloader_test.go b/les/downloader/downloader_test.go
new file mode 100644
index 0000000000..17cd3630c9
--- /dev/null
+++ b/les/downloader/downloader_test.go
@@ -0,0 +1,1622 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package downloader
+
+import (
+ "errors"
+ "fmt"
+ "math/big"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/state/snapshot"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+// Reduce some of the parameters to make the tester faster.
+func init() {
+ fullMaxForkAncestry = 10000
+ lightMaxForkAncestry = 10000
+ blockCacheMaxItems = 1024
+ fsHeaderContCheck = 500 * time.Millisecond
+}
+
+// downloadTester is a test simulator for mocking out local block chain.
+type downloadTester struct {
+ downloader *Downloader
+
+ genesis *types.Block // Genesis blocks used by the tester and peers
+ stateDb ethdb.Database // Database used by the tester for syncing from peers
+ peerDb ethdb.Database // Database of the peers containing all data
+ peers map[string]*downloadTesterPeer
+
+ ownHashes []common.Hash // Hash chain belonging to the tester
+ ownHeaders map[common.Hash]*types.Header // Headers belonging to the tester
+ ownBlocks map[common.Hash]*types.Block // Blocks belonging to the tester
+ ownReceipts map[common.Hash]types.Receipts // Receipts belonging to the tester
+ ownChainTd map[common.Hash]*big.Int // Total difficulties of the blocks in the local chain
+
+ ancientHeaders map[common.Hash]*types.Header // Ancient headers belonging to the tester
+ ancientBlocks map[common.Hash]*types.Block // Ancient blocks belonging to the tester
+ ancientReceipts map[common.Hash]types.Receipts // Ancient receipts belonging to the tester
+ ancientChainTd map[common.Hash]*big.Int // Ancient total difficulties of the blocks in the local chain
+
+ lock sync.RWMutex
+}
+
+// newTester creates a new downloader test mocker.
+func newTester() *downloadTester {
+ tester := &downloadTester{
+ genesis: testGenesis,
+ peerDb: testDB,
+ peers: make(map[string]*downloadTesterPeer),
+ ownHashes: []common.Hash{testGenesis.Hash()},
+ ownHeaders: map[common.Hash]*types.Header{testGenesis.Hash(): testGenesis.Header()},
+ ownBlocks: map[common.Hash]*types.Block{testGenesis.Hash(): testGenesis},
+ ownReceipts: map[common.Hash]types.Receipts{testGenesis.Hash(): nil},
+ ownChainTd: map[common.Hash]*big.Int{testGenesis.Hash(): testGenesis.Difficulty()},
+
+ // Initialize ancient store with test genesis block
+ ancientHeaders: map[common.Hash]*types.Header{testGenesis.Hash(): testGenesis.Header()},
+ ancientBlocks: map[common.Hash]*types.Block{testGenesis.Hash(): testGenesis},
+ ancientReceipts: map[common.Hash]types.Receipts{testGenesis.Hash(): nil},
+ ancientChainTd: map[common.Hash]*big.Int{testGenesis.Hash(): testGenesis.Difficulty()},
+ }
+ tester.stateDb = rawdb.NewMemoryDatabase()
+ tester.stateDb.Put(testGenesis.Root().Bytes(), []byte{0x00})
+
+ tester.downloader = New(0, tester.stateDb, trie.NewSyncBloom(1, tester.stateDb), new(event.TypeMux), tester, nil, tester.dropPeer)
+ return tester
+}
+
+// terminate aborts any operations on the embedded downloader and releases all
+// held resources.
+func (dl *downloadTester) terminate() {
+ dl.downloader.Terminate()
+}
+
+// sync starts synchronizing with a remote peer, blocking until it completes.
+func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error {
+ dl.lock.RLock()
+ hash := dl.peers[id].chain.headBlock().Hash()
+ // If no particular TD was requested, load from the peer's blockchain
+ if td == nil {
+ td = dl.peers[id].chain.td(hash)
+ }
+ dl.lock.RUnlock()
+
+ // Synchronise with the chosen peer and ensure proper cleanup afterwards
+ err := dl.downloader.synchronise(id, hash, td, mode)
+ select {
+ case <-dl.downloader.cancelCh:
+ // Ok, downloader fully cancelled after sync cycle
+ default:
+ // Downloader is still accepting packets, can block a peer up
+ panic("downloader active post sync cycle") // panic will be caught by tester
+ }
+ return err
+}
+
+// HasHeader checks if a header is present in the testers canonical chain.
+func (dl *downloadTester) HasHeader(hash common.Hash, number uint64) bool {
+ return dl.GetHeaderByHash(hash) != nil
+}
+
+// HasBlock checks if a block is present in the testers canonical chain.
+func (dl *downloadTester) HasBlock(hash common.Hash, number uint64) bool {
+ return dl.GetBlockByHash(hash) != nil
+}
+
+// HasFastBlock checks if a block is present in the testers canonical chain.
+func (dl *downloadTester) HasFastBlock(hash common.Hash, number uint64) bool {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ if _, ok := dl.ancientReceipts[hash]; ok {
+ return true
+ }
+ _, ok := dl.ownReceipts[hash]
+ return ok
+}
+
+// GetHeader retrieves a header from the testers canonical chain.
+func (dl *downloadTester) GetHeaderByHash(hash common.Hash) *types.Header {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+ return dl.getHeaderByHash(hash)
+}
+
+// getHeaderByHash returns the header if found either within ancients or own blocks)
+// This method assumes that the caller holds at least the read-lock (dl.lock)
+func (dl *downloadTester) getHeaderByHash(hash common.Hash) *types.Header {
+ header := dl.ancientHeaders[hash]
+ if header != nil {
+ return header
+ }
+ return dl.ownHeaders[hash]
+}
+
+// GetBlock retrieves a block from the testers canonical chain.
+func (dl *downloadTester) GetBlockByHash(hash common.Hash) *types.Block {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ block := dl.ancientBlocks[hash]
+ if block != nil {
+ return block
+ }
+ return dl.ownBlocks[hash]
+}
+
+// CurrentHeader retrieves the current head header from the canonical chain.
+func (dl *downloadTester) CurrentHeader() *types.Header {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ for i := len(dl.ownHashes) - 1; i >= 0; i-- {
+ if header := dl.ancientHeaders[dl.ownHashes[i]]; header != nil {
+ return header
+ }
+ if header := dl.ownHeaders[dl.ownHashes[i]]; header != nil {
+ return header
+ }
+ }
+ return dl.genesis.Header()
+}
+
+// CurrentBlock retrieves the current head block from the canonical chain.
+func (dl *downloadTester) CurrentBlock() *types.Block {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ for i := len(dl.ownHashes) - 1; i >= 0; i-- {
+ if block := dl.ancientBlocks[dl.ownHashes[i]]; block != nil {
+ if _, err := dl.stateDb.Get(block.Root().Bytes()); err == nil {
+ return block
+ }
+ return block
+ }
+ if block := dl.ownBlocks[dl.ownHashes[i]]; block != nil {
+ if _, err := dl.stateDb.Get(block.Root().Bytes()); err == nil {
+ return block
+ }
+ }
+ }
+ return dl.genesis
+}
+
+// CurrentFastBlock retrieves the current head fast-sync block from the canonical chain.
+func (dl *downloadTester) CurrentFastBlock() *types.Block {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ for i := len(dl.ownHashes) - 1; i >= 0; i-- {
+ if block := dl.ancientBlocks[dl.ownHashes[i]]; block != nil {
+ return block
+ }
+ if block := dl.ownBlocks[dl.ownHashes[i]]; block != nil {
+ return block
+ }
+ }
+ return dl.genesis
+}
+
+// FastSyncCommitHead manually sets the head block to a given hash.
+func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error {
+ // For now only check that the state trie is correct
+ if block := dl.GetBlockByHash(hash); block != nil {
+ _, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb))
+ return err
+ }
+ return fmt.Errorf("non existent block: %x", hash[:4])
+}
+
+// GetTd retrieves the block's total difficulty from the canonical chain.
+func (dl *downloadTester) GetTd(hash common.Hash, number uint64) *big.Int {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ return dl.getTd(hash)
+}
+
+// getTd retrieves the block's total difficulty if found either within
+// ancients or own blocks).
+// This method assumes that the caller holds at least the read-lock (dl.lock)
+func (dl *downloadTester) getTd(hash common.Hash) *big.Int {
+ if td := dl.ancientChainTd[hash]; td != nil {
+ return td
+ }
+ return dl.ownChainTd[hash]
+}
+
+// InsertHeaderChain injects a new batch of headers into the simulated chain.
+func (dl *downloadTester) InsertHeaderChain(headers []*types.Header, checkFreq int) (i int, err error) {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+ // Do a quick check, as the blockchain.InsertHeaderChain doesn't insert anything in case of errors
+ if dl.getHeaderByHash(headers[0].ParentHash) == nil {
+ return 0, fmt.Errorf("InsertHeaderChain: unknown parent at first position, parent of number %d", headers[0].Number)
+ }
+ var hashes []common.Hash
+ for i := 1; i < len(headers); i++ {
+ hash := headers[i-1].Hash()
+ if headers[i].ParentHash != headers[i-1].Hash() {
+ return i, fmt.Errorf("non-contiguous import at position %d", i)
+ }
+ hashes = append(hashes, hash)
+ }
+ hashes = append(hashes, headers[len(headers)-1].Hash())
+ // Do a full insert if pre-checks passed
+ for i, header := range headers {
+ hash := hashes[i]
+ if dl.getHeaderByHash(hash) != nil {
+ continue
+ }
+ if dl.getHeaderByHash(header.ParentHash) == nil {
+ // This _should_ be impossible, due to precheck and induction
+ return i, fmt.Errorf("InsertHeaderChain: unknown parent at position %d", i)
+ }
+ dl.ownHashes = append(dl.ownHashes, hash)
+ dl.ownHeaders[hash] = header
+
+ td := dl.getTd(header.ParentHash)
+ dl.ownChainTd[hash] = new(big.Int).Add(td, header.Difficulty)
+ }
+ return len(headers), nil
+}
+
+// InsertChain injects a new batch of blocks into the simulated chain.
+func (dl *downloadTester) InsertChain(blocks types.Blocks) (i int, err error) {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+ for i, block := range blocks {
+ if parent, ok := dl.ownBlocks[block.ParentHash()]; !ok {
+ return i, fmt.Errorf("InsertChain: unknown parent at position %d / %d", i, len(blocks))
+ } else if _, err := dl.stateDb.Get(parent.Root().Bytes()); err != nil {
+ return i, fmt.Errorf("InsertChain: unknown parent state %x: %v", parent.Root(), err)
+ }
+ if hdr := dl.getHeaderByHash(block.Hash()); hdr == nil {
+ dl.ownHashes = append(dl.ownHashes, block.Hash())
+ dl.ownHeaders[block.Hash()] = block.Header()
+ }
+ dl.ownBlocks[block.Hash()] = block
+ dl.ownReceipts[block.Hash()] = make(types.Receipts, 0)
+ dl.stateDb.Put(block.Root().Bytes(), []byte{0x00})
+ td := dl.getTd(block.ParentHash())
+ dl.ownChainTd[block.Hash()] = new(big.Int).Add(td, block.Difficulty())
+ }
+ return len(blocks), nil
+}
+
+// InsertReceiptChain injects a new batch of receipts into the simulated chain.
+func (dl *downloadTester) InsertReceiptChain(blocks types.Blocks, receipts []types.Receipts, ancientLimit uint64) (i int, err error) {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ for i := 0; i < len(blocks) && i < len(receipts); i++ {
+ if _, ok := dl.ownHeaders[blocks[i].Hash()]; !ok {
+ return i, errors.New("unknown owner")
+ }
+ if _, ok := dl.ancientBlocks[blocks[i].ParentHash()]; !ok {
+ if _, ok := dl.ownBlocks[blocks[i].ParentHash()]; !ok {
+ return i, errors.New("InsertReceiptChain: unknown parent")
+ }
+ }
+ if blocks[i].NumberU64() <= ancientLimit {
+ dl.ancientBlocks[blocks[i].Hash()] = blocks[i]
+ dl.ancientReceipts[blocks[i].Hash()] = receipts[i]
+
+ // Migrate from active db to ancient db
+ dl.ancientHeaders[blocks[i].Hash()] = blocks[i].Header()
+ dl.ancientChainTd[blocks[i].Hash()] = new(big.Int).Add(dl.ancientChainTd[blocks[i].ParentHash()], blocks[i].Difficulty())
+ delete(dl.ownHeaders, blocks[i].Hash())
+ delete(dl.ownChainTd, blocks[i].Hash())
+ } else {
+ dl.ownBlocks[blocks[i].Hash()] = blocks[i]
+ dl.ownReceipts[blocks[i].Hash()] = receipts[i]
+ }
+ }
+ return len(blocks), nil
+}
+
+// SetHead rewinds the local chain to a new head.
+func (dl *downloadTester) SetHead(head uint64) error {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ // Find the hash of the head to reset to
+ var hash common.Hash
+ for h, header := range dl.ownHeaders {
+ if header.Number.Uint64() == head {
+ hash = h
+ }
+ }
+ for h, header := range dl.ancientHeaders {
+ if header.Number.Uint64() == head {
+ hash = h
+ }
+ }
+ if hash == (common.Hash{}) {
+ return fmt.Errorf("unknown head to set: %d", head)
+ }
+ // Find the offset in the header chain
+ var offset int
+ for o, h := range dl.ownHashes {
+ if h == hash {
+ offset = o
+ break
+ }
+ }
+ // Remove all the hashes and associated data afterwards
+ for i := offset + 1; i < len(dl.ownHashes); i++ {
+ delete(dl.ownChainTd, dl.ownHashes[i])
+ delete(dl.ownHeaders, dl.ownHashes[i])
+ delete(dl.ownReceipts, dl.ownHashes[i])
+ delete(dl.ownBlocks, dl.ownHashes[i])
+
+ delete(dl.ancientChainTd, dl.ownHashes[i])
+ delete(dl.ancientHeaders, dl.ownHashes[i])
+ delete(dl.ancientReceipts, dl.ownHashes[i])
+ delete(dl.ancientBlocks, dl.ownHashes[i])
+ }
+ dl.ownHashes = dl.ownHashes[:offset+1]
+ return nil
+}
+
+// Rollback removes some recently added elements from the chain.
+func (dl *downloadTester) Rollback(hashes []common.Hash) {
+}
+
+// newPeer registers a new block download source into the downloader.
+func (dl *downloadTester) newPeer(id string, version uint, chain *testChain) error {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ peer := &downloadTesterPeer{dl: dl, id: id, chain: chain}
+ dl.peers[id] = peer
+ return dl.downloader.RegisterPeer(id, version, peer)
+}
+
+// dropPeer simulates a hard peer removal from the connection pool.
+func (dl *downloadTester) dropPeer(id string) {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ delete(dl.peers, id)
+ dl.downloader.UnregisterPeer(id)
+}
+
+// Snapshots implements the BlockChain interface for the downloader, but is a noop.
+func (dl *downloadTester) Snapshots() *snapshot.Tree {
+ return nil
+}
+
+type downloadTesterPeer struct {
+ dl *downloadTester
+ id string
+ chain *testChain
+ missingStates map[common.Hash]bool // State entries that fast sync should not return
+}
+
+// Head constructs a function to retrieve a peer's current head hash
+// and total difficulty.
+func (dlp *downloadTesterPeer) Head() (common.Hash, *big.Int) {
+ b := dlp.chain.headBlock()
+ return b.Hash(), dlp.chain.td(b.Hash())
+}
+
+// RequestHeadersByHash constructs a GetBlockHeaders function based on a hashed
+// origin; associated with a particular peer in the download tester. The returned
+// function can be used to retrieve batches of headers from the particular peer.
+func (dlp *downloadTesterPeer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
+ result := dlp.chain.headersByHash(origin, amount, skip, reverse)
+ go dlp.dl.downloader.DeliverHeaders(dlp.id, result)
+ return nil
+}
+
+// RequestHeadersByNumber constructs a GetBlockHeaders function based on a numbered
+// origin; associated with a particular peer in the download tester. The returned
+// function can be used to retrieve batches of headers from the particular peer.
+func (dlp *downloadTesterPeer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
+ result := dlp.chain.headersByNumber(origin, amount, skip, reverse)
+ go dlp.dl.downloader.DeliverHeaders(dlp.id, result)
+ return nil
+}
+
+// RequestBodies constructs a getBlockBodies method associated with a particular
+// peer in the download tester. The returned function can be used to retrieve
+// batches of block bodies from the particularly requested peer.
+func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash) error {
+ txs, uncles := dlp.chain.bodies(hashes)
+ go dlp.dl.downloader.DeliverBodies(dlp.id, txs, uncles)
+ return nil
+}
+
+// RequestReceipts constructs a getReceipts method associated with a particular
+// peer in the download tester. The returned function can be used to retrieve
+// batches of block receipts from the particularly requested peer.
+func (dlp *downloadTesterPeer) RequestReceipts(hashes []common.Hash) error {
+ receipts := dlp.chain.receipts(hashes)
+ go dlp.dl.downloader.DeliverReceipts(dlp.id, receipts)
+ return nil
+}
+
+// RequestNodeData constructs a getNodeData method associated with a particular
+// peer in the download tester. The returned function can be used to retrieve
+// batches of node state data from the particularly requested peer.
+func (dlp *downloadTesterPeer) RequestNodeData(hashes []common.Hash) error {
+ dlp.dl.lock.RLock()
+ defer dlp.dl.lock.RUnlock()
+
+ results := make([][]byte, 0, len(hashes))
+ for _, hash := range hashes {
+ if data, err := dlp.dl.peerDb.Get(hash.Bytes()); err == nil {
+ if !dlp.missingStates[hash] {
+ results = append(results, data)
+ }
+ }
+ }
+ go dlp.dl.downloader.DeliverNodeData(dlp.id, results)
+ return nil
+}
+
+// assertOwnChain checks if the local chain contains the correct number of items
+// of the various chain components.
+func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
+ // Mark this method as a helper to report errors at callsite, not in here
+ t.Helper()
+
+ assertOwnForkedChain(t, tester, 1, []int{length})
+}
+
+// assertOwnForkedChain checks if the local forked chain contains the correct
+// number of items of the various chain components.
+func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) {
+ // Mark this method as a helper to report errors at callsite, not in here
+ t.Helper()
+
+ // Initialize the counters for the first fork
+ headers, blocks, receipts := lengths[0], lengths[0], lengths[0]
+
+ // Update the counters for each subsequent fork
+ for _, length := range lengths[1:] {
+ headers += length - common
+ blocks += length - common
+ receipts += length - common
+ }
+ if tester.downloader.getMode() == LightSync {
+ blocks, receipts = 1, 1
+ }
+ if hs := len(tester.ownHeaders) + len(tester.ancientHeaders) - 1; hs != headers {
+ t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers)
+ }
+ if bs := len(tester.ownBlocks) + len(tester.ancientBlocks) - 1; bs != blocks {
+ t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks)
+ }
+ if rs := len(tester.ownReceipts) + len(tester.ancientReceipts) - 1; rs != receipts {
+ t.Fatalf("synchronised receipts mismatch: have %v, want %v", rs, receipts)
+ }
+}
+
+func TestCanonicalSynchronisation66Full(t *testing.T) { testCanonSync(t, eth.ETH66, FullSync) }
+func TestCanonicalSynchronisation66Fast(t *testing.T) { testCanonSync(t, eth.ETH66, FastSync) }
+func TestCanonicalSynchronisation66Light(t *testing.T) { testCanonSync(t, eth.ETH66, LightSync) }
+
+func testCanonSync(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ // Create a small enough block chain to download
+ chain := testChainBase.shorten(blockCacheMaxItems - 15)
+ tester.newPeer("peer", protocol, chain)
+
+ // Synchronise with the peer and make sure all relevant data was retrieved
+ if err := tester.sync("peer", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chain.len())
+}
+
+// Tests that if a large batch of blocks are being downloaded, it is throttled
+// until the cached blocks are retrieved.
+func TestThrottling66Full(t *testing.T) { testThrottling(t, eth.ETH66, FullSync) }
+func TestThrottling66Fast(t *testing.T) { testThrottling(t, eth.ETH66, FastSync) }
+
+func testThrottling(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+ tester := newTester()
+
+ // Create a long block chain to download and the tester
+ targetBlocks := testChainBase.len() - 1
+ tester.newPeer("peer", protocol, testChainBase)
+
+ // Wrap the importer to allow stepping
+ blocked, proceed := uint32(0), make(chan struct{})
+ tester.downloader.chainInsertHook = func(results []*fetchResult) {
+ atomic.StoreUint32(&blocked, uint32(len(results)))
+ <-proceed
+ }
+ // Start a synchronisation concurrently
+ errc := make(chan error, 1)
+ go func() {
+ errc <- tester.sync("peer", nil, mode)
+ }()
+ // Iteratively take some blocks, always checking the retrieval count
+ for {
+ // Check the retrieval count synchronously (! reason for this ugly block)
+ tester.lock.RLock()
+ retrieved := len(tester.ownBlocks)
+ tester.lock.RUnlock()
+ if retrieved >= targetBlocks+1 {
+ break
+ }
+ // Wait a bit for sync to throttle itself
+ var cached, frozen int
+ for start := time.Now(); time.Since(start) < 3*time.Second; {
+ time.Sleep(25 * time.Millisecond)
+
+ tester.lock.Lock()
+ tester.downloader.queue.lock.Lock()
+ tester.downloader.queue.resultCache.lock.Lock()
+ {
+ cached = tester.downloader.queue.resultCache.countCompleted()
+ frozen = int(atomic.LoadUint32(&blocked))
+ retrieved = len(tester.ownBlocks)
+ }
+ tester.downloader.queue.resultCache.lock.Unlock()
+ tester.downloader.queue.lock.Unlock()
+ tester.lock.Unlock()
+
+ if cached == blockCacheMaxItems ||
+ cached == blockCacheMaxItems-reorgProtHeaderDelay ||
+ retrieved+cached+frozen == targetBlocks+1 ||
+ retrieved+cached+frozen == targetBlocks+1-reorgProtHeaderDelay {
+ break
+ }
+ }
+ // Make sure we filled up the cache, then exhaust it
+ time.Sleep(25 * time.Millisecond) // give it a chance to screw up
+ tester.lock.RLock()
+ retrieved = len(tester.ownBlocks)
+ tester.lock.RUnlock()
+ if cached != blockCacheMaxItems && cached != blockCacheMaxItems-reorgProtHeaderDelay && retrieved+cached+frozen != targetBlocks+1 && retrieved+cached+frozen != targetBlocks+1-reorgProtHeaderDelay {
+ t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheMaxItems, retrieved, frozen, targetBlocks+1)
+ }
+
+ // Permit the blocked blocks to import
+ if atomic.LoadUint32(&blocked) > 0 {
+ atomic.StoreUint32(&blocked, uint32(0))
+ proceed <- struct{}{}
+ }
+ }
+ // Check that we haven't pulled more blocks than available
+ assertOwnChain(t, tester, targetBlocks+1)
+ if err := <-errc; err != nil {
+ t.Fatalf("block synchronization failed: %v", err)
+ }
+ tester.terminate()
+
+}
+
+// Tests that simple synchronization against a forked chain works correctly. In
+// this test common ancestor lookup should *not* be short circuited, and a full
+// binary search should be executed.
+func TestForkedSync66Full(t *testing.T) { testForkedSync(t, eth.ETH66, FullSync) }
+func TestForkedSync66Fast(t *testing.T) { testForkedSync(t, eth.ETH66, FastSync) }
+func TestForkedSync66Light(t *testing.T) { testForkedSync(t, eth.ETH66, LightSync) }
+
+func testForkedSync(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ chainA := testChainForkLightA.shorten(testChainBase.len() + 80)
+ chainB := testChainForkLightB.shorten(testChainBase.len() + 80)
+ tester.newPeer("fork A", protocol, chainA)
+ tester.newPeer("fork B", protocol, chainB)
+ // Synchronise with the peer and make sure all blocks were retrieved
+ if err := tester.sync("fork A", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chainA.len())
+
+ // Synchronise with the second peer and make sure that fork is pulled too
+ if err := tester.sync("fork B", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnForkedChain(t, tester, testChainBase.len(), []int{chainA.len(), chainB.len()})
+}
+
+// Tests that synchronising against a much shorter but much heavyer fork works
+// corrently and is not dropped.
+func TestHeavyForkedSync66Full(t *testing.T) { testHeavyForkedSync(t, eth.ETH66, FullSync) }
+func TestHeavyForkedSync66Fast(t *testing.T) { testHeavyForkedSync(t, eth.ETH66, FastSync) }
+func TestHeavyForkedSync66Light(t *testing.T) { testHeavyForkedSync(t, eth.ETH66, LightSync) }
+
+func testHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ chainA := testChainForkLightA.shorten(testChainBase.len() + 80)
+ chainB := testChainForkHeavy.shorten(testChainBase.len() + 80)
+ tester.newPeer("light", protocol, chainA)
+ tester.newPeer("heavy", protocol, chainB)
+
+ // Synchronise with the peer and make sure all blocks were retrieved
+ if err := tester.sync("light", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chainA.len())
+
+ // Synchronise with the second peer and make sure that fork is pulled too
+ if err := tester.sync("heavy", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnForkedChain(t, tester, testChainBase.len(), []int{chainA.len(), chainB.len()})
+}
+
+// Tests that chain forks are contained within a certain interval of the current
+// chain head, ensuring that malicious peers cannot waste resources by feeding
+// long dead chains.
+func TestBoundedForkedSync66Full(t *testing.T) { testBoundedForkedSync(t, eth.ETH66, FullSync) }
+func TestBoundedForkedSync66Fast(t *testing.T) { testBoundedForkedSync(t, eth.ETH66, FastSync) }
+func TestBoundedForkedSync66Light(t *testing.T) { testBoundedForkedSync(t, eth.ETH66, LightSync) }
+
+func testBoundedForkedSync(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ chainA := testChainForkLightA
+ chainB := testChainForkLightB
+ tester.newPeer("original", protocol, chainA)
+ tester.newPeer("rewriter", protocol, chainB)
+
+ // Synchronise with the peer and make sure all blocks were retrieved
+ if err := tester.sync("original", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chainA.len())
+
+ // Synchronise with the second peer and ensure that the fork is rejected to being too old
+ if err := tester.sync("rewriter", nil, mode); err != errInvalidAncestor {
+ t.Fatalf("sync failure mismatch: have %v, want %v", err, errInvalidAncestor)
+ }
+}
+
+// Tests that chain forks are contained within a certain interval of the current
+// chain head for short but heavy forks too. These are a bit special because they
+// take different ancestor lookup paths.
+func TestBoundedHeavyForkedSync66Full(t *testing.T) {
+ testBoundedHeavyForkedSync(t, eth.ETH66, FullSync)
+}
+func TestBoundedHeavyForkedSync66Fast(t *testing.T) {
+ testBoundedHeavyForkedSync(t, eth.ETH66, FastSync)
+}
+func TestBoundedHeavyForkedSync66Light(t *testing.T) {
+ testBoundedHeavyForkedSync(t, eth.ETH66, LightSync)
+}
+
+func testBoundedHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+ tester := newTester()
+
+ // Create a long enough forked chain
+ chainA := testChainForkLightA
+ chainB := testChainForkHeavy
+ tester.newPeer("original", protocol, chainA)
+
+ // Synchronise with the peer and make sure all blocks were retrieved
+ if err := tester.sync("original", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chainA.len())
+
+ tester.newPeer("heavy-rewriter", protocol, chainB)
+ // Synchronise with the second peer and ensure that the fork is rejected to being too old
+ if err := tester.sync("heavy-rewriter", nil, mode); err != errInvalidAncestor {
+ t.Fatalf("sync failure mismatch: have %v, want %v", err, errInvalidAncestor)
+ }
+ tester.terminate()
+}
+
+// Tests that an inactive downloader will not accept incoming block headers,
+// bodies and receipts.
+func TestInactiveDownloader63(t *testing.T) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ // Check that neither block headers nor bodies are accepted
+ if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive {
+ t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
+ }
+ if err := tester.downloader.DeliverBodies("bad peer", [][]*types.Transaction{}, [][]*types.Header{}); err != errNoSyncActive {
+ t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
+ }
+ if err := tester.downloader.DeliverReceipts("bad peer", [][]*types.Receipt{}); err != errNoSyncActive {
+ t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
+ }
+}
+
+// Tests that a canceled download wipes all previously accumulated state.
+func TestCancel66Full(t *testing.T) { testCancel(t, eth.ETH66, FullSync) }
+func TestCancel66Fast(t *testing.T) { testCancel(t, eth.ETH66, FastSync) }
+func TestCancel66Light(t *testing.T) { testCancel(t, eth.ETH66, LightSync) }
+
+func testCancel(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ chain := testChainBase.shorten(MaxHeaderFetch)
+ tester.newPeer("peer", protocol, chain)
+
+ // Make sure canceling works with a pristine downloader
+ tester.downloader.Cancel()
+ if !tester.downloader.queue.Idle() {
+ t.Errorf("download queue not idle")
+ }
+ // Synchronise with the peer, but cancel afterwards
+ if err := tester.sync("peer", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ tester.downloader.Cancel()
+ if !tester.downloader.queue.Idle() {
+ t.Errorf("download queue not idle")
+ }
+}
+
+// Tests that synchronisation from multiple peers works as intended (multi thread sanity test).
+func TestMultiSynchronisation66Full(t *testing.T) { testMultiSynchronisation(t, eth.ETH66, FullSync) }
+func TestMultiSynchronisation66Fast(t *testing.T) { testMultiSynchronisation(t, eth.ETH66, FastSync) }
+func TestMultiSynchronisation66Light(t *testing.T) { testMultiSynchronisation(t, eth.ETH66, LightSync) }
+
+func testMultiSynchronisation(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ // Create various peers with various parts of the chain
+ targetPeers := 8
+ chain := testChainBase.shorten(targetPeers * 100)
+
+ for i := 0; i < targetPeers; i++ {
+ id := fmt.Sprintf("peer #%d", i)
+ tester.newPeer(id, protocol, chain.shorten(chain.len()/(i+1)))
+ }
+ if err := tester.sync("peer #0", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chain.len())
+}
+
+// Tests that synchronisations behave well in multi-version protocol environments
+// and not wreak havoc on other nodes in the network.
+func TestMultiProtoSynchronisation66Full(t *testing.T) { testMultiProtoSync(t, eth.ETH66, FullSync) }
+func TestMultiProtoSynchronisation66Fast(t *testing.T) { testMultiProtoSync(t, eth.ETH66, FastSync) }
+func TestMultiProtoSynchronisation66Light(t *testing.T) { testMultiProtoSync(t, eth.ETH66, LightSync) }
+
+func testMultiProtoSync(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ // Create a small enough block chain to download
+ chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+ // Create peers of every type
+ tester.newPeer("peer 66", eth.ETH66, chain)
+ //tester.newPeer("peer 65", eth.ETH67, chain)
+
+ // Synchronise with the requested peer and make sure all blocks were retrieved
+ if err := tester.sync(fmt.Sprintf("peer %d", protocol), nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chain.len())
+
+ // Check that no peers have been dropped off
+ for _, version := range []int{66} {
+ peer := fmt.Sprintf("peer %d", version)
+ if _, ok := tester.peers[peer]; !ok {
+ t.Errorf("%s dropped", peer)
+ }
+ }
+}
+
+// Tests that if a block is empty (e.g. header only), no body request should be
+// made, and instead the header should be assembled into a whole block in itself.
+func TestEmptyShortCircuit66Full(t *testing.T) { testEmptyShortCircuit(t, eth.ETH66, FullSync) }
+func TestEmptyShortCircuit66Fast(t *testing.T) { testEmptyShortCircuit(t, eth.ETH66, FastSync) }
+func TestEmptyShortCircuit66Light(t *testing.T) { testEmptyShortCircuit(t, eth.ETH66, LightSync) }
+
+func testEmptyShortCircuit(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ // Create a block chain to download
+ chain := testChainBase
+ tester.newPeer("peer", protocol, chain)
+
+ // Instrument the downloader to signal body requests
+ bodiesHave, receiptsHave := int32(0), int32(0)
+ tester.downloader.bodyFetchHook = func(headers []*types.Header) {
+ atomic.AddInt32(&bodiesHave, int32(len(headers)))
+ }
+ tester.downloader.receiptFetchHook = func(headers []*types.Header) {
+ atomic.AddInt32(&receiptsHave, int32(len(headers)))
+ }
+ // Synchronise with the peer and make sure all blocks were retrieved
+ if err := tester.sync("peer", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chain.len())
+
+ // Validate the number of block bodies that should have been requested
+ bodiesNeeded, receiptsNeeded := 0, 0
+ for _, block := range chain.blockm {
+ if mode != LightSync && block != tester.genesis && (len(block.Transactions()) > 0 || len(block.Uncles()) > 0) {
+ bodiesNeeded++
+ }
+ }
+ for _, receipt := range chain.receiptm {
+ if mode == FastSync && len(receipt) > 0 {
+ receiptsNeeded++
+ }
+ }
+ if int(bodiesHave) != bodiesNeeded {
+ t.Errorf("body retrieval count mismatch: have %v, want %v", bodiesHave, bodiesNeeded)
+ }
+ if int(receiptsHave) != receiptsNeeded {
+ t.Errorf("receipt retrieval count mismatch: have %v, want %v", receiptsHave, receiptsNeeded)
+ }
+}
+
+// Tests that headers are enqueued continuously, preventing malicious nodes from
+// stalling the downloader by feeding gapped header chains.
+func TestMissingHeaderAttack66Full(t *testing.T) { testMissingHeaderAttack(t, eth.ETH66, FullSync) }
+func TestMissingHeaderAttack66Fast(t *testing.T) { testMissingHeaderAttack(t, eth.ETH66, FastSync) }
+func TestMissingHeaderAttack66Light(t *testing.T) { testMissingHeaderAttack(t, eth.ETH66, LightSync) }
+
+func testMissingHeaderAttack(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ chain := testChainBase.shorten(blockCacheMaxItems - 15)
+ brokenChain := chain.shorten(chain.len())
+ delete(brokenChain.headerm, brokenChain.chain[brokenChain.len()/2])
+ tester.newPeer("attack", protocol, brokenChain)
+
+ if err := tester.sync("attack", nil, mode); err == nil {
+ t.Fatalf("succeeded attacker synchronisation")
+ }
+ // Synchronise with the valid peer and make sure sync succeeds
+ tester.newPeer("valid", protocol, chain)
+ if err := tester.sync("valid", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chain.len())
+}
+
+// Tests that if requested headers are shifted (i.e. first is missing), the queue
+// detects the invalid numbering.
+func TestShiftedHeaderAttack66Full(t *testing.T) { testShiftedHeaderAttack(t, eth.ETH66, FullSync) }
+func TestShiftedHeaderAttack66Fast(t *testing.T) { testShiftedHeaderAttack(t, eth.ETH66, FastSync) }
+func TestShiftedHeaderAttack66Light(t *testing.T) { testShiftedHeaderAttack(t, eth.ETH66, LightSync) }
+
+func testShiftedHeaderAttack(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+
+ chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+ // Attempt a full sync with an attacker feeding shifted headers
+ brokenChain := chain.shorten(chain.len())
+ delete(brokenChain.headerm, brokenChain.chain[1])
+ delete(brokenChain.blockm, brokenChain.chain[1])
+ delete(brokenChain.receiptm, brokenChain.chain[1])
+ tester.newPeer("attack", protocol, brokenChain)
+ if err := tester.sync("attack", nil, mode); err == nil {
+ t.Fatalf("succeeded attacker synchronisation")
+ }
+
+ // Synchronise with the valid peer and make sure sync succeeds
+ tester.newPeer("valid", protocol, chain)
+ if err := tester.sync("valid", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ assertOwnChain(t, tester, chain.len())
+}
+
+// Tests that upon detecting an invalid header, the recent ones are rolled back
+// for various failure scenarios. Afterwards a full sync is attempted to make
+// sure no state was corrupted.
+func TestInvalidHeaderRollback66Fast(t *testing.T) { testInvalidHeaderRollback(t, eth.ETH66, FastSync) }
+
+func testInvalidHeaderRollback(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+
+ // Create a small enough block chain to download
+ targetBlocks := 3*fsHeaderSafetyNet + 256 + fsMinFullBlocks
+ chain := testChainBase.shorten(targetBlocks)
+
+ // Attempt to sync with an attacker that feeds junk during the fast sync phase.
+ // This should result in the last fsHeaderSafetyNet headers being rolled back.
+ missing := fsHeaderSafetyNet + MaxHeaderFetch + 1
+ fastAttackChain := chain.shorten(chain.len())
+ delete(fastAttackChain.headerm, fastAttackChain.chain[missing])
+ tester.newPeer("fast-attack", protocol, fastAttackChain)
+
+ if err := tester.sync("fast-attack", nil, mode); err == nil {
+ t.Fatalf("succeeded fast attacker synchronisation")
+ }
+ if head := tester.CurrentHeader().Number.Int64(); int(head) > MaxHeaderFetch {
+ t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch)
+ }
+
+ // Attempt to sync with an attacker that feeds junk during the block import phase.
+ // This should result in both the last fsHeaderSafetyNet number of headers being
+ // rolled back, and also the pivot point being reverted to a non-block status.
+ missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1
+ blockAttackChain := chain.shorten(chain.len())
+ delete(fastAttackChain.headerm, fastAttackChain.chain[missing]) // Make sure the fast-attacker doesn't fill in
+ delete(blockAttackChain.headerm, blockAttackChain.chain[missing])
+ tester.newPeer("block-attack", protocol, blockAttackChain)
+
+ if err := tester.sync("block-attack", nil, mode); err == nil {
+ t.Fatalf("succeeded block attacker synchronisation")
+ }
+ if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
+ t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
+ }
+ if mode == FastSync {
+ if head := tester.CurrentBlock().NumberU64(); head != 0 {
+ t.Errorf("fast sync pivot block #%d not rolled back", head)
+ }
+ }
+
+ // Attempt to sync with an attacker that withholds promised blocks after the
+ // fast sync pivot point. This could be a trial to leave the node with a bad
+ // but already imported pivot block.
+ withholdAttackChain := chain.shorten(chain.len())
+ tester.newPeer("withhold-attack", protocol, withholdAttackChain)
+ tester.downloader.syncInitHook = func(uint64, uint64) {
+ for i := missing; i < withholdAttackChain.len(); i++ {
+ delete(withholdAttackChain.headerm, withholdAttackChain.chain[i])
+ }
+ tester.downloader.syncInitHook = nil
+ }
+ if err := tester.sync("withhold-attack", nil, mode); err == nil {
+ t.Fatalf("succeeded withholding attacker synchronisation")
+ }
+ if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
+ t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
+ }
+ if mode == FastSync {
+ if head := tester.CurrentBlock().NumberU64(); head != 0 {
+ t.Errorf("fast sync pivot block #%d not rolled back", head)
+ }
+ }
+
+ // synchronise with the valid peer and make sure sync succeeds. Since the last rollback
+ // should also disable fast syncing for this process, verify that we did a fresh full
+ // sync. Note, we can't assert anything about the receipts since we won't purge the
+ // database of them, hence we can't use assertOwnChain.
+ tester.newPeer("valid", protocol, chain)
+ if err := tester.sync("valid", nil, mode); err != nil {
+ t.Fatalf("failed to synchronise blocks: %v", err)
+ }
+ if hs := len(tester.ownHeaders); hs != chain.len() {
+ t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, chain.len())
+ }
+ if mode != LightSync {
+ if bs := len(tester.ownBlocks); bs != chain.len() {
+ t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, chain.len())
+ }
+ }
+ tester.terminate()
+}
+
+// Tests that a peer advertising a high TD doesn't get to stall the downloader
+// afterwards by not sending any useful hashes.
+func TestHighTDStarvationAttack66Full(t *testing.T) {
+ testHighTDStarvationAttack(t, eth.ETH66, FullSync)
+}
+func TestHighTDStarvationAttack66Fast(t *testing.T) {
+ testHighTDStarvationAttack(t, eth.ETH66, FastSync)
+}
+func TestHighTDStarvationAttack66Light(t *testing.T) {
+ testHighTDStarvationAttack(t, eth.ETH66, LightSync)
+}
+
+func testHighTDStarvationAttack(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+
+ chain := testChainBase.shorten(1)
+ tester.newPeer("attack", protocol, chain)
+ if err := tester.sync("attack", big.NewInt(1000000), mode); err != errStallingPeer {
+ t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer)
+ }
+ tester.terminate()
+}
+
+// Tests that misbehaving peers are disconnected, whilst behaving ones are not.
+func TestBlockHeaderAttackerDropping66(t *testing.T) { testBlockHeaderAttackerDropping(t, eth.ETH66) }
+
+func testBlockHeaderAttackerDropping(t *testing.T, protocol uint) {
+ t.Parallel()
+
+ // Define the disconnection requirement for individual hash fetch errors
+ tests := []struct {
+ result error
+ drop bool
+ }{
+ {nil, false}, // Sync succeeded, all is well
+ {errBusy, false}, // Sync is already in progress, no problem
+ {errUnknownPeer, false}, // Peer is unknown, was already dropped, don't double drop
+ {errBadPeer, true}, // Peer was deemed bad for some reason, drop it
+ {errStallingPeer, true}, // Peer was detected to be stalling, drop it
+ {errUnsyncedPeer, true}, // Peer was detected to be unsynced, drop it
+ {errNoPeers, false}, // No peers to download from, soft race, no issue
+ {errTimeout, true}, // No hashes received in due time, drop the peer
+ {errEmptyHeaderSet, true}, // No headers were returned as a response, drop as it's a dead end
+ {errPeersUnavailable, true}, // Nobody had the advertised blocks, drop the advertiser
+ {errInvalidAncestor, true}, // Agreed upon ancestor is not acceptable, drop the chain rewriter
+ {errInvalidChain, true}, // Hash chain was detected as invalid, definitely drop
+ {errInvalidBody, false}, // A bad peer was detected, but not the sync origin
+ {errInvalidReceipt, false}, // A bad peer was detected, but not the sync origin
+ {errCancelContentProcessing, false}, // Synchronisation was canceled, origin may be innocent, don't drop
+ }
+ // Run the tests and check disconnection status
+ tester := newTester()
+ defer tester.terminate()
+ chain := testChainBase.shorten(1)
+
+ for i, tt := range tests {
+ // Register a new peer and ensure its presence
+ id := fmt.Sprintf("test %d", i)
+ if err := tester.newPeer(id, protocol, chain); err != nil {
+ t.Fatalf("test %d: failed to register new peer: %v", i, err)
+ }
+ if _, ok := tester.peers[id]; !ok {
+ t.Fatalf("test %d: registered peer not found", i)
+ }
+ // Simulate a synchronisation and check the required result
+ tester.downloader.synchroniseMock = func(string, common.Hash) error { return tt.result }
+
+ tester.downloader.Synchronise(id, tester.genesis.Hash(), big.NewInt(1000), FullSync)
+ if _, ok := tester.peers[id]; !ok != tt.drop {
+ t.Errorf("test %d: peer drop mismatch for %v: have %v, want %v", i, tt.result, !ok, tt.drop)
+ }
+ }
+}
+
+// Tests that synchronisation progress (origin block number, current block number
+// and highest block number) is tracked and updated correctly.
+func TestSyncProgress66Full(t *testing.T) { testSyncProgress(t, eth.ETH66, FullSync) }
+func TestSyncProgress66Fast(t *testing.T) { testSyncProgress(t, eth.ETH66, FastSync) }
+func TestSyncProgress66Light(t *testing.T) { testSyncProgress(t, eth.ETH66, LightSync) }
+
+func testSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+ chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+ // Set a sync init hook to catch progress changes
+ starting := make(chan struct{})
+ progress := make(chan struct{})
+
+ tester.downloader.syncInitHook = func(origin, latest uint64) {
+ starting <- struct{}{}
+ <-progress
+ }
+ checkProgress(t, tester.downloader, "pristine", ethereum.SyncProgress{})
+
+ // Synchronise half the blocks and check initial progress
+ tester.newPeer("peer-half", protocol, chain.shorten(chain.len()/2))
+ pending := new(sync.WaitGroup)
+ pending.Add(1)
+
+ go func() {
+ defer pending.Done()
+ if err := tester.sync("peer-half", nil, mode); err != nil {
+ panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+ }
+ }()
+ <-starting
+ checkProgress(t, tester.downloader, "initial", ethereum.SyncProgress{
+ HighestBlock: uint64(chain.len()/2 - 1),
+ })
+ progress <- struct{}{}
+ pending.Wait()
+
+ // Synchronise all the blocks and check continuation progress
+ tester.newPeer("peer-full", protocol, chain)
+ pending.Add(1)
+ go func() {
+ defer pending.Done()
+ if err := tester.sync("peer-full", nil, mode); err != nil {
+ panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+ }
+ }()
+ <-starting
+ checkProgress(t, tester.downloader, "completing", ethereum.SyncProgress{
+ StartingBlock: uint64(chain.len()/2 - 1),
+ CurrentBlock: uint64(chain.len()/2 - 1),
+ HighestBlock: uint64(chain.len() - 1),
+ })
+
+ // Check final progress after successful sync
+ progress <- struct{}{}
+ pending.Wait()
+ checkProgress(t, tester.downloader, "final", ethereum.SyncProgress{
+ StartingBlock: uint64(chain.len()/2 - 1),
+ CurrentBlock: uint64(chain.len() - 1),
+ HighestBlock: uint64(chain.len() - 1),
+ })
+}
+
+func checkProgress(t *testing.T, d *Downloader, stage string, want ethereum.SyncProgress) {
+ // Mark this method as a helper to report errors at callsite, not in here
+ t.Helper()
+
+ p := d.Progress()
+ p.KnownStates, p.PulledStates = 0, 0
+ want.KnownStates, want.PulledStates = 0, 0
+ if p != want {
+ t.Fatalf("%s progress mismatch:\nhave %+v\nwant %+v", stage, p, want)
+ }
+}
+
+// Tests that synchronisation progress (origin block number and highest block
+// number) is tracked and updated correctly in case of a fork (or manual head
+// revertal).
+func TestForkedSyncProgress66Full(t *testing.T) { testForkedSyncProgress(t, eth.ETH66, FullSync) }
+func TestForkedSyncProgress66Fast(t *testing.T) { testForkedSyncProgress(t, eth.ETH66, FastSync) }
+func TestForkedSyncProgress66Light(t *testing.T) { testForkedSyncProgress(t, eth.ETH66, LightSync) }
+
+func testForkedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+ chainA := testChainForkLightA.shorten(testChainBase.len() + MaxHeaderFetch)
+ chainB := testChainForkLightB.shorten(testChainBase.len() + MaxHeaderFetch)
+
+ // Set a sync init hook to catch progress changes
+ starting := make(chan struct{})
+ progress := make(chan struct{})
+
+ tester.downloader.syncInitHook = func(origin, latest uint64) {
+ starting <- struct{}{}
+ <-progress
+ }
+ checkProgress(t, tester.downloader, "pristine", ethereum.SyncProgress{})
+
+ // Synchronise with one of the forks and check progress
+ tester.newPeer("fork A", protocol, chainA)
+ pending := new(sync.WaitGroup)
+ pending.Add(1)
+ go func() {
+ defer pending.Done()
+ if err := tester.sync("fork A", nil, mode); err != nil {
+ panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+ }
+ }()
+ <-starting
+
+ checkProgress(t, tester.downloader, "initial", ethereum.SyncProgress{
+ HighestBlock: uint64(chainA.len() - 1),
+ })
+ progress <- struct{}{}
+ pending.Wait()
+
+ // Simulate a successful sync above the fork
+ tester.downloader.syncStatsChainOrigin = tester.downloader.syncStatsChainHeight
+
+ // Synchronise with the second fork and check progress resets
+ tester.newPeer("fork B", protocol, chainB)
+ pending.Add(1)
+ go func() {
+ defer pending.Done()
+ if err := tester.sync("fork B", nil, mode); err != nil {
+ panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+ }
+ }()
+ <-starting
+ checkProgress(t, tester.downloader, "forking", ethereum.SyncProgress{
+ StartingBlock: uint64(testChainBase.len()) - 1,
+ CurrentBlock: uint64(chainA.len() - 1),
+ HighestBlock: uint64(chainB.len() - 1),
+ })
+
+ // Check final progress after successful sync
+ progress <- struct{}{}
+ pending.Wait()
+ checkProgress(t, tester.downloader, "final", ethereum.SyncProgress{
+ StartingBlock: uint64(testChainBase.len()) - 1,
+ CurrentBlock: uint64(chainB.len() - 1),
+ HighestBlock: uint64(chainB.len() - 1),
+ })
+}
+
+// Tests that if synchronisation is aborted due to some failure, then the progress
+// origin is not updated in the next sync cycle, as it should be considered the
+// continuation of the previous sync and not a new instance.
+func TestFailedSyncProgress66Full(t *testing.T) { testFailedSyncProgress(t, eth.ETH66, FullSync) }
+func TestFailedSyncProgress66Fast(t *testing.T) { testFailedSyncProgress(t, eth.ETH66, FastSync) }
+func TestFailedSyncProgress66Light(t *testing.T) { testFailedSyncProgress(t, eth.ETH66, LightSync) }
+
+func testFailedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+ chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+ // Set a sync init hook to catch progress changes
+ starting := make(chan struct{})
+ progress := make(chan struct{})
+
+ tester.downloader.syncInitHook = func(origin, latest uint64) {
+ starting <- struct{}{}
+ <-progress
+ }
+ checkProgress(t, tester.downloader, "pristine", ethereum.SyncProgress{})
+
+ // Attempt a full sync with a faulty peer
+ brokenChain := chain.shorten(chain.len())
+ missing := brokenChain.len() / 2
+ delete(brokenChain.headerm, brokenChain.chain[missing])
+ delete(brokenChain.blockm, brokenChain.chain[missing])
+ delete(brokenChain.receiptm, brokenChain.chain[missing])
+ tester.newPeer("faulty", protocol, brokenChain)
+
+ pending := new(sync.WaitGroup)
+ pending.Add(1)
+ go func() {
+ defer pending.Done()
+ if err := tester.sync("faulty", nil, mode); err == nil {
+ panic("succeeded faulty synchronisation")
+ }
+ }()
+ <-starting
+ checkProgress(t, tester.downloader, "initial", ethereum.SyncProgress{
+ HighestBlock: uint64(brokenChain.len() - 1),
+ })
+ progress <- struct{}{}
+ pending.Wait()
+ afterFailedSync := tester.downloader.Progress()
+
+ // Synchronise with a good peer and check that the progress origin remind the same
+ // after a failure
+ tester.newPeer("valid", protocol, chain)
+ pending.Add(1)
+ go func() {
+ defer pending.Done()
+ if err := tester.sync("valid", nil, mode); err != nil {
+ panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+ }
+ }()
+ <-starting
+ checkProgress(t, tester.downloader, "completing", afterFailedSync)
+
+ // Check final progress after successful sync
+ progress <- struct{}{}
+ pending.Wait()
+ checkProgress(t, tester.downloader, "final", ethereum.SyncProgress{
+ CurrentBlock: uint64(chain.len() - 1),
+ HighestBlock: uint64(chain.len() - 1),
+ })
+}
+
+// Tests that if an attacker fakes a chain height, after the attack is detected,
+// the progress height is successfully reduced at the next sync invocation.
+func TestFakedSyncProgress66Full(t *testing.T) { testFakedSyncProgress(t, eth.ETH66, FullSync) }
+func TestFakedSyncProgress66Fast(t *testing.T) { testFakedSyncProgress(t, eth.ETH66, FastSync) }
+func TestFakedSyncProgress66Light(t *testing.T) { testFakedSyncProgress(t, eth.ETH66, LightSync) }
+
+func testFakedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ tester := newTester()
+ defer tester.terminate()
+ chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+ // Set a sync init hook to catch progress changes
+ starting := make(chan struct{})
+ progress := make(chan struct{})
+ tester.downloader.syncInitHook = func(origin, latest uint64) {
+ starting <- struct{}{}
+ <-progress
+ }
+ checkProgress(t, tester.downloader, "pristine", ethereum.SyncProgress{})
+
+ // Create and sync with an attacker that promises a higher chain than available.
+ brokenChain := chain.shorten(chain.len())
+ numMissing := 5
+ for i := brokenChain.len() - 2; i > brokenChain.len()-numMissing; i-- {
+ delete(brokenChain.headerm, brokenChain.chain[i])
+ }
+ tester.newPeer("attack", protocol, brokenChain)
+
+ pending := new(sync.WaitGroup)
+ pending.Add(1)
+ go func() {
+ defer pending.Done()
+ if err := tester.sync("attack", nil, mode); err == nil {
+ panic("succeeded attacker synchronisation")
+ }
+ }()
+ <-starting
+ checkProgress(t, tester.downloader, "initial", ethereum.SyncProgress{
+ HighestBlock: uint64(brokenChain.len() - 1),
+ })
+ progress <- struct{}{}
+ pending.Wait()
+ afterFailedSync := tester.downloader.Progress()
+
+ // Synchronise with a good peer and check that the progress height has been reduced to
+ // the true value.
+ validChain := chain.shorten(chain.len() - numMissing)
+ tester.newPeer("valid", protocol, validChain)
+ pending.Add(1)
+
+ go func() {
+ defer pending.Done()
+ if err := tester.sync("valid", nil, mode); err != nil {
+ panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+ }
+ }()
+ <-starting
+ checkProgress(t, tester.downloader, "completing", ethereum.SyncProgress{
+ CurrentBlock: afterFailedSync.CurrentBlock,
+ HighestBlock: uint64(validChain.len() - 1),
+ })
+
+ // Check final progress after successful sync.
+ progress <- struct{}{}
+ pending.Wait()
+ checkProgress(t, tester.downloader, "final", ethereum.SyncProgress{
+ CurrentBlock: uint64(validChain.len() - 1),
+ HighestBlock: uint64(validChain.len() - 1),
+ })
+}
+
+// This test reproduces an issue where unexpected deliveries would
+// block indefinitely if they arrived at the right time.
+func TestDeliverHeadersHang66Full(t *testing.T) { testDeliverHeadersHang(t, eth.ETH66, FullSync) }
+func TestDeliverHeadersHang66Fast(t *testing.T) { testDeliverHeadersHang(t, eth.ETH66, FastSync) }
+func TestDeliverHeadersHang66Light(t *testing.T) { testDeliverHeadersHang(t, eth.ETH66, LightSync) }
+
+func testDeliverHeadersHang(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ master := newTester()
+ defer master.terminate()
+ chain := testChainBase.shorten(15)
+
+ for i := 0; i < 200; i++ {
+ tester := newTester()
+ tester.peerDb = master.peerDb
+ tester.newPeer("peer", protocol, chain)
+
+ // Whenever the downloader requests headers, flood it with
+ // a lot of unrequested header deliveries.
+ tester.downloader.peers.peers["peer"].peer = &floodingTestPeer{
+ peer: tester.downloader.peers.peers["peer"].peer,
+ tester: tester,
+ }
+ if err := tester.sync("peer", nil, mode); err != nil {
+ t.Errorf("test %d: sync failed: %v", i, err)
+ }
+ tester.terminate()
+ }
+}
+
+type floodingTestPeer struct {
+ peer Peer
+ tester *downloadTester
+}
+
+func (ftp *floodingTestPeer) Head() (common.Hash, *big.Int) { return ftp.peer.Head() }
+func (ftp *floodingTestPeer) RequestHeadersByHash(hash common.Hash, count int, skip int, reverse bool) error {
+ return ftp.peer.RequestHeadersByHash(hash, count, skip, reverse)
+}
+func (ftp *floodingTestPeer) RequestBodies(hashes []common.Hash) error {
+ return ftp.peer.RequestBodies(hashes)
+}
+func (ftp *floodingTestPeer) RequestReceipts(hashes []common.Hash) error {
+ return ftp.peer.RequestReceipts(hashes)
+}
+func (ftp *floodingTestPeer) RequestNodeData(hashes []common.Hash) error {
+ return ftp.peer.RequestNodeData(hashes)
+}
+
+func (ftp *floodingTestPeer) RequestHeadersByNumber(from uint64, count, skip int, reverse bool) error {
+ deliveriesDone := make(chan struct{}, 500)
+ for i := 0; i < cap(deliveriesDone)-1; i++ {
+ peer := fmt.Sprintf("fake-peer%d", i)
+ go func() {
+ ftp.tester.downloader.DeliverHeaders(peer, []*types.Header{{}, {}, {}, {}})
+ deliveriesDone <- struct{}{}
+ }()
+ }
+
+ // None of the extra deliveries should block.
+ timeout := time.After(60 * time.Second)
+ launched := false
+ for i := 0; i < cap(deliveriesDone); i++ {
+ select {
+ case <-deliveriesDone:
+ if !launched {
+ // Start delivering the requested headers
+ // after one of the flooding responses has arrived.
+ go func() {
+ ftp.peer.RequestHeadersByNumber(from, count, skip, reverse)
+ deliveriesDone <- struct{}{}
+ }()
+ launched = true
+ }
+ case <-timeout:
+ panic("blocked")
+ }
+ }
+ return nil
+}
+
+func TestRemoteHeaderRequestSpan(t *testing.T) {
+ testCases := []struct {
+ remoteHeight uint64
+ localHeight uint64
+ expected []int
+ }{
+ // Remote is way higher. We should ask for the remote head and go backwards
+ {1500, 1000,
+ []int{1323, 1339, 1355, 1371, 1387, 1403, 1419, 1435, 1451, 1467, 1483, 1499},
+ },
+ {15000, 13006,
+ []int{14823, 14839, 14855, 14871, 14887, 14903, 14919, 14935, 14951, 14967, 14983, 14999},
+ },
+ // Remote is pretty close to us. We don't have to fetch as many
+ {1200, 1150,
+ []int{1149, 1154, 1159, 1164, 1169, 1174, 1179, 1184, 1189, 1194, 1199},
+ },
+ // Remote is equal to us (so on a fork with higher td)
+ // We should get the closest couple of ancestors
+ {1500, 1500,
+ []int{1497, 1499},
+ },
+ // We're higher than the remote! Odd
+ {1000, 1500,
+ []int{997, 999},
+ },
+ // Check some weird edgecases that it behaves somewhat rationally
+ {0, 1500,
+ []int{0, 2},
+ },
+ {6000000, 0,
+ []int{5999823, 5999839, 5999855, 5999871, 5999887, 5999903, 5999919, 5999935, 5999951, 5999967, 5999983, 5999999},
+ },
+ {0, 0,
+ []int{0, 2},
+ },
+ }
+ reqs := func(from, count, span int) []int {
+ var r []int
+ num := from
+ for len(r) < count {
+ r = append(r, num)
+ num += span + 1
+ }
+ return r
+ }
+ for i, tt := range testCases {
+ from, count, span, max := calculateRequestSpan(tt.remoteHeight, tt.localHeight)
+ data := reqs(int(from), count, span)
+
+ if max != uint64(data[len(data)-1]) {
+ t.Errorf("test %d: wrong last value %d != %d", i, data[len(data)-1], max)
+ }
+ failed := false
+ if len(data) != len(tt.expected) {
+ failed = true
+ t.Errorf("test %d: length wrong, expected %d got %d", i, len(tt.expected), len(data))
+ } else {
+ for j, n := range data {
+ if n != tt.expected[j] {
+ failed = true
+ break
+ }
+ }
+ }
+ if failed {
+ res := strings.Replace(fmt.Sprint(data), " ", ",", -1)
+ exp := strings.Replace(fmt.Sprint(tt.expected), " ", ",", -1)
+ t.Logf("got: %v\n", res)
+ t.Logf("exp: %v\n", exp)
+ t.Errorf("test %d: wrong values", i)
+ }
+ }
+}
+
+// Tests that peers below a pre-configured checkpoint block are prevented from
+// being fast-synced from, avoiding potential cheap eclipse attacks.
+func TestCheckpointEnforcement66Full(t *testing.T) { testCheckpointEnforcement(t, eth.ETH66, FullSync) }
+func TestCheckpointEnforcement66Fast(t *testing.T) { testCheckpointEnforcement(t, eth.ETH66, FastSync) }
+func TestCheckpointEnforcement66Light(t *testing.T) {
+ testCheckpointEnforcement(t, eth.ETH66, LightSync)
+}
+
+func testCheckpointEnforcement(t *testing.T, protocol uint, mode SyncMode) {
+ t.Parallel()
+
+ // Create a new tester with a particular hard coded checkpoint block
+ tester := newTester()
+ defer tester.terminate()
+
+ tester.downloader.checkpoint = uint64(fsMinFullBlocks) + 256
+ chain := testChainBase.shorten(int(tester.downloader.checkpoint) - 1)
+
+ // Attempt to sync with the peer and validate the result
+ tester.newPeer("peer", protocol, chain)
+
+ var expect error
+ if mode == FastSync || mode == LightSync {
+ expect = errUnsyncedPeer
+ }
+ if err := tester.sync("peer", nil, mode); !errors.Is(err, expect) {
+ t.Fatalf("block sync error mismatch: have %v, want %v", err, expect)
+ }
+ if mode == FastSync || mode == LightSync {
+ assertOwnChain(t, tester, 1)
+ } else {
+ assertOwnChain(t, tester, chain.len())
+ }
+}
diff --git a/les/downloader/events.go b/les/downloader/events.go
new file mode 100644
index 0000000000..25255a3a72
--- /dev/null
+++ b/les/downloader/events.go
@@ -0,0 +1,25 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package downloader
+
+import "github.com/ethereum/go-ethereum/core/types"
+
+type DoneEvent struct {
+ Latest *types.Header
+}
+type StartEvent struct{}
+type FailedEvent struct{ Err error }
diff --git a/les/downloader/metrics.go b/les/downloader/metrics.go
new file mode 100644
index 0000000000..c38732043a
--- /dev/null
+++ b/les/downloader/metrics.go
@@ -0,0 +1,45 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// Contains the metrics collected by the downloader.
+
+package downloader
+
+import (
+ "github.com/ethereum/go-ethereum/metrics"
+)
+
+var (
+ headerInMeter = metrics.NewRegisteredMeter("eth/downloader/headers/in", nil)
+ headerReqTimer = metrics.NewRegisteredTimer("eth/downloader/headers/req", nil)
+ headerDropMeter = metrics.NewRegisteredMeter("eth/downloader/headers/drop", nil)
+ headerTimeoutMeter = metrics.NewRegisteredMeter("eth/downloader/headers/timeout", nil)
+
+ bodyInMeter = metrics.NewRegisteredMeter("eth/downloader/bodies/in", nil)
+ bodyReqTimer = metrics.NewRegisteredTimer("eth/downloader/bodies/req", nil)
+ bodyDropMeter = metrics.NewRegisteredMeter("eth/downloader/bodies/drop", nil)
+ bodyTimeoutMeter = metrics.NewRegisteredMeter("eth/downloader/bodies/timeout", nil)
+
+ receiptInMeter = metrics.NewRegisteredMeter("eth/downloader/receipts/in", nil)
+ receiptReqTimer = metrics.NewRegisteredTimer("eth/downloader/receipts/req", nil)
+ receiptDropMeter = metrics.NewRegisteredMeter("eth/downloader/receipts/drop", nil)
+ receiptTimeoutMeter = metrics.NewRegisteredMeter("eth/downloader/receipts/timeout", nil)
+
+ stateInMeter = metrics.NewRegisteredMeter("eth/downloader/states/in", nil)
+ stateDropMeter = metrics.NewRegisteredMeter("eth/downloader/states/drop", nil)
+
+ throttleCounter = metrics.NewRegisteredCounter("eth/downloader/throttle", nil)
+)
diff --git a/les/downloader/modes.go b/les/downloader/modes.go
new file mode 100644
index 0000000000..3ea14d22d7
--- /dev/null
+++ b/les/downloader/modes.go
@@ -0,0 +1,81 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package downloader
+
+import "fmt"
+
+// SyncMode represents the synchronisation mode of the downloader.
+// It is a uint32 as it is used with atomic operations.
+type SyncMode uint32
+
+const (
+ FullSync SyncMode = iota // Synchronise the entire blockchain history from full blocks
+ FastSync // Quickly download the headers, full sync only at the chain
+ SnapSync // Download the chain and the state via compact snapshots
+ LightSync // Download only the headers and terminate afterwards
+)
+
+func (mode SyncMode) IsValid() bool {
+ return mode >= FullSync && mode <= LightSync
+}
+
+// String implements the stringer interface.
+func (mode SyncMode) String() string {
+ switch mode {
+ case FullSync:
+ return "full"
+ case FastSync:
+ return "fast"
+ case SnapSync:
+ return "snap"
+ case LightSync:
+ return "light"
+ default:
+ return "unknown"
+ }
+}
+
+func (mode SyncMode) MarshalText() ([]byte, error) {
+ switch mode {
+ case FullSync:
+ return []byte("full"), nil
+ case FastSync:
+ return []byte("fast"), nil
+ case SnapSync:
+ return []byte("snap"), nil
+ case LightSync:
+ return []byte("light"), nil
+ default:
+ return nil, fmt.Errorf("unknown sync mode %d", mode)
+ }
+}
+
+func (mode *SyncMode) UnmarshalText(text []byte) error {
+ switch string(text) {
+ case "full":
+ *mode = FullSync
+ case "fast":
+ *mode = FastSync
+ case "snap":
+ *mode = SnapSync
+ case "light":
+ *mode = LightSync
+ default:
+ return fmt.Errorf(`unknown sync mode %q, want "full", "fast" or "light"`, text)
+ }
+ return nil
+}
diff --git a/les/downloader/peer.go b/les/downloader/peer.go
new file mode 100644
index 0000000000..8632948329
--- /dev/null
+++ b/les/downloader/peer.go
@@ -0,0 +1,501 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// Contains the active peer-set of the downloader, maintaining both failures
+// as well as reputation metrics to prioritize the block retrievals.
+
+package downloader
+
+import (
+ "errors"
+ "math/big"
+ "sort"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
+ "github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/msgrate"
+)
+
+const (
+ maxLackingHashes = 4096 // Maximum number of entries allowed on the list or lacking items
+)
+
+var (
+ errAlreadyFetching = errors.New("already fetching blocks from peer")
+ errAlreadyRegistered = errors.New("peer is already registered")
+ errNotRegistered = errors.New("peer is not registered")
+)
+
+// peerConnection represents an active peer from which hashes and blocks are retrieved.
+type peerConnection struct {
+ id string // Unique identifier of the peer
+
+ headerIdle int32 // Current header activity state of the peer (idle = 0, active = 1)
+ blockIdle int32 // Current block activity state of the peer (idle = 0, active = 1)
+ receiptIdle int32 // Current receipt activity state of the peer (idle = 0, active = 1)
+ stateIdle int32 // Current node data activity state of the peer (idle = 0, active = 1)
+
+ headerStarted time.Time // Time instance when the last header fetch was started
+ blockStarted time.Time // Time instance when the last block (body) fetch was started
+ receiptStarted time.Time // Time instance when the last receipt fetch was started
+ stateStarted time.Time // Time instance when the last node data fetch was started
+
+ rates *msgrate.Tracker // Tracker to hone in on the number of items retrievable per second
+ lacking map[common.Hash]struct{} // Set of hashes not to request (didn't have previously)
+
+ peer Peer
+
+ version uint // Eth protocol version number to switch strategies
+ log log.Logger // Contextual logger to add extra infos to peer logs
+ lock sync.RWMutex
+}
+
+// LightPeer encapsulates the methods required to synchronise with a remote light peer.
+type LightPeer interface {
+ Head() (common.Hash, *big.Int)
+ RequestHeadersByHash(common.Hash, int, int, bool) error
+ RequestHeadersByNumber(uint64, int, int, bool) error
+}
+
+// Peer encapsulates the methods required to synchronise with a remote full peer.
+type Peer interface {
+ LightPeer
+ RequestBodies([]common.Hash) error
+ RequestReceipts([]common.Hash) error
+ RequestNodeData([]common.Hash) error
+}
+
+// lightPeerWrapper wraps a LightPeer struct, stubbing out the Peer-only methods.
+type lightPeerWrapper struct {
+ peer LightPeer
+}
+
+func (w *lightPeerWrapper) Head() (common.Hash, *big.Int) { return w.peer.Head() }
+func (w *lightPeerWrapper) RequestHeadersByHash(h common.Hash, amount int, skip int, reverse bool) error {
+ return w.peer.RequestHeadersByHash(h, amount, skip, reverse)
+}
+func (w *lightPeerWrapper) RequestHeadersByNumber(i uint64, amount int, skip int, reverse bool) error {
+ return w.peer.RequestHeadersByNumber(i, amount, skip, reverse)
+}
+func (w *lightPeerWrapper) RequestBodies([]common.Hash) error {
+ panic("RequestBodies not supported in light client mode sync")
+}
+func (w *lightPeerWrapper) RequestReceipts([]common.Hash) error {
+ panic("RequestReceipts not supported in light client mode sync")
+}
+func (w *lightPeerWrapper) RequestNodeData([]common.Hash) error {
+ panic("RequestNodeData not supported in light client mode sync")
+}
+
+// newPeerConnection creates a new downloader peer.
+func newPeerConnection(id string, version uint, peer Peer, logger log.Logger) *peerConnection {
+ return &peerConnection{
+ id: id,
+ lacking: make(map[common.Hash]struct{}),
+ peer: peer,
+ version: version,
+ log: logger,
+ }
+}
+
+// Reset clears the internal state of a peer entity.
+func (p *peerConnection) Reset() {
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ atomic.StoreInt32(&p.headerIdle, 0)
+ atomic.StoreInt32(&p.blockIdle, 0)
+ atomic.StoreInt32(&p.receiptIdle, 0)
+ atomic.StoreInt32(&p.stateIdle, 0)
+
+ p.lacking = make(map[common.Hash]struct{})
+}
+
+// FetchHeaders sends a header retrieval request to the remote peer.
+func (p *peerConnection) FetchHeaders(from uint64, count int) error {
+ // Short circuit if the peer is already fetching
+ if !atomic.CompareAndSwapInt32(&p.headerIdle, 0, 1) {
+ return errAlreadyFetching
+ }
+ p.headerStarted = time.Now()
+
+ // Issue the header retrieval request (absolute upwards without gaps)
+ go p.peer.RequestHeadersByNumber(from, count, 0, false)
+
+ return nil
+}
+
+// FetchBodies sends a block body retrieval request to the remote peer.
+func (p *peerConnection) FetchBodies(request *fetchRequest) error {
+ // Short circuit if the peer is already fetching
+ if !atomic.CompareAndSwapInt32(&p.blockIdle, 0, 1) {
+ return errAlreadyFetching
+ }
+ p.blockStarted = time.Now()
+
+ go func() {
+ // Convert the header set to a retrievable slice
+ hashes := make([]common.Hash, 0, len(request.Headers))
+ for _, header := range request.Headers {
+ hashes = append(hashes, header.Hash())
+ }
+ p.peer.RequestBodies(hashes)
+ }()
+
+ return nil
+}
+
+// FetchReceipts sends a receipt retrieval request to the remote peer.
+func (p *peerConnection) FetchReceipts(request *fetchRequest) error {
+ // Short circuit if the peer is already fetching
+ if !atomic.CompareAndSwapInt32(&p.receiptIdle, 0, 1) {
+ return errAlreadyFetching
+ }
+ p.receiptStarted = time.Now()
+
+ go func() {
+ // Convert the header set to a retrievable slice
+ hashes := make([]common.Hash, 0, len(request.Headers))
+ for _, header := range request.Headers {
+ hashes = append(hashes, header.Hash())
+ }
+ p.peer.RequestReceipts(hashes)
+ }()
+
+ return nil
+}
+
+// FetchNodeData sends a node state data retrieval request to the remote peer.
+func (p *peerConnection) FetchNodeData(hashes []common.Hash) error {
+ // Short circuit if the peer is already fetching
+ if !atomic.CompareAndSwapInt32(&p.stateIdle, 0, 1) {
+ return errAlreadyFetching
+ }
+ p.stateStarted = time.Now()
+
+ go p.peer.RequestNodeData(hashes)
+
+ return nil
+}
+
+// SetHeadersIdle sets the peer to idle, allowing it to execute new header retrieval
+// requests. Its estimated header retrieval throughput is updated with that measured
+// just now.
+func (p *peerConnection) SetHeadersIdle(delivered int, deliveryTime time.Time) {
+ p.rates.Update(eth.BlockHeadersMsg, deliveryTime.Sub(p.headerStarted), delivered)
+ atomic.StoreInt32(&p.headerIdle, 0)
+}
+
+// SetBodiesIdle sets the peer to idle, allowing it to execute block body retrieval
+// requests. Its estimated body retrieval throughput is updated with that measured
+// just now.
+func (p *peerConnection) SetBodiesIdle(delivered int, deliveryTime time.Time) {
+ p.rates.Update(eth.BlockBodiesMsg, deliveryTime.Sub(p.blockStarted), delivered)
+ atomic.StoreInt32(&p.blockIdle, 0)
+}
+
+// SetReceiptsIdle sets the peer to idle, allowing it to execute new receipt
+// retrieval requests. Its estimated receipt retrieval throughput is updated
+// with that measured just now.
+func (p *peerConnection) SetReceiptsIdle(delivered int, deliveryTime time.Time) {
+ p.rates.Update(eth.ReceiptsMsg, deliveryTime.Sub(p.receiptStarted), delivered)
+ atomic.StoreInt32(&p.receiptIdle, 0)
+}
+
+// SetNodeDataIdle sets the peer to idle, allowing it to execute new state trie
+// data retrieval requests. Its estimated state retrieval throughput is updated
+// with that measured just now.
+func (p *peerConnection) SetNodeDataIdle(delivered int, deliveryTime time.Time) {
+ p.rates.Update(eth.NodeDataMsg, deliveryTime.Sub(p.stateStarted), delivered)
+ atomic.StoreInt32(&p.stateIdle, 0)
+}
+
+// HeaderCapacity retrieves the peers header download allowance based on its
+// previously discovered throughput.
+func (p *peerConnection) HeaderCapacity(targetRTT time.Duration) int {
+ cap := p.rates.Capacity(eth.BlockHeadersMsg, targetRTT)
+ if cap > MaxHeaderFetch {
+ cap = MaxHeaderFetch
+ }
+ return cap
+}
+
+// BlockCapacity retrieves the peers block download allowance based on its
+// previously discovered throughput.
+func (p *peerConnection) BlockCapacity(targetRTT time.Duration) int {
+ cap := p.rates.Capacity(eth.BlockBodiesMsg, targetRTT)
+ if cap > MaxBlockFetch {
+ cap = MaxBlockFetch
+ }
+ return cap
+}
+
+// ReceiptCapacity retrieves the peers receipt download allowance based on its
+// previously discovered throughput.
+func (p *peerConnection) ReceiptCapacity(targetRTT time.Duration) int {
+ cap := p.rates.Capacity(eth.ReceiptsMsg, targetRTT)
+ if cap > MaxReceiptFetch {
+ cap = MaxReceiptFetch
+ }
+ return cap
+}
+
+// NodeDataCapacity retrieves the peers state download allowance based on its
+// previously discovered throughput.
+func (p *peerConnection) NodeDataCapacity(targetRTT time.Duration) int {
+ cap := p.rates.Capacity(eth.NodeDataMsg, targetRTT)
+ if cap > MaxStateFetch {
+ cap = MaxStateFetch
+ }
+ return cap
+}
+
+// MarkLacking appends a new entity to the set of items (blocks, receipts, states)
+// that a peer is known not to have (i.e. have been requested before). If the
+// set reaches its maximum allowed capacity, items are randomly dropped off.
+func (p *peerConnection) MarkLacking(hash common.Hash) {
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ for len(p.lacking) >= maxLackingHashes {
+ for drop := range p.lacking {
+ delete(p.lacking, drop)
+ break
+ }
+ }
+ p.lacking[hash] = struct{}{}
+}
+
+// Lacks retrieves whether the hash of a blockchain item is on the peers lacking
+// list (i.e. whether we know that the peer does not have it).
+func (p *peerConnection) Lacks(hash common.Hash) bool {
+ p.lock.RLock()
+ defer p.lock.RUnlock()
+
+ _, ok := p.lacking[hash]
+ return ok
+}
+
+// peerSet represents the collection of active peer participating in the chain
+// download procedure.
+type peerSet struct {
+ peers map[string]*peerConnection
+ rates *msgrate.Trackers // Set of rate trackers to give the sync a common beat
+
+ newPeerFeed event.Feed
+ peerDropFeed event.Feed
+
+ lock sync.RWMutex
+}
+
+// newPeerSet creates a new peer set top track the active download sources.
+func newPeerSet() *peerSet {
+ return &peerSet{
+ peers: make(map[string]*peerConnection),
+ rates: msgrate.NewTrackers(log.New("proto", "eth")),
+ }
+}
+
+// SubscribeNewPeers subscribes to peer arrival events.
+func (ps *peerSet) SubscribeNewPeers(ch chan<- *peerConnection) event.Subscription {
+ return ps.newPeerFeed.Subscribe(ch)
+}
+
+// SubscribePeerDrops subscribes to peer departure events.
+func (ps *peerSet) SubscribePeerDrops(ch chan<- *peerConnection) event.Subscription {
+ return ps.peerDropFeed.Subscribe(ch)
+}
+
+// Reset iterates over the current peer set, and resets each of the known peers
+// to prepare for a next batch of block retrieval.
+func (ps *peerSet) Reset() {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ for _, peer := range ps.peers {
+ peer.Reset()
+ }
+}
+
+// Register injects a new peer into the working set, or returns an error if the
+// peer is already known.
+//
+// The method also sets the starting throughput values of the new peer to the
+// average of all existing peers, to give it a realistic chance of being used
+// for data retrievals.
+func (ps *peerSet) Register(p *peerConnection) error {
+ // Register the new peer with some meaningful defaults
+ ps.lock.Lock()
+ if _, ok := ps.peers[p.id]; ok {
+ ps.lock.Unlock()
+ return errAlreadyRegistered
+ }
+ p.rates = msgrate.NewTracker(ps.rates.MeanCapacities(), ps.rates.MedianRoundTrip())
+ if err := ps.rates.Track(p.id, p.rates); err != nil {
+ return err
+ }
+ ps.peers[p.id] = p
+ ps.lock.Unlock()
+
+ ps.newPeerFeed.Send(p)
+ return nil
+}
+
+// Unregister removes a remote peer from the active set, disabling any further
+// actions to/from that particular entity.
+func (ps *peerSet) Unregister(id string) error {
+ ps.lock.Lock()
+ p, ok := ps.peers[id]
+ if !ok {
+ ps.lock.Unlock()
+ return errNotRegistered
+ }
+ delete(ps.peers, id)
+ ps.rates.Untrack(id)
+ ps.lock.Unlock()
+
+ ps.peerDropFeed.Send(p)
+ return nil
+}
+
+// Peer retrieves the registered peer with the given id.
+func (ps *peerSet) Peer(id string) *peerConnection {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ return ps.peers[id]
+}
+
+// Len returns if the current number of peers in the set.
+func (ps *peerSet) Len() int {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ return len(ps.peers)
+}
+
+// AllPeers retrieves a flat list of all the peers within the set.
+func (ps *peerSet) AllPeers() []*peerConnection {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ list := make([]*peerConnection, 0, len(ps.peers))
+ for _, p := range ps.peers {
+ list = append(list, p)
+ }
+ return list
+}
+
+// HeaderIdlePeers retrieves a flat list of all the currently header-idle peers
+// within the active peer set, ordered by their reputation.
+func (ps *peerSet) HeaderIdlePeers() ([]*peerConnection, int) {
+ idle := func(p *peerConnection) bool {
+ return atomic.LoadInt32(&p.headerIdle) == 0
+ }
+ throughput := func(p *peerConnection) int {
+ return p.rates.Capacity(eth.BlockHeadersMsg, time.Second)
+ }
+ return ps.idlePeers(eth.ETH66, eth.ETH66, idle, throughput)
+}
+
+// BodyIdlePeers retrieves a flat list of all the currently body-idle peers within
+// the active peer set, ordered by their reputation.
+func (ps *peerSet) BodyIdlePeers() ([]*peerConnection, int) {
+ idle := func(p *peerConnection) bool {
+ return atomic.LoadInt32(&p.blockIdle) == 0
+ }
+ throughput := func(p *peerConnection) int {
+ return p.rates.Capacity(eth.BlockBodiesMsg, time.Second)
+ }
+ return ps.idlePeers(eth.ETH66, eth.ETH66, idle, throughput)
+}
+
+// ReceiptIdlePeers retrieves a flat list of all the currently receipt-idle peers
+// within the active peer set, ordered by their reputation.
+func (ps *peerSet) ReceiptIdlePeers() ([]*peerConnection, int) {
+ idle := func(p *peerConnection) bool {
+ return atomic.LoadInt32(&p.receiptIdle) == 0
+ }
+ throughput := func(p *peerConnection) int {
+ return p.rates.Capacity(eth.ReceiptsMsg, time.Second)
+ }
+ return ps.idlePeers(eth.ETH66, eth.ETH66, idle, throughput)
+}
+
+// NodeDataIdlePeers retrieves a flat list of all the currently node-data-idle
+// peers within the active peer set, ordered by their reputation.
+func (ps *peerSet) NodeDataIdlePeers() ([]*peerConnection, int) {
+ idle := func(p *peerConnection) bool {
+ return atomic.LoadInt32(&p.stateIdle) == 0
+ }
+ throughput := func(p *peerConnection) int {
+ return p.rates.Capacity(eth.NodeDataMsg, time.Second)
+ }
+ return ps.idlePeers(eth.ETH66, eth.ETH66, idle, throughput)
+}
+
+// idlePeers retrieves a flat list of all currently idle peers satisfying the
+// protocol version constraints, using the provided function to check idleness.
+// The resulting set of peers are sorted by their capacity.
+func (ps *peerSet) idlePeers(minProtocol, maxProtocol uint, idleCheck func(*peerConnection) bool, capacity func(*peerConnection) int) ([]*peerConnection, int) {
+ ps.lock.RLock()
+ defer ps.lock.RUnlock()
+
+ var (
+ total = 0
+ idle = make([]*peerConnection, 0, len(ps.peers))
+ tps = make([]int, 0, len(ps.peers))
+ )
+ for _, p := range ps.peers {
+ if p.version >= minProtocol && p.version <= maxProtocol {
+ if idleCheck(p) {
+ idle = append(idle, p)
+ tps = append(tps, capacity(p))
+ }
+ total++
+ }
+ }
+
+ // And sort them
+ sortPeers := &peerCapacitySort{idle, tps}
+ sort.Sort(sortPeers)
+ return sortPeers.p, total
+}
+
+// peerCapacitySort implements sort.Interface.
+// It sorts peer connections by capacity (descending).
+type peerCapacitySort struct {
+ p []*peerConnection
+ tp []int
+}
+
+func (ps *peerCapacitySort) Len() int {
+ return len(ps.p)
+}
+
+func (ps *peerCapacitySort) Less(i, j int) bool {
+ return ps.tp[i] > ps.tp[j]
+}
+
+func (ps *peerCapacitySort) Swap(i, j int) {
+ ps.p[i], ps.p[j] = ps.p[j], ps.p[i]
+ ps.tp[i], ps.tp[j] = ps.tp[j], ps.tp[i]
+}
diff --git a/les/downloader/queue.go b/les/downloader/queue.go
new file mode 100644
index 0000000000..04ec12cfa9
--- /dev/null
+++ b/les/downloader/queue.go
@@ -0,0 +1,913 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// Contains the block download scheduler to collect download tasks and schedule
+// them in an ordered, and throttled way.
+
+package downloader
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/prque"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+const (
+ bodyType = uint(0)
+ receiptType = uint(1)
+)
+
+var (
+ blockCacheMaxItems = 8192 // Maximum number of blocks to cache before throttling the download
+ blockCacheInitialItems = 2048 // Initial number of blocks to start fetching, before we know the sizes of the blocks
+ blockCacheMemory = 256 * 1024 * 1024 // Maximum amount of memory to use for block caching
+ blockCacheSizeWeight = 0.1 // Multiplier to approximate the average block size based on past ones
+)
+
+var (
+ errNoFetchesPending = errors.New("no fetches pending")
+ errStaleDelivery = errors.New("stale delivery")
+)
+
+// fetchRequest is a currently running data retrieval operation.
+type fetchRequest struct {
+ Peer *peerConnection // Peer to which the request was sent
+ From uint64 // [eth/62] Requested chain element index (used for skeleton fills only)
+ Headers []*types.Header // [eth/62] Requested headers, sorted by request order
+ Time time.Time // Time when the request was made
+}
+
+// fetchResult is a struct collecting partial results from data fetchers until
+// all outstanding pieces complete and the result as a whole can be processed.
+type fetchResult struct {
+ pending int32 // Flag telling what deliveries are outstanding
+
+ Header *types.Header
+ Uncles []*types.Header
+ Transactions types.Transactions
+ Receipts types.Receipts
+}
+
+func newFetchResult(header *types.Header, fastSync bool) *fetchResult {
+ item := &fetchResult{
+ Header: header,
+ }
+ if !header.EmptyBody() {
+ item.pending |= (1 << bodyType)
+ }
+ if fastSync && !header.EmptyReceipts() {
+ item.pending |= (1 << receiptType)
+ }
+ return item
+}
+
+// SetBodyDone flags the body as finished.
+func (f *fetchResult) SetBodyDone() {
+ if v := atomic.LoadInt32(&f.pending); (v & (1 << bodyType)) != 0 {
+ atomic.AddInt32(&f.pending, -1)
+ }
+}
+
+// AllDone checks if item is done.
+func (f *fetchResult) AllDone() bool {
+ return atomic.LoadInt32(&f.pending) == 0
+}
+
+// SetReceiptsDone flags the receipts as finished.
+func (f *fetchResult) SetReceiptsDone() {
+ if v := atomic.LoadInt32(&f.pending); (v & (1 << receiptType)) != 0 {
+ atomic.AddInt32(&f.pending, -2)
+ }
+}
+
+// Done checks if the given type is done already
+func (f *fetchResult) Done(kind uint) bool {
+ v := atomic.LoadInt32(&f.pending)
+ return v&(1< 0
+}
+
+// InFlightBlocks retrieves whether there are block fetch requests currently in
+// flight.
+func (q *queue) InFlightBlocks() bool {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ return len(q.blockPendPool) > 0
+}
+
+// InFlightReceipts retrieves whether there are receipt fetch requests currently
+// in flight.
+func (q *queue) InFlightReceipts() bool {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ return len(q.receiptPendPool) > 0
+}
+
+// Idle returns if the queue is fully idle or has some data still inside.
+func (q *queue) Idle() bool {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ queued := q.blockTaskQueue.Size() + q.receiptTaskQueue.Size()
+ pending := len(q.blockPendPool) + len(q.receiptPendPool)
+
+ return (queued + pending) == 0
+}
+
+// ScheduleSkeleton adds a batch of header retrieval tasks to the queue to fill
+// up an already retrieved header skeleton.
+func (q *queue) ScheduleSkeleton(from uint64, skeleton []*types.Header) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ // No skeleton retrieval can be in progress, fail hard if so (huge implementation bug)
+ if q.headerResults != nil {
+ panic("skeleton assembly already in progress")
+ }
+ // Schedule all the header retrieval tasks for the skeleton assembly
+ q.headerTaskPool = make(map[uint64]*types.Header)
+ q.headerTaskQueue = prque.New(nil)
+ q.headerPeerMiss = make(map[string]map[uint64]struct{}) // Reset availability to correct invalid chains
+ q.headerResults = make([]*types.Header, len(skeleton)*MaxHeaderFetch)
+ q.headerProced = 0
+ q.headerOffset = from
+ q.headerContCh = make(chan bool, 1)
+
+ for i, header := range skeleton {
+ index := from + uint64(i*MaxHeaderFetch)
+
+ q.headerTaskPool[index] = header
+ q.headerTaskQueue.Push(index, -int64(index))
+ }
+}
+
+// RetrieveHeaders retrieves the header chain assemble based on the scheduled
+// skeleton.
+func (q *queue) RetrieveHeaders() ([]*types.Header, int) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ headers, proced := q.headerResults, q.headerProced
+ q.headerResults, q.headerProced = nil, 0
+
+ return headers, proced
+}
+
+// Schedule adds a set of headers for the download queue for scheduling, returning
+// the new headers encountered.
+func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ // Insert all the headers prioritised by the contained block number
+ inserts := make([]*types.Header, 0, len(headers))
+ for _, header := range headers {
+ // Make sure chain order is honoured and preserved throughout
+ hash := header.Hash()
+ if header.Number == nil || header.Number.Uint64() != from {
+ log.Warn("Header broke chain ordering", "number", header.Number, "hash", hash, "expected", from)
+ break
+ }
+ if q.headerHead != (common.Hash{}) && q.headerHead != header.ParentHash {
+ log.Warn("Header broke chain ancestry", "number", header.Number, "hash", hash)
+ break
+ }
+ // Make sure no duplicate requests are executed
+ // We cannot skip this, even if the block is empty, since this is
+ // what triggers the fetchResult creation.
+ if _, ok := q.blockTaskPool[hash]; ok {
+ log.Warn("Header already scheduled for block fetch", "number", header.Number, "hash", hash)
+ } else {
+ q.blockTaskPool[hash] = header
+ q.blockTaskQueue.Push(header, -int64(header.Number.Uint64()))
+ }
+ // Queue for receipt retrieval
+ if q.mode == FastSync && !header.EmptyReceipts() {
+ if _, ok := q.receiptTaskPool[hash]; ok {
+ log.Warn("Header already scheduled for receipt fetch", "number", header.Number, "hash", hash)
+ } else {
+ q.receiptTaskPool[hash] = header
+ q.receiptTaskQueue.Push(header, -int64(header.Number.Uint64()))
+ }
+ }
+ inserts = append(inserts, header)
+ q.headerHead = hash
+ from++
+ }
+ return inserts
+}
+
+// Results retrieves and permanently removes a batch of fetch results from
+// the cache. the result slice will be empty if the queue has been closed.
+// Results can be called concurrently with Deliver and Schedule,
+// but assumes that there are not two simultaneous callers to Results
+func (q *queue) Results(block bool) []*fetchResult {
+ // Abort early if there are no items and non-blocking requested
+ if !block && !q.resultCache.HasCompletedItems() {
+ return nil
+ }
+ closed := false
+ for !closed && !q.resultCache.HasCompletedItems() {
+ // In order to wait on 'active', we need to obtain the lock.
+ // That may take a while, if someone is delivering at the same
+ // time, so after obtaining the lock, we check again if there
+ // are any results to fetch.
+ // Also, in-between we ask for the lock and the lock is obtained,
+ // someone can have closed the queue. In that case, we should
+ // return the available results and stop blocking
+ q.lock.Lock()
+ if q.resultCache.HasCompletedItems() || q.closed {
+ q.lock.Unlock()
+ break
+ }
+ // No items available, and not closed
+ q.active.Wait()
+ closed = q.closed
+ q.lock.Unlock()
+ }
+ // Regardless if closed or not, we can still deliver whatever we have
+ results := q.resultCache.GetCompleted(maxResultsProcess)
+ for _, result := range results {
+ // Recalculate the result item weights to prevent memory exhaustion
+ size := result.Header.Size()
+ for _, uncle := range result.Uncles {
+ size += uncle.Size()
+ }
+ for _, receipt := range result.Receipts {
+ size += receipt.Size()
+ }
+ for _, tx := range result.Transactions {
+ size += tx.Size()
+ }
+ q.resultSize = common.StorageSize(blockCacheSizeWeight)*size +
+ (1-common.StorageSize(blockCacheSizeWeight))*q.resultSize
+ }
+ // Using the newly calibrated resultsize, figure out the new throttle limit
+ // on the result cache
+ throttleThreshold := uint64((common.StorageSize(blockCacheMemory) + q.resultSize - 1) / q.resultSize)
+ throttleThreshold = q.resultCache.SetThrottleThreshold(throttleThreshold)
+
+ // Log some info at certain times
+ if time.Since(q.lastStatLog) > 60*time.Second {
+ q.lastStatLog = time.Now()
+ info := q.Stats()
+ info = append(info, "throttle", throttleThreshold)
+ log.Info("Downloader queue stats", info...)
+ }
+ return results
+}
+
+func (q *queue) Stats() []interface{} {
+ q.lock.RLock()
+ defer q.lock.RUnlock()
+
+ return q.stats()
+}
+
+func (q *queue) stats() []interface{} {
+ return []interface{}{
+ "receiptTasks", q.receiptTaskQueue.Size(),
+ "blockTasks", q.blockTaskQueue.Size(),
+ "itemSize", q.resultSize,
+ }
+}
+
+// ReserveHeaders reserves a set of headers for the given peer, skipping any
+// previously failed batches.
+func (q *queue) ReserveHeaders(p *peerConnection, count int) *fetchRequest {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ // Short circuit if the peer's already downloading something (sanity check to
+ // not corrupt state)
+ if _, ok := q.headerPendPool[p.id]; ok {
+ return nil
+ }
+ // Retrieve a batch of hashes, skipping previously failed ones
+ send, skip := uint64(0), []uint64{}
+ for send == 0 && !q.headerTaskQueue.Empty() {
+ from, _ := q.headerTaskQueue.Pop()
+ if q.headerPeerMiss[p.id] != nil {
+ if _, ok := q.headerPeerMiss[p.id][from.(uint64)]; ok {
+ skip = append(skip, from.(uint64))
+ continue
+ }
+ }
+ send = from.(uint64)
+ }
+ // Merge all the skipped batches back
+ for _, from := range skip {
+ q.headerTaskQueue.Push(from, -int64(from))
+ }
+ // Assemble and return the block download request
+ if send == 0 {
+ return nil
+ }
+ request := &fetchRequest{
+ Peer: p,
+ From: send,
+ Time: time.Now(),
+ }
+ q.headerPendPool[p.id] = request
+ return request
+}
+
+// ReserveBodies reserves a set of body fetches for the given peer, skipping any
+// previously failed downloads. Beside the next batch of needed fetches, it also
+// returns a flag whether empty blocks were queued requiring processing.
+func (q *queue) ReserveBodies(p *peerConnection, count int) (*fetchRequest, bool, bool) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ return q.reserveHeaders(p, count, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool, bodyType)
+}
+
+// ReserveReceipts reserves a set of receipt fetches for the given peer, skipping
+// any previously failed downloads. Beside the next batch of needed fetches, it
+// also returns a flag whether empty receipts were queued requiring importing.
+func (q *queue) ReserveReceipts(p *peerConnection, count int) (*fetchRequest, bool, bool) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ return q.reserveHeaders(p, count, q.receiptTaskPool, q.receiptTaskQueue, q.receiptPendPool, receiptType)
+}
+
+// reserveHeaders reserves a set of data download operations for a given peer,
+// skipping any previously failed ones. This method is a generic version used
+// by the individual special reservation functions.
+//
+// Note, this method expects the queue lock to be already held for writing. The
+// reason the lock is not obtained in here is because the parameters already need
+// to access the queue, so they already need a lock anyway.
+//
+// Returns:
+// item - the fetchRequest
+// progress - whether any progress was made
+// throttle - if the caller should throttle for a while
+func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque,
+ pendPool map[string]*fetchRequest, kind uint) (*fetchRequest, bool, bool) {
+ // Short circuit if the pool has been depleted, or if the peer's already
+ // downloading something (sanity check not to corrupt state)
+ if taskQueue.Empty() {
+ return nil, false, true
+ }
+ if _, ok := pendPool[p.id]; ok {
+ return nil, false, false
+ }
+ // Retrieve a batch of tasks, skipping previously failed ones
+ send := make([]*types.Header, 0, count)
+ skip := make([]*types.Header, 0)
+ progress := false
+ throttled := false
+ for proc := 0; len(send) < count && !taskQueue.Empty(); proc++ {
+ // the task queue will pop items in order, so the highest prio block
+ // is also the lowest block number.
+ h, _ := taskQueue.Peek()
+ header := h.(*types.Header)
+ // we can ask the resultcache if this header is within the
+ // "prioritized" segment of blocks. If it is not, we need to throttle
+
+ stale, throttle, item, err := q.resultCache.AddFetch(header, q.mode == FastSync)
+ if stale {
+ // Don't put back in the task queue, this item has already been
+ // delivered upstream
+ taskQueue.PopItem()
+ progress = true
+ delete(taskPool, header.Hash())
+ proc = proc - 1
+ log.Error("Fetch reservation already delivered", "number", header.Number.Uint64())
+ continue
+ }
+ if throttle {
+ // There are no resultslots available. Leave it in the task queue
+ // However, if there are any left as 'skipped', we should not tell
+ // the caller to throttle, since we still want some other
+ // peer to fetch those for us
+ throttled = len(skip) == 0
+ break
+ }
+ if err != nil {
+ // this most definitely should _not_ happen
+ log.Warn("Failed to reserve headers", "err", err)
+ // There are no resultslots available. Leave it in the task queue
+ break
+ }
+ if item.Done(kind) {
+ // If it's a noop, we can skip this task
+ delete(taskPool, header.Hash())
+ taskQueue.PopItem()
+ proc = proc - 1
+ progress = true
+ continue
+ }
+ // Remove it from the task queue
+ taskQueue.PopItem()
+ // Otherwise unless the peer is known not to have the data, add to the retrieve list
+ if p.Lacks(header.Hash()) {
+ skip = append(skip, header)
+ } else {
+ send = append(send, header)
+ }
+ }
+ // Merge all the skipped headers back
+ for _, header := range skip {
+ taskQueue.Push(header, -int64(header.Number.Uint64()))
+ }
+ if q.resultCache.HasCompletedItems() {
+ // Wake Results, resultCache was modified
+ q.active.Signal()
+ }
+ // Assemble and return the block download request
+ if len(send) == 0 {
+ return nil, progress, throttled
+ }
+ request := &fetchRequest{
+ Peer: p,
+ Headers: send,
+ Time: time.Now(),
+ }
+ pendPool[p.id] = request
+ return request, progress, throttled
+}
+
+// CancelHeaders aborts a fetch request, returning all pending skeleton indexes to the queue.
+func (q *queue) CancelHeaders(request *fetchRequest) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+ q.cancel(request, q.headerTaskQueue, q.headerPendPool)
+}
+
+// CancelBodies aborts a body fetch request, returning all pending headers to the
+// task queue.
+func (q *queue) CancelBodies(request *fetchRequest) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+ q.cancel(request, q.blockTaskQueue, q.blockPendPool)
+}
+
+// CancelReceipts aborts a body fetch request, returning all pending headers to
+// the task queue.
+func (q *queue) CancelReceipts(request *fetchRequest) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+ q.cancel(request, q.receiptTaskQueue, q.receiptPendPool)
+}
+
+// Cancel aborts a fetch request, returning all pending hashes to the task queue.
+func (q *queue) cancel(request *fetchRequest, taskQueue *prque.Prque, pendPool map[string]*fetchRequest) {
+ if request.From > 0 {
+ taskQueue.Push(request.From, -int64(request.From))
+ }
+ for _, header := range request.Headers {
+ taskQueue.Push(header, -int64(header.Number.Uint64()))
+ }
+ delete(pendPool, request.Peer.id)
+}
+
+// Revoke cancels all pending requests belonging to a given peer. This method is
+// meant to be called during a peer drop to quickly reassign owned data fetches
+// to remaining nodes.
+func (q *queue) Revoke(peerID string) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ if request, ok := q.blockPendPool[peerID]; ok {
+ for _, header := range request.Headers {
+ q.blockTaskQueue.Push(header, -int64(header.Number.Uint64()))
+ }
+ delete(q.blockPendPool, peerID)
+ }
+ if request, ok := q.receiptPendPool[peerID]; ok {
+ for _, header := range request.Headers {
+ q.receiptTaskQueue.Push(header, -int64(header.Number.Uint64()))
+ }
+ delete(q.receiptPendPool, peerID)
+ }
+}
+
+// ExpireHeaders checks for in flight requests that exceeded a timeout allowance,
+// canceling them and returning the responsible peers for penalisation.
+func (q *queue) ExpireHeaders(timeout time.Duration) map[string]int {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ return q.expire(timeout, q.headerPendPool, q.headerTaskQueue, headerTimeoutMeter)
+}
+
+// ExpireBodies checks for in flight block body requests that exceeded a timeout
+// allowance, canceling them and returning the responsible peers for penalisation.
+func (q *queue) ExpireBodies(timeout time.Duration) map[string]int {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ return q.expire(timeout, q.blockPendPool, q.blockTaskQueue, bodyTimeoutMeter)
+}
+
+// ExpireReceipts checks for in flight receipt requests that exceeded a timeout
+// allowance, canceling them and returning the responsible peers for penalisation.
+func (q *queue) ExpireReceipts(timeout time.Duration) map[string]int {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ return q.expire(timeout, q.receiptPendPool, q.receiptTaskQueue, receiptTimeoutMeter)
+}
+
+// expire is the generic check that move expired tasks from a pending pool back
+// into a task pool, returning all entities caught with expired tasks.
+//
+// Note, this method expects the queue lock to be already held. The
+// reason the lock is not obtained in here is because the parameters already need
+// to access the queue, so they already need a lock anyway.
+func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest, taskQueue *prque.Prque, timeoutMeter metrics.Meter) map[string]int {
+ // Iterate over the expired requests and return each to the queue
+ expiries := make(map[string]int)
+ for id, request := range pendPool {
+ if time.Since(request.Time) > timeout {
+ // Update the metrics with the timeout
+ timeoutMeter.Mark(1)
+
+ // Return any non satisfied requests to the pool
+ if request.From > 0 {
+ taskQueue.Push(request.From, -int64(request.From))
+ }
+ for _, header := range request.Headers {
+ taskQueue.Push(header, -int64(header.Number.Uint64()))
+ }
+ // Add the peer to the expiry report along the number of failed requests
+ expiries[id] = len(request.Headers)
+
+ // Remove the expired requests from the pending pool directly
+ delete(pendPool, id)
+ }
+ }
+ return expiries
+}
+
+// DeliverHeaders injects a header retrieval response into the header results
+// cache. This method either accepts all headers it received, or none of them
+// if they do not map correctly to the skeleton.
+//
+// If the headers are accepted, the method makes an attempt to deliver the set
+// of ready headers to the processor to keep the pipeline full. However it will
+// not block to prevent stalling other pending deliveries.
+func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh chan []*types.Header) (int, error) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ var logger log.Logger
+ if len(id) < 16 {
+ // Tests use short IDs, don't choke on them
+ logger = log.New("peer", id)
+ } else {
+ logger = log.New("peer", id[:16])
+ }
+ // Short circuit if the data was never requested
+ request := q.headerPendPool[id]
+ if request == nil {
+ return 0, errNoFetchesPending
+ }
+ headerReqTimer.UpdateSince(request.Time)
+ delete(q.headerPendPool, id)
+
+ // Ensure headers can be mapped onto the skeleton chain
+ target := q.headerTaskPool[request.From].Hash()
+
+ accepted := len(headers) == MaxHeaderFetch
+ if accepted {
+ if headers[0].Number.Uint64() != request.From {
+ logger.Trace("First header broke chain ordering", "number", headers[0].Number, "hash", headers[0].Hash(), "expected", request.From)
+ accepted = false
+ } else if headers[len(headers)-1].Hash() != target {
+ logger.Trace("Last header broke skeleton structure ", "number", headers[len(headers)-1].Number, "hash", headers[len(headers)-1].Hash(), "expected", target)
+ accepted = false
+ }
+ }
+ if accepted {
+ parentHash := headers[0].Hash()
+ for i, header := range headers[1:] {
+ hash := header.Hash()
+ if want := request.From + 1 + uint64(i); header.Number.Uint64() != want {
+ logger.Warn("Header broke chain ordering", "number", header.Number, "hash", hash, "expected", want)
+ accepted = false
+ break
+ }
+ if parentHash != header.ParentHash {
+ logger.Warn("Header broke chain ancestry", "number", header.Number, "hash", hash)
+ accepted = false
+ break
+ }
+ // Set-up parent hash for next round
+ parentHash = hash
+ }
+ }
+ // If the batch of headers wasn't accepted, mark as unavailable
+ if !accepted {
+ logger.Trace("Skeleton filling not accepted", "from", request.From)
+
+ miss := q.headerPeerMiss[id]
+ if miss == nil {
+ q.headerPeerMiss[id] = make(map[uint64]struct{})
+ miss = q.headerPeerMiss[id]
+ }
+ miss[request.From] = struct{}{}
+
+ q.headerTaskQueue.Push(request.From, -int64(request.From))
+ return 0, errors.New("delivery not accepted")
+ }
+ // Clean up a successful fetch and try to deliver any sub-results
+ copy(q.headerResults[request.From-q.headerOffset:], headers)
+ delete(q.headerTaskPool, request.From)
+
+ ready := 0
+ for q.headerProced+ready < len(q.headerResults) && q.headerResults[q.headerProced+ready] != nil {
+ ready += MaxHeaderFetch
+ }
+ if ready > 0 {
+ // Headers are ready for delivery, gather them and push forward (non blocking)
+ process := make([]*types.Header, ready)
+ copy(process, q.headerResults[q.headerProced:q.headerProced+ready])
+
+ select {
+ case headerProcCh <- process:
+ logger.Trace("Pre-scheduled new headers", "count", len(process), "from", process[0].Number)
+ q.headerProced += len(process)
+ default:
+ }
+ }
+ // Check for termination and return
+ if len(q.headerTaskPool) == 0 {
+ q.headerContCh <- false
+ }
+ return len(headers), nil
+}
+
+// DeliverBodies injects a block body retrieval response into the results queue.
+// The method returns the number of blocks bodies accepted from the delivery and
+// also wakes any threads waiting for data delivery.
+func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLists [][]*types.Header) (int, error) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+ trieHasher := trie.NewStackTrie(nil)
+ validate := func(index int, header *types.Header) error {
+ if types.DeriveSha(types.Transactions(txLists[index]), trieHasher) != header.TxHash {
+ return errInvalidBody
+ }
+ if types.CalcUncleHash(uncleLists[index]) != header.UncleHash {
+ return errInvalidBody
+ }
+ return nil
+ }
+
+ reconstruct := func(index int, result *fetchResult) {
+ result.Transactions = txLists[index]
+ result.Uncles = uncleLists[index]
+ result.SetBodyDone()
+ }
+ return q.deliver(id, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool,
+ bodyReqTimer, len(txLists), validate, reconstruct)
+}
+
+// DeliverReceipts injects a receipt retrieval response into the results queue.
+// The method returns the number of transaction receipts accepted from the delivery
+// and also wakes any threads waiting for data delivery.
+func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) (int, error) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+ trieHasher := trie.NewStackTrie(nil)
+ validate := func(index int, header *types.Header) error {
+ if types.DeriveSha(types.Receipts(receiptList[index]), trieHasher) != header.ReceiptHash {
+ return errInvalidReceipt
+ }
+ return nil
+ }
+ reconstruct := func(index int, result *fetchResult) {
+ result.Receipts = receiptList[index]
+ result.SetReceiptsDone()
+ }
+ return q.deliver(id, q.receiptTaskPool, q.receiptTaskQueue, q.receiptPendPool,
+ receiptReqTimer, len(receiptList), validate, reconstruct)
+}
+
+// deliver injects a data retrieval response into the results queue.
+//
+// Note, this method expects the queue lock to be already held for writing. The
+// reason this lock is not obtained in here is because the parameters already need
+// to access the queue, so they already need a lock anyway.
+func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header,
+ taskQueue *prque.Prque, pendPool map[string]*fetchRequest, reqTimer metrics.Timer,
+ results int, validate func(index int, header *types.Header) error,
+ reconstruct func(index int, result *fetchResult)) (int, error) {
+
+ // Short circuit if the data was never requested
+ request := pendPool[id]
+ if request == nil {
+ return 0, errNoFetchesPending
+ }
+ reqTimer.UpdateSince(request.Time)
+ delete(pendPool, id)
+
+ // If no data items were retrieved, mark them as unavailable for the origin peer
+ if results == 0 {
+ for _, header := range request.Headers {
+ request.Peer.MarkLacking(header.Hash())
+ }
+ }
+ // Assemble each of the results with their headers and retrieved data parts
+ var (
+ accepted int
+ failure error
+ i int
+ hashes []common.Hash
+ )
+ for _, header := range request.Headers {
+ // Short circuit assembly if no more fetch results are found
+ if i >= results {
+ break
+ }
+ // Validate the fields
+ if err := validate(i, header); err != nil {
+ failure = err
+ break
+ }
+ hashes = append(hashes, header.Hash())
+ i++
+ }
+
+ for _, header := range request.Headers[:i] {
+ if res, stale, err := q.resultCache.GetDeliverySlot(header.Number.Uint64()); err == nil {
+ reconstruct(accepted, res)
+ } else {
+ // else: betweeen here and above, some other peer filled this result,
+ // or it was indeed a no-op. This should not happen, but if it does it's
+ // not something to panic about
+ log.Error("Delivery stale", "stale", stale, "number", header.Number.Uint64(), "err", err)
+ failure = errStaleDelivery
+ }
+ // Clean up a successful fetch
+ delete(taskPool, hashes[accepted])
+ accepted++
+ }
+ // Return all failed or missing fetches to the queue
+ for _, header := range request.Headers[accepted:] {
+ taskQueue.Push(header, -int64(header.Number.Uint64()))
+ }
+ // Wake up Results
+ if accepted > 0 {
+ q.active.Signal()
+ }
+ if failure == nil {
+ return accepted, nil
+ }
+ // If none of the data was good, it's a stale delivery
+ if accepted > 0 {
+ return accepted, fmt.Errorf("partial failure: %v", failure)
+ }
+ return accepted, fmt.Errorf("%w: %v", failure, errStaleDelivery)
+}
+
+// Prepare configures the result cache to allow accepting and caching inbound
+// fetch results.
+func (q *queue) Prepare(offset uint64, mode SyncMode) {
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ // Prepare the queue for sync results
+ q.resultCache.Prepare(offset)
+ q.mode = mode
+}
diff --git a/les/downloader/queue_test.go b/les/downloader/queue_test.go
new file mode 100644
index 0000000000..cde5f306a2
--- /dev/null
+++ b/les/downloader/queue_test.go
@@ -0,0 +1,452 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package downloader
+
+import (
+ "fmt"
+ "math/big"
+ "math/rand"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/consensus/ethash"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/params"
+)
+
+var (
+ testdb = rawdb.NewMemoryDatabase()
+ genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000000000))
+)
+
+// makeChain creates a chain of n blocks starting at and including parent.
+// the returned hash chain is ordered head->parent. In addition, every 3rd block
+// contains a transaction and every 5th an uncle to allow testing correct block
+// reassembly.
+func makeChain(n int, seed byte, parent *types.Block, empty bool) ([]*types.Block, []types.Receipts) {
+ blocks, receipts := core.GenerateChain(params.TestChainConfig, parent, ethash.NewFaker(), testdb, n, func(i int, block *core.BlockGen) {
+ block.SetCoinbase(common.Address{seed})
+ // Add one tx to every secondblock
+ if !empty && i%2 == 0 {
+ signer := types.MakeSigner(params.TestChainConfig, block.Number())
+ tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, block.BaseFee(), nil), signer, testKey)
+ if err != nil {
+ panic(err)
+ }
+ block.AddTx(tx)
+ }
+ })
+ return blocks, receipts
+}
+
+type chainData struct {
+ blocks []*types.Block
+ offset int
+}
+
+var chain *chainData
+var emptyChain *chainData
+
+func init() {
+ // Create a chain of blocks to import
+ targetBlocks := 128
+ blocks, _ := makeChain(targetBlocks, 0, genesis, false)
+ chain = &chainData{blocks, 0}
+
+ blocks, _ = makeChain(targetBlocks, 0, genesis, true)
+ emptyChain = &chainData{blocks, 0}
+}
+
+func (chain *chainData) headers() []*types.Header {
+ hdrs := make([]*types.Header, len(chain.blocks))
+ for i, b := range chain.blocks {
+ hdrs[i] = b.Header()
+ }
+ return hdrs
+}
+
+func (chain *chainData) Len() int {
+ return len(chain.blocks)
+}
+
+func dummyPeer(id string) *peerConnection {
+ p := &peerConnection{
+ id: id,
+ lacking: make(map[common.Hash]struct{}),
+ }
+ return p
+}
+
+func TestBasics(t *testing.T) {
+ numOfBlocks := len(emptyChain.blocks)
+ numOfReceipts := len(emptyChain.blocks) / 2
+
+ q := newQueue(10, 10)
+ if !q.Idle() {
+ t.Errorf("new queue should be idle")
+ }
+ q.Prepare(1, FastSync)
+ if res := q.Results(false); len(res) != 0 {
+ t.Fatal("new queue should have 0 results")
+ }
+
+ // Schedule a batch of headers
+ q.Schedule(chain.headers(), 1)
+ if q.Idle() {
+ t.Errorf("queue should not be idle")
+ }
+ if got, exp := q.PendingBlocks(), chain.Len(); got != exp {
+ t.Errorf("wrong pending block count, got %d, exp %d", got, exp)
+ }
+ // Only non-empty receipts get added to task-queue
+ if got, exp := q.PendingReceipts(), 64; got != exp {
+ t.Errorf("wrong pending receipt count, got %d, exp %d", got, exp)
+ }
+ // Items are now queued for downloading, next step is that we tell the
+ // queue that a certain peer will deliver them for us
+ {
+ peer := dummyPeer("peer-1")
+ fetchReq, _, throttle := q.ReserveBodies(peer, 50)
+ if !throttle {
+ // queue size is only 10, so throttling should occur
+ t.Fatal("should throttle")
+ }
+ // But we should still get the first things to fetch
+ if got, exp := len(fetchReq.Headers), 5; got != exp {
+ t.Fatalf("expected %d requests, got %d", exp, got)
+ }
+ if got, exp := fetchReq.Headers[0].Number.Uint64(), uint64(1); got != exp {
+ t.Fatalf("expected header %d, got %d", exp, got)
+ }
+ }
+ if exp, got := q.blockTaskQueue.Size(), numOfBlocks-10; exp != got {
+ t.Errorf("expected block task queue to be %d, got %d", exp, got)
+ }
+ if exp, got := q.receiptTaskQueue.Size(), numOfReceipts; exp != got {
+ t.Errorf("expected receipt task queue to be %d, got %d", exp, got)
+ }
+ {
+ peer := dummyPeer("peer-2")
+ fetchReq, _, throttle := q.ReserveBodies(peer, 50)
+
+ // The second peer should hit throttling
+ if !throttle {
+ t.Fatalf("should not throttle")
+ }
+ // And not get any fetches at all, since it was throttled to begin with
+ if fetchReq != nil {
+ t.Fatalf("should have no fetches, got %d", len(fetchReq.Headers))
+ }
+ }
+ if exp, got := q.blockTaskQueue.Size(), numOfBlocks-10; exp != got {
+ t.Errorf("expected block task queue to be %d, got %d", exp, got)
+ }
+ if exp, got := q.receiptTaskQueue.Size(), numOfReceipts; exp != got {
+ t.Errorf("expected receipt task queue to be %d, got %d", exp, got)
+ }
+ {
+ // The receipt delivering peer should not be affected
+ // by the throttling of body deliveries
+ peer := dummyPeer("peer-3")
+ fetchReq, _, throttle := q.ReserveReceipts(peer, 50)
+ if !throttle {
+ // queue size is only 10, so throttling should occur
+ t.Fatal("should throttle")
+ }
+ // But we should still get the first things to fetch
+ if got, exp := len(fetchReq.Headers), 5; got != exp {
+ t.Fatalf("expected %d requests, got %d", exp, got)
+ }
+ if got, exp := fetchReq.Headers[0].Number.Uint64(), uint64(1); got != exp {
+ t.Fatalf("expected header %d, got %d", exp, got)
+ }
+
+ }
+ if exp, got := q.blockTaskQueue.Size(), numOfBlocks-10; exp != got {
+ t.Errorf("expected block task queue to be %d, got %d", exp, got)
+ }
+ if exp, got := q.receiptTaskQueue.Size(), numOfReceipts-5; exp != got {
+ t.Errorf("expected receipt task queue to be %d, got %d", exp, got)
+ }
+ if got, exp := q.resultCache.countCompleted(), 0; got != exp {
+ t.Errorf("wrong processable count, got %d, exp %d", got, exp)
+ }
+}
+
+func TestEmptyBlocks(t *testing.T) {
+ numOfBlocks := len(emptyChain.blocks)
+
+ q := newQueue(10, 10)
+
+ q.Prepare(1, FastSync)
+ // Schedule a batch of headers
+ q.Schedule(emptyChain.headers(), 1)
+ if q.Idle() {
+ t.Errorf("queue should not be idle")
+ }
+ if got, exp := q.PendingBlocks(), len(emptyChain.blocks); got != exp {
+ t.Errorf("wrong pending block count, got %d, exp %d", got, exp)
+ }
+ if got, exp := q.PendingReceipts(), 0; got != exp {
+ t.Errorf("wrong pending receipt count, got %d, exp %d", got, exp)
+ }
+ // They won't be processable, because the fetchresults haven't been
+ // created yet
+ if got, exp := q.resultCache.countCompleted(), 0; got != exp {
+ t.Errorf("wrong processable count, got %d, exp %d", got, exp)
+ }
+
+ // Items are now queued for downloading, next step is that we tell the
+ // queue that a certain peer will deliver them for us
+ // That should trigger all of them to suddenly become 'done'
+ {
+ // Reserve blocks
+ peer := dummyPeer("peer-1")
+ fetchReq, _, _ := q.ReserveBodies(peer, 50)
+
+ // there should be nothing to fetch, blocks are empty
+ if fetchReq != nil {
+ t.Fatal("there should be no body fetch tasks remaining")
+ }
+
+ }
+ if q.blockTaskQueue.Size() != numOfBlocks-10 {
+ t.Errorf("expected block task queue to be %d, got %d", numOfBlocks-10, q.blockTaskQueue.Size())
+ }
+ if q.receiptTaskQueue.Size() != 0 {
+ t.Errorf("expected receipt task queue to be %d, got %d", 0, q.receiptTaskQueue.Size())
+ }
+ {
+ peer := dummyPeer("peer-3")
+ fetchReq, _, _ := q.ReserveReceipts(peer, 50)
+
+ // there should be nothing to fetch, blocks are empty
+ if fetchReq != nil {
+ t.Fatal("there should be no body fetch tasks remaining")
+ }
+ }
+ if q.blockTaskQueue.Size() != numOfBlocks-10 {
+ t.Errorf("expected block task queue to be %d, got %d", numOfBlocks-10, q.blockTaskQueue.Size())
+ }
+ if q.receiptTaskQueue.Size() != 0 {
+ t.Errorf("expected receipt task queue to be %d, got %d", 0, q.receiptTaskQueue.Size())
+ }
+ if got, exp := q.resultCache.countCompleted(), 10; got != exp {
+ t.Errorf("wrong processable count, got %d, exp %d", got, exp)
+ }
+}
+
+// XTestDelivery does some more extensive testing of events that happen,
+// blocks that become known and peers that make reservations and deliveries.
+// disabled since it's not really a unit-test, but can be executed to test
+// some more advanced scenarios
+func XTestDelivery(t *testing.T) {
+ // the outside network, holding blocks
+ blo, rec := makeChain(128, 0, genesis, false)
+ world := newNetwork()
+ world.receipts = rec
+ world.chain = blo
+ world.progress(10)
+ if false {
+ log.Root().SetHandler(log.StdoutHandler)
+
+ }
+ q := newQueue(10, 10)
+ var wg sync.WaitGroup
+ q.Prepare(1, FastSync)
+ wg.Add(1)
+ go func() {
+ // deliver headers
+ defer wg.Done()
+ c := 1
+ for {
+ //fmt.Printf("getting headers from %d\n", c)
+ hdrs := world.headers(c)
+ l := len(hdrs)
+ //fmt.Printf("scheduling %d headers, first %d last %d\n",
+ // l, hdrs[0].Number.Uint64(), hdrs[len(hdrs)-1].Number.Uint64())
+ q.Schedule(hdrs, uint64(c))
+ c += l
+ }
+ }()
+ wg.Add(1)
+ go func() {
+ // collect results
+ defer wg.Done()
+ tot := 0
+ for {
+ res := q.Results(true)
+ tot += len(res)
+ fmt.Printf("got %d results, %d tot\n", len(res), tot)
+ // Now we can forget about these
+ world.forget(res[len(res)-1].Header.Number.Uint64())
+
+ }
+ }()
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ // reserve body fetch
+ i := 4
+ for {
+ peer := dummyPeer(fmt.Sprintf("peer-%d", i))
+ f, _, _ := q.ReserveBodies(peer, rand.Intn(30))
+ if f != nil {
+ var emptyList []*types.Header
+ var txs [][]*types.Transaction
+ var uncles [][]*types.Header
+ numToSkip := rand.Intn(len(f.Headers))
+ for _, hdr := range f.Headers[0 : len(f.Headers)-numToSkip] {
+ txs = append(txs, world.getTransactions(hdr.Number.Uint64()))
+ uncles = append(uncles, emptyList)
+ }
+ time.Sleep(100 * time.Millisecond)
+ _, err := q.DeliverBodies(peer.id, txs, uncles)
+ if err != nil {
+ fmt.Printf("delivered %d bodies %v\n", len(txs), err)
+ }
+ } else {
+ i++
+ time.Sleep(200 * time.Millisecond)
+ }
+ }
+ }()
+ go func() {
+ defer wg.Done()
+ // reserve receiptfetch
+ peer := dummyPeer("peer-3")
+ for {
+ f, _, _ := q.ReserveReceipts(peer, rand.Intn(50))
+ if f != nil {
+ var rcs [][]*types.Receipt
+ for _, hdr := range f.Headers {
+ rcs = append(rcs, world.getReceipts(hdr.Number.Uint64()))
+ }
+ _, err := q.DeliverReceipts(peer.id, rcs)
+ if err != nil {
+ fmt.Printf("delivered %d receipts %v\n", len(rcs), err)
+ }
+ time.Sleep(100 * time.Millisecond)
+ } else {
+ time.Sleep(200 * time.Millisecond)
+ }
+ }
+ }()
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < 50; i++ {
+ time.Sleep(300 * time.Millisecond)
+ //world.tick()
+ //fmt.Printf("trying to progress\n")
+ world.progress(rand.Intn(100))
+ }
+ for i := 0; i < 50; i++ {
+ time.Sleep(2990 * time.Millisecond)
+
+ }
+ }()
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for {
+ time.Sleep(990 * time.Millisecond)
+ fmt.Printf("world block tip is %d\n",
+ world.chain[len(world.chain)-1].Header().Number.Uint64())
+ fmt.Println(q.Stats())
+ }
+ }()
+ wg.Wait()
+}
+
+func newNetwork() *network {
+ var l sync.RWMutex
+ return &network{
+ cond: sync.NewCond(&l),
+ offset: 1, // block 1 is at blocks[0]
+ }
+}
+
+// represents the network
+type network struct {
+ offset int
+ chain []*types.Block
+ receipts []types.Receipts
+ lock sync.RWMutex
+ cond *sync.Cond
+}
+
+func (n *network) getTransactions(blocknum uint64) types.Transactions {
+ index := blocknum - uint64(n.offset)
+ return n.chain[index].Transactions()
+}
+func (n *network) getReceipts(blocknum uint64) types.Receipts {
+ index := blocknum - uint64(n.offset)
+ if got := n.chain[index].Header().Number.Uint64(); got != blocknum {
+ fmt.Printf("Err, got %d exp %d\n", got, blocknum)
+ panic("sd")
+ }
+ return n.receipts[index]
+}
+
+func (n *network) forget(blocknum uint64) {
+ index := blocknum - uint64(n.offset)
+ n.chain = n.chain[index:]
+ n.receipts = n.receipts[index:]
+ n.offset = int(blocknum)
+
+}
+func (n *network) progress(numBlocks int) {
+
+ n.lock.Lock()
+ defer n.lock.Unlock()
+ //fmt.Printf("progressing...\n")
+ newBlocks, newR := makeChain(numBlocks, 0, n.chain[len(n.chain)-1], false)
+ n.chain = append(n.chain, newBlocks...)
+ n.receipts = append(n.receipts, newR...)
+ n.cond.Broadcast()
+
+}
+
+func (n *network) headers(from int) []*types.Header {
+ numHeaders := 128
+ var hdrs []*types.Header
+ index := from - n.offset
+
+ for index >= len(n.chain) {
+ // wait for progress
+ n.cond.L.Lock()
+ //fmt.Printf("header going into wait\n")
+ n.cond.Wait()
+ index = from - n.offset
+ n.cond.L.Unlock()
+ }
+ n.lock.RLock()
+ defer n.lock.RUnlock()
+ for i, b := range n.chain[index:] {
+ hdrs = append(hdrs, b.Header())
+ if i >= numHeaders {
+ break
+ }
+ }
+ return hdrs
+}
diff --git a/les/downloader/resultstore.go b/les/downloader/resultstore.go
new file mode 100644
index 0000000000..21928c2a00
--- /dev/null
+++ b/les/downloader/resultstore.go
@@ -0,0 +1,194 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package downloader
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "github.com/ethereum/go-ethereum/core/types"
+)
+
+// resultStore implements a structure for maintaining fetchResults, tracking their
+// download-progress and delivering (finished) results.
+type resultStore struct {
+ items []*fetchResult // Downloaded but not yet delivered fetch results
+ resultOffset uint64 // Offset of the first cached fetch result in the block chain
+
+ // Internal index of first non-completed entry, updated atomically when needed.
+ // If all items are complete, this will equal length(items), so
+ // *important* : is not safe to use for indexing without checking against length
+ indexIncomplete int32 // atomic access
+
+ // throttleThreshold is the limit up to which we _want_ to fill the
+ // results. If blocks are large, we want to limit the results to less
+ // than the number of available slots, and maybe only fill 1024 out of
+ // 8192 possible places. The queue will, at certain times, recalibrate
+ // this index.
+ throttleThreshold uint64
+
+ lock sync.RWMutex
+}
+
+func newResultStore(size int) *resultStore {
+ return &resultStore{
+ resultOffset: 0,
+ items: make([]*fetchResult, size),
+ throttleThreshold: uint64(size),
+ }
+}
+
+// SetThrottleThreshold updates the throttling threshold based on the requested
+// limit and the total queue capacity. It returns the (possibly capped) threshold
+func (r *resultStore) SetThrottleThreshold(threshold uint64) uint64 {
+ r.lock.Lock()
+ defer r.lock.Unlock()
+
+ limit := uint64(len(r.items))
+ if threshold >= limit {
+ threshold = limit
+ }
+ r.throttleThreshold = threshold
+ return r.throttleThreshold
+}
+
+// AddFetch adds a header for body/receipt fetching. This is used when the queue
+// wants to reserve headers for fetching.
+//
+// It returns the following:
+// stale - if true, this item is already passed, and should not be requested again
+// throttled - if true, the store is at capacity, this particular header is not prio now
+// item - the result to store data into
+// err - any error that occurred
+func (r *resultStore) AddFetch(header *types.Header, fastSync bool) (stale, throttled bool, item *fetchResult, err error) {
+ r.lock.Lock()
+ defer r.lock.Unlock()
+
+ var index int
+ item, index, stale, throttled, err = r.getFetchResult(header.Number.Uint64())
+ if err != nil || stale || throttled {
+ return stale, throttled, item, err
+ }
+ if item == nil {
+ item = newFetchResult(header, fastSync)
+ r.items[index] = item
+ }
+ return stale, throttled, item, err
+}
+
+// GetDeliverySlot returns the fetchResult for the given header. If the 'stale' flag
+// is true, that means the header has already been delivered 'upstream'. This method
+// does not bubble up the 'throttle' flag, since it's moot at the point in time when
+// the item is downloaded and ready for delivery
+func (r *resultStore) GetDeliverySlot(headerNumber uint64) (*fetchResult, bool, error) {
+ r.lock.RLock()
+ defer r.lock.RUnlock()
+
+ res, _, stale, _, err := r.getFetchResult(headerNumber)
+ return res, stale, err
+}
+
+// getFetchResult returns the fetchResult corresponding to the given item, and
+// the index where the result is stored.
+func (r *resultStore) getFetchResult(headerNumber uint64) (item *fetchResult, index int, stale, throttle bool, err error) {
+ index = int(int64(headerNumber) - int64(r.resultOffset))
+ throttle = index >= int(r.throttleThreshold)
+ stale = index < 0
+
+ if index >= len(r.items) {
+ err = fmt.Errorf("%w: index allocation went beyond available resultStore space "+
+ "(index [%d] = header [%d] - resultOffset [%d], len(resultStore) = %d", errInvalidChain,
+ index, headerNumber, r.resultOffset, len(r.items))
+ return nil, index, stale, throttle, err
+ }
+ if stale {
+ return nil, index, stale, throttle, nil
+ }
+ item = r.items[index]
+ return item, index, stale, throttle, nil
+}
+
+// hasCompletedItems returns true if there are processable items available
+// this method is cheaper than countCompleted
+func (r *resultStore) HasCompletedItems() bool {
+ r.lock.RLock()
+ defer r.lock.RUnlock()
+
+ if len(r.items) == 0 {
+ return false
+ }
+ if item := r.items[0]; item != nil && item.AllDone() {
+ return true
+ }
+ return false
+}
+
+// countCompleted returns the number of items ready for delivery, stopping at
+// the first non-complete item.
+//
+// The mthod assumes (at least) rlock is held.
+func (r *resultStore) countCompleted() int {
+ // We iterate from the already known complete point, and see
+ // if any more has completed since last count
+ index := atomic.LoadInt32(&r.indexIncomplete)
+ for ; ; index++ {
+ if index >= int32(len(r.items)) {
+ break
+ }
+ result := r.items[index]
+ if result == nil || !result.AllDone() {
+ break
+ }
+ }
+ atomic.StoreInt32(&r.indexIncomplete, index)
+ return int(index)
+}
+
+// GetCompleted returns the next batch of completed fetchResults
+func (r *resultStore) GetCompleted(limit int) []*fetchResult {
+ r.lock.Lock()
+ defer r.lock.Unlock()
+
+ completed := r.countCompleted()
+ if limit > completed {
+ limit = completed
+ }
+ results := make([]*fetchResult, limit)
+ copy(results, r.items[:limit])
+
+ // Delete the results from the cache and clear the tail.
+ copy(r.items, r.items[limit:])
+ for i := len(r.items) - limit; i < len(r.items); i++ {
+ r.items[i] = nil
+ }
+ // Advance the expected block number of the first cache entry
+ r.resultOffset += uint64(limit)
+ atomic.AddInt32(&r.indexIncomplete, int32(-limit))
+
+ return results
+}
+
+// Prepare initialises the offset with the given block number
+func (r *resultStore) Prepare(offset uint64) {
+ r.lock.Lock()
+ defer r.lock.Unlock()
+
+ if r.resultOffset < offset {
+ r.resultOffset = offset
+ }
+}
diff --git a/les/downloader/statesync.go b/les/downloader/statesync.go
new file mode 100644
index 0000000000..6c53e5577a
--- /dev/null
+++ b/les/downloader/statesync.go
@@ -0,0 +1,615 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package downloader
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/state"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/trie"
+ "golang.org/x/crypto/sha3"
+)
+
+// stateReq represents a batch of state fetch requests grouped together into
+// a single data retrieval network packet.
+type stateReq struct {
+ nItems uint16 // Number of items requested for download (max is 384, so uint16 is sufficient)
+ trieTasks map[common.Hash]*trieTask // Trie node download tasks to track previous attempts
+ codeTasks map[common.Hash]*codeTask // Byte code download tasks to track previous attempts
+ timeout time.Duration // Maximum round trip time for this to complete
+ timer *time.Timer // Timer to fire when the RTT timeout expires
+ peer *peerConnection // Peer that we're requesting from
+ delivered time.Time // Time when the packet was delivered (independent when we process it)
+ response [][]byte // Response data of the peer (nil for timeouts)
+ dropped bool // Flag whether the peer dropped off early
+}
+
+// timedOut returns if this request timed out.
+func (req *stateReq) timedOut() bool {
+ return req.response == nil
+}
+
+// stateSyncStats is a collection of progress stats to report during a state trie
+// sync to RPC requests as well as to display in user logs.
+type stateSyncStats struct {
+ processed uint64 // Number of state entries processed
+ duplicate uint64 // Number of state entries downloaded twice
+ unexpected uint64 // Number of non-requested state entries received
+ pending uint64 // Number of still pending state entries
+}
+
+// syncState starts downloading state with the given root hash.
+func (d *Downloader) syncState(root common.Hash) *stateSync {
+ // Create the state sync
+ s := newStateSync(d, root)
+ select {
+ case d.stateSyncStart <- s:
+ // If we tell the statesync to restart with a new root, we also need
+ // to wait for it to actually also start -- when old requests have timed
+ // out or been delivered
+ <-s.started
+ case <-d.quitCh:
+ s.err = errCancelStateFetch
+ close(s.done)
+ }
+ return s
+}
+
+// stateFetcher manages the active state sync and accepts requests
+// on its behalf.
+func (d *Downloader) stateFetcher() {
+ for {
+ select {
+ case s := <-d.stateSyncStart:
+ for next := s; next != nil; {
+ next = d.runStateSync(next)
+ }
+ case <-d.stateCh:
+ // Ignore state responses while no sync is running.
+ case <-d.quitCh:
+ return
+ }
+ }
+}
+
+// runStateSync runs a state synchronisation until it completes or another root
+// hash is requested to be switched over to.
+func (d *Downloader) runStateSync(s *stateSync) *stateSync {
+ var (
+ active = make(map[string]*stateReq) // Currently in-flight requests
+ finished []*stateReq // Completed or failed requests
+ timeout = make(chan *stateReq) // Timed out active requests
+ )
+ log.Trace("State sync starting", "root", s.root)
+
+ defer func() {
+ // Cancel active request timers on exit. Also set peers to idle so they're
+ // available for the next sync.
+ for _, req := range active {
+ req.timer.Stop()
+ req.peer.SetNodeDataIdle(int(req.nItems), time.Now())
+ }
+ }()
+ go s.run()
+ defer s.Cancel()
+
+ // Listen for peer departure events to cancel assigned tasks
+ peerDrop := make(chan *peerConnection, 1024)
+ peerSub := s.d.peers.SubscribePeerDrops(peerDrop)
+ defer peerSub.Unsubscribe()
+
+ for {
+ // Enable sending of the first buffered element if there is one.
+ var (
+ deliverReq *stateReq
+ deliverReqCh chan *stateReq
+ )
+ if len(finished) > 0 {
+ deliverReq = finished[0]
+ deliverReqCh = s.deliver
+ }
+
+ select {
+ // The stateSync lifecycle:
+ case next := <-d.stateSyncStart:
+ d.spindownStateSync(active, finished, timeout, peerDrop)
+ return next
+
+ case <-s.done:
+ d.spindownStateSync(active, finished, timeout, peerDrop)
+ return nil
+
+ // Send the next finished request to the current sync:
+ case deliverReqCh <- deliverReq:
+ // Shift out the first request, but also set the emptied slot to nil for GC
+ copy(finished, finished[1:])
+ finished[len(finished)-1] = nil
+ finished = finished[:len(finished)-1]
+
+ // Handle incoming state packs:
+ case pack := <-d.stateCh:
+ // Discard any data not requested (or previously timed out)
+ req := active[pack.PeerId()]
+ if req == nil {
+ log.Debug("Unrequested node data", "peer", pack.PeerId(), "len", pack.Items())
+ continue
+ }
+ // Finalize the request and queue up for processing
+ req.timer.Stop()
+ req.response = pack.(*statePack).states
+ req.delivered = time.Now()
+
+ finished = append(finished, req)
+ delete(active, pack.PeerId())
+
+ // Handle dropped peer connections:
+ case p := <-peerDrop:
+ // Skip if no request is currently pending
+ req := active[p.id]
+ if req == nil {
+ continue
+ }
+ // Finalize the request and queue up for processing
+ req.timer.Stop()
+ req.dropped = true
+ req.delivered = time.Now()
+
+ finished = append(finished, req)
+ delete(active, p.id)
+
+ // Handle timed-out requests:
+ case req := <-timeout:
+ // If the peer is already requesting something else, ignore the stale timeout.
+ // This can happen when the timeout and the delivery happens simultaneously,
+ // causing both pathways to trigger.
+ if active[req.peer.id] != req {
+ continue
+ }
+ req.delivered = time.Now()
+ // Move the timed out data back into the download queue
+ finished = append(finished, req)
+ delete(active, req.peer.id)
+
+ // Track outgoing state requests:
+ case req := <-d.trackStateReq:
+ // If an active request already exists for this peer, we have a problem. In
+ // theory the trie node schedule must never assign two requests to the same
+ // peer. In practice however, a peer might receive a request, disconnect and
+ // immediately reconnect before the previous times out. In this case the first
+ // request is never honored, alas we must not silently overwrite it, as that
+ // causes valid requests to go missing and sync to get stuck.
+ if old := active[req.peer.id]; old != nil {
+ log.Warn("Busy peer assigned new state fetch", "peer", old.peer.id)
+ // Move the previous request to the finished set
+ old.timer.Stop()
+ old.dropped = true
+ old.delivered = time.Now()
+ finished = append(finished, old)
+ }
+ // Start a timer to notify the sync loop if the peer stalled.
+ req.timer = time.AfterFunc(req.timeout, func() {
+ timeout <- req
+ })
+ active[req.peer.id] = req
+ }
+ }
+}
+
+// spindownStateSync 'drains' the outstanding requests; some will be delivered and other
+// will time out. This is to ensure that when the next stateSync starts working, all peers
+// are marked as idle and de facto _are_ idle.
+func (d *Downloader) spindownStateSync(active map[string]*stateReq, finished []*stateReq, timeout chan *stateReq, peerDrop chan *peerConnection) {
+ log.Trace("State sync spinning down", "active", len(active), "finished", len(finished))
+ for len(active) > 0 {
+ var (
+ req *stateReq
+ reason string
+ )
+ select {
+ // Handle (drop) incoming state packs:
+ case pack := <-d.stateCh:
+ req = active[pack.PeerId()]
+ reason = "delivered"
+ // Handle dropped peer connections:
+ case p := <-peerDrop:
+ req = active[p.id]
+ reason = "peerdrop"
+ // Handle timed-out requests:
+ case req = <-timeout:
+ reason = "timeout"
+ }
+ if req == nil {
+ continue
+ }
+ req.peer.log.Trace("State peer marked idle (spindown)", "req.items", int(req.nItems), "reason", reason)
+ req.timer.Stop()
+ delete(active, req.peer.id)
+ req.peer.SetNodeDataIdle(int(req.nItems), time.Now())
+ }
+ // The 'finished' set contains deliveries that we were going to pass to processing.
+ // Those are now moot, but we still need to set those peers as idle, which would
+ // otherwise have been done after processing
+ for _, req := range finished {
+ req.peer.SetNodeDataIdle(int(req.nItems), time.Now())
+ }
+}
+
+// stateSync schedules requests for downloading a particular state trie defined
+// by a given state root.
+type stateSync struct {
+ d *Downloader // Downloader instance to access and manage current peerset
+
+ root common.Hash // State root currently being synced
+ sched *trie.Sync // State trie sync scheduler defining the tasks
+ keccak crypto.KeccakState // Keccak256 hasher to verify deliveries with
+
+ trieTasks map[common.Hash]*trieTask // Set of trie node tasks currently queued for retrieval
+ codeTasks map[common.Hash]*codeTask // Set of byte code tasks currently queued for retrieval
+
+ numUncommitted int
+ bytesUncommitted int
+
+ started chan struct{} // Started is signalled once the sync loop starts
+
+ deliver chan *stateReq // Delivery channel multiplexing peer responses
+ cancel chan struct{} // Channel to signal a termination request
+ cancelOnce sync.Once // Ensures cancel only ever gets called once
+ done chan struct{} // Channel to signal termination completion
+ err error // Any error hit during sync (set before completion)
+}
+
+// trieTask represents a single trie node download task, containing a set of
+// peers already attempted retrieval from to detect stalled syncs and abort.
+type trieTask struct {
+ path [][]byte
+ attempts map[string]struct{}
+}
+
+// codeTask represents a single byte code download task, containing a set of
+// peers already attempted retrieval from to detect stalled syncs and abort.
+type codeTask struct {
+ attempts map[string]struct{}
+}
+
+// newStateSync creates a new state trie download scheduler. This method does not
+// yet start the sync. The user needs to call run to initiate.
+func newStateSync(d *Downloader, root common.Hash) *stateSync {
+ return &stateSync{
+ d: d,
+ root: root,
+ sched: state.NewStateSync(root, d.stateDB, d.stateBloom, nil),
+ keccak: sha3.NewLegacyKeccak256().(crypto.KeccakState),
+ trieTasks: make(map[common.Hash]*trieTask),
+ codeTasks: make(map[common.Hash]*codeTask),
+ deliver: make(chan *stateReq),
+ cancel: make(chan struct{}),
+ done: make(chan struct{}),
+ started: make(chan struct{}),
+ }
+}
+
+// run starts the task assignment and response processing loop, blocking until
+// it finishes, and finally notifying any goroutines waiting for the loop to
+// finish.
+func (s *stateSync) run() {
+ close(s.started)
+ if s.d.snapSync {
+ s.err = s.d.SnapSyncer.Sync(s.root, s.cancel)
+ } else {
+ s.err = s.loop()
+ }
+ close(s.done)
+}
+
+// Wait blocks until the sync is done or canceled.
+func (s *stateSync) Wait() error {
+ <-s.done
+ return s.err
+}
+
+// Cancel cancels the sync and waits until it has shut down.
+func (s *stateSync) Cancel() error {
+ s.cancelOnce.Do(func() {
+ close(s.cancel)
+ })
+ return s.Wait()
+}
+
+// loop is the main event loop of a state trie sync. It it responsible for the
+// assignment of new tasks to peers (including sending it to them) as well as
+// for the processing of inbound data. Note, that the loop does not directly
+// receive data from peers, rather those are buffered up in the downloader and
+// pushed here async. The reason is to decouple processing from data receipt
+// and timeouts.
+func (s *stateSync) loop() (err error) {
+ // Listen for new peer events to assign tasks to them
+ newPeer := make(chan *peerConnection, 1024)
+ peerSub := s.d.peers.SubscribeNewPeers(newPeer)
+ defer peerSub.Unsubscribe()
+ defer func() {
+ cerr := s.commit(true)
+ if err == nil {
+ err = cerr
+ }
+ }()
+
+ // Keep assigning new tasks until the sync completes or aborts
+ for s.sched.Pending() > 0 {
+ if err = s.commit(false); err != nil {
+ return err
+ }
+ s.assignTasks()
+ // Tasks assigned, wait for something to happen
+ select {
+ case <-newPeer:
+ // New peer arrived, try to assign it download tasks
+
+ case <-s.cancel:
+ return errCancelStateFetch
+
+ case <-s.d.cancelCh:
+ return errCanceled
+
+ case req := <-s.deliver:
+ // Response, disconnect or timeout triggered, drop the peer if stalling
+ log.Trace("Received node data response", "peer", req.peer.id, "count", len(req.response), "dropped", req.dropped, "timeout", !req.dropped && req.timedOut())
+ if req.nItems <= 2 && !req.dropped && req.timedOut() {
+ // 2 items are the minimum requested, if even that times out, we've no use of
+ // this peer at the moment.
+ log.Warn("Stalling state sync, dropping peer", "peer", req.peer.id)
+ if s.d.dropPeer == nil {
+ // The dropPeer method is nil when `--copydb` is used for a local copy.
+ // Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored
+ req.peer.log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", req.peer.id)
+ } else {
+ s.d.dropPeer(req.peer.id)
+
+ // If this peer was the master peer, abort sync immediately
+ s.d.cancelLock.RLock()
+ master := req.peer.id == s.d.cancelPeer
+ s.d.cancelLock.RUnlock()
+
+ if master {
+ s.d.cancel()
+ return errTimeout
+ }
+ }
+ }
+ // Process all the received blobs and check for stale delivery
+ delivered, err := s.process(req)
+ req.peer.SetNodeDataIdle(delivered, req.delivered)
+ if err != nil {
+ log.Warn("Node data write error", "err", err)
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func (s *stateSync) commit(force bool) error {
+ if !force && s.bytesUncommitted < ethdb.IdealBatchSize {
+ return nil
+ }
+ start := time.Now()
+ b := s.d.stateDB.NewBatch()
+ if err := s.sched.Commit(b); err != nil {
+ return err
+ }
+ if err := b.Write(); err != nil {
+ return fmt.Errorf("DB write error: %v", err)
+ }
+ s.updateStats(s.numUncommitted, 0, 0, time.Since(start))
+ s.numUncommitted = 0
+ s.bytesUncommitted = 0
+ return nil
+}
+
+// assignTasks attempts to assign new tasks to all idle peers, either from the
+// batch currently being retried, or fetching new data from the trie sync itself.
+func (s *stateSync) assignTasks() {
+ // Iterate over all idle peers and try to assign them state fetches
+ peers, _ := s.d.peers.NodeDataIdlePeers()
+ for _, p := range peers {
+ // Assign a batch of fetches proportional to the estimated latency/bandwidth
+ cap := p.NodeDataCapacity(s.d.peers.rates.TargetRoundTrip())
+ req := &stateReq{peer: p, timeout: s.d.peers.rates.TargetTimeout()}
+
+ nodes, _, codes := s.fillTasks(cap, req)
+
+ // If the peer was assigned tasks to fetch, send the network request
+ if len(nodes)+len(codes) > 0 {
+ req.peer.log.Trace("Requesting batch of state data", "nodes", len(nodes), "codes", len(codes), "root", s.root)
+ select {
+ case s.d.trackStateReq <- req:
+ req.peer.FetchNodeData(append(nodes, codes...)) // Unified retrieval under eth/6x
+ case <-s.cancel:
+ case <-s.d.cancelCh:
+ }
+ }
+ }
+}
+
+// fillTasks fills the given request object with a maximum of n state download
+// tasks to send to the remote peer.
+func (s *stateSync) fillTasks(n int, req *stateReq) (nodes []common.Hash, paths []trie.SyncPath, codes []common.Hash) {
+ // Refill available tasks from the scheduler.
+ if fill := n - (len(s.trieTasks) + len(s.codeTasks)); fill > 0 {
+ nodes, paths, codes := s.sched.Missing(fill)
+ for i, hash := range nodes {
+ s.trieTasks[hash] = &trieTask{
+ path: paths[i],
+ attempts: make(map[string]struct{}),
+ }
+ }
+ for _, hash := range codes {
+ s.codeTasks[hash] = &codeTask{
+ attempts: make(map[string]struct{}),
+ }
+ }
+ }
+ // Find tasks that haven't been tried with the request's peer. Prefer code
+ // over trie nodes as those can be written to disk and forgotten about.
+ nodes = make([]common.Hash, 0, n)
+ paths = make([]trie.SyncPath, 0, n)
+ codes = make([]common.Hash, 0, n)
+
+ req.trieTasks = make(map[common.Hash]*trieTask, n)
+ req.codeTasks = make(map[common.Hash]*codeTask, n)
+
+ for hash, t := range s.codeTasks {
+ // Stop when we've gathered enough requests
+ if len(nodes)+len(codes) == n {
+ break
+ }
+ // Skip any requests we've already tried from this peer
+ if _, ok := t.attempts[req.peer.id]; ok {
+ continue
+ }
+ // Assign the request to this peer
+ t.attempts[req.peer.id] = struct{}{}
+ codes = append(codes, hash)
+ req.codeTasks[hash] = t
+ delete(s.codeTasks, hash)
+ }
+ for hash, t := range s.trieTasks {
+ // Stop when we've gathered enough requests
+ if len(nodes)+len(codes) == n {
+ break
+ }
+ // Skip any requests we've already tried from this peer
+ if _, ok := t.attempts[req.peer.id]; ok {
+ continue
+ }
+ // Assign the request to this peer
+ t.attempts[req.peer.id] = struct{}{}
+
+ nodes = append(nodes, hash)
+ paths = append(paths, t.path)
+
+ req.trieTasks[hash] = t
+ delete(s.trieTasks, hash)
+ }
+ req.nItems = uint16(len(nodes) + len(codes))
+ return nodes, paths, codes
+}
+
+// process iterates over a batch of delivered state data, injecting each item
+// into a running state sync, re-queuing any items that were requested but not
+// delivered. Returns whether the peer actually managed to deliver anything of
+// value, and any error that occurred.
+func (s *stateSync) process(req *stateReq) (int, error) {
+ // Collect processing stats and update progress if valid data was received
+ duplicate, unexpected, successful := 0, 0, 0
+
+ defer func(start time.Time) {
+ if duplicate > 0 || unexpected > 0 {
+ s.updateStats(0, duplicate, unexpected, time.Since(start))
+ }
+ }(time.Now())
+
+ // Iterate over all the delivered data and inject one-by-one into the trie
+ for _, blob := range req.response {
+ hash, err := s.processNodeData(blob)
+ switch err {
+ case nil:
+ s.numUncommitted++
+ s.bytesUncommitted += len(blob)
+ successful++
+ case trie.ErrNotRequested:
+ unexpected++
+ case trie.ErrAlreadyProcessed:
+ duplicate++
+ default:
+ return successful, fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err)
+ }
+ // Delete from both queues (one delivery is enough for the syncer)
+ delete(req.trieTasks, hash)
+ delete(req.codeTasks, hash)
+ }
+ // Put unfulfilled tasks back into the retry queue
+ npeers := s.d.peers.Len()
+ for hash, task := range req.trieTasks {
+ // If the node did deliver something, missing items may be due to a protocol
+ // limit or a previous timeout + delayed delivery. Both cases should permit
+ // the node to retry the missing items (to avoid single-peer stalls).
+ if len(req.response) > 0 || req.timedOut() {
+ delete(task.attempts, req.peer.id)
+ }
+ // If we've requested the node too many times already, it may be a malicious
+ // sync where nobody has the right data. Abort.
+ if len(task.attempts) >= npeers {
+ return successful, fmt.Errorf("trie node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers)
+ }
+ // Missing item, place into the retry queue.
+ s.trieTasks[hash] = task
+ }
+ for hash, task := range req.codeTasks {
+ // If the node did deliver something, missing items may be due to a protocol
+ // limit or a previous timeout + delayed delivery. Both cases should permit
+ // the node to retry the missing items (to avoid single-peer stalls).
+ if len(req.response) > 0 || req.timedOut() {
+ delete(task.attempts, req.peer.id)
+ }
+ // If we've requested the node too many times already, it may be a malicious
+ // sync where nobody has the right data. Abort.
+ if len(task.attempts) >= npeers {
+ return successful, fmt.Errorf("byte code %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers)
+ }
+ // Missing item, place into the retry queue.
+ s.codeTasks[hash] = task
+ }
+ return successful, nil
+}
+
+// processNodeData tries to inject a trie node data blob delivered from a remote
+// peer into the state trie, returning whether anything useful was written or any
+// error occurred.
+func (s *stateSync) processNodeData(blob []byte) (common.Hash, error) {
+ res := trie.SyncResult{Data: blob}
+ s.keccak.Reset()
+ s.keccak.Write(blob)
+ s.keccak.Read(res.Hash[:])
+ err := s.sched.Process(res)
+ return res.Hash, err
+}
+
+// updateStats bumps the various state sync progress counters and displays a log
+// message for the user to see.
+func (s *stateSync) updateStats(written, duplicate, unexpected int, duration time.Duration) {
+ s.d.syncStatsLock.Lock()
+ defer s.d.syncStatsLock.Unlock()
+
+ s.d.syncStatsState.pending = uint64(s.sched.Pending())
+ s.d.syncStatsState.processed += uint64(written)
+ s.d.syncStatsState.duplicate += uint64(duplicate)
+ s.d.syncStatsState.unexpected += uint64(unexpected)
+
+ if written > 0 || duplicate > 0 || unexpected > 0 {
+ log.Info("Imported new state entries", "count", written, "elapsed", common.PrettyDuration(duration), "processed", s.d.syncStatsState.processed, "pending", s.d.syncStatsState.pending, "trieretry", len(s.trieTasks), "coderetry", len(s.codeTasks), "duplicate", s.d.syncStatsState.duplicate, "unexpected", s.d.syncStatsState.unexpected)
+ }
+ if written > 0 {
+ rawdb.WriteFastTrieProgress(s.d.stateDB, s.d.syncStatsState.processed)
+ }
+}
diff --git a/les/downloader/testchain_test.go b/les/downloader/testchain_test.go
new file mode 100644
index 0000000000..b9865f7e03
--- /dev/null
+++ b/les/downloader/testchain_test.go
@@ -0,0 +1,230 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package downloader
+
+import (
+ "fmt"
+ "math/big"
+ "sync"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/consensus/ethash"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/params"
+)
+
+// Test chain parameters.
+var (
+ testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+ testAddress = crypto.PubkeyToAddress(testKey.PublicKey)
+ testDB = rawdb.NewMemoryDatabase()
+ testGenesis = core.GenesisBlockForTesting(testDB, testAddress, big.NewInt(1000000000000000))
+)
+
+// The common prefix of all test chains:
+var testChainBase = newTestChain(blockCacheMaxItems+200, testGenesis)
+
+// Different forks on top of the base chain:
+var testChainForkLightA, testChainForkLightB, testChainForkHeavy *testChain
+
+func init() {
+ var forkLen = int(fullMaxForkAncestry + 50)
+ var wg sync.WaitGroup
+ wg.Add(3)
+ go func() { testChainForkLightA = testChainBase.makeFork(forkLen, false, 1); wg.Done() }()
+ go func() { testChainForkLightB = testChainBase.makeFork(forkLen, false, 2); wg.Done() }()
+ go func() { testChainForkHeavy = testChainBase.makeFork(forkLen, true, 3); wg.Done() }()
+ wg.Wait()
+}
+
+type testChain struct {
+ genesis *types.Block
+ chain []common.Hash
+ headerm map[common.Hash]*types.Header
+ blockm map[common.Hash]*types.Block
+ receiptm map[common.Hash][]*types.Receipt
+ tdm map[common.Hash]*big.Int
+}
+
+// newTestChain creates a blockchain of the given length.
+func newTestChain(length int, genesis *types.Block) *testChain {
+ tc := new(testChain).copy(length)
+ tc.genesis = genesis
+ tc.chain = append(tc.chain, genesis.Hash())
+ tc.headerm[tc.genesis.Hash()] = tc.genesis.Header()
+ tc.tdm[tc.genesis.Hash()] = tc.genesis.Difficulty()
+ tc.blockm[tc.genesis.Hash()] = tc.genesis
+ tc.generate(length-1, 0, genesis, false)
+ return tc
+}
+
+// makeFork creates a fork on top of the test chain.
+func (tc *testChain) makeFork(length int, heavy bool, seed byte) *testChain {
+ fork := tc.copy(tc.len() + length)
+ fork.generate(length, seed, tc.headBlock(), heavy)
+ return fork
+}
+
+// shorten creates a copy of the chain with the given length. It panics if the
+// length is longer than the number of available blocks.
+func (tc *testChain) shorten(length int) *testChain {
+ if length > tc.len() {
+ panic(fmt.Errorf("can't shorten test chain to %d blocks, it's only %d blocks long", length, tc.len()))
+ }
+ return tc.copy(length)
+}
+
+func (tc *testChain) copy(newlen int) *testChain {
+ cpy := &testChain{
+ genesis: tc.genesis,
+ headerm: make(map[common.Hash]*types.Header, newlen),
+ blockm: make(map[common.Hash]*types.Block, newlen),
+ receiptm: make(map[common.Hash][]*types.Receipt, newlen),
+ tdm: make(map[common.Hash]*big.Int, newlen),
+ }
+ for i := 0; i < len(tc.chain) && i < newlen; i++ {
+ hash := tc.chain[i]
+ cpy.chain = append(cpy.chain, tc.chain[i])
+ cpy.tdm[hash] = tc.tdm[hash]
+ cpy.blockm[hash] = tc.blockm[hash]
+ cpy.headerm[hash] = tc.headerm[hash]
+ cpy.receiptm[hash] = tc.receiptm[hash]
+ }
+ return cpy
+}
+
+// generate creates a chain of n blocks starting at and including parent.
+// the returned hash chain is ordered head->parent. In addition, every 22th block
+// contains a transaction and every 5th an uncle to allow testing correct block
+// reassembly.
+func (tc *testChain) generate(n int, seed byte, parent *types.Block, heavy bool) {
+ // start := time.Now()
+ // defer func() { fmt.Printf("test chain generated in %v\n", time.Since(start)) }()
+
+ blocks, receipts := core.GenerateChain(params.TestChainConfig, parent, ethash.NewFaker(), testDB, n, func(i int, block *core.BlockGen) {
+ block.SetCoinbase(common.Address{seed})
+ // If a heavy chain is requested, delay blocks to raise difficulty
+ if heavy {
+ block.OffsetTime(-1)
+ }
+ // Include transactions to the miner to make blocks more interesting.
+ if parent == tc.genesis && i%22 == 0 {
+ signer := types.MakeSigner(params.TestChainConfig, block.Number())
+ tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, block.BaseFee(), nil), signer, testKey)
+ if err != nil {
+ panic(err)
+ }
+ block.AddTx(tx)
+ }
+ // if the block number is a multiple of 5, add a bonus uncle to the block
+ if i > 0 && i%5 == 0 {
+ block.AddUncle(&types.Header{
+ ParentHash: block.PrevBlock(i - 1).Hash(),
+ Number: big.NewInt(block.Number().Int64() - 1),
+ })
+ }
+ })
+
+ // Convert the block-chain into a hash-chain and header/block maps
+ td := new(big.Int).Set(tc.td(parent.Hash()))
+ for i, b := range blocks {
+ td := td.Add(td, b.Difficulty())
+ hash := b.Hash()
+ tc.chain = append(tc.chain, hash)
+ tc.blockm[hash] = b
+ tc.headerm[hash] = b.Header()
+ tc.receiptm[hash] = receipts[i]
+ tc.tdm[hash] = new(big.Int).Set(td)
+ }
+}
+
+// len returns the total number of blocks in the chain.
+func (tc *testChain) len() int {
+ return len(tc.chain)
+}
+
+// headBlock returns the head of the chain.
+func (tc *testChain) headBlock() *types.Block {
+ return tc.blockm[tc.chain[len(tc.chain)-1]]
+}
+
+// td returns the total difficulty of the given block.
+func (tc *testChain) td(hash common.Hash) *big.Int {
+ return tc.tdm[hash]
+}
+
+// headersByHash returns headers in order from the given hash.
+func (tc *testChain) headersByHash(origin common.Hash, amount int, skip int, reverse bool) []*types.Header {
+ num, _ := tc.hashToNumber(origin)
+ return tc.headersByNumber(num, amount, skip, reverse)
+}
+
+// headersByNumber returns headers from the given number.
+func (tc *testChain) headersByNumber(origin uint64, amount int, skip int, reverse bool) []*types.Header {
+ result := make([]*types.Header, 0, amount)
+
+ if !reverse {
+ for num := origin; num < uint64(len(tc.chain)) && len(result) < amount; num += uint64(skip) + 1 {
+ if header, ok := tc.headerm[tc.chain[int(num)]]; ok {
+ result = append(result, header)
+ }
+ }
+ } else {
+ for num := int64(origin); num >= 0 && len(result) < amount; num -= int64(skip) + 1 {
+ if header, ok := tc.headerm[tc.chain[int(num)]]; ok {
+ result = append(result, header)
+ }
+ }
+ }
+ return result
+}
+
+// receipts returns the receipts of the given block hashes.
+func (tc *testChain) receipts(hashes []common.Hash) [][]*types.Receipt {
+ results := make([][]*types.Receipt, 0, len(hashes))
+ for _, hash := range hashes {
+ if receipt, ok := tc.receiptm[hash]; ok {
+ results = append(results, receipt)
+ }
+ }
+ return results
+}
+
+// bodies returns the block bodies of the given block hashes.
+func (tc *testChain) bodies(hashes []common.Hash) ([][]*types.Transaction, [][]*types.Header) {
+ transactions := make([][]*types.Transaction, 0, len(hashes))
+ uncles := make([][]*types.Header, 0, len(hashes))
+ for _, hash := range hashes {
+ if block, ok := tc.blockm[hash]; ok {
+ transactions = append(transactions, block.Transactions())
+ uncles = append(uncles, block.Uncles())
+ }
+ }
+ return transactions, uncles
+}
+
+func (tc *testChain) hashToNumber(target common.Hash) (uint64, bool) {
+ for num, hash := range tc.chain {
+ if hash == target {
+ return uint64(num), true
+ }
+ }
+ return 0, false
+}
diff --git a/les/downloader/types.go b/les/downloader/types.go
new file mode 100644
index 0000000000..ff70bfa0e3
--- /dev/null
+++ b/les/downloader/types.go
@@ -0,0 +1,79 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package downloader
+
+import (
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/core/types"
+)
+
+// peerDropFn is a callback type for dropping a peer detected as malicious.
+type peerDropFn func(id string)
+
+// dataPack is a data message returned by a peer for some query.
+type dataPack interface {
+ PeerId() string
+ Items() int
+ Stats() string
+}
+
+// headerPack is a batch of block headers returned by a peer.
+type headerPack struct {
+ peerID string
+ headers []*types.Header
+}
+
+func (p *headerPack) PeerId() string { return p.peerID }
+func (p *headerPack) Items() int { return len(p.headers) }
+func (p *headerPack) Stats() string { return fmt.Sprintf("%d", len(p.headers)) }
+
+// bodyPack is a batch of block bodies returned by a peer.
+type bodyPack struct {
+ peerID string
+ transactions [][]*types.Transaction
+ uncles [][]*types.Header
+}
+
+func (p *bodyPack) PeerId() string { return p.peerID }
+func (p *bodyPack) Items() int {
+ if len(p.transactions) <= len(p.uncles) {
+ return len(p.transactions)
+ }
+ return len(p.uncles)
+}
+func (p *bodyPack) Stats() string { return fmt.Sprintf("%d:%d", len(p.transactions), len(p.uncles)) }
+
+// receiptPack is a batch of receipts returned by a peer.
+type receiptPack struct {
+ peerID string
+ receipts [][]*types.Receipt
+}
+
+func (p *receiptPack) PeerId() string { return p.peerID }
+func (p *receiptPack) Items() int { return len(p.receipts) }
+func (p *receiptPack) Stats() string { return fmt.Sprintf("%d", len(p.receipts)) }
+
+// statePack is a batch of states returned by a peer.
+type statePack struct {
+ peerID string
+ states [][]byte
+}
+
+func (p *statePack) PeerId() string { return p.peerID }
+func (p *statePack) Items() int { return len(p.states) }
+func (p *statePack) Stats() string { return fmt.Sprintf("%d", len(p.states)) }
diff --git a/les/fetcher.go b/les/fetcher.go
index 5eea996748..d944d32858 100644
--- a/les/fetcher.go
+++ b/les/fetcher.go
@@ -27,8 +27,8 @@ import (
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/eth/fetcher"
"github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/les/fetcher"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
diff --git a/les/fetcher/block_fetcher.go b/les/fetcher/block_fetcher.go
new file mode 100644
index 0000000000..283008db0f
--- /dev/null
+++ b/les/fetcher/block_fetcher.go
@@ -0,0 +1,889 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// This is a temporary package whilst working on the eth/66 blocking refactors.
+// After that work is done, les needs to be refactored to use the new package,
+// or alternatively use a stripped down version of it. Either way, we need to
+// keep the changes scoped so duplicating temporarily seems the sanest.
+package fetcher
+
+import (
+ "errors"
+ "math/rand"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/prque"
+ "github.com/ethereum/go-ethereum/consensus"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+const (
+ lightTimeout = time.Millisecond // Time allowance before an announced header is explicitly requested
+ arriveTimeout = 500 * time.Millisecond // Time allowance before an announced block/transaction is explicitly requested
+ gatherSlack = 100 * time.Millisecond // Interval used to collate almost-expired announces with fetches
+ fetchTimeout = 5 * time.Second // Maximum allotted time to return an explicitly requested block/transaction
+)
+
+const (
+ maxUncleDist = 7 // Maximum allowed backward distance from the chain head
+ maxQueueDist = 32 // Maximum allowed distance from the chain head to queue
+ hashLimit = 256 // Maximum number of unique blocks or headers a peer may have announced
+ blockLimit = 64 // Maximum number of unique blocks a peer may have delivered
+)
+
+var (
+ blockAnnounceInMeter = metrics.NewRegisteredMeter("eth/fetcher/block/announces/in", nil)
+ blockAnnounceOutTimer = metrics.NewRegisteredTimer("eth/fetcher/block/announces/out", nil)
+ blockAnnounceDropMeter = metrics.NewRegisteredMeter("eth/fetcher/block/announces/drop", nil)
+ blockAnnounceDOSMeter = metrics.NewRegisteredMeter("eth/fetcher/block/announces/dos", nil)
+
+ blockBroadcastInMeter = metrics.NewRegisteredMeter("eth/fetcher/block/broadcasts/in", nil)
+ blockBroadcastOutTimer = metrics.NewRegisteredTimer("eth/fetcher/block/broadcasts/out", nil)
+ blockBroadcastDropMeter = metrics.NewRegisteredMeter("eth/fetcher/block/broadcasts/drop", nil)
+ blockBroadcastDOSMeter = metrics.NewRegisteredMeter("eth/fetcher/block/broadcasts/dos", nil)
+
+ headerFetchMeter = metrics.NewRegisteredMeter("eth/fetcher/block/headers", nil)
+ bodyFetchMeter = metrics.NewRegisteredMeter("eth/fetcher/block/bodies", nil)
+
+ headerFilterInMeter = metrics.NewRegisteredMeter("eth/fetcher/block/filter/headers/in", nil)
+ headerFilterOutMeter = metrics.NewRegisteredMeter("eth/fetcher/block/filter/headers/out", nil)
+ bodyFilterInMeter = metrics.NewRegisteredMeter("eth/fetcher/block/filter/bodies/in", nil)
+ bodyFilterOutMeter = metrics.NewRegisteredMeter("eth/fetcher/block/filter/bodies/out", nil)
+)
+
+var errTerminated = errors.New("terminated")
+
+// HeaderRetrievalFn is a callback type for retrieving a header from the local chain.
+type HeaderRetrievalFn func(common.Hash) *types.Header
+
+// blockRetrievalFn is a callback type for retrieving a block from the local chain.
+type blockRetrievalFn func(common.Hash) *types.Block
+
+// headerRequesterFn is a callback type for sending a header retrieval request.
+type headerRequesterFn func(common.Hash) error
+
+// bodyRequesterFn is a callback type for sending a body retrieval request.
+type bodyRequesterFn func([]common.Hash) error
+
+// headerVerifierFn is a callback type to verify a block's header for fast propagation.
+type headerVerifierFn func(header *types.Header) error
+
+// blockBroadcasterFn is a callback type for broadcasting a block to connected peers.
+type blockBroadcasterFn func(block *types.Block, propagate bool)
+
+// chainHeightFn is a callback type to retrieve the current chain height.
+type chainHeightFn func() uint64
+
+// headersInsertFn is a callback type to insert a batch of headers into the local chain.
+type headersInsertFn func(headers []*types.Header) (int, error)
+
+// chainInsertFn is a callback type to insert a batch of blocks into the local chain.
+type chainInsertFn func(types.Blocks) (int, error)
+
+// peerDropFn is a callback type for dropping a peer detected as malicious.
+type peerDropFn func(id string)
+
+// blockAnnounce is the hash notification of the availability of a new block in the
+// network.
+type blockAnnounce struct {
+ hash common.Hash // Hash of the block being announced
+ number uint64 // Number of the block being announced (0 = unknown | old protocol)
+ header *types.Header // Header of the block partially reassembled (new protocol)
+ time time.Time // Timestamp of the announcement
+
+ origin string // Identifier of the peer originating the notification
+
+ fetchHeader headerRequesterFn // Fetcher function to retrieve the header of an announced block
+ fetchBodies bodyRequesterFn // Fetcher function to retrieve the body of an announced block
+}
+
+// headerFilterTask represents a batch of headers needing fetcher filtering.
+type headerFilterTask struct {
+ peer string // The source peer of block headers
+ headers []*types.Header // Collection of headers to filter
+ time time.Time // Arrival time of the headers
+}
+
+// bodyFilterTask represents a batch of block bodies (transactions and uncles)
+// needing fetcher filtering.
+type bodyFilterTask struct {
+ peer string // The source peer of block bodies
+ transactions [][]*types.Transaction // Collection of transactions per block bodies
+ uncles [][]*types.Header // Collection of uncles per block bodies
+ time time.Time // Arrival time of the blocks' contents
+}
+
+// blockOrHeaderInject represents a schedules import operation.
+type blockOrHeaderInject struct {
+ origin string
+
+ header *types.Header // Used for light mode fetcher which only cares about header.
+ block *types.Block // Used for normal mode fetcher which imports full block.
+}
+
+// number returns the block number of the injected object.
+func (inject *blockOrHeaderInject) number() uint64 {
+ if inject.header != nil {
+ return inject.header.Number.Uint64()
+ }
+ return inject.block.NumberU64()
+}
+
+// number returns the block hash of the injected object.
+func (inject *blockOrHeaderInject) hash() common.Hash {
+ if inject.header != nil {
+ return inject.header.Hash()
+ }
+ return inject.block.Hash()
+}
+
+// BlockFetcher is responsible for accumulating block announcements from various peers
+// and scheduling them for retrieval.
+type BlockFetcher struct {
+ light bool // The indicator whether it's a light fetcher or normal one.
+
+ // Various event channels
+ notify chan *blockAnnounce
+ inject chan *blockOrHeaderInject
+
+ headerFilter chan chan *headerFilterTask
+ bodyFilter chan chan *bodyFilterTask
+
+ done chan common.Hash
+ quit chan struct{}
+
+ // Announce states
+ announces map[string]int // Per peer blockAnnounce counts to prevent memory exhaustion
+ announced map[common.Hash][]*blockAnnounce // Announced blocks, scheduled for fetching
+ fetching map[common.Hash]*blockAnnounce // Announced blocks, currently fetching
+ fetched map[common.Hash][]*blockAnnounce // Blocks with headers fetched, scheduled for body retrieval
+ completing map[common.Hash]*blockAnnounce // Blocks with headers, currently body-completing
+
+ // Block cache
+ queue *prque.Prque // Queue containing the import operations (block number sorted)
+ queues map[string]int // Per peer block counts to prevent memory exhaustion
+ queued map[common.Hash]*blockOrHeaderInject // Set of already queued blocks (to dedup imports)
+
+ // Callbacks
+ getHeader HeaderRetrievalFn // Retrieves a header from the local chain
+ getBlock blockRetrievalFn // Retrieves a block from the local chain
+ verifyHeader headerVerifierFn // Checks if a block's headers have a valid proof of work
+ broadcastBlock blockBroadcasterFn // Broadcasts a block to connected peers
+ chainHeight chainHeightFn // Retrieves the current chain's height
+ insertHeaders headersInsertFn // Injects a batch of headers into the chain
+ insertChain chainInsertFn // Injects a batch of blocks into the chain
+ dropPeer peerDropFn // Drops a peer for misbehaving
+
+ // Testing hooks
+ announceChangeHook func(common.Hash, bool) // Method to call upon adding or deleting a hash from the blockAnnounce list
+ queueChangeHook func(common.Hash, bool) // Method to call upon adding or deleting a block from the import queue
+ fetchingHook func([]common.Hash) // Method to call upon starting a block (eth/61) or header (eth/62) fetch
+ completingHook func([]common.Hash) // Method to call upon starting a block body fetch (eth/62)
+ importedHook func(*types.Header, *types.Block) // Method to call upon successful header or block import (both eth/61 and eth/62)
+}
+
+// NewBlockFetcher creates a block fetcher to retrieve blocks based on hash announcements.
+func NewBlockFetcher(light bool, getHeader HeaderRetrievalFn, getBlock blockRetrievalFn, verifyHeader headerVerifierFn, broadcastBlock blockBroadcasterFn, chainHeight chainHeightFn, insertHeaders headersInsertFn, insertChain chainInsertFn, dropPeer peerDropFn) *BlockFetcher {
+ return &BlockFetcher{
+ light: light,
+ notify: make(chan *blockAnnounce),
+ inject: make(chan *blockOrHeaderInject),
+ headerFilter: make(chan chan *headerFilterTask),
+ bodyFilter: make(chan chan *bodyFilterTask),
+ done: make(chan common.Hash),
+ quit: make(chan struct{}),
+ announces: make(map[string]int),
+ announced: make(map[common.Hash][]*blockAnnounce),
+ fetching: make(map[common.Hash]*blockAnnounce),
+ fetched: make(map[common.Hash][]*blockAnnounce),
+ completing: make(map[common.Hash]*blockAnnounce),
+ queue: prque.New(nil),
+ queues: make(map[string]int),
+ queued: make(map[common.Hash]*blockOrHeaderInject),
+ getHeader: getHeader,
+ getBlock: getBlock,
+ verifyHeader: verifyHeader,
+ broadcastBlock: broadcastBlock,
+ chainHeight: chainHeight,
+ insertHeaders: insertHeaders,
+ insertChain: insertChain,
+ dropPeer: dropPeer,
+ }
+}
+
+// Start boots up the announcement based synchroniser, accepting and processing
+// hash notifications and block fetches until termination requested.
+func (f *BlockFetcher) Start() {
+ go f.loop()
+}
+
+// Stop terminates the announcement based synchroniser, canceling all pending
+// operations.
+func (f *BlockFetcher) Stop() {
+ close(f.quit)
+}
+
+// Notify announces the fetcher of the potential availability of a new block in
+// the network.
+func (f *BlockFetcher) Notify(peer string, hash common.Hash, number uint64, time time.Time,
+ headerFetcher headerRequesterFn, bodyFetcher bodyRequesterFn) error {
+ block := &blockAnnounce{
+ hash: hash,
+ number: number,
+ time: time,
+ origin: peer,
+ fetchHeader: headerFetcher,
+ fetchBodies: bodyFetcher,
+ }
+ select {
+ case f.notify <- block:
+ return nil
+ case <-f.quit:
+ return errTerminated
+ }
+}
+
+// Enqueue tries to fill gaps the fetcher's future import queue.
+func (f *BlockFetcher) Enqueue(peer string, block *types.Block) error {
+ op := &blockOrHeaderInject{
+ origin: peer,
+ block: block,
+ }
+ select {
+ case f.inject <- op:
+ return nil
+ case <-f.quit:
+ return errTerminated
+ }
+}
+
+// FilterHeaders extracts all the headers that were explicitly requested by the fetcher,
+// returning those that should be handled differently.
+func (f *BlockFetcher) FilterHeaders(peer string, headers []*types.Header, time time.Time) []*types.Header {
+ log.Trace("Filtering headers", "peer", peer, "headers", len(headers))
+
+ // Send the filter channel to the fetcher
+ filter := make(chan *headerFilterTask)
+
+ select {
+ case f.headerFilter <- filter:
+ case <-f.quit:
+ return nil
+ }
+ // Request the filtering of the header list
+ select {
+ case filter <- &headerFilterTask{peer: peer, headers: headers, time: time}:
+ case <-f.quit:
+ return nil
+ }
+ // Retrieve the headers remaining after filtering
+ select {
+ case task := <-filter:
+ return task.headers
+ case <-f.quit:
+ return nil
+ }
+}
+
+// FilterBodies extracts all the block bodies that were explicitly requested by
+// the fetcher, returning those that should be handled differently.
+func (f *BlockFetcher) FilterBodies(peer string, transactions [][]*types.Transaction, uncles [][]*types.Header, time time.Time) ([][]*types.Transaction, [][]*types.Header) {
+ log.Trace("Filtering bodies", "peer", peer, "txs", len(transactions), "uncles", len(uncles))
+
+ // Send the filter channel to the fetcher
+ filter := make(chan *bodyFilterTask)
+
+ select {
+ case f.bodyFilter <- filter:
+ case <-f.quit:
+ return nil, nil
+ }
+ // Request the filtering of the body list
+ select {
+ case filter <- &bodyFilterTask{peer: peer, transactions: transactions, uncles: uncles, time: time}:
+ case <-f.quit:
+ return nil, nil
+ }
+ // Retrieve the bodies remaining after filtering
+ select {
+ case task := <-filter:
+ return task.transactions, task.uncles
+ case <-f.quit:
+ return nil, nil
+ }
+}
+
+// Loop is the main fetcher loop, checking and processing various notification
+// events.
+func (f *BlockFetcher) loop() {
+ // Iterate the block fetching until a quit is requested
+ var (
+ fetchTimer = time.NewTimer(0)
+ completeTimer = time.NewTimer(0)
+ )
+ <-fetchTimer.C // clear out the channel
+ <-completeTimer.C
+ defer fetchTimer.Stop()
+ defer completeTimer.Stop()
+
+ for {
+ // Clean up any expired block fetches
+ for hash, announce := range f.fetching {
+ if time.Since(announce.time) > fetchTimeout {
+ f.forgetHash(hash)
+ }
+ }
+ // Import any queued blocks that could potentially fit
+ height := f.chainHeight()
+ for !f.queue.Empty() {
+ op := f.queue.PopItem().(*blockOrHeaderInject)
+ hash := op.hash()
+ if f.queueChangeHook != nil {
+ f.queueChangeHook(hash, false)
+ }
+ // If too high up the chain or phase, continue later
+ number := op.number()
+ if number > height+1 {
+ f.queue.Push(op, -int64(number))
+ if f.queueChangeHook != nil {
+ f.queueChangeHook(hash, true)
+ }
+ break
+ }
+ // Otherwise if fresh and still unknown, try and import
+ if (number+maxUncleDist < height) || (f.light && f.getHeader(hash) != nil) || (!f.light && f.getBlock(hash) != nil) {
+ f.forgetBlock(hash)
+ continue
+ }
+ if f.light {
+ f.importHeaders(op.origin, op.header)
+ } else {
+ f.importBlocks(op.origin, op.block)
+ }
+ }
+ // Wait for an outside event to occur
+ select {
+ case <-f.quit:
+ // BlockFetcher terminating, abort all operations
+ return
+
+ case notification := <-f.notify:
+ // A block was announced, make sure the peer isn't DOSing us
+ blockAnnounceInMeter.Mark(1)
+
+ count := f.announces[notification.origin] + 1
+ if count > hashLimit {
+ log.Debug("Peer exceeded outstanding announces", "peer", notification.origin, "limit", hashLimit)
+ blockAnnounceDOSMeter.Mark(1)
+ break
+ }
+ // If we have a valid block number, check that it's potentially useful
+ if notification.number > 0 {
+ if dist := int64(notification.number) - int64(f.chainHeight()); dist < -maxUncleDist || dist > maxQueueDist {
+ log.Debug("Peer discarded announcement", "peer", notification.origin, "number", notification.number, "hash", notification.hash, "distance", dist)
+ blockAnnounceDropMeter.Mark(1)
+ break
+ }
+ }
+ // All is well, schedule the announce if block's not yet downloading
+ if _, ok := f.fetching[notification.hash]; ok {
+ break
+ }
+ if _, ok := f.completing[notification.hash]; ok {
+ break
+ }
+ f.announces[notification.origin] = count
+ f.announced[notification.hash] = append(f.announced[notification.hash], notification)
+ if f.announceChangeHook != nil && len(f.announced[notification.hash]) == 1 {
+ f.announceChangeHook(notification.hash, true)
+ }
+ if len(f.announced) == 1 {
+ f.rescheduleFetch(fetchTimer)
+ }
+
+ case op := <-f.inject:
+ // A direct block insertion was requested, try and fill any pending gaps
+ blockBroadcastInMeter.Mark(1)
+
+ // Now only direct block injection is allowed, drop the header injection
+ // here silently if we receive.
+ if f.light {
+ continue
+ }
+ f.enqueue(op.origin, nil, op.block)
+
+ case hash := <-f.done:
+ // A pending import finished, remove all traces of the notification
+ f.forgetHash(hash)
+ f.forgetBlock(hash)
+
+ case <-fetchTimer.C:
+ // At least one block's timer ran out, check for needing retrieval
+ request := make(map[string][]common.Hash)
+
+ for hash, announces := range f.announced {
+ // In current LES protocol(les2/les3), only header announce is
+ // available, no need to wait too much time for header broadcast.
+ timeout := arriveTimeout - gatherSlack
+ if f.light {
+ timeout = 0
+ }
+ if time.Since(announces[0].time) > timeout {
+ // Pick a random peer to retrieve from, reset all others
+ announce := announces[rand.Intn(len(announces))]
+ f.forgetHash(hash)
+
+ // If the block still didn't arrive, queue for fetching
+ if (f.light && f.getHeader(hash) == nil) || (!f.light && f.getBlock(hash) == nil) {
+ request[announce.origin] = append(request[announce.origin], hash)
+ f.fetching[hash] = announce
+ }
+ }
+ }
+ // Send out all block header requests
+ for peer, hashes := range request {
+ log.Trace("Fetching scheduled headers", "peer", peer, "list", hashes)
+
+ // Create a closure of the fetch and schedule in on a new thread
+ fetchHeader, hashes := f.fetching[hashes[0]].fetchHeader, hashes
+ go func() {
+ if f.fetchingHook != nil {
+ f.fetchingHook(hashes)
+ }
+ for _, hash := range hashes {
+ headerFetchMeter.Mark(1)
+ fetchHeader(hash) // Suboptimal, but protocol doesn't allow batch header retrievals
+ }
+ }()
+ }
+ // Schedule the next fetch if blocks are still pending
+ f.rescheduleFetch(fetchTimer)
+
+ case <-completeTimer.C:
+ // At least one header's timer ran out, retrieve everything
+ request := make(map[string][]common.Hash)
+
+ for hash, announces := range f.fetched {
+ // Pick a random peer to retrieve from, reset all others
+ announce := announces[rand.Intn(len(announces))]
+ f.forgetHash(hash)
+
+ // If the block still didn't arrive, queue for completion
+ if f.getBlock(hash) == nil {
+ request[announce.origin] = append(request[announce.origin], hash)
+ f.completing[hash] = announce
+ }
+ }
+ // Send out all block body requests
+ for peer, hashes := range request {
+ log.Trace("Fetching scheduled bodies", "peer", peer, "list", hashes)
+
+ // Create a closure of the fetch and schedule in on a new thread
+ if f.completingHook != nil {
+ f.completingHook(hashes)
+ }
+ bodyFetchMeter.Mark(int64(len(hashes)))
+ go f.completing[hashes[0]].fetchBodies(hashes)
+ }
+ // Schedule the next fetch if blocks are still pending
+ f.rescheduleComplete(completeTimer)
+
+ case filter := <-f.headerFilter:
+ // Headers arrived from a remote peer. Extract those that were explicitly
+ // requested by the fetcher, and return everything else so it's delivered
+ // to other parts of the system.
+ var task *headerFilterTask
+ select {
+ case task = <-filter:
+ case <-f.quit:
+ return
+ }
+ headerFilterInMeter.Mark(int64(len(task.headers)))
+
+ // Split the batch of headers into unknown ones (to return to the caller),
+ // known incomplete ones (requiring body retrievals) and completed blocks.
+ unknown, incomplete, complete, lightHeaders := []*types.Header{}, []*blockAnnounce{}, []*types.Block{}, []*blockAnnounce{}
+ for _, header := range task.headers {
+ hash := header.Hash()
+
+ // Filter fetcher-requested headers from other synchronisation algorithms
+ if announce := f.fetching[hash]; announce != nil && announce.origin == task.peer && f.fetched[hash] == nil && f.completing[hash] == nil && f.queued[hash] == nil {
+ // If the delivered header does not match the promised number, drop the announcer
+ if header.Number.Uint64() != announce.number {
+ log.Trace("Invalid block number fetched", "peer", announce.origin, "hash", header.Hash(), "announced", announce.number, "provided", header.Number)
+ f.dropPeer(announce.origin)
+ f.forgetHash(hash)
+ continue
+ }
+ // Collect all headers only if we are running in light
+ // mode and the headers are not imported by other means.
+ if f.light {
+ if f.getHeader(hash) == nil {
+ announce.header = header
+ lightHeaders = append(lightHeaders, announce)
+ }
+ f.forgetHash(hash)
+ continue
+ }
+ // Only keep if not imported by other means
+ if f.getBlock(hash) == nil {
+ announce.header = header
+ announce.time = task.time
+
+ // If the block is empty (header only), short circuit into the final import queue
+ if header.TxHash == types.EmptyRootHash && header.UncleHash == types.EmptyUncleHash {
+ log.Trace("Block empty, skipping body retrieval", "peer", announce.origin, "number", header.Number, "hash", header.Hash())
+
+ block := types.NewBlockWithHeader(header)
+ block.ReceivedAt = task.time
+
+ complete = append(complete, block)
+ f.completing[hash] = announce
+ continue
+ }
+ // Otherwise add to the list of blocks needing completion
+ incomplete = append(incomplete, announce)
+ } else {
+ log.Trace("Block already imported, discarding header", "peer", announce.origin, "number", header.Number, "hash", header.Hash())
+ f.forgetHash(hash)
+ }
+ } else {
+ // BlockFetcher doesn't know about it, add to the return list
+ unknown = append(unknown, header)
+ }
+ }
+ headerFilterOutMeter.Mark(int64(len(unknown)))
+ select {
+ case filter <- &headerFilterTask{headers: unknown, time: task.time}:
+ case <-f.quit:
+ return
+ }
+ // Schedule the retrieved headers for body completion
+ for _, announce := range incomplete {
+ hash := announce.header.Hash()
+ if _, ok := f.completing[hash]; ok {
+ continue
+ }
+ f.fetched[hash] = append(f.fetched[hash], announce)
+ if len(f.fetched) == 1 {
+ f.rescheduleComplete(completeTimer)
+ }
+ }
+ // Schedule the header for light fetcher import
+ for _, announce := range lightHeaders {
+ f.enqueue(announce.origin, announce.header, nil)
+ }
+ // Schedule the header-only blocks for import
+ for _, block := range complete {
+ if announce := f.completing[block.Hash()]; announce != nil {
+ f.enqueue(announce.origin, nil, block)
+ }
+ }
+
+ case filter := <-f.bodyFilter:
+ // Block bodies arrived, extract any explicitly requested blocks, return the rest
+ var task *bodyFilterTask
+ select {
+ case task = <-filter:
+ case <-f.quit:
+ return
+ }
+ bodyFilterInMeter.Mark(int64(len(task.transactions)))
+ blocks := []*types.Block{}
+ // abort early if there's nothing explicitly requested
+ if len(f.completing) > 0 {
+ for i := 0; i < len(task.transactions) && i < len(task.uncles); i++ {
+ // Match up a body to any possible completion request
+ var (
+ matched = false
+ uncleHash common.Hash // calculated lazily and reused
+ txnHash common.Hash // calculated lazily and reused
+ )
+ for hash, announce := range f.completing {
+ if f.queued[hash] != nil || announce.origin != task.peer {
+ continue
+ }
+ if uncleHash == (common.Hash{}) {
+ uncleHash = types.CalcUncleHash(task.uncles[i])
+ }
+ if uncleHash != announce.header.UncleHash {
+ continue
+ }
+ if txnHash == (common.Hash{}) {
+ txnHash = types.DeriveSha(types.Transactions(task.transactions[i]), trie.NewStackTrie(nil))
+ }
+ if txnHash != announce.header.TxHash {
+ continue
+ }
+ // Mark the body matched, reassemble if still unknown
+ matched = true
+ if f.getBlock(hash) == nil {
+ block := types.NewBlockWithHeader(announce.header).WithBody(task.transactions[i], task.uncles[i])
+ block.ReceivedAt = task.time
+ blocks = append(blocks, block)
+ } else {
+ f.forgetHash(hash)
+ }
+
+ }
+ if matched {
+ task.transactions = append(task.transactions[:i], task.transactions[i+1:]...)
+ task.uncles = append(task.uncles[:i], task.uncles[i+1:]...)
+ i--
+ continue
+ }
+ }
+ }
+ bodyFilterOutMeter.Mark(int64(len(task.transactions)))
+ select {
+ case filter <- task:
+ case <-f.quit:
+ return
+ }
+ // Schedule the retrieved blocks for ordered import
+ for _, block := range blocks {
+ if announce := f.completing[block.Hash()]; announce != nil {
+ f.enqueue(announce.origin, nil, block)
+ }
+ }
+ }
+ }
+}
+
+// rescheduleFetch resets the specified fetch timer to the next blockAnnounce timeout.
+func (f *BlockFetcher) rescheduleFetch(fetch *time.Timer) {
+ // Short circuit if no blocks are announced
+ if len(f.announced) == 0 {
+ return
+ }
+ // Schedule announcement retrieval quickly for light mode
+ // since server won't send any headers to client.
+ if f.light {
+ fetch.Reset(lightTimeout)
+ return
+ }
+ // Otherwise find the earliest expiring announcement
+ earliest := time.Now()
+ for _, announces := range f.announced {
+ if earliest.After(announces[0].time) {
+ earliest = announces[0].time
+ }
+ }
+ fetch.Reset(arriveTimeout - time.Since(earliest))
+}
+
+// rescheduleComplete resets the specified completion timer to the next fetch timeout.
+func (f *BlockFetcher) rescheduleComplete(complete *time.Timer) {
+ // Short circuit if no headers are fetched
+ if len(f.fetched) == 0 {
+ return
+ }
+ // Otherwise find the earliest expiring announcement
+ earliest := time.Now()
+ for _, announces := range f.fetched {
+ if earliest.After(announces[0].time) {
+ earliest = announces[0].time
+ }
+ }
+ complete.Reset(gatherSlack - time.Since(earliest))
+}
+
+// enqueue schedules a new header or block import operation, if the component
+// to be imported has not yet been seen.
+func (f *BlockFetcher) enqueue(peer string, header *types.Header, block *types.Block) {
+ var (
+ hash common.Hash
+ number uint64
+ )
+ if header != nil {
+ hash, number = header.Hash(), header.Number.Uint64()
+ } else {
+ hash, number = block.Hash(), block.NumberU64()
+ }
+ // Ensure the peer isn't DOSing us
+ count := f.queues[peer] + 1
+ if count > blockLimit {
+ log.Debug("Discarded delivered header or block, exceeded allowance", "peer", peer, "number", number, "hash", hash, "limit", blockLimit)
+ blockBroadcastDOSMeter.Mark(1)
+ f.forgetHash(hash)
+ return
+ }
+ // Discard any past or too distant blocks
+ if dist := int64(number) - int64(f.chainHeight()); dist < -maxUncleDist || dist > maxQueueDist {
+ log.Debug("Discarded delivered header or block, too far away", "peer", peer, "number", number, "hash", hash, "distance", dist)
+ blockBroadcastDropMeter.Mark(1)
+ f.forgetHash(hash)
+ return
+ }
+ // Schedule the block for future importing
+ if _, ok := f.queued[hash]; !ok {
+ op := &blockOrHeaderInject{origin: peer}
+ if header != nil {
+ op.header = header
+ } else {
+ op.block = block
+ }
+ f.queues[peer] = count
+ f.queued[hash] = op
+ f.queue.Push(op, -int64(number))
+ if f.queueChangeHook != nil {
+ f.queueChangeHook(hash, true)
+ }
+ log.Debug("Queued delivered header or block", "peer", peer, "number", number, "hash", hash, "queued", f.queue.Size())
+ }
+}
+
+// importHeaders spawns a new goroutine to run a header insertion into the chain.
+// If the header's number is at the same height as the current import phase, it
+// updates the phase states accordingly.
+func (f *BlockFetcher) importHeaders(peer string, header *types.Header) {
+ hash := header.Hash()
+ log.Debug("Importing propagated header", "peer", peer, "number", header.Number, "hash", hash)
+
+ go func() {
+ defer func() { f.done <- hash }()
+ // If the parent's unknown, abort insertion
+ parent := f.getHeader(header.ParentHash)
+ if parent == nil {
+ log.Debug("Unknown parent of propagated header", "peer", peer, "number", header.Number, "hash", hash, "parent", header.ParentHash)
+ return
+ }
+ // Validate the header and if something went wrong, drop the peer
+ if err := f.verifyHeader(header); err != nil && err != consensus.ErrFutureBlock {
+ log.Debug("Propagated header verification failed", "peer", peer, "number", header.Number, "hash", hash, "err", err)
+ f.dropPeer(peer)
+ return
+ }
+ // Run the actual import and log any issues
+ if _, err := f.insertHeaders([]*types.Header{header}); err != nil {
+ log.Debug("Propagated header import failed", "peer", peer, "number", header.Number, "hash", hash, "err", err)
+ return
+ }
+ // Invoke the testing hook if needed
+ if f.importedHook != nil {
+ f.importedHook(header, nil)
+ }
+ }()
+}
+
+// importBlocks spawns a new goroutine to run a block insertion into the chain. If the
+// block's number is at the same height as the current import phase, it updates
+// the phase states accordingly.
+func (f *BlockFetcher) importBlocks(peer string, block *types.Block) {
+ hash := block.Hash()
+
+ // Run the import on a new thread
+ log.Debug("Importing propagated block", "peer", peer, "number", block.Number(), "hash", hash)
+ go func() {
+ defer func() { f.done <- hash }()
+
+ // If the parent's unknown, abort insertion
+ parent := f.getBlock(block.ParentHash())
+ if parent == nil {
+ log.Debug("Unknown parent of propagated block", "peer", peer, "number", block.Number(), "hash", hash, "parent", block.ParentHash())
+ return
+ }
+ // Quickly validate the header and propagate the block if it passes
+ switch err := f.verifyHeader(block.Header()); err {
+ case nil:
+ // All ok, quickly propagate to our peers
+ blockBroadcastOutTimer.UpdateSince(block.ReceivedAt)
+ go f.broadcastBlock(block, true)
+
+ case consensus.ErrFutureBlock:
+ // Weird future block, don't fail, but neither propagate
+
+ default:
+ // Something went very wrong, drop the peer
+ log.Debug("Propagated block verification failed", "peer", peer, "number", block.Number(), "hash", hash, "err", err)
+ f.dropPeer(peer)
+ return
+ }
+ // Run the actual import and log any issues
+ if _, err := f.insertChain(types.Blocks{block}); err != nil {
+ log.Debug("Propagated block import failed", "peer", peer, "number", block.Number(), "hash", hash, "err", err)
+ return
+ }
+ // If import succeeded, broadcast the block
+ blockAnnounceOutTimer.UpdateSince(block.ReceivedAt)
+ go f.broadcastBlock(block, false)
+
+ // Invoke the testing hook if needed
+ if f.importedHook != nil {
+ f.importedHook(nil, block)
+ }
+ }()
+}
+
+// forgetHash removes all traces of a block announcement from the fetcher's
+// internal state.
+func (f *BlockFetcher) forgetHash(hash common.Hash) {
+ // Remove all pending announces and decrement DOS counters
+ if announceMap, ok := f.announced[hash]; ok {
+ for _, announce := range announceMap {
+ f.announces[announce.origin]--
+ if f.announces[announce.origin] <= 0 {
+ delete(f.announces, announce.origin)
+ }
+ }
+ delete(f.announced, hash)
+ if f.announceChangeHook != nil {
+ f.announceChangeHook(hash, false)
+ }
+ }
+ // Remove any pending fetches and decrement the DOS counters
+ if announce := f.fetching[hash]; announce != nil {
+ f.announces[announce.origin]--
+ if f.announces[announce.origin] <= 0 {
+ delete(f.announces, announce.origin)
+ }
+ delete(f.fetching, hash)
+ }
+
+ // Remove any pending completion requests and decrement the DOS counters
+ for _, announce := range f.fetched[hash] {
+ f.announces[announce.origin]--
+ if f.announces[announce.origin] <= 0 {
+ delete(f.announces, announce.origin)
+ }
+ }
+ delete(f.fetched, hash)
+
+ // Remove any pending completions and decrement the DOS counters
+ if announce := f.completing[hash]; announce != nil {
+ f.announces[announce.origin]--
+ if f.announces[announce.origin] <= 0 {
+ delete(f.announces, announce.origin)
+ }
+ delete(f.completing, hash)
+ }
+}
+
+// forgetBlock removes all traces of a queued block from the fetcher's internal
+// state.
+func (f *BlockFetcher) forgetBlock(hash common.Hash) {
+ if insert := f.queued[hash]; insert != nil {
+ f.queues[insert.origin]--
+ if f.queues[insert.origin] == 0 {
+ delete(f.queues, insert.origin)
+ }
+ delete(f.queued, hash)
+ }
+}
diff --git a/les/fetcher/block_fetcher_test.go b/les/fetcher/block_fetcher_test.go
new file mode 100644
index 0000000000..b6d1125b56
--- /dev/null
+++ b/les/fetcher/block_fetcher_test.go
@@ -0,0 +1,896 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package fetcher
+
+import (
+ "errors"
+ "math/big"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/consensus/ethash"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/trie"
+)
+
+var (
+ testdb = rawdb.NewMemoryDatabase()
+ testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+ testAddress = crypto.PubkeyToAddress(testKey.PublicKey)
+ genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000000000))
+ unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit, BaseFee: big.NewInt(params.InitialBaseFee)}, nil, nil, nil, trie.NewStackTrie(nil))
+)
+
+// makeChain creates a chain of n blocks starting at and including parent.
+// the returned hash chain is ordered head->parent. In addition, every 3rd block
+// contains a transaction and every 5th an uncle to allow testing correct block
+// reassembly.
+func makeChain(n int, seed byte, parent *types.Block) ([]common.Hash, map[common.Hash]*types.Block) {
+ blocks, _ := core.GenerateChain(params.TestChainConfig, parent, ethash.NewFaker(), testdb, n, func(i int, block *core.BlockGen) {
+ block.SetCoinbase(common.Address{seed})
+
+ // If the block number is multiple of 3, send a bonus transaction to the miner
+ if parent == genesis && i%3 == 0 {
+ signer := types.MakeSigner(params.TestChainConfig, block.Number())
+ tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, block.BaseFee(), nil), signer, testKey)
+ if err != nil {
+ panic(err)
+ }
+ block.AddTx(tx)
+ }
+ // If the block number is a multiple of 5, add a bonus uncle to the block
+ if i%5 == 0 {
+ block.AddUncle(&types.Header{ParentHash: block.PrevBlock(i - 1).Hash(), Number: big.NewInt(int64(i - 1))})
+ }
+ })
+ hashes := make([]common.Hash, n+1)
+ hashes[len(hashes)-1] = parent.Hash()
+ blockm := make(map[common.Hash]*types.Block, n+1)
+ blockm[parent.Hash()] = parent
+ for i, b := range blocks {
+ hashes[len(hashes)-i-2] = b.Hash()
+ blockm[b.Hash()] = b
+ }
+ return hashes, blockm
+}
+
+// fetcherTester is a test simulator for mocking out local block chain.
+type fetcherTester struct {
+ fetcher *BlockFetcher
+
+ hashes []common.Hash // Hash chain belonging to the tester
+ headers map[common.Hash]*types.Header // Headers belonging to the tester
+ blocks map[common.Hash]*types.Block // Blocks belonging to the tester
+ drops map[string]bool // Map of peers dropped by the fetcher
+
+ lock sync.RWMutex
+}
+
+// newTester creates a new fetcher test mocker.
+func newTester(light bool) *fetcherTester {
+ tester := &fetcherTester{
+ hashes: []common.Hash{genesis.Hash()},
+ headers: map[common.Hash]*types.Header{genesis.Hash(): genesis.Header()},
+ blocks: map[common.Hash]*types.Block{genesis.Hash(): genesis},
+ drops: make(map[string]bool),
+ }
+ tester.fetcher = NewBlockFetcher(light, tester.getHeader, tester.getBlock, tester.verifyHeader, tester.broadcastBlock, tester.chainHeight, tester.insertHeaders, tester.insertChain, tester.dropPeer)
+ tester.fetcher.Start()
+
+ return tester
+}
+
+// getHeader retrieves a header from the tester's block chain.
+func (f *fetcherTester) getHeader(hash common.Hash) *types.Header {
+ f.lock.RLock()
+ defer f.lock.RUnlock()
+
+ return f.headers[hash]
+}
+
+// getBlock retrieves a block from the tester's block chain.
+func (f *fetcherTester) getBlock(hash common.Hash) *types.Block {
+ f.lock.RLock()
+ defer f.lock.RUnlock()
+
+ return f.blocks[hash]
+}
+
+// verifyHeader is a nop placeholder for the block header verification.
+func (f *fetcherTester) verifyHeader(header *types.Header) error {
+ return nil
+}
+
+// broadcastBlock is a nop placeholder for the block broadcasting.
+func (f *fetcherTester) broadcastBlock(block *types.Block, propagate bool) {
+}
+
+// chainHeight retrieves the current height (block number) of the chain.
+func (f *fetcherTester) chainHeight() uint64 {
+ f.lock.RLock()
+ defer f.lock.RUnlock()
+
+ if f.fetcher.light {
+ return f.headers[f.hashes[len(f.hashes)-1]].Number.Uint64()
+ }
+ return f.blocks[f.hashes[len(f.hashes)-1]].NumberU64()
+}
+
+// insertChain injects a new headers into the simulated chain.
+func (f *fetcherTester) insertHeaders(headers []*types.Header) (int, error) {
+ f.lock.Lock()
+ defer f.lock.Unlock()
+
+ for i, header := range headers {
+ // Make sure the parent in known
+ if _, ok := f.headers[header.ParentHash]; !ok {
+ return i, errors.New("unknown parent")
+ }
+ // Discard any new blocks if the same height already exists
+ if header.Number.Uint64() <= f.headers[f.hashes[len(f.hashes)-1]].Number.Uint64() {
+ return i, nil
+ }
+ // Otherwise build our current chain
+ f.hashes = append(f.hashes, header.Hash())
+ f.headers[header.Hash()] = header
+ }
+ return 0, nil
+}
+
+// insertChain injects a new blocks into the simulated chain.
+func (f *fetcherTester) insertChain(blocks types.Blocks) (int, error) {
+ f.lock.Lock()
+ defer f.lock.Unlock()
+
+ for i, block := range blocks {
+ // Make sure the parent in known
+ if _, ok := f.blocks[block.ParentHash()]; !ok {
+ return i, errors.New("unknown parent")
+ }
+ // Discard any new blocks if the same height already exists
+ if block.NumberU64() <= f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() {
+ return i, nil
+ }
+ // Otherwise build our current chain
+ f.hashes = append(f.hashes, block.Hash())
+ f.blocks[block.Hash()] = block
+ }
+ return 0, nil
+}
+
+// dropPeer is an emulator for the peer removal, simply accumulating the various
+// peers dropped by the fetcher.
+func (f *fetcherTester) dropPeer(peer string) {
+ f.lock.Lock()
+ defer f.lock.Unlock()
+
+ f.drops[peer] = true
+}
+
+// makeHeaderFetcher retrieves a block header fetcher associated with a simulated peer.
+func (f *fetcherTester) makeHeaderFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) headerRequesterFn {
+ closure := make(map[common.Hash]*types.Block)
+ for hash, block := range blocks {
+ closure[hash] = block
+ }
+ // Create a function that return a header from the closure
+ return func(hash common.Hash) error {
+ // Gather the blocks to return
+ headers := make([]*types.Header, 0, 1)
+ if block, ok := closure[hash]; ok {
+ headers = append(headers, block.Header())
+ }
+ // Return on a new thread
+ go f.fetcher.FilterHeaders(peer, headers, time.Now().Add(drift))
+
+ return nil
+ }
+}
+
+// makeBodyFetcher retrieves a block body fetcher associated with a simulated peer.
+func (f *fetcherTester) makeBodyFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) bodyRequesterFn {
+ closure := make(map[common.Hash]*types.Block)
+ for hash, block := range blocks {
+ closure[hash] = block
+ }
+ // Create a function that returns blocks from the closure
+ return func(hashes []common.Hash) error {
+ // Gather the block bodies to return
+ transactions := make([][]*types.Transaction, 0, len(hashes))
+ uncles := make([][]*types.Header, 0, len(hashes))
+
+ for _, hash := range hashes {
+ if block, ok := closure[hash]; ok {
+ transactions = append(transactions, block.Transactions())
+ uncles = append(uncles, block.Uncles())
+ }
+ }
+ // Return on a new thread
+ go f.fetcher.FilterBodies(peer, transactions, uncles, time.Now().Add(drift))
+
+ return nil
+ }
+}
+
+// verifyFetchingEvent verifies that one single event arrive on a fetching channel.
+func verifyFetchingEvent(t *testing.T, fetching chan []common.Hash, arrive bool) {
+ if arrive {
+ select {
+ case <-fetching:
+ case <-time.After(time.Second):
+ t.Fatalf("fetching timeout")
+ }
+ } else {
+ select {
+ case <-fetching:
+ t.Fatalf("fetching invoked")
+ case <-time.After(10 * time.Millisecond):
+ }
+ }
+}
+
+// verifyCompletingEvent verifies that one single event arrive on an completing channel.
+func verifyCompletingEvent(t *testing.T, completing chan []common.Hash, arrive bool) {
+ if arrive {
+ select {
+ case <-completing:
+ case <-time.After(time.Second):
+ t.Fatalf("completing timeout")
+ }
+ } else {
+ select {
+ case <-completing:
+ t.Fatalf("completing invoked")
+ case <-time.After(10 * time.Millisecond):
+ }
+ }
+}
+
+// verifyImportEvent verifies that one single event arrive on an import channel.
+func verifyImportEvent(t *testing.T, imported chan interface{}, arrive bool) {
+ if arrive {
+ select {
+ case <-imported:
+ case <-time.After(time.Second):
+ t.Fatalf("import timeout")
+ }
+ } else {
+ select {
+ case <-imported:
+ t.Fatalf("import invoked")
+ case <-time.After(20 * time.Millisecond):
+ }
+ }
+}
+
+// verifyImportCount verifies that exactly count number of events arrive on an
+// import hook channel.
+func verifyImportCount(t *testing.T, imported chan interface{}, count int) {
+ for i := 0; i < count; i++ {
+ select {
+ case <-imported:
+ case <-time.After(time.Second):
+ t.Fatalf("block %d: import timeout", i+1)
+ }
+ }
+ verifyImportDone(t, imported)
+}
+
+// verifyImportDone verifies that no more events are arriving on an import channel.
+func verifyImportDone(t *testing.T, imported chan interface{}) {
+ select {
+ case <-imported:
+ t.Fatalf("extra block imported")
+ case <-time.After(50 * time.Millisecond):
+ }
+}
+
+// verifyChainHeight verifies the chain height is as expected.
+func verifyChainHeight(t *testing.T, fetcher *fetcherTester, height uint64) {
+ if fetcher.chainHeight() != height {
+ t.Fatalf("chain height mismatch, got %d, want %d", fetcher.chainHeight(), height)
+ }
+}
+
+// Tests that a fetcher accepts block/header announcements and initiates retrievals
+// for them, successfully importing into the local chain.
+func TestFullSequentialAnnouncements(t *testing.T) { testSequentialAnnouncements(t, false) }
+func TestLightSequentialAnnouncements(t *testing.T) { testSequentialAnnouncements(t, true) }
+
+func testSequentialAnnouncements(t *testing.T, light bool) {
+ // Create a chain of blocks to import
+ targetBlocks := 4 * hashLimit
+ hashes, blocks := makeChain(targetBlocks, 0, genesis)
+
+ tester := newTester(light)
+ headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+ bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
+
+ // Iteratively announce blocks until all are imported
+ imported := make(chan interface{})
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) {
+ if light {
+ if header == nil {
+ t.Fatalf("Fetcher try to import empty header")
+ }
+ imported <- header
+ } else {
+ if block == nil {
+ t.Fatalf("Fetcher try to import empty block")
+ }
+ imported <- block
+ }
+ }
+ for i := len(hashes) - 2; i >= 0; i-- {
+ tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+ verifyImportEvent(t, imported, true)
+ }
+ verifyImportDone(t, imported)
+ verifyChainHeight(t, tester, uint64(len(hashes)-1))
+}
+
+// Tests that if blocks are announced by multiple peers (or even the same buggy
+// peer), they will only get downloaded at most once.
+func TestFullConcurrentAnnouncements(t *testing.T) { testConcurrentAnnouncements(t, false) }
+func TestLightConcurrentAnnouncements(t *testing.T) { testConcurrentAnnouncements(t, true) }
+
+func testConcurrentAnnouncements(t *testing.T, light bool) {
+ // Create a chain of blocks to import
+ targetBlocks := 4 * hashLimit
+ hashes, blocks := makeChain(targetBlocks, 0, genesis)
+
+ // Assemble a tester with a built in counter for the requests
+ tester := newTester(light)
+ firstHeaderFetcher := tester.makeHeaderFetcher("first", blocks, -gatherSlack)
+ firstBodyFetcher := tester.makeBodyFetcher("first", blocks, 0)
+ secondHeaderFetcher := tester.makeHeaderFetcher("second", blocks, -gatherSlack)
+ secondBodyFetcher := tester.makeBodyFetcher("second", blocks, 0)
+
+ counter := uint32(0)
+ firstHeaderWrapper := func(hash common.Hash) error {
+ atomic.AddUint32(&counter, 1)
+ return firstHeaderFetcher(hash)
+ }
+ secondHeaderWrapper := func(hash common.Hash) error {
+ atomic.AddUint32(&counter, 1)
+ return secondHeaderFetcher(hash)
+ }
+ // Iteratively announce blocks until all are imported
+ imported := make(chan interface{})
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) {
+ if light {
+ if header == nil {
+ t.Fatalf("Fetcher try to import empty header")
+ }
+ imported <- header
+ } else {
+ if block == nil {
+ t.Fatalf("Fetcher try to import empty block")
+ }
+ imported <- block
+ }
+ }
+ for i := len(hashes) - 2; i >= 0; i-- {
+ tester.fetcher.Notify("first", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), firstHeaderWrapper, firstBodyFetcher)
+ tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout+time.Millisecond), secondHeaderWrapper, secondBodyFetcher)
+ tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout-time.Millisecond), secondHeaderWrapper, secondBodyFetcher)
+ verifyImportEvent(t, imported, true)
+ }
+ verifyImportDone(t, imported)
+
+ // Make sure no blocks were retrieved twice
+ if int(counter) != targetBlocks {
+ t.Fatalf("retrieval count mismatch: have %v, want %v", counter, targetBlocks)
+ }
+ verifyChainHeight(t, tester, uint64(len(hashes)-1))
+}
+
+// Tests that announcements arriving while a previous is being fetched still
+// results in a valid import.
+func TestFullOverlappingAnnouncements(t *testing.T) { testOverlappingAnnouncements(t, false) }
+func TestLightOverlappingAnnouncements(t *testing.T) { testOverlappingAnnouncements(t, true) }
+
+func testOverlappingAnnouncements(t *testing.T, light bool) {
+ // Create a chain of blocks to import
+ targetBlocks := 4 * hashLimit
+ hashes, blocks := makeChain(targetBlocks, 0, genesis)
+
+ tester := newTester(light)
+ headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+ bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
+
+ // Iteratively announce blocks, but overlap them continuously
+ overlap := 16
+ imported := make(chan interface{}, len(hashes)-1)
+ for i := 0; i < overlap; i++ {
+ imported <- nil
+ }
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) {
+ if light {
+ if header == nil {
+ t.Fatalf("Fetcher try to import empty header")
+ }
+ imported <- header
+ } else {
+ if block == nil {
+ t.Fatalf("Fetcher try to import empty block")
+ }
+ imported <- block
+ }
+ }
+
+ for i := len(hashes) - 2; i >= 0; i-- {
+ tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+ select {
+ case <-imported:
+ case <-time.After(time.Second):
+ t.Fatalf("block %d: import timeout", len(hashes)-i)
+ }
+ }
+ // Wait for all the imports to complete and check count
+ verifyImportCount(t, imported, overlap)
+ verifyChainHeight(t, tester, uint64(len(hashes)-1))
+}
+
+// Tests that announces already being retrieved will not be duplicated.
+func TestFullPendingDeduplication(t *testing.T) { testPendingDeduplication(t, false) }
+func TestLightPendingDeduplication(t *testing.T) { testPendingDeduplication(t, true) }
+
+func testPendingDeduplication(t *testing.T, light bool) {
+ // Create a hash and corresponding block
+ hashes, blocks := makeChain(1, 0, genesis)
+
+ // Assemble a tester with a built in counter and delayed fetcher
+ tester := newTester(light)
+ headerFetcher := tester.makeHeaderFetcher("repeater", blocks, -gatherSlack)
+ bodyFetcher := tester.makeBodyFetcher("repeater", blocks, 0)
+
+ delay := 50 * time.Millisecond
+ counter := uint32(0)
+ headerWrapper := func(hash common.Hash) error {
+ atomic.AddUint32(&counter, 1)
+
+ // Simulate a long running fetch
+ go func() {
+ time.Sleep(delay)
+ headerFetcher(hash)
+ }()
+ return nil
+ }
+ checkNonExist := func() bool {
+ return tester.getBlock(hashes[0]) == nil
+ }
+ if light {
+ checkNonExist = func() bool {
+ return tester.getHeader(hashes[0]) == nil
+ }
+ }
+ // Announce the same block many times until it's fetched (wait for any pending ops)
+ for checkNonExist() {
+ tester.fetcher.Notify("repeater", hashes[0], 1, time.Now().Add(-arriveTimeout), headerWrapper, bodyFetcher)
+ time.Sleep(time.Millisecond)
+ }
+ time.Sleep(delay)
+
+ // Check that all blocks were imported and none fetched twice
+ if int(counter) != 1 {
+ t.Fatalf("retrieval count mismatch: have %v, want %v", counter, 1)
+ }
+ verifyChainHeight(t, tester, 1)
+}
+
+// Tests that announcements retrieved in a random order are cached and eventually
+// imported when all the gaps are filled in.
+func TestFullRandomArrivalImport(t *testing.T) { testRandomArrivalImport(t, false) }
+func TestLightRandomArrivalImport(t *testing.T) { testRandomArrivalImport(t, true) }
+
+func testRandomArrivalImport(t *testing.T, light bool) {
+ // Create a chain of blocks to import, and choose one to delay
+ targetBlocks := maxQueueDist
+ hashes, blocks := makeChain(targetBlocks, 0, genesis)
+ skip := targetBlocks / 2
+
+ tester := newTester(light)
+ headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+ bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
+
+ // Iteratively announce blocks, skipping one entry
+ imported := make(chan interface{}, len(hashes)-1)
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) {
+ if light {
+ if header == nil {
+ t.Fatalf("Fetcher try to import empty header")
+ }
+ imported <- header
+ } else {
+ if block == nil {
+ t.Fatalf("Fetcher try to import empty block")
+ }
+ imported <- block
+ }
+ }
+ for i := len(hashes) - 1; i >= 0; i-- {
+ if i != skip {
+ tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+ time.Sleep(time.Millisecond)
+ }
+ }
+ // Finally announce the skipped entry and check full import
+ tester.fetcher.Notify("valid", hashes[skip], uint64(len(hashes)-skip-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+ verifyImportCount(t, imported, len(hashes)-1)
+ verifyChainHeight(t, tester, uint64(len(hashes)-1))
+}
+
+// Tests that direct block enqueues (due to block propagation vs. hash announce)
+// are correctly schedule, filling and import queue gaps.
+func TestQueueGapFill(t *testing.T) {
+ // Create a chain of blocks to import, and choose one to not announce at all
+ targetBlocks := maxQueueDist
+ hashes, blocks := makeChain(targetBlocks, 0, genesis)
+ skip := targetBlocks / 2
+
+ tester := newTester(false)
+ headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+ bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
+
+ // Iteratively announce blocks, skipping one entry
+ imported := make(chan interface{}, len(hashes)-1)
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { imported <- block }
+
+ for i := len(hashes) - 1; i >= 0; i-- {
+ if i != skip {
+ tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+ time.Sleep(time.Millisecond)
+ }
+ }
+ // Fill the missing block directly as if propagated
+ tester.fetcher.Enqueue("valid", blocks[hashes[skip]])
+ verifyImportCount(t, imported, len(hashes)-1)
+ verifyChainHeight(t, tester, uint64(len(hashes)-1))
+}
+
+// Tests that blocks arriving from various sources (multiple propagations, hash
+// announces, etc) do not get scheduled for import multiple times.
+func TestImportDeduplication(t *testing.T) {
+ // Create two blocks to import (one for duplication, the other for stalling)
+ hashes, blocks := makeChain(2, 0, genesis)
+
+ // Create the tester and wrap the importer with a counter
+ tester := newTester(false)
+ headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+ bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
+
+ counter := uint32(0)
+ tester.fetcher.insertChain = func(blocks types.Blocks) (int, error) {
+ atomic.AddUint32(&counter, uint32(len(blocks)))
+ return tester.insertChain(blocks)
+ }
+ // Instrument the fetching and imported events
+ fetching := make(chan []common.Hash)
+ imported := make(chan interface{}, len(hashes)-1)
+ tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- hashes }
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { imported <- block }
+
+ // Announce the duplicating block, wait for retrieval, and also propagate directly
+ tester.fetcher.Notify("valid", hashes[0], 1, time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+ <-fetching
+
+ tester.fetcher.Enqueue("valid", blocks[hashes[0]])
+ tester.fetcher.Enqueue("valid", blocks[hashes[0]])
+ tester.fetcher.Enqueue("valid", blocks[hashes[0]])
+
+ // Fill the missing block directly as if propagated, and check import uniqueness
+ tester.fetcher.Enqueue("valid", blocks[hashes[1]])
+ verifyImportCount(t, imported, 2)
+
+ if counter != 2 {
+ t.Fatalf("import invocation count mismatch: have %v, want %v", counter, 2)
+ }
+}
+
+// Tests that blocks with numbers much lower or higher than out current head get
+// discarded to prevent wasting resources on useless blocks from faulty peers.
+func TestDistantPropagationDiscarding(t *testing.T) {
+ // Create a long chain to import and define the discard boundaries
+ hashes, blocks := makeChain(3*maxQueueDist, 0, genesis)
+ head := hashes[len(hashes)/2]
+
+ low, high := len(hashes)/2+maxUncleDist+1, len(hashes)/2-maxQueueDist-1
+
+ // Create a tester and simulate a head block being the middle of the above chain
+ tester := newTester(false)
+
+ tester.lock.Lock()
+ tester.hashes = []common.Hash{head}
+ tester.blocks = map[common.Hash]*types.Block{head: blocks[head]}
+ tester.lock.Unlock()
+
+ // Ensure that a block with a lower number than the threshold is discarded
+ tester.fetcher.Enqueue("lower", blocks[hashes[low]])
+ time.Sleep(10 * time.Millisecond)
+ if !tester.fetcher.queue.Empty() {
+ t.Fatalf("fetcher queued stale block")
+ }
+ // Ensure that a block with a higher number than the threshold is discarded
+ tester.fetcher.Enqueue("higher", blocks[hashes[high]])
+ time.Sleep(10 * time.Millisecond)
+ if !tester.fetcher.queue.Empty() {
+ t.Fatalf("fetcher queued future block")
+ }
+}
+
+// Tests that announcements with numbers much lower or higher than out current
+// head get discarded to prevent wasting resources on useless blocks from faulty
+// peers.
+func TestFullDistantAnnouncementDiscarding(t *testing.T) { testDistantAnnouncementDiscarding(t, false) }
+func TestLightDistantAnnouncementDiscarding(t *testing.T) { testDistantAnnouncementDiscarding(t, true) }
+
+func testDistantAnnouncementDiscarding(t *testing.T, light bool) {
+ // Create a long chain to import and define the discard boundaries
+ hashes, blocks := makeChain(3*maxQueueDist, 0, genesis)
+ head := hashes[len(hashes)/2]
+
+ low, high := len(hashes)/2+maxUncleDist+1, len(hashes)/2-maxQueueDist-1
+
+ // Create a tester and simulate a head block being the middle of the above chain
+ tester := newTester(light)
+
+ tester.lock.Lock()
+ tester.hashes = []common.Hash{head}
+ tester.headers = map[common.Hash]*types.Header{head: blocks[head].Header()}
+ tester.blocks = map[common.Hash]*types.Block{head: blocks[head]}
+ tester.lock.Unlock()
+
+ headerFetcher := tester.makeHeaderFetcher("lower", blocks, -gatherSlack)
+ bodyFetcher := tester.makeBodyFetcher("lower", blocks, 0)
+
+ fetching := make(chan struct{}, 2)
+ tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- struct{}{} }
+
+ // Ensure that a block with a lower number than the threshold is discarded
+ tester.fetcher.Notify("lower", hashes[low], blocks[hashes[low]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+ select {
+ case <-time.After(50 * time.Millisecond):
+ case <-fetching:
+ t.Fatalf("fetcher requested stale header")
+ }
+ // Ensure that a block with a higher number than the threshold is discarded
+ tester.fetcher.Notify("higher", hashes[high], blocks[hashes[high]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+ select {
+ case <-time.After(50 * time.Millisecond):
+ case <-fetching:
+ t.Fatalf("fetcher requested future header")
+ }
+}
+
+// Tests that peers announcing blocks with invalid numbers (i.e. not matching
+// the headers provided afterwards) get dropped as malicious.
+func TestFullInvalidNumberAnnouncement(t *testing.T) { testInvalidNumberAnnouncement(t, false) }
+func TestLightInvalidNumberAnnouncement(t *testing.T) { testInvalidNumberAnnouncement(t, true) }
+
+func testInvalidNumberAnnouncement(t *testing.T, light bool) {
+ // Create a single block to import and check numbers against
+ hashes, blocks := makeChain(1, 0, genesis)
+
+ tester := newTester(light)
+ badHeaderFetcher := tester.makeHeaderFetcher("bad", blocks, -gatherSlack)
+ badBodyFetcher := tester.makeBodyFetcher("bad", blocks, 0)
+
+ imported := make(chan interface{})
+ announced := make(chan interface{})
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) {
+ if light {
+ if header == nil {
+ t.Fatalf("Fetcher try to import empty header")
+ }
+ imported <- header
+ } else {
+ if block == nil {
+ t.Fatalf("Fetcher try to import empty block")
+ }
+ imported <- block
+ }
+ }
+ // Announce a block with a bad number, check for immediate drop
+ tester.fetcher.announceChangeHook = func(hash common.Hash, b bool) {
+ announced <- nil
+ }
+ tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), badHeaderFetcher, badBodyFetcher)
+ verifyAnnounce := func() {
+ for i := 0; i < 2; i++ {
+ select {
+ case <-announced:
+ continue
+ case <-time.After(1 * time.Second):
+ t.Fatal("announce timeout")
+ return
+ }
+ }
+ }
+ verifyAnnounce()
+ verifyImportEvent(t, imported, false)
+ tester.lock.RLock()
+ dropped := tester.drops["bad"]
+ tester.lock.RUnlock()
+
+ if !dropped {
+ t.Fatalf("peer with invalid numbered announcement not dropped")
+ }
+ goodHeaderFetcher := tester.makeHeaderFetcher("good", blocks, -gatherSlack)
+ goodBodyFetcher := tester.makeBodyFetcher("good", blocks, 0)
+ // Make sure a good announcement passes without a drop
+ tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), goodHeaderFetcher, goodBodyFetcher)
+ verifyAnnounce()
+ verifyImportEvent(t, imported, true)
+
+ tester.lock.RLock()
+ dropped = tester.drops["good"]
+ tester.lock.RUnlock()
+
+ if dropped {
+ t.Fatalf("peer with valid numbered announcement dropped")
+ }
+ verifyImportDone(t, imported)
+}
+
+// Tests that if a block is empty (i.e. header only), no body request should be
+// made, and instead the header should be assembled into a whole block in itself.
+func TestEmptyBlockShortCircuit(t *testing.T) {
+ // Create a chain of blocks to import
+ hashes, blocks := makeChain(32, 0, genesis)
+
+ tester := newTester(false)
+ headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+ bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
+
+ // Add a monitoring hook for all internal events
+ fetching := make(chan []common.Hash)
+ tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- hashes }
+
+ completing := make(chan []common.Hash)
+ tester.fetcher.completingHook = func(hashes []common.Hash) { completing <- hashes }
+
+ imported := make(chan interface{})
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) {
+ if block == nil {
+ t.Fatalf("Fetcher try to import empty block")
+ }
+ imported <- block
+ }
+ // Iteratively announce blocks until all are imported
+ for i := len(hashes) - 2; i >= 0; i-- {
+ tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+
+ // All announces should fetch the header
+ verifyFetchingEvent(t, fetching, true)
+
+ // Only blocks with data contents should request bodies
+ verifyCompletingEvent(t, completing, len(blocks[hashes[i]].Transactions()) > 0 || len(blocks[hashes[i]].Uncles()) > 0)
+
+ // Irrelevant of the construct, import should succeed
+ verifyImportEvent(t, imported, true)
+ }
+ verifyImportDone(t, imported)
+}
+
+// Tests that a peer is unable to use unbounded memory with sending infinite
+// block announcements to a node, but that even in the face of such an attack,
+// the fetcher remains operational.
+func TestHashMemoryExhaustionAttack(t *testing.T) {
+ // Create a tester with instrumented import hooks
+ tester := newTester(false)
+
+ imported, announces := make(chan interface{}), int32(0)
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { imported <- block }
+ tester.fetcher.announceChangeHook = func(hash common.Hash, added bool) {
+ if added {
+ atomic.AddInt32(&announces, 1)
+ } else {
+ atomic.AddInt32(&announces, -1)
+ }
+ }
+ // Create a valid chain and an infinite junk chain
+ targetBlocks := hashLimit + 2*maxQueueDist
+ hashes, blocks := makeChain(targetBlocks, 0, genesis)
+ validHeaderFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+ validBodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
+
+ attack, _ := makeChain(targetBlocks, 0, unknownBlock)
+ attackerHeaderFetcher := tester.makeHeaderFetcher("attacker", nil, -gatherSlack)
+ attackerBodyFetcher := tester.makeBodyFetcher("attacker", nil, 0)
+
+ // Feed the tester a huge hashset from the attacker, and a limited from the valid peer
+ for i := 0; i < len(attack); i++ {
+ if i < maxQueueDist {
+ tester.fetcher.Notify("valid", hashes[len(hashes)-2-i], uint64(i+1), time.Now(), validHeaderFetcher, validBodyFetcher)
+ }
+ tester.fetcher.Notify("attacker", attack[i], 1 /* don't distance drop */, time.Now(), attackerHeaderFetcher, attackerBodyFetcher)
+ }
+ if count := atomic.LoadInt32(&announces); count != hashLimit+maxQueueDist {
+ t.Fatalf("queued announce count mismatch: have %d, want %d", count, hashLimit+maxQueueDist)
+ }
+ // Wait for fetches to complete
+ verifyImportCount(t, imported, maxQueueDist)
+
+ // Feed the remaining valid hashes to ensure DOS protection state remains clean
+ for i := len(hashes) - maxQueueDist - 2; i >= 0; i-- {
+ tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), validHeaderFetcher, validBodyFetcher)
+ verifyImportEvent(t, imported, true)
+ }
+ verifyImportDone(t, imported)
+}
+
+// Tests that blocks sent to the fetcher (either through propagation or via hash
+// announces and retrievals) don't pile up indefinitely, exhausting available
+// system memory.
+func TestBlockMemoryExhaustionAttack(t *testing.T) {
+ // Create a tester with instrumented import hooks
+ tester := newTester(false)
+
+ imported, enqueued := make(chan interface{}), int32(0)
+ tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { imported <- block }
+ tester.fetcher.queueChangeHook = func(hash common.Hash, added bool) {
+ if added {
+ atomic.AddInt32(&enqueued, 1)
+ } else {
+ atomic.AddInt32(&enqueued, -1)
+ }
+ }
+ // Create a valid chain and a batch of dangling (but in range) blocks
+ targetBlocks := hashLimit + 2*maxQueueDist
+ hashes, blocks := makeChain(targetBlocks, 0, genesis)
+ attack := make(map[common.Hash]*types.Block)
+ for i := byte(0); len(attack) < blockLimit+2*maxQueueDist; i++ {
+ hashes, blocks := makeChain(maxQueueDist-1, i, unknownBlock)
+ for _, hash := range hashes[:maxQueueDist-2] {
+ attack[hash] = blocks[hash]
+ }
+ }
+ // Try to feed all the attacker blocks make sure only a limited batch is accepted
+ for _, block := range attack {
+ tester.fetcher.Enqueue("attacker", block)
+ }
+ time.Sleep(200 * time.Millisecond)
+ if queued := atomic.LoadInt32(&enqueued); queued != blockLimit {
+ t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit)
+ }
+ // Queue up a batch of valid blocks, and check that a new peer is allowed to do so
+ for i := 0; i < maxQueueDist-1; i++ {
+ tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-3-i]])
+ }
+ time.Sleep(100 * time.Millisecond)
+ if queued := atomic.LoadInt32(&enqueued); queued != blockLimit+maxQueueDist-1 {
+ t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit+maxQueueDist-1)
+ }
+ // Insert the missing piece (and sanity check the import)
+ tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2]])
+ verifyImportCount(t, imported, maxQueueDist)
+
+ // Insert the remaining blocks in chunks to ensure clean DOS protection
+ for i := maxQueueDist; i < len(hashes)-1; i++ {
+ tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2-i]])
+ verifyImportEvent(t, imported, true)
+ }
+ verifyImportDone(t, imported)
+}
diff --git a/les/handler_test.go b/les/handler_test.go
index bb8ad33829..aba45764b3 100644
--- a/les/handler_test.go
+++ b/les/handler_test.go
@@ -30,7 +30,7 @@ import (
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
- "github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/les/downloader"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/params"
diff --git a/les/sync.go b/les/sync.go
index fa5ef4ff82..31cd06ca70 100644
--- a/les/sync.go
+++ b/les/sync.go
@@ -23,7 +23,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
- "github.com/ethereum/go-ethereum/eth/downloader"
+ "github.com/ethereum/go-ethereum/les/downloader"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params"