les, les/lespay: implement new server pool (#20758)

This PR reimplements the light client server pool. It is also a first step
to move certain logic into a new lespay package. This package will contain
the implementation of the lespay token sale functions, the token buying and
selling logic and other components related to peer selection/prioritization
and service quality evaluation. Over the long term this package will be
reusable for incentivizing future protocols.

Since the LES peer logic is now based on enode.Iterator, it can now use
DNS-based fallback discovery to find servers.

This document describes the function of the new components:
https://gist.github.com/zsfelfoldi/3c7ace895234b7b345ab4f71dab102d4
This commit is contained in:
Felföldi Zsolt 2020-05-22 13:46:34 +02:00 committed by GitHub
parent 65ce550b37
commit b4a2681120
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 2904 additions and 1028 deletions

@ -1563,19 +1563,19 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
cfg.NetworkId = 3 cfg.NetworkId = 3
} }
cfg.Genesis = core.DefaultRopstenGenesisBlock() cfg.Genesis = core.DefaultRopstenGenesisBlock()
setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.RopstenGenesisHash]) setDNSDiscoveryDefaults(cfg, params.RopstenGenesisHash)
case ctx.GlobalBool(RinkebyFlag.Name): case ctx.GlobalBool(RinkebyFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) { if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 4 cfg.NetworkId = 4
} }
cfg.Genesis = core.DefaultRinkebyGenesisBlock() cfg.Genesis = core.DefaultRinkebyGenesisBlock()
setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.RinkebyGenesisHash]) setDNSDiscoveryDefaults(cfg, params.RinkebyGenesisHash)
case ctx.GlobalBool(GoerliFlag.Name): case ctx.GlobalBool(GoerliFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) { if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 5 cfg.NetworkId = 5
} }
cfg.Genesis = core.DefaultGoerliGenesisBlock() cfg.Genesis = core.DefaultGoerliGenesisBlock()
setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.GoerliGenesisHash]) setDNSDiscoveryDefaults(cfg, params.GoerliGenesisHash)
case ctx.GlobalBool(DeveloperFlag.Name): case ctx.GlobalBool(DeveloperFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) { if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 1337 cfg.NetworkId = 1337
@ -1604,18 +1604,25 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
} }
default: default:
if cfg.NetworkId == 1 { if cfg.NetworkId == 1 {
setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.MainnetGenesisHash]) setDNSDiscoveryDefaults(cfg, params.MainnetGenesisHash)
} }
} }
} }
// setDNSDiscoveryDefaults configures DNS discovery with the given URL if // setDNSDiscoveryDefaults configures DNS discovery with the given URL if
// no URLs are set. // no URLs are set.
func setDNSDiscoveryDefaults(cfg *eth.Config, url string) { func setDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) {
if cfg.DiscoveryURLs != nil { if cfg.DiscoveryURLs != nil {
return return // already set through flags/config
}
protocol := "eth"
if cfg.SyncMode == downloader.LightSync {
protocol = "les"
}
if url := params.KnownDNSNetwork(genesis, protocol); url != "" {
cfg.DiscoveryURLs = []string{url}
} }
cfg.DiscoveryURLs = []string{url}
} }
// RegisterEthService adds an Ethereum client to the stack. // RegisterEthService adds an Ethereum client to the stack.

@ -27,7 +27,7 @@ import (
"strings" "strings"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
) )
@ -44,6 +44,18 @@ var (
ErrLocalIncompatibleOrStale = errors.New("local incompatible or needs update") ErrLocalIncompatibleOrStale = errors.New("local incompatible or needs update")
) )
// Blockchain defines all necessary method to build a forkID.
type Blockchain interface {
// Config retrieves the chain's fork configuration.
Config() *params.ChainConfig
// Genesis retrieves the chain's genesis block.
Genesis() *types.Block
// CurrentHeader retrieves the current head header of the canonical chain.
CurrentHeader() *types.Header
}
// ID is a fork identifier as defined by EIP-2124. // ID is a fork identifier as defined by EIP-2124.
type ID struct { type ID struct {
Hash [4]byte // CRC32 checksum of the genesis block and passed fork block numbers Hash [4]byte // CRC32 checksum of the genesis block and passed fork block numbers
@ -54,7 +66,7 @@ type ID struct {
type Filter func(id ID) error type Filter func(id ID) error
// NewID calculates the Ethereum fork ID from the chain config and head. // NewID calculates the Ethereum fork ID from the chain config and head.
func NewID(chain *core.BlockChain) ID { func NewID(chain Blockchain) ID {
return newID( return newID(
chain.Config(), chain.Config(),
chain.Genesis().Hash(), chain.Genesis().Hash(),
@ -85,7 +97,7 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID {
// NewFilter creates a filter that returns if a fork ID should be rejected or not // NewFilter creates a filter that returns if a fork ID should be rejected or not
// based on the local chain's status. // based on the local chain's status.
func NewFilter(chain *core.BlockChain) Filter { func NewFilter(chain Blockchain) Filter {
return newFilter( return newFilter(
chain.Config(), chain.Config(),
chain.Genesis().Hash(), chain.Genesis().Hash(),

@ -72,7 +72,7 @@ type Ethereum struct {
blockchain *core.BlockChain blockchain *core.BlockChain
protocolManager *ProtocolManager protocolManager *ProtocolManager
lesServer LesServer lesServer LesServer
dialCandiates enode.Iterator dialCandidates enode.Iterator
// DB interfaces // DB interfaces
chainDb ethdb.Database // Block chain database chainDb ethdb.Database // Block chain database
@ -226,7 +226,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
} }
eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams) eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams)
eth.dialCandiates, err = eth.setupDiscovery(&ctx.Config.P2P) eth.dialCandidates, err = eth.setupDiscovery(&ctx.Config.P2P)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -523,7 +523,7 @@ func (s *Ethereum) Protocols() []p2p.Protocol {
for i, vsn := range ProtocolVersions { for i, vsn := range ProtocolVersions {
protos[i] = s.protocolManager.makeProtocol(vsn) protos[i] = s.protocolManager.makeProtocol(vsn)
protos[i].Attributes = []enr.Entry{s.currentEthEntry()} protos[i].Attributes = []enr.Entry{s.currentEthEntry()}
protos[i].DialCandidates = s.dialCandiates protos[i].DialCandidates = s.dialCandidates
} }
if s.lesServer != nil { if s.lesServer != nil {
protos = append(protos, s.lesServer.Protocols()...) protos = append(protos, s.lesServer.Protocols()...)

@ -51,16 +51,17 @@ import (
type LightEthereum struct { type LightEthereum struct {
lesCommons lesCommons
peers *serverPeerSet peers *serverPeerSet
reqDist *requestDistributor reqDist *requestDistributor
retriever *retrieveManager retriever *retrieveManager
odr *LesOdr odr *LesOdr
relay *lesTxRelay relay *lesTxRelay
handler *clientHandler handler *clientHandler
txPool *light.TxPool txPool *light.TxPool
blockchain *light.LightChain blockchain *light.LightChain
serverPool *serverPool serverPool *serverPool
valueTracker *lpc.ValueTracker valueTracker *lpc.ValueTracker
dialCandidates enode.Iterator
bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests
bloomIndexer *core.ChainIndexer // Bloom indexer operating during block imports bloomIndexer *core.ChainIndexer // Bloom indexer operating during block imports
@ -104,11 +105,19 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb), engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb),
bloomRequests: make(chan chan *bloombits.Retrieval), bloomRequests: make(chan chan *bloombits.Retrieval),
bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations), bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations),
serverPool: newServerPool(chainDb, config.UltraLightServers),
valueTracker: lpc.NewValueTracker(lespayDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)), valueTracker: lpc.NewValueTracker(lespayDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)),
} }
peers.subscribe((*vtSubscription)(leth.valueTracker)) peers.subscribe((*vtSubscription)(leth.valueTracker))
leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool)
dnsdisc, err := leth.setupDiscovery(&ctx.Config.P2P)
if err != nil {
return nil, err
}
leth.serverPool = newServerPool(lespayDb, []byte("serverpool:"), leth.valueTracker, dnsdisc, time.Second, nil, &mclock.System{}, config.UltraLightServers)
peers.subscribe(leth.serverPool)
leth.dialCandidates = leth.serverPool.dialIterator
leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool.getTimeout)
leth.relay = newLesTxRelay(peers, leth.retriever) leth.relay = newLesTxRelay(peers, leth.retriever)
leth.odr = NewLesOdr(chainDb, light.DefaultClientIndexerConfig, leth.retriever) leth.odr = NewLesOdr(chainDb, light.DefaultClientIndexerConfig, leth.retriever)
@ -140,11 +149,6 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
leth.chtIndexer.Start(leth.blockchain) leth.chtIndexer.Start(leth.blockchain)
leth.bloomIndexer.Start(leth.blockchain) leth.bloomIndexer.Start(leth.blockchain)
leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth)
if leth.handler.ulc != nil {
log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
leth.blockchain.DisableCheckFreq()
}
// Rewind the chain in case of an incompatible config upgrade. // Rewind the chain in case of an incompatible config upgrade.
if compat, ok := genesisErr.(*params.ConfigCompatError); ok { if compat, ok := genesisErr.(*params.ConfigCompatError); ok {
log.Warn("Rewinding chain to upgrade configuration", "err", compat) log.Warn("Rewinding chain to upgrade configuration", "err", compat)
@ -159,6 +163,11 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
} }
leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams) leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams)
leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth)
if leth.handler.ulc != nil {
log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
leth.blockchain.DisableCheckFreq()
}
return leth, nil return leth, nil
} }
@ -260,7 +269,7 @@ func (s *LightEthereum) Protocols() []p2p.Protocol {
return p.Info() return p.Info()
} }
return nil return nil
}) }, s.dialCandidates)
} }
// Start implements node.Service, starting all internal goroutines needed by the // Start implements node.Service, starting all internal goroutines needed by the
@ -268,15 +277,12 @@ func (s *LightEthereum) Protocols() []p2p.Protocol {
func (s *LightEthereum) Start(srvr *p2p.Server) error { func (s *LightEthereum) Start(srvr *p2p.Server) error {
log.Warn("Light client mode is an experimental feature") log.Warn("Light client mode is an experimental feature")
s.serverPool.start()
// Start bloom request workers. // Start bloom request workers.
s.wg.Add(bloomServiceThreads) s.wg.Add(bloomServiceThreads)
s.startBloomHandlers(params.BloomBitsBlocksClient) s.startBloomHandlers(params.BloomBitsBlocksClient)
s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId) s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId)
// clients are searching for the first advertised protocol in the list
protocolVersion := AdvertiseProtocolVersions[0]
s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion))
return nil return nil
} }
@ -284,6 +290,8 @@ func (s *LightEthereum) Start(srvr *p2p.Server) error {
// Ethereum protocol. // Ethereum protocol.
func (s *LightEthereum) Stop() error { func (s *LightEthereum) Stop() error {
close(s.closeCh) close(s.closeCh)
s.serverPool.stop()
s.valueTracker.Stop()
s.peers.close() s.peers.close()
s.reqDist.close() s.reqDist.close()
s.odr.Stop() s.odr.Stop()
@ -295,8 +303,6 @@ func (s *LightEthereum) Stop() error {
s.txPool.Stop() s.txPool.Stop()
s.engine.Close() s.engine.Close()
s.eventMux.Stop() s.eventMux.Stop()
s.serverPool.stop()
s.valueTracker.Stop()
s.chainDb.Close() s.chainDb.Close()
s.wg.Wait() s.wg.Wait()
log.Info("Light ethereum stopped") log.Info("Light ethereum stopped")

@ -64,7 +64,7 @@ func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.T
if checkpoint != nil { if checkpoint != nil {
height = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1 height = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1
} }
handler.fetcher = newLightFetcher(handler) handler.fetcher = newLightFetcher(handler, backend.serverPool.getTimeout)
handler.downloader = downloader.New(height, backend.chainDb, nil, backend.eventMux, nil, backend.blockchain, handler.removePeer) handler.downloader = downloader.New(height, backend.chainDb, nil, backend.eventMux, nil, backend.blockchain, handler.removePeer)
handler.backend.peers.subscribe((*downloaderPeerNotify)(handler)) handler.backend.peers.subscribe((*downloaderPeerNotify)(handler))
return handler return handler
@ -85,14 +85,9 @@ func (h *clientHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter)
} }
peer := newServerPeer(int(version), h.backend.config.NetworkId, trusted, p, newMeteredMsgWriter(rw, int(version))) peer := newServerPeer(int(version), h.backend.config.NetworkId, trusted, p, newMeteredMsgWriter(rw, int(version)))
defer peer.close() defer peer.close()
peer.poolEntry = h.backend.serverPool.connect(peer, peer.Node())
if peer.poolEntry == nil {
return p2p.DiscRequested
}
h.wg.Add(1) h.wg.Add(1)
defer h.wg.Done() defer h.wg.Done()
err := h.handle(peer) err := h.handle(peer)
h.backend.serverPool.disconnect(peer.poolEntry)
return err return err
} }
@ -129,10 +124,6 @@ func (h *clientHandler) handle(p *serverPeer) error {
h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td}) h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td})
// pool entry can be nil during the unit test.
if p.poolEntry != nil {
h.backend.serverPool.registered(p.poolEntry)
}
// Mark the peer starts to be served. // Mark the peer starts to be served.
atomic.StoreUint32(&p.serving, 1) atomic.StoreUint32(&p.serving, 1)
defer atomic.StoreUint32(&p.serving, 0) defer atomic.StoreUint32(&p.serving, 0)

@ -81,7 +81,7 @@ type NodeInfo struct {
} }
// makeProtocols creates protocol descriptors for the given LES versions. // makeProtocols creates protocol descriptors for the given LES versions.
func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}) []p2p.Protocol { func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}, dialCandidates enode.Iterator) []p2p.Protocol {
protos := make([]p2p.Protocol, len(versions)) protos := make([]p2p.Protocol, len(versions))
for i, version := range versions { for i, version := range versions {
version := version version := version
@ -93,7 +93,8 @@ func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p
Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error { Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
return runPeer(version, peer, rw) return runPeer(version, peer, rw)
}, },
PeerInfo: peerInfo, PeerInfo: peerInfo,
DialCandidates: dialCandidates,
} }
} }
return protos return protos

@ -180,12 +180,11 @@ func (d *requestDistributor) loop() {
type selectPeerItem struct { type selectPeerItem struct {
peer distPeer peer distPeer
req *distReq req *distReq
weight int64 weight uint64
} }
// Weight implements wrsItem interface func selectPeerWeight(i interface{}) uint64 {
func (sp selectPeerItem) Weight() int64 { return i.(selectPeerItem).weight
return sp.weight
} }
// nextRequest returns the next possible request from any peer, along with the // nextRequest returns the next possible request from any peer, along with the
@ -220,9 +219,9 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
wait, bufRemain := peer.waitBefore(cost) wait, bufRemain := peer.waitBefore(cost)
if wait == 0 { if wait == 0 {
if sel == nil { if sel == nil {
sel = utils.NewWeightedRandomSelect() sel = utils.NewWeightedRandomSelect(selectPeerWeight)
} }
sel.Update(selectPeerItem{peer: peer, req: req, weight: int64(bufRemain*1000000) + 1}) sel.Update(selectPeerItem{peer: peer, req: req, weight: uint64(bufRemain*1000000) + 1})
} else { } else {
if bestWait == 0 || wait < bestWait { if bestWait == 0 || wait < bestWait {
bestWait = wait bestWait = wait

@ -17,6 +17,9 @@
package les package les
import ( import (
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/dnsdisc"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -30,3 +33,12 @@ type lesEntry struct {
func (e lesEntry) ENRKey() string { func (e lesEntry) ENRKey() string {
return "les" return "les"
} }
// setupDiscovery creates the node discovery source for the eth protocol.
func (eth *LightEthereum) setupDiscovery(cfg *p2p.Config) (enode.Iterator, error) {
if /*cfg.NoDiscovery || */ len(eth.config.DiscoveryURLs) == 0 {
return nil, nil
}
client := dnsdisc.NewClient(dnsdisc.Config{})
return client.NewIterator(eth.config.DiscoveryURLs...)
}

@ -40,8 +40,9 @@ const (
// ODR system to ensure that we only request data related to a certain block from peers who have already processed // ODR system to ensure that we only request data related to a certain block from peers who have already processed
// and announced that block. // and announced that block.
type lightFetcher struct { type lightFetcher struct {
handler *clientHandler handler *clientHandler
chain *light.LightChain chain *light.LightChain
softRequestTimeout func() time.Duration
lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests
maxConfirmedTd *big.Int maxConfirmedTd *big.Int
@ -109,18 +110,19 @@ type fetchResponse struct {
} }
// newLightFetcher creates a new light fetcher // newLightFetcher creates a new light fetcher
func newLightFetcher(h *clientHandler) *lightFetcher { func newLightFetcher(h *clientHandler, softRequestTimeout func() time.Duration) *lightFetcher {
f := &lightFetcher{ f := &lightFetcher{
handler: h, handler: h,
chain: h.backend.blockchain, chain: h.backend.blockchain,
peers: make(map[*serverPeer]*fetcherPeerInfo), peers: make(map[*serverPeer]*fetcherPeerInfo),
deliverChn: make(chan fetchResponse, 100), deliverChn: make(chan fetchResponse, 100),
requested: make(map[uint64]fetchRequest), requested: make(map[uint64]fetchRequest),
timeoutChn: make(chan uint64), timeoutChn: make(chan uint64),
requestTrigger: make(chan struct{}, 1), requestTrigger: make(chan struct{}, 1),
syncDone: make(chan *serverPeer), syncDone: make(chan *serverPeer),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
maxConfirmedTd: big.NewInt(0), maxConfirmedTd: big.NewInt(0),
softRequestTimeout: softRequestTimeout,
} }
h.backend.peers.subscribe(f) h.backend.peers.subscribe(f)
@ -163,7 +165,7 @@ func (f *lightFetcher) syncLoop() {
f.lock.Unlock() f.lock.Unlock()
} else { } else {
go func() { go func() {
time.Sleep(softRequestTimeout) time.Sleep(f.softRequestTimeout())
f.reqMu.Lock() f.reqMu.Lock()
req, ok := f.requested[reqID] req, ok := f.requested[reqID]
if ok { if ok {
@ -187,7 +189,6 @@ func (f *lightFetcher) syncLoop() {
} }
f.reqMu.Unlock() f.reqMu.Unlock()
if ok { if ok {
f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true)
req.peer.Log().Debug("Fetching data timed out hard") req.peer.Log().Debug("Fetching data timed out hard")
go f.handler.removePeer(req.peer.id) go f.handler.removePeer(req.peer.id)
} }
@ -201,9 +202,6 @@ func (f *lightFetcher) syncLoop() {
delete(f.requested, resp.reqID) delete(f.requested, resp.reqID)
} }
f.reqMu.Unlock() f.reqMu.Unlock()
if ok {
f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout)
}
f.lock.Lock() f.lock.Lock()
if !ok || !(f.syncing || f.processResponse(req, resp)) { if !ok || !(f.syncing || f.processResponse(req, resp)) {
resp.peer.Log().Debug("Failed processing response") resp.peer.Log().Debug("Failed processing response")
@ -879,12 +877,10 @@ func (f *lightFetcher) checkUpdateStats(p *serverPeer, newEntry *updateStatsEntr
fp.firstUpdateStats = newEntry fp.firstUpdateStats = newEntry
} }
for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) { for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) {
f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout)
fp.firstUpdateStats = fp.firstUpdateStats.next fp.firstUpdateStats = fp.firstUpdateStats.next
} }
if fp.confirmedTd != nil { if fp.confirmedTd != nil {
for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 { for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 {
f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time))
fp.firstUpdateStats = fp.firstUpdateStats.next fp.firstUpdateStats = fp.firstUpdateStats.next
} }
} }

@ -0,0 +1,107 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package client
import (
"sync"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/nodestate"
)
// FillSet tries to read nodes from an input iterator and add them to a node set by
// setting the specified node state flag(s) until the size of the set reaches the target.
// Note that other mechanisms (like other FillSet instances reading from different inputs)
// can also set the same flag(s) and FillSet will always care about the total number of
// nodes having those flags.
type FillSet struct {
lock sync.Mutex
cond *sync.Cond
ns *nodestate.NodeStateMachine
input enode.Iterator
closed bool
flags nodestate.Flags
count, target int
}
// NewFillSet creates a new FillSet
func NewFillSet(ns *nodestate.NodeStateMachine, input enode.Iterator, flags nodestate.Flags) *FillSet {
fs := &FillSet{
ns: ns,
input: input,
flags: flags,
}
fs.cond = sync.NewCond(&fs.lock)
ns.SubscribeState(flags, func(n *enode.Node, oldState, newState nodestate.Flags) {
fs.lock.Lock()
if oldState.Equals(flags) {
fs.count--
}
if newState.Equals(flags) {
fs.count++
}
if fs.target > fs.count {
fs.cond.Signal()
}
fs.lock.Unlock()
})
go fs.readLoop()
return fs
}
// readLoop keeps reading nodes from the input and setting the specified flags for them
// whenever the node set size is under the current target
func (fs *FillSet) readLoop() {
for {
fs.lock.Lock()
for fs.target <= fs.count && !fs.closed {
fs.cond.Wait()
}
fs.lock.Unlock()
if !fs.input.Next() {
return
}
fs.ns.SetState(fs.input.Node(), fs.flags, nodestate.Flags{}, 0)
}
}
// SetTarget sets the current target for node set size. If the previous target was not
// reached and FillSet was still waiting for the next node from the input then the next
// incoming node will be added to the set regardless of the target. This ensures that
// all nodes coming from the input are eventually added to the set.
func (fs *FillSet) SetTarget(target int) {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.target = target
if fs.target > fs.count {
fs.cond.Signal()
}
}
// Close shuts FillSet down and closes the input iterator
func (fs *FillSet) Close() {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.closed = true
fs.input.Close()
fs.cond.Signal()
}

@ -0,0 +1,113 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package client
import (
"math/rand"
"testing"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/nodestate"
)
type testIter struct {
waitCh chan struct{}
nodeCh chan *enode.Node
node *enode.Node
}
func (i *testIter) Next() bool {
i.waitCh <- struct{}{}
i.node = <-i.nodeCh
return i.node != nil
}
func (i *testIter) Node() *enode.Node {
return i.node
}
func (i *testIter) Close() {}
func (i *testIter) push() {
var id enode.ID
rand.Read(id[:])
i.nodeCh <- enode.SignNull(new(enr.Record), id)
}
func (i *testIter) waiting(timeout time.Duration) bool {
select {
case <-i.waitCh:
return true
case <-time.After(timeout):
return false
}
}
func TestFillSet(t *testing.T) {
ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
iter := &testIter{
waitCh: make(chan struct{}),
nodeCh: make(chan *enode.Node),
}
fs := NewFillSet(ns, iter, sfTest1)
ns.Start()
expWaiting := func(i int, push bool) {
for ; i > 0; i-- {
if !iter.waiting(time.Second * 10) {
t.Fatalf("FillSet not waiting for new nodes")
}
if push {
iter.push()
}
}
}
expNotWaiting := func() {
if iter.waiting(time.Millisecond * 100) {
t.Fatalf("FillSet unexpectedly waiting for new nodes")
}
}
expNotWaiting()
fs.SetTarget(3)
expWaiting(3, true)
expNotWaiting()
fs.SetTarget(100)
expWaiting(2, true)
expWaiting(1, false)
// lower the target before the previous one has been filled up
fs.SetTarget(0)
iter.push()
expNotWaiting()
fs.SetTarget(10)
expWaiting(4, true)
expNotWaiting()
// remove all previosly set flags
ns.ForEach(sfTest1, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
ns.SetState(node, nodestate.Flags{}, sfTest1, 0)
})
// now expect FillSet to fill the set up again with 10 new nodes
expWaiting(10, true)
expNotWaiting()
fs.Close()
ns.Stop()
}

@ -0,0 +1,123 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package client
import (
"sync"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/nodestate"
)
// QueueIterator returns nodes from the specified selectable set in the same order as
// they entered the set.
type QueueIterator struct {
lock sync.Mutex
cond *sync.Cond
ns *nodestate.NodeStateMachine
queue []*enode.Node
nextNode *enode.Node
waitCallback func(bool)
fifo, closed bool
}
// NewQueueIterator creates a new QueueIterator. Nodes are selectable if they have all the required
// and none of the disabled flags set. When a node is selected the selectedFlag is set which also
// disables further selectability until it is removed or times out.
func NewQueueIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, fifo bool, waitCallback func(bool)) *QueueIterator {
qi := &QueueIterator{
ns: ns,
fifo: fifo,
waitCallback: waitCallback,
}
qi.cond = sync.NewCond(&qi.lock)
ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) {
oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
if newMatch == oldMatch {
return
}
qi.lock.Lock()
defer qi.lock.Unlock()
if newMatch {
qi.queue = append(qi.queue, n)
} else {
id := n.ID()
for i, qn := range qi.queue {
if qn.ID() == id {
copy(qi.queue[i:len(qi.queue)-1], qi.queue[i+1:])
qi.queue = qi.queue[:len(qi.queue)-1]
break
}
}
}
qi.cond.Signal()
})
return qi
}
// Next moves to the next selectable node.
func (qi *QueueIterator) Next() bool {
qi.lock.Lock()
if !qi.closed && len(qi.queue) == 0 {
if qi.waitCallback != nil {
qi.waitCallback(true)
}
for !qi.closed && len(qi.queue) == 0 {
qi.cond.Wait()
}
if qi.waitCallback != nil {
qi.waitCallback(false)
}
}
if qi.closed {
qi.nextNode = nil
qi.lock.Unlock()
return false
}
// Move to the next node in queue.
if qi.fifo {
qi.nextNode = qi.queue[0]
copy(qi.queue[:len(qi.queue)-1], qi.queue[1:])
qi.queue = qi.queue[:len(qi.queue)-1]
} else {
qi.nextNode = qi.queue[len(qi.queue)-1]
qi.queue = qi.queue[:len(qi.queue)-1]
}
qi.lock.Unlock()
return true
}
// Close ends the iterator.
func (qi *QueueIterator) Close() {
qi.lock.Lock()
qi.closed = true
qi.lock.Unlock()
qi.cond.Signal()
}
// Node returns the current node.
func (qi *QueueIterator) Node() *enode.Node {
qi.lock.Lock()
defer qi.lock.Unlock()
return qi.nextNode
}

@ -0,0 +1,106 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package client
import (
"testing"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/nodestate"
)
func testNodeID(i int) enode.ID {
return enode.ID{42, byte(i % 256), byte(i / 256)}
}
func testNodeIndex(id enode.ID) int {
if id[0] != 42 {
return -1
}
return int(id[1]) + int(id[2])*256
}
func testNode(i int) *enode.Node {
return enode.SignNull(new(enr.Record), testNodeID(i))
}
func TestQueueIteratorFIFO(t *testing.T) {
testQueueIterator(t, true)
}
func TestQueueIteratorLIFO(t *testing.T) {
testQueueIterator(t, false)
}
func testQueueIterator(t *testing.T, fifo bool) {
ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
qi := NewQueueIterator(ns, sfTest2, sfTest3.Or(sfTest4), fifo, nil)
ns.Start()
for i := 1; i <= iterTestNodeCount; i++ {
ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0)
}
next := func() int {
ch := make(chan struct{})
go func() {
qi.Next()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second * 5):
t.Fatalf("Iterator.Next() timeout")
}
node := qi.Node()
ns.SetState(node, sfTest4, nodestate.Flags{}, 0)
return testNodeIndex(node.ID())
}
exp := func(i int) {
n := next()
if n != i {
t.Errorf("Wrong item returned by iterator (expected %d, got %d)", i, n)
}
}
explist := func(list []int) {
for i := range list {
if fifo {
exp(list[i])
} else {
exp(list[len(list)-1-i])
}
}
}
ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0)
ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0)
ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0)
explist([]int{1, 2, 3})
ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0)
ns.SetState(testNode(5), sfTest2, nodestate.Flags{}, 0)
ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0)
ns.SetState(testNode(5), sfTest3, nodestate.Flags{}, 0)
explist([]int{4, 6})
ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0)
ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
ns.SetState(testNode(2), sfTest3, nodestate.Flags{}, 0)
ns.SetState(testNode(2), nodestate.Flags{}, sfTest3, 0)
explist([]int{1, 3, 2})
ns.Stop()
}

@ -213,6 +213,15 @@ func (vt *ValueTracker) StatsExpirer() *utils.Expirer {
return &vt.statsExpirer return &vt.statsExpirer
} }
// StatsExpirer returns the current expiration factor so that other values can be expired
// with the same rate as the service value statistics.
func (vt *ValueTracker) StatsExpFactor() utils.ExpirationFactor {
vt.statsExpLock.RLock()
defer vt.statsExpLock.RUnlock()
return vt.statsExpFactor
}
// loadFromDb loads the value tracker's state from the database and converts saved // loadFromDb loads the value tracker's state from the database and converts saved
// request basket index mapping if it does not match the specified index to name mapping. // request basket index mapping if it does not match the specified index to name mapping.
func (vt *ValueTracker) loadFromDb(mapping []string) error { func (vt *ValueTracker) loadFromDb(mapping []string) error {
@ -500,16 +509,3 @@ func (vt *ValueTracker) RequestStats() []RequestStatsItem {
} }
return res return res
} }
// TotalServiceValue returns the total service value provided by the given node (as
// a function of the weights which are calculated from the request timeout value).
func (vt *ValueTracker) TotalServiceValue(nv *NodeValueTracker, weights ResponseTimeWeights) float64 {
vt.statsExpLock.RLock()
expFactor := vt.statsExpFactor
vt.statsExpLock.RUnlock()
nv.lock.Lock()
defer nv.lock.Unlock()
return nv.rtStats.Value(weights, expFactor)
}

@ -0,0 +1,128 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package client
import (
"sync"
"github.com/ethereum/go-ethereum/les/utils"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/nodestate"
)
// WrsIterator returns nodes from the specified selectable set with a weighted random
// selection. Selection weights are provided by a callback function.
type WrsIterator struct {
lock sync.Mutex
cond *sync.Cond
ns *nodestate.NodeStateMachine
wrs *utils.WeightedRandomSelect
nextNode *enode.Node
closed bool
}
// NewWrsIterator creates a new WrsIterator. Nodes are selectable if they have all the required
// and none of the disabled flags set. When a node is selected the selectedFlag is set which also
// disables further selectability until it is removed or times out.
func NewWrsIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, weightField nodestate.Field) *WrsIterator {
wfn := func(i interface{}) uint64 {
n := ns.GetNode(i.(enode.ID))
if n == nil {
return 0
}
wt, _ := ns.GetField(n, weightField).(uint64)
return wt
}
w := &WrsIterator{
ns: ns,
wrs: utils.NewWeightedRandomSelect(wfn),
}
w.cond = sync.NewCond(&w.lock)
ns.SubscribeField(weightField, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
if state.HasAll(requireFlags) && state.HasNone(disableFlags) {
w.lock.Lock()
w.wrs.Update(n.ID())
w.lock.Unlock()
w.cond.Signal()
}
})
ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) {
oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
if newMatch == oldMatch {
return
}
w.lock.Lock()
if newMatch {
w.wrs.Update(n.ID())
} else {
w.wrs.Remove(n.ID())
}
w.lock.Unlock()
w.cond.Signal()
})
return w
}
// Next selects the next node.
func (w *WrsIterator) Next() bool {
w.nextNode = w.chooseNode()
return w.nextNode != nil
}
func (w *WrsIterator) chooseNode() *enode.Node {
w.lock.Lock()
defer w.lock.Unlock()
for {
for !w.closed && w.wrs.IsEmpty() {
w.cond.Wait()
}
if w.closed {
return nil
}
// Choose the next node at random. Even though w.wrs is guaranteed
// non-empty here, Choose might return nil if all items have weight
// zero.
if c := w.wrs.Choose(); c != nil {
id := c.(enode.ID)
w.wrs.Remove(id)
return w.ns.GetNode(id)
}
}
}
// Close ends the iterator.
func (w *WrsIterator) Close() {
w.lock.Lock()
w.closed = true
w.lock.Unlock()
w.cond.Signal()
}
// Node returns the current node.
func (w *WrsIterator) Node() *enode.Node {
w.lock.Lock()
defer w.lock.Unlock()
return w.nextNode
}

@ -0,0 +1,103 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package client
import (
"reflect"
"testing"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/p2p/nodestate"
)
var (
testSetup = &nodestate.Setup{}
sfTest1 = testSetup.NewFlag("test1")
sfTest2 = testSetup.NewFlag("test2")
sfTest3 = testSetup.NewFlag("test3")
sfTest4 = testSetup.NewFlag("test4")
sfiTestWeight = testSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0)))
)
const iterTestNodeCount = 6
func TestWrsIterator(t *testing.T) {
ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
w := NewWrsIterator(ns, sfTest2, sfTest3.Or(sfTest4), sfiTestWeight)
ns.Start()
for i := 1; i <= iterTestNodeCount; i++ {
ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0)
ns.SetField(testNode(i), sfiTestWeight, uint64(1))
}
next := func() int {
ch := make(chan struct{})
go func() {
w.Next()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second * 5):
t.Fatalf("Iterator.Next() timeout")
}
node := w.Node()
ns.SetState(node, sfTest4, nodestate.Flags{}, 0)
return testNodeIndex(node.ID())
}
set := make(map[int]bool)
expset := func() {
for len(set) > 0 {
n := next()
if !set[n] {
t.Errorf("Item returned by iterator not in the expected set (got %d)", n)
}
delete(set, n)
}
}
ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0)
ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0)
ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0)
set[1] = true
set[2] = true
set[3] = true
expset()
ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0)
ns.SetState(testNode(5), sfTest2.Or(sfTest3), nodestate.Flags{}, 0)
ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0)
set[4] = true
set[6] = true
expset()
ns.SetField(testNode(2), sfiTestWeight, uint64(0))
ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0)
ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
set[1] = true
set[3] = true
expset()
ns.SetField(testNode(2), sfiTestWeight, uint64(1))
ns.SetState(testNode(2), nodestate.Flags{}, sfTest2, 0)
ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
ns.SetState(testNode(2), sfTest2, sfTest4, 0)
ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
set[1] = true
set[2] = true
set[3] = true
expset()
ns.Stop()
}

@ -107,6 +107,13 @@ var (
requestRTT = metrics.NewRegisteredTimer("les/client/req/rtt", nil) requestRTT = metrics.NewRegisteredTimer("les/client/req/rtt", nil)
requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil) requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil)
serverSelectableGauge = metrics.NewRegisteredGauge("les/client/serverPool/selectable", nil)
serverDialedMeter = metrics.NewRegisteredMeter("les/client/serverPool/dialed", nil)
serverConnectedGauge = metrics.NewRegisteredGauge("les/client/serverPool/connected", nil)
sessionValueMeter = metrics.NewRegisteredMeter("les/client/serverPool/sessionValue", nil)
totalValueGauge = metrics.NewRegisteredGauge("les/client/serverPool/totalValue", nil)
suggestedTimeoutGauge = metrics.NewRegisteredGauge("les/client/serverPool/timeout", nil)
) )
// meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of

@ -336,7 +336,6 @@ type serverPeer struct {
checkpointNumber uint64 // The block height which the checkpoint is registered. checkpointNumber uint64 // The block height which the checkpoint is registered.
checkpoint params.TrustedCheckpoint // The advertised checkpoint sent by server. checkpoint params.TrustedCheckpoint // The advertised checkpoint sent by server.
poolEntry *poolEntry // Statistic for server peer.
fcServer *flowcontrol.ServerNode // Client side mirror token bucket. fcServer *flowcontrol.ServerNode // Client side mirror token bucket.
vtLock sync.Mutex vtLock sync.Mutex
valueTracker *lpc.ValueTracker valueTracker *lpc.ValueTracker

@ -130,7 +130,6 @@ func init() {
} }
requestMapping[uint32(code)] = rm requestMapping[uint32(code)] = rm
} }
} }
type errCode int type errCode int

@ -24,22 +24,20 @@ import (
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
) )
var ( var (
retryQueue = time.Millisecond * 100 retryQueue = time.Millisecond * 100
softRequestTimeout = time.Millisecond * 500
hardRequestTimeout = time.Second * 10 hardRequestTimeout = time.Second * 10
) )
// retrieveManager is a layer on top of requestDistributor which takes care of // retrieveManager is a layer on top of requestDistributor which takes care of
// matching replies by request ID and handles timeouts and resends if necessary. // matching replies by request ID and handles timeouts and resends if necessary.
type retrieveManager struct { type retrieveManager struct {
dist *requestDistributor dist *requestDistributor
peers *serverPeerSet peers *serverPeerSet
serverPool peerSelector softRequestTimeout func() time.Duration
lock sync.RWMutex lock sync.RWMutex
sentReqs map[uint64]*sentReq sentReqs map[uint64]*sentReq
@ -48,11 +46,6 @@ type retrieveManager struct {
// validatorFunc is a function that processes a reply message // validatorFunc is a function that processes a reply message
type validatorFunc func(distPeer, *Msg) error type validatorFunc func(distPeer, *Msg) error
// peerSelector receives feedback info about response times and timeouts
type peerSelector interface {
adjustResponseTime(*poolEntry, time.Duration, bool)
}
// sentReq represents a request sent and tracked by retrieveManager // sentReq represents a request sent and tracked by retrieveManager
type sentReq struct { type sentReq struct {
rm *retrieveManager rm *retrieveManager
@ -99,12 +92,12 @@ const (
) )
// newRetrieveManager creates the retrieve manager // newRetrieveManager creates the retrieve manager
func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, serverPool peerSelector) *retrieveManager { func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, srto func() time.Duration) *retrieveManager {
return &retrieveManager{ return &retrieveManager{
peers: peers, peers: peers,
dist: dist, dist: dist,
serverPool: serverPool, sentReqs: make(map[uint64]*sentReq),
sentReqs: make(map[uint64]*sentReq), softRequestTimeout: srto,
} }
} }
@ -325,8 +318,7 @@ func (r *sentReq) tryRequest() {
return return
} }
reqSent := mclock.Now() hrto := false
srto, hrto := false, false
r.lock.RLock() r.lock.RLock()
s, ok := r.sentTo[p] s, ok := r.sentTo[p]
@ -338,11 +330,7 @@ func (r *sentReq) tryRequest() {
defer func() { defer func() {
// send feedback to server pool and remove peer if hard timeout happened // send feedback to server pool and remove peer if hard timeout happened
pp, ok := p.(*serverPeer) pp, ok := p.(*serverPeer)
if ok && r.rm.serverPool != nil { if hrto && ok {
respTime := time.Duration(mclock.Now() - reqSent)
r.rm.serverPool.adjustResponseTime(pp.poolEntry, respTime, srto)
}
if hrto {
pp.Log().Debug("Request timed out hard") pp.Log().Debug("Request timed out hard")
if r.rm.peers != nil { if r.rm.peers != nil {
r.rm.peers.unregister(pp.id) r.rm.peers.unregister(pp.id)
@ -363,8 +351,7 @@ func (r *sentReq) tryRequest() {
} }
r.eventsCh <- reqPeerEvent{event, p} r.eventsCh <- reqPeerEvent{event, p}
return return
case <-time.After(softRequestTimeout): case <-time.After(r.rm.softRequestTimeout()):
srto = true
r.eventsCh <- reqPeerEvent{rpSoftTimeout, p} r.eventsCh <- reqPeerEvent{rpSoftTimeout, p}
} }

@ -157,7 +157,7 @@ func (s *LesServer) Protocols() []p2p.Protocol {
return p.Info() return p.Info()
} }
return nil return nil
}) }, nil)
// Add "les" ENR entries. // Add "les" ENR entries.
for i := range ps { for i := range ps {
ps[i].Attributes = []enr.Entry{&lesEntry{}} ps[i].Attributes = []enr.Entry{&lesEntry{}}

File diff suppressed because it is too large Load Diff

352
les/serverpool_test.go Normal file

@ -0,0 +1,352 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package les
import (
"math/rand"
"sync/atomic"
"testing"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
lpc "github.com/ethereum/go-ethereum/les/lespay/client"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
)
const (
spTestNodes = 1000
spTestTarget = 5
spTestLength = 10000
spMinTotal = 40000
spMaxTotal = 50000
)
func testNodeID(i int) enode.ID {
return enode.ID{42, byte(i % 256), byte(i / 256)}
}
func testNodeIndex(id enode.ID) int {
if id[0] != 42 {
return -1
}
return int(id[1]) + int(id[2])*256
}
type serverPoolTest struct {
db ethdb.KeyValueStore
clock *mclock.Simulated
quit chan struct{}
preNeg, preNegFail bool
vt *lpc.ValueTracker
sp *serverPool
input enode.Iterator
testNodes []spTestNode
trusted []string
waitCount, waitEnded int32
cycle, conn, servedConn int
serviceCycles, dialCount int
disconnect map[int][]int
}
type spTestNode struct {
connectCycles, waitCycles int
nextConnCycle, totalConn int
connected, service bool
peer *serverPeer
}
func newServerPoolTest(preNeg, preNegFail bool) *serverPoolTest {
nodes := make([]*enode.Node, spTestNodes)
for i := range nodes {
nodes[i] = enode.SignNull(&enr.Record{}, testNodeID(i))
}
return &serverPoolTest{
clock: &mclock.Simulated{},
db: memorydb.New(),
input: enode.CycleNodes(nodes),
testNodes: make([]spTestNode, spTestNodes),
preNeg: preNeg,
preNegFail: preNegFail,
}
}
func (s *serverPoolTest) beginWait() {
// ensure that dialIterator and the maximal number of pre-neg queries are not all stuck in a waiting state
for atomic.AddInt32(&s.waitCount, 1) > preNegLimit {
atomic.AddInt32(&s.waitCount, -1)
s.clock.Run(time.Second)
}
}
func (s *serverPoolTest) endWait() {
atomic.AddInt32(&s.waitCount, -1)
atomic.AddInt32(&s.waitEnded, 1)
}
func (s *serverPoolTest) addTrusted(i int) {
s.trusted = append(s.trusted, enode.SignNull(&enr.Record{}, testNodeID(i)).String())
}
func (s *serverPoolTest) start() {
var testQuery queryFunc
if s.preNeg {
testQuery = func(node *enode.Node) int {
idx := testNodeIndex(node.ID())
n := &s.testNodes[idx]
canConnect := !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle
if s.preNegFail {
// simulate a scenario where UDP queries never work
s.beginWait()
s.clock.Sleep(time.Second * 5)
s.endWait()
return -1
} else {
switch idx % 3 {
case 0:
// pre-neg returns true only if connection is possible
if canConnect {
return 1
} else {
return 0
}
case 1:
// pre-neg returns true but connection might still fail
return 1
case 2:
// pre-neg returns true if connection is possible, otherwise timeout (node unresponsive)
if canConnect {
return 1
} else {
s.beginWait()
s.clock.Sleep(time.Second * 5)
s.endWait()
return -1
}
}
return -1
}
}
}
s.vt = lpc.NewValueTracker(s.db, s.clock, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000))
s.sp = newServerPool(s.db, []byte("serverpool:"), s.vt, s.input, 0, testQuery, s.clock, s.trusted)
s.sp.validSchemes = enode.ValidSchemesForTesting
s.sp.unixTime = func() int64 { return int64(s.clock.Now()) / int64(time.Second) }
s.disconnect = make(map[int][]int)
s.sp.start()
s.quit = make(chan struct{})
go func() {
last := int32(-1)
for {
select {
case <-time.After(time.Millisecond * 100):
c := atomic.LoadInt32(&s.waitEnded)
if c == last {
// advance clock if test is stuck (might happen in rare cases)
s.clock.Run(time.Second)
}
last = c
case <-s.quit:
return
}
}
}()
}
func (s *serverPoolTest) stop() {
close(s.quit)
s.sp.stop()
s.vt.Stop()
for i := range s.testNodes {
n := &s.testNodes[i]
if n.connected {
n.totalConn += s.cycle
}
n.connected = false
n.peer = nil
n.nextConnCycle = 0
}
s.conn, s.servedConn = 0, 0
}
func (s *serverPoolTest) run() {
for count := spTestLength; count > 0; count-- {
if dcList := s.disconnect[s.cycle]; dcList != nil {
for _, idx := range dcList {
n := &s.testNodes[idx]
s.sp.unregisterPeer(n.peer)
n.totalConn += s.cycle
n.connected = false
n.peer = nil
s.conn--
if n.service {
s.servedConn--
}
n.nextConnCycle = s.cycle + n.waitCycles
}
delete(s.disconnect, s.cycle)
}
if s.conn < spTestTarget {
s.dialCount++
s.beginWait()
s.sp.dialIterator.Next()
s.endWait()
dial := s.sp.dialIterator.Node()
id := dial.ID()
idx := testNodeIndex(id)
n := &s.testNodes[idx]
if !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle {
s.conn++
if n.service {
s.servedConn++
}
n.totalConn -= s.cycle
n.connected = true
dc := s.cycle + n.connectCycles
s.disconnect[dc] = append(s.disconnect[dc], idx)
n.peer = &serverPeer{peerCommons: peerCommons{Peer: p2p.NewPeer(id, "", nil)}}
s.sp.registerPeer(n.peer)
if n.service {
s.vt.Served(s.vt.GetNode(id), []lpc.ServedRequest{{ReqType: 0, Amount: 100}}, 0)
}
}
}
s.serviceCycles += s.servedConn
s.clock.Run(time.Second)
s.cycle++
}
}
func (s *serverPoolTest) setNodes(count, conn, wait int, service, trusted bool) (res []int) {
for ; count > 0; count-- {
idx := rand.Intn(spTestNodes)
for s.testNodes[idx].connectCycles != 0 || s.testNodes[idx].connected {
idx = rand.Intn(spTestNodes)
}
res = append(res, idx)
s.testNodes[idx] = spTestNode{
connectCycles: conn,
waitCycles: wait,
service: service,
}
if trusted {
s.addTrusted(idx)
}
}
return
}
func (s *serverPoolTest) resetNodes() {
for i, n := range s.testNodes {
if n.connected {
n.totalConn += s.cycle
s.sp.unregisterPeer(n.peer)
}
s.testNodes[i] = spTestNode{totalConn: n.totalConn}
}
s.conn, s.servedConn = 0, 0
s.disconnect = make(map[int][]int)
s.trusted = nil
}
func (s *serverPoolTest) checkNodes(t *testing.T, nodes []int) {
var sum int
for _, idx := range nodes {
n := &s.testNodes[idx]
if n.connected {
n.totalConn += s.cycle
}
sum += n.totalConn
n.totalConn = 0
if n.connected {
n.totalConn -= s.cycle
}
}
if sum < spMinTotal || sum > spMaxTotal {
t.Errorf("Total connection amount %d outside expected range %d to %d", sum, spMinTotal, spMaxTotal)
}
}
func TestServerPool(t *testing.T) { testServerPool(t, false, false) }
func TestServerPoolWithPreNeg(t *testing.T) { testServerPool(t, true, false) }
func TestServerPoolWithPreNegFail(t *testing.T) { testServerPool(t, true, true) }
func testServerPool(t *testing.T, preNeg, fail bool) {
s := newServerPoolTest(preNeg, fail)
nodes := s.setNodes(100, 200, 200, true, false)
s.setNodes(100, 20, 20, false, false)
s.start()
s.run()
s.stop()
s.checkNodes(t, nodes)
}
func TestServerPoolChangedNodes(t *testing.T) { testServerPoolChangedNodes(t, false) }
func TestServerPoolChangedNodesWithPreNeg(t *testing.T) { testServerPoolChangedNodes(t, true) }
func testServerPoolChangedNodes(t *testing.T, preNeg bool) {
s := newServerPoolTest(preNeg, false)
nodes := s.setNodes(100, 200, 200, true, false)
s.setNodes(100, 20, 20, false, false)
s.start()
s.run()
s.checkNodes(t, nodes)
for i := 0; i < 3; i++ {
s.resetNodes()
nodes := s.setNodes(100, 200, 200, true, false)
s.setNodes(100, 20, 20, false, false)
s.run()
s.checkNodes(t, nodes)
}
s.stop()
}
func TestServerPoolRestartNoDiscovery(t *testing.T) { testServerPoolRestartNoDiscovery(t, false) }
func TestServerPoolRestartNoDiscoveryWithPreNeg(t *testing.T) {
testServerPoolRestartNoDiscovery(t, true)
}
func testServerPoolRestartNoDiscovery(t *testing.T, preNeg bool) {
s := newServerPoolTest(preNeg, false)
nodes := s.setNodes(100, 200, 200, true, false)
s.setNodes(100, 20, 20, false, false)
s.start()
s.run()
s.stop()
s.checkNodes(t, nodes)
s.input = nil
s.start()
s.run()
s.stop()
s.checkNodes(t, nodes)
}
func TestServerPoolTrustedNoDiscovery(t *testing.T) { testServerPoolTrustedNoDiscovery(t, false) }
func TestServerPoolTrustedNoDiscoveryWithPreNeg(t *testing.T) {
testServerPoolTrustedNoDiscovery(t, true)
}
func testServerPoolTrustedNoDiscovery(t *testing.T, preNeg bool) {
s := newServerPoolTest(preNeg, false)
trusted := s.setNodes(200, 200, 200, true, true)
s.input = nil
s.start()
s.run()
s.stop()
s.checkNodes(t, trusted)
}

@ -508,7 +508,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer
clock = &mclock.Simulated{} clock = &mclock.Simulated{}
} }
dist := newRequestDistributor(speers, clock) dist := newRequestDistributor(speers, clock)
rm := newRetrieveManager(speers, dist, nil) rm := newRetrieveManager(speers, dist, func() time.Duration { return time.Millisecond * 500 })
odr := NewLesOdr(cdb, light.TestClientIndexerConfig, rm) odr := NewLesOdr(cdb, light.TestClientIndexerConfig, rm)
sindexers := testIndexers(sdb, nil, light.TestServerIndexerConfig) sindexers := testIndexers(sdb, nil, light.TestServerIndexerConfig)

@ -63,14 +63,7 @@ func ExpFactor(logOffset Fixed64) ExpirationFactor {
// Value calculates the expired value based on a floating point base and integer // Value calculates the expired value based on a floating point base and integer
// power-of-2 exponent. This function should be used by multi-value expired structures. // power-of-2 exponent. This function should be used by multi-value expired structures.
func (e ExpirationFactor) Value(base float64, exp uint64) float64 { func (e ExpirationFactor) Value(base float64, exp uint64) float64 {
res := base / e.Factor return base / e.Factor * math.Pow(2, float64(int64(exp-e.Exp)))
if exp > e.Exp {
res *= float64(uint64(1) << (exp - e.Exp))
}
if exp < e.Exp {
res /= float64(uint64(1) << (e.Exp - exp))
}
return res
} }
// value calculates the value at the given moment. // value calculates the value at the given moment.

@ -16,28 +16,44 @@
package utils package utils
import "math/rand" import (
"math/rand"
)
// wrsItem interface should be implemented by any entries that are to be selected from type (
// a WeightedRandomSelect set. Note that recalculating monotonously decreasing item // WeightedRandomSelect is capable of weighted random selection from a set of items
// weights on-demand (without constantly calling Update) is allowed WeightedRandomSelect struct {
type wrsItem interface { root *wrsNode
Weight() int64 idx map[WrsItem]int
} wfn WeightFn
}
// WeightedRandomSelect is capable of weighted random selection from a set of items WrsItem interface{}
type WeightedRandomSelect struct { WeightFn func(interface{}) uint64
root *wrsNode )
idx map[wrsItem]int
}
// NewWeightedRandomSelect returns a new WeightedRandomSelect structure // NewWeightedRandomSelect returns a new WeightedRandomSelect structure
func NewWeightedRandomSelect() *WeightedRandomSelect { func NewWeightedRandomSelect(wfn WeightFn) *WeightedRandomSelect {
return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[wrsItem]int)} return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[WrsItem]int), wfn: wfn}
}
// Update updates an item's weight, adds it if it was non-existent or removes it if
// the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
func (w *WeightedRandomSelect) Update(item WrsItem) {
w.setWeight(item, w.wfn(item))
}
// Remove removes an item from the set
func (w *WeightedRandomSelect) Remove(item WrsItem) {
w.setWeight(item, 0)
}
// IsEmpty returns true if the set is empty
func (w *WeightedRandomSelect) IsEmpty() bool {
return w.root.sumWeight == 0
} }
// setWeight sets an item's weight to a specific value (removes it if zero) // setWeight sets an item's weight to a specific value (removes it if zero)
func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) { func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
idx, ok := w.idx[item] idx, ok := w.idx[item]
if ok { if ok {
w.root.setWeight(idx, weight) w.root.setWeight(idx, weight)
@ -58,33 +74,22 @@ func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) {
} }
} }
// Update updates an item's weight, adds it if it was non-existent or removes it if
// the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
func (w *WeightedRandomSelect) Update(item wrsItem) {
w.setWeight(item, item.Weight())
}
// Remove removes an item from the set
func (w *WeightedRandomSelect) Remove(item wrsItem) {
w.setWeight(item, 0)
}
// Choose randomly selects an item from the set, with a chance proportional to its // Choose randomly selects an item from the set, with a chance proportional to its
// current weight. If the weight of the chosen element has been decreased since the // current weight. If the weight of the chosen element has been decreased since the
// last stored value, returns it with a newWeight/oldWeight chance, otherwise just // last stored value, returns it with a newWeight/oldWeight chance, otherwise just
// updates its weight and selects another one // updates its weight and selects another one
func (w *WeightedRandomSelect) Choose() wrsItem { func (w *WeightedRandomSelect) Choose() WrsItem {
for { for {
if w.root.sumWeight == 0 { if w.root.sumWeight == 0 {
return nil return nil
} }
val := rand.Int63n(w.root.sumWeight) val := uint64(rand.Int63n(int64(w.root.sumWeight)))
choice, lastWeight := w.root.choose(val) choice, lastWeight := w.root.choose(val)
weight := choice.Weight() weight := w.wfn(choice)
if weight != lastWeight { if weight != lastWeight {
w.setWeight(choice, weight) w.setWeight(choice, weight)
} }
if weight >= lastWeight || rand.Int63n(lastWeight) < weight { if weight >= lastWeight || uint64(rand.Int63n(int64(lastWeight))) < weight {
return choice return choice
} }
} }
@ -92,16 +97,16 @@ func (w *WeightedRandomSelect) Choose() wrsItem {
const wrsBranches = 8 // max number of branches in the wrsNode tree const wrsBranches = 8 // max number of branches in the wrsNode tree
// wrsNode is a node of a tree structure that can store wrsItems or further wrsNodes. // wrsNode is a node of a tree structure that can store WrsItems or further wrsNodes.
type wrsNode struct { type wrsNode struct {
items [wrsBranches]interface{} items [wrsBranches]interface{}
weights [wrsBranches]int64 weights [wrsBranches]uint64
sumWeight int64 sumWeight uint64
level, itemCnt, maxItems int level, itemCnt, maxItems int
} }
// insert recursively inserts a new item to the tree and returns the item index // insert recursively inserts a new item to the tree and returns the item index
func (n *wrsNode) insert(item wrsItem, weight int64) int { func (n *wrsNode) insert(item WrsItem, weight uint64) int {
branch := 0 branch := 0
for n.items[branch] != nil && (n.level == 0 || n.items[branch].(*wrsNode).itemCnt == n.items[branch].(*wrsNode).maxItems) { for n.items[branch] != nil && (n.level == 0 || n.items[branch].(*wrsNode).itemCnt == n.items[branch].(*wrsNode).maxItems) {
branch++ branch++
@ -129,7 +134,7 @@ func (n *wrsNode) insert(item wrsItem, weight int64) int {
// setWeight updates the weight of a certain item (which should exist) and returns // setWeight updates the weight of a certain item (which should exist) and returns
// the change of the last weight value stored in the tree // the change of the last weight value stored in the tree
func (n *wrsNode) setWeight(idx int, weight int64) int64 { func (n *wrsNode) setWeight(idx int, weight uint64) uint64 {
if n.level == 0 { if n.level == 0 {
oldWeight := n.weights[idx] oldWeight := n.weights[idx]
n.weights[idx] = weight n.weights[idx] = weight
@ -152,12 +157,12 @@ func (n *wrsNode) setWeight(idx int, weight int64) int64 {
return diff return diff
} }
// Choose recursively selects an item from the tree and returns it along with its weight // choose recursively selects an item from the tree and returns it along with its weight
func (n *wrsNode) choose(val int64) (wrsItem, int64) { func (n *wrsNode) choose(val uint64) (WrsItem, uint64) {
for i, w := range n.weights { for i, w := range n.weights {
if val < w { if val < w {
if n.level == 0 { if n.level == 0 {
return n.items[i].(wrsItem), n.weights[i] return n.items[i].(WrsItem), n.weights[i]
} }
return n.items[i].(*wrsNode).choose(val) return n.items[i].(*wrsNode).choose(val)
} }

@ -26,17 +26,18 @@ type testWrsItem struct {
widx *int widx *int
} }
func (t *testWrsItem) Weight() int64 { func testWeight(i interface{}) uint64 {
t := i.(*testWrsItem)
w := *t.widx w := *t.widx
if w == -1 || w == t.idx { if w == -1 || w == t.idx {
return int64(t.idx + 1) return uint64(t.idx + 1)
} }
return 0 return 0
} }
func TestWeightedRandomSelect(t *testing.T) { func TestWeightedRandomSelect(t *testing.T) {
testFn := func(cnt int) { testFn := func(cnt int) {
s := NewWeightedRandomSelect() s := NewWeightedRandomSelect(testWeight)
w := -1 w := -1
list := make([]testWrsItem, cnt) list := make([]testWrsItem, cnt)
for i := range list { for i := range list {

880
p2p/nodestate/nodestate.go Normal file

@ -0,0 +1,880 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package nodestate
import (
"errors"
"reflect"
"sync"
"time"
"unsafe"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rlp"
)
type (
// NodeStateMachine connects different system components operating on subsets of
// network nodes. Node states are represented by 64 bit vectors with each bit assigned
// to a state flag. Each state flag has a descriptor structure and the mapping is
// created automatically. It is possible to subscribe to subsets of state flags and
// receive a callback if one of the nodes has a relevant state flag changed.
// Callbacks can also modify further flags of the same node or other nodes. State
// updates only return after all immediate effects throughout the system have happened
// (deadlocks should be avoided by design of the implemented state logic). The caller
// can also add timeouts assigned to a certain node and a subset of state flags.
// If the timeout elapses, the flags are reset. If all relevant flags are reset then
// the timer is dropped. State flags with no timeout are persisted in the database
// if the flag descriptor enables saving. If a node has no state flags set at any
// moment then it is discarded.
//
// Extra node fields can also be registered so system components can also store more
// complex state for each node that is relevant to them, without creating a custom
// peer set. Fields can be shared across multiple components if they all know the
// field ID. Subscription to fields is also possible. Persistent fields should have
// an encoder and a decoder function.
NodeStateMachine struct {
started, stopped bool
lock sync.Mutex
clock mclock.Clock
db ethdb.KeyValueStore
dbNodeKey []byte
nodes map[enode.ID]*nodeInfo
offlineCallbackList []offlineCallback
// Registered state flags or fields. Modifications are allowed
// only when the node state machine has not been started.
setup *Setup
fields []*fieldInfo
saveFlags bitMask
// Installed callbacks. Modifications are allowed only when the
// node state machine has not been started.
stateSubs []stateSub
// Testing hooks, only for testing purposes.
saveNodeHook func(*nodeInfo)
}
// Flags represents a set of flags from a certain setup
Flags struct {
mask bitMask
setup *Setup
}
// Field represents a field from a certain setup
Field struct {
index int
setup *Setup
}
// flagDefinition describes a node state flag. Each registered instance is automatically
// mapped to a bit of the 64 bit node states.
// If persistent is true then the node is saved when state machine is shutdown.
flagDefinition struct {
name string
persistent bool
}
// fieldDefinition describes an optional node field of the given type. The contents
// of the field are only retained for each node as long as at least one of the
// state flags is set.
fieldDefinition struct {
name string
ftype reflect.Type
encode func(interface{}) ([]byte, error)
decode func([]byte) (interface{}, error)
}
// stateSetup contains the list of flags and fields used by the application
Setup struct {
Version uint
flags []flagDefinition
fields []fieldDefinition
}
// bitMask describes a node state or state mask. It represents a subset
// of node flags with each bit assigned to a flag index (LSB represents flag 0).
bitMask uint64
// StateCallback is a subscription callback which is called when one of the
// state flags that is included in the subscription state mask is changed.
// Note: oldState and newState are also masked with the subscription mask so only
// the relevant bits are included.
StateCallback func(n *enode.Node, oldState, newState Flags)
// FieldCallback is a subscription callback which is called when the value of
// a specific field is changed.
FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{})
// nodeInfo contains node state, fields and state timeouts
nodeInfo struct {
node *enode.Node
state bitMask
timeouts []*nodeStateTimeout
fields []interface{}
db, dirty bool
}
nodeInfoEnc struct {
Enr enr.Record
Version uint
State bitMask
Fields [][]byte
}
stateSub struct {
mask bitMask
callback StateCallback
}
nodeStateTimeout struct {
mask bitMask
timer mclock.Timer
}
fieldInfo struct {
fieldDefinition
subs []FieldCallback
}
offlineCallback struct {
node *enode.Node
state bitMask
fields []interface{}
}
)
// offlineState is a special state that is assumed to be set before a node is loaded from
// the database and after it is shut down.
const offlineState = bitMask(1)
// NewFlag creates a new node state flag
func (s *Setup) NewFlag(name string) Flags {
if s.flags == nil {
s.flags = []flagDefinition{{name: "offline"}}
}
f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s}
s.flags = append(s.flags, flagDefinition{name: name})
return f
}
// NewPersistentFlag creates a new persistent node state flag
func (s *Setup) NewPersistentFlag(name string) Flags {
if s.flags == nil {
s.flags = []flagDefinition{{name: "offline"}}
}
f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s}
s.flags = append(s.flags, flagDefinition{name: name, persistent: true})
return f
}
// OfflineFlag returns the system-defined offline flag belonging to the given setup
func (s *Setup) OfflineFlag() Flags {
return Flags{mask: offlineState, setup: s}
}
// NewField creates a new node state field
func (s *Setup) NewField(name string, ftype reflect.Type) Field {
f := Field{index: len(s.fields), setup: s}
s.fields = append(s.fields, fieldDefinition{
name: name,
ftype: ftype,
})
return f
}
// NewPersistentField creates a new persistent node field
func (s *Setup) NewPersistentField(name string, ftype reflect.Type, encode func(interface{}) ([]byte, error), decode func([]byte) (interface{}, error)) Field {
f := Field{index: len(s.fields), setup: s}
s.fields = append(s.fields, fieldDefinition{
name: name,
ftype: ftype,
encode: encode,
decode: decode,
})
return f
}
// flagOp implements binary flag operations and also checks whether the operands belong to the same setup
func flagOp(a, b Flags, trueIfA, trueIfB, trueIfBoth bool) Flags {
if a.setup == nil {
if a.mask != 0 {
panic("Node state flags have no setup reference")
}
a.setup = b.setup
}
if b.setup == nil {
if b.mask != 0 {
panic("Node state flags have no setup reference")
}
b.setup = a.setup
}
if a.setup != b.setup {
panic("Node state flags belong to a different setup")
}
res := Flags{setup: a.setup}
if trueIfA {
res.mask |= a.mask & ^b.mask
}
if trueIfB {
res.mask |= b.mask & ^a.mask
}
if trueIfBoth {
res.mask |= a.mask & b.mask
}
return res
}
// And returns the set of flags present in both a and b
func (a Flags) And(b Flags) Flags { return flagOp(a, b, false, false, true) }
// AndNot returns the set of flags present in a but not in b
func (a Flags) AndNot(b Flags) Flags { return flagOp(a, b, true, false, false) }
// Or returns the set of flags present in either a or b
func (a Flags) Or(b Flags) Flags { return flagOp(a, b, true, true, true) }
// Xor returns the set of flags present in either a or b but not both
func (a Flags) Xor(b Flags) Flags { return flagOp(a, b, true, true, false) }
// HasAll returns true if b is a subset of a
func (a Flags) HasAll(b Flags) bool { return flagOp(a, b, false, true, false).mask == 0 }
// HasNone returns true if a and b have no shared flags
func (a Flags) HasNone(b Flags) bool { return flagOp(a, b, false, false, true).mask == 0 }
// Equals returns true if a and b have the same flags set
func (a Flags) Equals(b Flags) bool { return flagOp(a, b, true, true, false).mask == 0 }
// IsEmpty returns true if a has no flags set
func (a Flags) IsEmpty() bool { return a.mask == 0 }
// MergeFlags merges multiple sets of state flags
func MergeFlags(list ...Flags) Flags {
if len(list) == 0 {
return Flags{}
}
res := list[0]
for i := 1; i < len(list); i++ {
res = res.Or(list[i])
}
return res
}
// String returns a list of the names of the flags specified in the bit mask
func (f Flags) String() string {
if f.mask == 0 {
return "[]"
}
s := "["
comma := false
for index, flag := range f.setup.flags {
if f.mask&(bitMask(1)<<uint(index)) != 0 {
if comma {
s = s + ", "
}
s = s + flag.name
comma = true
}
}
s = s + "]"
return s
}
// NewNodeStateMachine creates a new node state machine.
// If db is not nil then the node states, fields and active timeouts are persisted.
// Persistence can be enabled or disabled for each state flag and field.
func NewNodeStateMachine(db ethdb.KeyValueStore, dbKey []byte, clock mclock.Clock, setup *Setup) *NodeStateMachine {
if setup.flags == nil {
panic("No state flags defined")
}
if len(setup.flags) > 8*int(unsafe.Sizeof(bitMask(0))) {
panic("Too many node state flags")
}
ns := &NodeStateMachine{
db: db,
dbNodeKey: dbKey,
clock: clock,
setup: setup,
nodes: make(map[enode.ID]*nodeInfo),
fields: make([]*fieldInfo, len(setup.fields)),
}
stateNameMap := make(map[string]int)
for index, flag := range setup.flags {
if _, ok := stateNameMap[flag.name]; ok {
panic("Node state flag name collision")
}
stateNameMap[flag.name] = index
if flag.persistent {
ns.saveFlags |= bitMask(1) << uint(index)
}
}
fieldNameMap := make(map[string]int)
for index, field := range setup.fields {
if _, ok := fieldNameMap[field.name]; ok {
panic("Node field name collision")
}
ns.fields[index] = &fieldInfo{fieldDefinition: field}
fieldNameMap[field.name] = index
}
return ns
}
// stateMask checks whether the set of flags belongs to the same setup and returns its internal bit mask
func (ns *NodeStateMachine) stateMask(flags Flags) bitMask {
if flags.setup != ns.setup && flags.mask != 0 {
panic("Node state flags belong to a different setup")
}
return flags.mask
}
// fieldIndex checks whether the field belongs to the same setup and returns its internal index
func (ns *NodeStateMachine) fieldIndex(field Field) int {
if field.setup != ns.setup {
panic("Node field belongs to a different setup")
}
return field.index
}
// SubscribeState adds a node state subscription. The callback is called while the state
// machine mutex is not held and it is allowed to make further state updates. All immediate
// changes throughout the system are processed in the same thread/goroutine. It is the
// responsibility of the implemented state logic to avoid deadlocks caused by the callbacks,
// infinite toggling of flags or hazardous/non-deterministic state changes.
// State subscriptions should be installed before loading the node database or making the
// first state update.
func (ns *NodeStateMachine) SubscribeState(flags Flags, callback StateCallback) {
ns.lock.Lock()
defer ns.lock.Unlock()
if ns.started {
panic("state machine already started")
}
ns.stateSubs = append(ns.stateSubs, stateSub{ns.stateMask(flags), callback})
}
// SubscribeField adds a node field subscription. Same rules apply as for SubscribeState.
func (ns *NodeStateMachine) SubscribeField(field Field, callback FieldCallback) {
ns.lock.Lock()
defer ns.lock.Unlock()
if ns.started {
panic("state machine already started")
}
f := ns.fields[ns.fieldIndex(field)]
f.subs = append(f.subs, callback)
}
// newNode creates a new nodeInfo
func (ns *NodeStateMachine) newNode(n *enode.Node) *nodeInfo {
return &nodeInfo{node: n, fields: make([]interface{}, len(ns.fields))}
}
// checkStarted checks whether the state machine has already been started and panics otherwise.
func (ns *NodeStateMachine) checkStarted() {
if !ns.started {
panic("state machine not started yet")
}
}
// Start starts the state machine, enabling state and field operations and disabling
// further subscriptions.
func (ns *NodeStateMachine) Start() {
ns.lock.Lock()
if ns.started {
panic("state machine already started")
}
ns.started = true
if ns.db != nil {
ns.loadFromDb()
}
ns.lock.Unlock()
ns.offlineCallbacks(true)
}
// Stop stops the state machine and saves its state if a database was supplied
func (ns *NodeStateMachine) Stop() {
ns.lock.Lock()
for _, node := range ns.nodes {
fields := make([]interface{}, len(node.fields))
copy(fields, node.fields)
ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields})
}
ns.stopped = true
if ns.db != nil {
ns.saveToDb()
ns.lock.Unlock()
} else {
ns.lock.Unlock()
}
ns.offlineCallbacks(false)
}
// loadFromDb loads persisted node states from the database
func (ns *NodeStateMachine) loadFromDb() {
it := ns.db.NewIterator(ns.dbNodeKey, nil)
for it.Next() {
var id enode.ID
if len(it.Key()) != len(ns.dbNodeKey)+len(id) {
log.Error("Node state db entry with invalid length", "found", len(it.Key()), "expected", len(ns.dbNodeKey)+len(id))
continue
}
copy(id[:], it.Key()[len(ns.dbNodeKey):])
ns.decodeNode(id, it.Value())
}
}
type dummyIdentity enode.ID
func (id dummyIdentity) Verify(r *enr.Record, sig []byte) error { return nil }
func (id dummyIdentity) NodeAddr(r *enr.Record) []byte { return id[:] }
// decodeNode decodes a node database entry and adds it to the node set if successful
func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) {
var enc nodeInfoEnc
if err := rlp.DecodeBytes(data, &enc); err != nil {
log.Error("Failed to decode node info", "id", id, "error", err)
return
}
n, _ := enode.New(dummyIdentity(id), &enc.Enr)
node := ns.newNode(n)
node.db = true
if enc.Version != ns.setup.Version {
log.Debug("Removing stored node with unknown version", "current", ns.setup.Version, "stored", enc.Version)
ns.deleteNode(id)
return
}
if len(enc.Fields) > len(ns.setup.fields) {
log.Error("Invalid node field count", "id", id, "stored", len(enc.Fields))
return
}
// Resolve persisted node fields
for i, encField := range enc.Fields {
if len(encField) == 0 {
continue
}
if decode := ns.fields[i].decode; decode != nil {
if field, err := decode(encField); err == nil {
node.fields[i] = field
} else {
log.Error("Failed to decode node field", "id", id, "field name", ns.fields[i].name, "error", err)
return
}
} else {
log.Error("Cannot decode node field", "id", id, "field name", ns.fields[i].name)
return
}
}
// It's a compatible node record, add it to set.
ns.nodes[id] = node
node.state = enc.State
fields := make([]interface{}, len(node.fields))
copy(fields, node.fields)
ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields})
log.Debug("Loaded node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup})
}
// saveNode saves the given node info to the database
func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error {
if ns.db == nil {
return nil
}
storedState := node.state & ns.saveFlags
for _, t := range node.timeouts {
storedState &= ^t.mask
}
if storedState == 0 {
if node.db {
node.db = false
ns.deleteNode(id)
}
node.dirty = false
return nil
}
enc := nodeInfoEnc{
Enr: *node.node.Record(),
Version: ns.setup.Version,
State: storedState,
Fields: make([][]byte, len(ns.fields)),
}
log.Debug("Saved node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup})
lastIndex := -1
for i, f := range node.fields {
if f == nil {
continue
}
encode := ns.fields[i].encode
if encode == nil {
continue
}
blob, err := encode(f)
if err != nil {
return err
}
enc.Fields[i] = blob
lastIndex = i
}
enc.Fields = enc.Fields[:lastIndex+1]
data, err := rlp.EncodeToBytes(&enc)
if err != nil {
return err
}
if err := ns.db.Put(append(ns.dbNodeKey, id[:]...), data); err != nil {
return err
}
node.dirty, node.db = false, true
if ns.saveNodeHook != nil {
ns.saveNodeHook(node)
}
return nil
}
// deleteNode removes a node info from the database
func (ns *NodeStateMachine) deleteNode(id enode.ID) {
ns.db.Delete(append(ns.dbNodeKey, id[:]...))
}
// saveToDb saves the persistent flags and fields of all nodes that have been changed
func (ns *NodeStateMachine) saveToDb() {
for id, node := range ns.nodes {
if node.dirty {
err := ns.saveNode(id, node)
if err != nil {
log.Error("Failed to save node", "id", id, "error", err)
}
}
}
}
// updateEnode updates the enode entry belonging to the given node if it already exists
func (ns *NodeStateMachine) updateEnode(n *enode.Node) (enode.ID, *nodeInfo) {
id := n.ID()
node := ns.nodes[id]
if node != nil && n.Seq() > node.node.Seq() {
node.node = n
}
return id, node
}
// Persist saves the persistent state and fields of the given node immediately
func (ns *NodeStateMachine) Persist(n *enode.Node) error {
ns.lock.Lock()
defer ns.lock.Unlock()
ns.checkStarted()
if id, node := ns.updateEnode(n); node != nil && node.dirty {
err := ns.saveNode(id, node)
if err != nil {
log.Error("Failed to save node", "id", id, "error", err)
}
return err
}
return nil
}
// SetState updates the given node state flags and processes all resulting callbacks.
// It only returns after all subsequent immediate changes (including those changed by the
// callbacks) have been processed. If a flag with a timeout is set again, the operation
// removes or replaces the existing timeout.
func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) {
ns.lock.Lock()
ns.checkStarted()
if ns.stopped {
ns.lock.Unlock()
return
}
set, reset := ns.stateMask(setFlags), ns.stateMask(resetFlags)
id, node := ns.updateEnode(n)
if node == nil {
if set == 0 {
ns.lock.Unlock()
return
}
node = ns.newNode(n)
ns.nodes[id] = node
}
oldState := node.state
newState := (node.state & (^reset)) | set
changed := oldState ^ newState
node.state = newState
// Remove the timeout callbacks for all reset and set flags,
// even they are not existent(it's noop).
ns.removeTimeouts(node, set|reset)
// Register the timeout callback if the new state is not empty
// and timeout itself is required.
if timeout != 0 && newState != 0 {
ns.addTimeout(n, set, timeout)
}
if newState == oldState {
ns.lock.Unlock()
return
}
if newState == 0 {
delete(ns.nodes, id)
if node.db {
ns.deleteNode(id)
}
} else {
if changed&ns.saveFlags != 0 {
node.dirty = true
}
}
ns.lock.Unlock()
// call state update subscription callbacks without holding the mutex
for _, sub := range ns.stateSubs {
if changed&sub.mask != 0 {
sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup})
}
}
if newState == 0 {
// call field subscriptions for discarded fields
for i, v := range node.fields {
if v != nil {
f := ns.fields[i]
if len(f.subs) > 0 {
for _, cb := range f.subs {
cb(n, Flags{setup: ns.setup}, v, nil)
}
}
}
}
}
}
// offlineCallbacks calls state update callbacks at startup or shutdown
func (ns *NodeStateMachine) offlineCallbacks(start bool) {
for _, cb := range ns.offlineCallbackList {
for _, sub := range ns.stateSubs {
offState := offlineState & sub.mask
onState := cb.state & sub.mask
if offState != onState {
if start {
sub.callback(cb.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup})
} else {
sub.callback(cb.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup})
}
}
}
for i, f := range cb.fields {
if f != nil && ns.fields[i].subs != nil {
for _, fsub := range ns.fields[i].subs {
if start {
fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, nil, f)
} else {
fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, f, nil)
}
}
}
}
}
ns.offlineCallbackList = nil
}
// AddTimeout adds a node state timeout associated to the given state flag(s).
// After the specified time interval, the relevant states will be reset.
func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) {
ns.lock.Lock()
defer ns.lock.Unlock()
ns.checkStarted()
if ns.stopped {
return
}
ns.addTimeout(n, ns.stateMask(flags), timeout)
}
// addTimeout adds a node state timeout associated to the given state flag(s).
func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time.Duration) {
_, node := ns.updateEnode(n)
if node == nil {
return
}
mask &= node.state
if mask == 0 {
return
}
ns.removeTimeouts(node, mask)
t := &nodeStateTimeout{mask: mask}
t.timer = ns.clock.AfterFunc(timeout, func() {
ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0)
})
node.timeouts = append(node.timeouts, t)
if mask&ns.saveFlags != 0 {
node.dirty = true
}
}
// removeTimeout removes node state timeouts associated to the given state flag(s).
// If a timeout was associated to multiple flags which are not all included in the
// specified remove mask then only the included flags are de-associated and the timer
// stays active.
func (ns *NodeStateMachine) removeTimeouts(node *nodeInfo, mask bitMask) {
for i := 0; i < len(node.timeouts); i++ {
t := node.timeouts[i]
match := t.mask & mask
if match == 0 {
continue
}
t.mask -= match
if t.mask != 0 {
continue
}
t.timer.Stop()
node.timeouts[i] = node.timeouts[len(node.timeouts)-1]
node.timeouts = node.timeouts[:len(node.timeouts)-1]
i--
if match&ns.saveFlags != 0 {
node.dirty = true
}
}
}
// GetField retrieves the given field of the given node
func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} {
ns.lock.Lock()
defer ns.lock.Unlock()
ns.checkStarted()
if ns.stopped {
return nil
}
if _, node := ns.updateEnode(n); node != nil {
return node.fields[ns.fieldIndex(field)]
}
return nil
}
// SetField sets the given field of the given node
func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error {
ns.lock.Lock()
ns.checkStarted()
if ns.stopped {
ns.lock.Unlock()
return nil
}
_, node := ns.updateEnode(n)
if node == nil {
ns.lock.Unlock()
return nil
}
fieldIndex := ns.fieldIndex(field)
f := ns.fields[fieldIndex]
if value != nil && reflect.TypeOf(value) != f.ftype {
log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype)
ns.lock.Unlock()
return errors.New("invalid field type")
}
oldValue := node.fields[fieldIndex]
if value == oldValue {
ns.lock.Unlock()
return nil
}
node.fields[fieldIndex] = value
if f.encode != nil {
node.dirty = true
}
state := node.state
ns.lock.Unlock()
if len(f.subs) > 0 {
for _, cb := range f.subs {
cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value)
}
}
return nil
}
// ForEach calls the callback for each node having all of the required and none of the
// disabled flags set
func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) {
ns.lock.Lock()
ns.checkStarted()
type callback struct {
node *enode.Node
state bitMask
}
require, disable := ns.stateMask(requireFlags), ns.stateMask(disableFlags)
var callbacks []callback
for _, node := range ns.nodes {
if node.state&require == require && node.state&disable == 0 {
callbacks = append(callbacks, callback{node.node, node.state & (require | disable)})
}
}
ns.lock.Unlock()
for _, c := range callbacks {
cb(c.node, Flags{mask: c.state, setup: ns.setup})
}
}
// GetNode returns the enode currently associated with the given ID
func (ns *NodeStateMachine) GetNode(id enode.ID) *enode.Node {
ns.lock.Lock()
defer ns.lock.Unlock()
ns.checkStarted()
if node := ns.nodes[id]; node != nil {
return node.node
}
return nil
}
// AddLogMetrics adds logging and/or metrics for nodes entering, exiting and currently
// being in a given set specified by required and disabled state flags
func (ns *NodeStateMachine) AddLogMetrics(requireFlags, disableFlags Flags, name string, inMeter, outMeter metrics.Meter, gauge metrics.Gauge) {
var count int64
ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags) {
oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
if newMatch == oldMatch {
return
}
if newMatch {
count++
if name != "" {
log.Debug("Node entered", "set", name, "id", n.ID(), "count", count)
}
if inMeter != nil {
inMeter.Mark(1)
}
} else {
count--
if name != "" {
log.Debug("Node left", "set", name, "id", n.ID(), "count", count)
}
if outMeter != nil {
outMeter.Mark(1)
}
}
if gauge != nil {
gauge.Update(count)
}
})
}

@ -0,0 +1,389 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package nodestate
import (
"errors"
"fmt"
"reflect"
"testing"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rlp"
)
func testSetup(flagPersist []bool, fieldType []reflect.Type) (*Setup, []Flags, []Field) {
setup := &Setup{}
flags := make([]Flags, len(flagPersist))
for i, persist := range flagPersist {
if persist {
flags[i] = setup.NewPersistentFlag(fmt.Sprintf("flag-%d", i))
} else {
flags[i] = setup.NewFlag(fmt.Sprintf("flag-%d", i))
}
}
fields := make([]Field, len(fieldType))
for i, ftype := range fieldType {
switch ftype {
case reflect.TypeOf(uint64(0)):
fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, uint64FieldEnc, uint64FieldDec)
case reflect.TypeOf(""):
fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, stringFieldEnc, stringFieldDec)
default:
fields[i] = setup.NewField(fmt.Sprintf("field-%d", i), ftype)
}
}
return setup, flags, fields
}
func testNode(b byte) *enode.Node {
r := &enr.Record{}
r.SetSig(dummyIdentity{b}, []byte{42})
n, _ := enode.New(dummyIdentity{b}, r)
return n
}
func TestCallback(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
s, flags, _ := testSetup([]bool{false, false, false}, nil)
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
set0 := make(chan struct{}, 1)
set1 := make(chan struct{}, 1)
set2 := make(chan struct{}, 1)
ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} })
ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} })
ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} })
ns.Start()
ns.SetState(testNode(1), flags[0], Flags{}, 0)
ns.SetState(testNode(1), flags[1], Flags{}, time.Second)
ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second)
for i := 0; i < 3; i++ {
select {
case <-set0:
case <-set1:
case <-set2:
case <-time.After(time.Second):
t.Fatalf("failed to invoke callback")
}
}
}
func TestPersistentFlags(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
s, flags, _ := testSetup([]bool{true, true, true, false}, nil)
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
saveNode := make(chan *nodeInfo, 5)
ns.saveNodeHook = func(node *nodeInfo) {
saveNode <- node
}
ns.Start()
ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved
ns.SetState(testNode(2), flags[1], Flags{}, 0)
ns.SetState(testNode(3), flags[2], Flags{}, 0)
ns.SetState(testNode(4), flags[3], Flags{}, 0)
ns.SetState(testNode(5), flags[0], Flags{}, 0)
ns.Persist(testNode(5))
select {
case <-saveNode:
case <-time.After(time.Second):
t.Fatalf("Timeout")
}
ns.Stop()
for i := 0; i < 2; i++ {
select {
case <-saveNode:
case <-time.After(time.Second):
t.Fatalf("Timeout")
}
}
select {
case <-saveNode:
t.Fatalf("Unexpected saveNode")
case <-time.After(time.Millisecond * 100):
}
}
func TestSetField(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf("")})
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
saveNode := make(chan *nodeInfo, 1)
ns.saveNodeHook = func(node *nodeInfo) {
saveNode <- node
}
ns.Start()
// Set field before setting state
ns.SetField(testNode(1), fields[0], "hello world")
field := ns.GetField(testNode(1), fields[0])
if field != nil {
t.Fatalf("Field shouldn't be set before setting states")
}
// Set field after setting state
ns.SetState(testNode(1), flags[0], Flags{}, 0)
ns.SetField(testNode(1), fields[0], "hello world")
field = ns.GetField(testNode(1), fields[0])
if field == nil {
t.Fatalf("Field should be set after setting states")
}
if err := ns.SetField(testNode(1), fields[0], 123); err == nil {
t.Fatalf("Invalid field should be rejected")
}
// Dirty node should be written back
ns.Stop()
select {
case <-saveNode:
case <-time.After(time.Second):
t.Fatalf("Timeout")
}
}
func TestUnsetField(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
s, flags, fields := testSetup([]bool{false}, []reflect.Type{reflect.TypeOf("")})
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
ns.Start()
ns.SetState(testNode(1), flags[0], Flags{}, time.Second)
ns.SetField(testNode(1), fields[0], "hello world")
ns.SetState(testNode(1), Flags{}, flags[0], 0)
if field := ns.GetField(testNode(1), fields[0]); field != nil {
t.Fatalf("Field should be unset")
}
}
func TestSetState(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
s, flags, _ := testSetup([]bool{false, false, false}, nil)
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
type change struct{ old, new Flags }
set := make(chan change, 1)
ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) {
set <- change{
old: oldState,
new: newState,
}
})
ns.Start()
check := func(expectOld, expectNew Flags, expectChange bool) {
if expectChange {
select {
case c := <-set:
if !c.old.Equals(expectOld) {
t.Fatalf("Old state mismatch")
}
if !c.new.Equals(expectNew) {
t.Fatalf("New state mismatch")
}
case <-time.After(time.Second):
}
return
}
select {
case <-set:
t.Fatalf("Unexpected change")
case <-time.After(time.Millisecond * 100):
return
}
}
ns.SetState(testNode(1), flags[0], Flags{}, 0)
check(Flags{}, flags[0], true)
ns.SetState(testNode(1), flags[1], Flags{}, 0)
check(flags[0], flags[0].Or(flags[1]), true)
ns.SetState(testNode(1), flags[2], Flags{}, 0)
check(Flags{}, Flags{}, false)
ns.SetState(testNode(1), Flags{}, flags[0], 0)
check(flags[0].Or(flags[1]), flags[1], true)
ns.SetState(testNode(1), Flags{}, flags[1], 0)
check(flags[1], Flags{}, true)
ns.SetState(testNode(1), Flags{}, flags[2], 0)
check(Flags{}, Flags{}, false)
ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second)
check(Flags{}, flags[0].Or(flags[1]), true)
clock.Run(time.Second)
check(flags[0].Or(flags[1]), Flags{}, true)
}
func uint64FieldEnc(field interface{}) ([]byte, error) {
if u, ok := field.(uint64); ok {
enc, err := rlp.EncodeToBytes(&u)
return enc, err
} else {
return nil, errors.New("invalid field type")
}
}
func uint64FieldDec(enc []byte) (interface{}, error) {
var u uint64
err := rlp.DecodeBytes(enc, &u)
return u, err
}
func stringFieldEnc(field interface{}) ([]byte, error) {
if s, ok := field.(string); ok {
return []byte(s), nil
} else {
return nil, errors.New("invalid field type")
}
}
func stringFieldDec(enc []byte) (interface{}, error) {
return string(enc), nil
}
func TestPersistentFields(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0)), reflect.TypeOf("")})
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
ns.Start()
ns.SetState(testNode(1), flags[0], Flags{}, 0)
ns.SetField(testNode(1), fields[0], uint64(100))
ns.SetField(testNode(1), fields[1], "hello world")
ns.Stop()
ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
ns2.Start()
field0 := ns2.GetField(testNode(1), fields[0])
if !reflect.DeepEqual(field0, uint64(100)) {
t.Fatalf("Field changed")
}
field1 := ns2.GetField(testNode(1), fields[1])
if !reflect.DeepEqual(field1, "hello world") {
t.Fatalf("Field changed")
}
s.Version++
ns3 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
ns3.Start()
if ns3.GetField(testNode(1), fields[0]) != nil {
t.Fatalf("Old field version should have been discarded")
}
}
func TestFieldSub(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0))})
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
var (
lastState Flags
lastOldValue, lastNewValue interface{}
)
ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) {
lastState, lastOldValue, lastNewValue = state, oldValue, newValue
})
check := func(state Flags, oldValue, newValue interface{}) {
if !lastState.Equals(state) || lastOldValue != oldValue || lastNewValue != newValue {
t.Fatalf("Incorrect field sub callback (expected [%v %v %v], got [%v %v %v])", state, oldValue, newValue, lastState, lastOldValue, lastNewValue)
}
}
ns.Start()
ns.SetState(testNode(1), flags[0], Flags{}, 0)
ns.SetField(testNode(1), fields[0], uint64(100))
check(flags[0], nil, uint64(100))
ns.Stop()
check(s.OfflineFlag(), uint64(100), nil)
ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) {
lastState, lastOldValue, lastNewValue = state, oldValue, newValue
})
ns2.Start()
check(s.OfflineFlag(), nil, uint64(100))
ns2.SetState(testNode(1), Flags{}, flags[0], 0)
check(Flags{}, uint64(100), nil)
ns2.Stop()
}
func TestDuplicatedFlags(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
s, flags, _ := testSetup([]bool{true}, nil)
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
type change struct{ old, new Flags }
set := make(chan change, 1)
ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) {
set <- change{oldState, newState}
})
ns.Start()
defer ns.Stop()
check := func(expectOld, expectNew Flags, expectChange bool) {
if expectChange {
select {
case c := <-set:
if !c.old.Equals(expectOld) {
t.Fatalf("Old state mismatch")
}
if !c.new.Equals(expectNew) {
t.Fatalf("New state mismatch")
}
case <-time.After(time.Second):
}
return
}
select {
case <-set:
t.Fatalf("Unexpected change")
case <-time.After(time.Millisecond * 100):
return
}
}
ns.SetState(testNode(1), flags[0], Flags{}, time.Second)
check(Flags{}, flags[0], true)
ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s
check(Flags{}, flags[0], false)
clock.Run(2 * time.Second)
check(flags[0], Flags{}, true)
}

@ -65,11 +65,22 @@ var GoerliBootnodes = []string{
const dnsPrefix = "enrtree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@" const dnsPrefix = "enrtree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@"
// These DNS names provide bootstrap connectivity for public testnets and the mainnet. // KnownDNSNetwork returns the address of a public DNS-based node list for the given
// See https://github.com/ethereum/discv4-dns-lists for more information. // genesis hash and protocol. See https://github.com/ethereum/discv4-dns-lists for more
var KnownDNSNetworks = map[common.Hash]string{ // information.
MainnetGenesisHash: dnsPrefix + "all.mainnet.ethdisco.net", func KnownDNSNetwork(genesis common.Hash, protocol string) string {
RopstenGenesisHash: dnsPrefix + "all.ropsten.ethdisco.net", var net string
RinkebyGenesisHash: dnsPrefix + "all.rinkeby.ethdisco.net", switch genesis {
GoerliGenesisHash: dnsPrefix + "all.goerli.ethdisco.net", case MainnetGenesisHash:
net = "mainnet"
case RopstenGenesisHash:
net = "ropsten"
case RinkebyGenesisHash:
net = "rinkeby"
case GoerliGenesisHash:
net = "goerli"
default:
return ""
}
return dnsPrefix + protocol + "." + net + ".ethdisco.net"
} }