From b4a2681120209836f7275753abcc9c7d3f9f8924 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felf=C3=B6ldi=20Zsolt?= Date: Fri, 22 May 2020 13:46:34 +0200 Subject: [PATCH] les, les/lespay: implement new server pool (#20758) This PR reimplements the light client server pool. It is also a first step to move certain logic into a new lespay package. This package will contain the implementation of the lespay token sale functions, the token buying and selling logic and other components related to peer selection/prioritization and service quality evaluation. Over the long term this package will be reusable for incentivizing future protocols. Since the LES peer logic is now based on enode.Iterator, it can now use DNS-based fallback discovery to find servers. This document describes the function of the new components: https://gist.github.com/zsfelfoldi/3c7ace895234b7b345ab4f71dab102d4 --- cmd/utils/flags.go | 21 +- core/forkid/forkid.go | 18 +- eth/backend.go | 6 +- les/client.go | 54 +- les/client_handler.go | 11 +- les/commons.go | 5 +- les/distributor.go | 11 +- les/enr_entry.go | 12 + les/fetcher.go | 36 +- les/lespay/client/fillset.go | 107 ++ les/lespay/client/fillset_test.go | 113 ++ les/lespay/client/queueiterator.go | 123 +++ les/lespay/client/queueiterator_test.go | 106 ++ les/lespay/client/valuetracker.go | 22 +- les/lespay/client/wrsiterator.go | 128 +++ les/lespay/client/wrsiterator_test.go | 103 ++ les/metrics.go | 7 + les/peer.go | 1 - les/protocol.go | 1 - les/retrieve.go | 35 +- les/server.go | 2 +- les/serverpool.go | 1263 ++++++++--------------- les/serverpool_test.go | 352 +++++++ les/test_helper.go | 2 +- les/utils/expiredvalue.go | 9 +- les/utils/weighted_select.go | 83 +- les/utils/weighted_select_test.go | 7 +- p2p/nodestate/nodestate.go | 880 ++++++++++++++++ p2p/nodestate/nodestate_test.go | 389 +++++++ params/bootnodes.go | 25 +- 30 files changed, 2904 insertions(+), 1028 deletions(-) create mode 100644 les/lespay/client/fillset.go create mode 100644 les/lespay/client/fillset_test.go create mode 100644 les/lespay/client/queueiterator.go create mode 100644 les/lespay/client/queueiterator_test.go create mode 100644 les/lespay/client/wrsiterator.go create mode 100644 les/lespay/client/wrsiterator_test.go create mode 100644 les/serverpool_test.go create mode 100644 p2p/nodestate/nodestate.go create mode 100644 p2p/nodestate/nodestate_test.go diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 6f6233e4d..59a2c9f83 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -1563,19 +1563,19 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { cfg.NetworkId = 3 } cfg.Genesis = core.DefaultRopstenGenesisBlock() - setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.RopstenGenesisHash]) + setDNSDiscoveryDefaults(cfg, params.RopstenGenesisHash) case ctx.GlobalBool(RinkebyFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 4 } cfg.Genesis = core.DefaultRinkebyGenesisBlock() - setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.RinkebyGenesisHash]) + setDNSDiscoveryDefaults(cfg, params.RinkebyGenesisHash) case ctx.GlobalBool(GoerliFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 5 } cfg.Genesis = core.DefaultGoerliGenesisBlock() - setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.GoerliGenesisHash]) + setDNSDiscoveryDefaults(cfg, params.GoerliGenesisHash) case ctx.GlobalBool(DeveloperFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 1337 @@ -1604,18 +1604,25 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { } default: if cfg.NetworkId == 1 { - setDNSDiscoveryDefaults(cfg, params.KnownDNSNetworks[params.MainnetGenesisHash]) + setDNSDiscoveryDefaults(cfg, params.MainnetGenesisHash) } } } // setDNSDiscoveryDefaults configures DNS discovery with the given URL if // no URLs are set. -func setDNSDiscoveryDefaults(cfg *eth.Config, url string) { +func setDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) { if cfg.DiscoveryURLs != nil { - return + return // already set through flags/config + } + + protocol := "eth" + if cfg.SyncMode == downloader.LightSync { + protocol = "les" + } + if url := params.KnownDNSNetwork(genesis, protocol); url != "" { + cfg.DiscoveryURLs = []string{url} } - cfg.DiscoveryURLs = []string{url} } // RegisterEthService adds an Ethereum client to the stack. diff --git a/core/forkid/forkid.go b/core/forkid/forkid.go index e433db446..08a948510 100644 --- a/core/forkid/forkid.go +++ b/core/forkid/forkid.go @@ -27,7 +27,7 @@ import ( "strings" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" ) @@ -44,6 +44,18 @@ var ( ErrLocalIncompatibleOrStale = errors.New("local incompatible or needs update") ) +// Blockchain defines all necessary method to build a forkID. +type Blockchain interface { + // Config retrieves the chain's fork configuration. + Config() *params.ChainConfig + + // Genesis retrieves the chain's genesis block. + Genesis() *types.Block + + // CurrentHeader retrieves the current head header of the canonical chain. + CurrentHeader() *types.Header +} + // ID is a fork identifier as defined by EIP-2124. type ID struct { Hash [4]byte // CRC32 checksum of the genesis block and passed fork block numbers @@ -54,7 +66,7 @@ type ID struct { type Filter func(id ID) error // NewID calculates the Ethereum fork ID from the chain config and head. -func NewID(chain *core.BlockChain) ID { +func NewID(chain Blockchain) ID { return newID( chain.Config(), chain.Genesis().Hash(), @@ -85,7 +97,7 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID { // NewFilter creates a filter that returns if a fork ID should be rejected or not // based on the local chain's status. -func NewFilter(chain *core.BlockChain) Filter { +func NewFilter(chain Blockchain) Filter { return newFilter( chain.Config(), chain.Genesis().Hash(), diff --git a/eth/backend.go b/eth/backend.go index 391e3c0e6..e4f98360d 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -72,7 +72,7 @@ type Ethereum struct { blockchain *core.BlockChain protocolManager *ProtocolManager lesServer LesServer - dialCandiates enode.Iterator + dialCandidates enode.Iterator // DB interfaces chainDb ethdb.Database // Block chain database @@ -226,7 +226,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { } eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams) - eth.dialCandiates, err = eth.setupDiscovery(&ctx.Config.P2P) + eth.dialCandidates, err = eth.setupDiscovery(&ctx.Config.P2P) if err != nil { return nil, err } @@ -523,7 +523,7 @@ func (s *Ethereum) Protocols() []p2p.Protocol { for i, vsn := range ProtocolVersions { protos[i] = s.protocolManager.makeProtocol(vsn) protos[i].Attributes = []enr.Entry{s.currentEthEntry()} - protos[i].DialCandidates = s.dialCandiates + protos[i].DialCandidates = s.dialCandidates } if s.lesServer != nil { protos = append(protos, s.lesServer.Protocols()...) diff --git a/les/client.go b/les/client.go index 2ac3285d8..34a654e22 100644 --- a/les/client.go +++ b/les/client.go @@ -51,16 +51,17 @@ import ( type LightEthereum struct { lesCommons - peers *serverPeerSet - reqDist *requestDistributor - retriever *retrieveManager - odr *LesOdr - relay *lesTxRelay - handler *clientHandler - txPool *light.TxPool - blockchain *light.LightChain - serverPool *serverPool - valueTracker *lpc.ValueTracker + peers *serverPeerSet + reqDist *requestDistributor + retriever *retrieveManager + odr *LesOdr + relay *lesTxRelay + handler *clientHandler + txPool *light.TxPool + blockchain *light.LightChain + serverPool *serverPool + valueTracker *lpc.ValueTracker + dialCandidates enode.Iterator bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests bloomIndexer *core.ChainIndexer // Bloom indexer operating during block imports @@ -104,11 +105,19 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb), bloomRequests: make(chan chan *bloombits.Retrieval), bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations), - serverPool: newServerPool(chainDb, config.UltraLightServers), valueTracker: lpc.NewValueTracker(lespayDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)), } peers.subscribe((*vtSubscription)(leth.valueTracker)) - leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool) + + dnsdisc, err := leth.setupDiscovery(&ctx.Config.P2P) + if err != nil { + return nil, err + } + leth.serverPool = newServerPool(lespayDb, []byte("serverpool:"), leth.valueTracker, dnsdisc, time.Second, nil, &mclock.System{}, config.UltraLightServers) + peers.subscribe(leth.serverPool) + leth.dialCandidates = leth.serverPool.dialIterator + + leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool.getTimeout) leth.relay = newLesTxRelay(peers, leth.retriever) leth.odr = NewLesOdr(chainDb, light.DefaultClientIndexerConfig, leth.retriever) @@ -140,11 +149,6 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { leth.chtIndexer.Start(leth.blockchain) leth.bloomIndexer.Start(leth.blockchain) - leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth) - if leth.handler.ulc != nil { - log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction) - leth.blockchain.DisableCheckFreq() - } // Rewind the chain in case of an incompatible config upgrade. if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) @@ -159,6 +163,11 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { } leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams) + leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth) + if leth.handler.ulc != nil { + log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction) + leth.blockchain.DisableCheckFreq() + } return leth, nil } @@ -260,7 +269,7 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { return p.Info() } return nil - }) + }, s.dialCandidates) } // Start implements node.Service, starting all internal goroutines needed by the @@ -268,15 +277,12 @@ func (s *LightEthereum) Protocols() []p2p.Protocol { func (s *LightEthereum) Start(srvr *p2p.Server) error { log.Warn("Light client mode is an experimental feature") + s.serverPool.start() // Start bloom request workers. s.wg.Add(bloomServiceThreads) s.startBloomHandlers(params.BloomBitsBlocksClient) s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId) - - // clients are searching for the first advertised protocol in the list - protocolVersion := AdvertiseProtocolVersions[0] - s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion)) return nil } @@ -284,6 +290,8 @@ func (s *LightEthereum) Start(srvr *p2p.Server) error { // Ethereum protocol. func (s *LightEthereum) Stop() error { close(s.closeCh) + s.serverPool.stop() + s.valueTracker.Stop() s.peers.close() s.reqDist.close() s.odr.Stop() @@ -295,8 +303,6 @@ func (s *LightEthereum) Stop() error { s.txPool.Stop() s.engine.Close() s.eventMux.Stop() - s.serverPool.stop() - s.valueTracker.Stop() s.chainDb.Close() s.wg.Wait() log.Info("Light ethereum stopped") diff --git a/les/client_handler.go b/les/client_handler.go index 6367fdb6b..abe472e46 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -64,7 +64,7 @@ func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.T if checkpoint != nil { height = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1 } - handler.fetcher = newLightFetcher(handler) + handler.fetcher = newLightFetcher(handler, backend.serverPool.getTimeout) handler.downloader = downloader.New(height, backend.chainDb, nil, backend.eventMux, nil, backend.blockchain, handler.removePeer) handler.backend.peers.subscribe((*downloaderPeerNotify)(handler)) return handler @@ -85,14 +85,9 @@ func (h *clientHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) } peer := newServerPeer(int(version), h.backend.config.NetworkId, trusted, p, newMeteredMsgWriter(rw, int(version))) defer peer.close() - peer.poolEntry = h.backend.serverPool.connect(peer, peer.Node()) - if peer.poolEntry == nil { - return p2p.DiscRequested - } h.wg.Add(1) defer h.wg.Done() err := h.handle(peer) - h.backend.serverPool.disconnect(peer.poolEntry) return err } @@ -129,10 +124,6 @@ func (h *clientHandler) handle(p *serverPeer) error { h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td}) - // pool entry can be nil during the unit test. - if p.poolEntry != nil { - h.backend.serverPool.registered(p.poolEntry) - } // Mark the peer starts to be served. atomic.StoreUint32(&p.serving, 1) defer atomic.StoreUint32(&p.serving, 0) diff --git a/les/commons.go b/les/commons.go index 29b5b7660..cd8a22834 100644 --- a/les/commons.go +++ b/les/commons.go @@ -81,7 +81,7 @@ type NodeInfo struct { } // makeProtocols creates protocol descriptors for the given LES versions. -func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}) []p2p.Protocol { +func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}, dialCandidates enode.Iterator) []p2p.Protocol { protos := make([]p2p.Protocol, len(versions)) for i, version := range versions { version := version @@ -93,7 +93,8 @@ func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error { return runPeer(version, peer, rw) }, - PeerInfo: peerInfo, + PeerInfo: peerInfo, + DialCandidates: dialCandidates, } } return protos diff --git a/les/distributor.go b/les/distributor.go index 97d2ccdfe..31150e4d7 100644 --- a/les/distributor.go +++ b/les/distributor.go @@ -180,12 +180,11 @@ func (d *requestDistributor) loop() { type selectPeerItem struct { peer distPeer req *distReq - weight int64 + weight uint64 } -// Weight implements wrsItem interface -func (sp selectPeerItem) Weight() int64 { - return sp.weight +func selectPeerWeight(i interface{}) uint64 { + return i.(selectPeerItem).weight } // nextRequest returns the next possible request from any peer, along with the @@ -220,9 +219,9 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) { wait, bufRemain := peer.waitBefore(cost) if wait == 0 { if sel == nil { - sel = utils.NewWeightedRandomSelect() + sel = utils.NewWeightedRandomSelect(selectPeerWeight) } - sel.Update(selectPeerItem{peer: peer, req: req, weight: int64(bufRemain*1000000) + 1}) + sel.Update(selectPeerItem{peer: peer, req: req, weight: uint64(bufRemain*1000000) + 1}) } else { if bestWait == 0 || wait < bestWait { bestWait = wait diff --git a/les/enr_entry.go b/les/enr_entry.go index c2a92dd99..65d0d1fdb 100644 --- a/les/enr_entry.go +++ b/les/enr_entry.go @@ -17,6 +17,9 @@ package les import ( + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/dnsdisc" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" ) @@ -30,3 +33,12 @@ type lesEntry struct { func (e lesEntry) ENRKey() string { return "les" } + +// setupDiscovery creates the node discovery source for the eth protocol. +func (eth *LightEthereum) setupDiscovery(cfg *p2p.Config) (enode.Iterator, error) { + if /*cfg.NoDiscovery || */ len(eth.config.DiscoveryURLs) == 0 { + return nil, nil + } + client := dnsdisc.NewClient(dnsdisc.Config{}) + return client.NewIterator(eth.config.DiscoveryURLs...) +} diff --git a/les/fetcher.go b/les/fetcher.go index 7fa81cbcb..aaf74aaa1 100644 --- a/les/fetcher.go +++ b/les/fetcher.go @@ -40,8 +40,9 @@ const ( // ODR system to ensure that we only request data related to a certain block from peers who have already processed // and announced that block. type lightFetcher struct { - handler *clientHandler - chain *light.LightChain + handler *clientHandler + chain *light.LightChain + softRequestTimeout func() time.Duration lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests maxConfirmedTd *big.Int @@ -109,18 +110,19 @@ type fetchResponse struct { } // newLightFetcher creates a new light fetcher -func newLightFetcher(h *clientHandler) *lightFetcher { +func newLightFetcher(h *clientHandler, softRequestTimeout func() time.Duration) *lightFetcher { f := &lightFetcher{ - handler: h, - chain: h.backend.blockchain, - peers: make(map[*serverPeer]*fetcherPeerInfo), - deliverChn: make(chan fetchResponse, 100), - requested: make(map[uint64]fetchRequest), - timeoutChn: make(chan uint64), - requestTrigger: make(chan struct{}, 1), - syncDone: make(chan *serverPeer), - closeCh: make(chan struct{}), - maxConfirmedTd: big.NewInt(0), + handler: h, + chain: h.backend.blockchain, + peers: make(map[*serverPeer]*fetcherPeerInfo), + deliverChn: make(chan fetchResponse, 100), + requested: make(map[uint64]fetchRequest), + timeoutChn: make(chan uint64), + requestTrigger: make(chan struct{}, 1), + syncDone: make(chan *serverPeer), + closeCh: make(chan struct{}), + maxConfirmedTd: big.NewInt(0), + softRequestTimeout: softRequestTimeout, } h.backend.peers.subscribe(f) @@ -163,7 +165,7 @@ func (f *lightFetcher) syncLoop() { f.lock.Unlock() } else { go func() { - time.Sleep(softRequestTimeout) + time.Sleep(f.softRequestTimeout()) f.reqMu.Lock() req, ok := f.requested[reqID] if ok { @@ -187,7 +189,6 @@ func (f *lightFetcher) syncLoop() { } f.reqMu.Unlock() if ok { - f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true) req.peer.Log().Debug("Fetching data timed out hard") go f.handler.removePeer(req.peer.id) } @@ -201,9 +202,6 @@ func (f *lightFetcher) syncLoop() { delete(f.requested, resp.reqID) } f.reqMu.Unlock() - if ok { - f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout) - } f.lock.Lock() if !ok || !(f.syncing || f.processResponse(req, resp)) { resp.peer.Log().Debug("Failed processing response") @@ -879,12 +877,10 @@ func (f *lightFetcher) checkUpdateStats(p *serverPeer, newEntry *updateStatsEntr fp.firstUpdateStats = newEntry } for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) { - f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout) fp.firstUpdateStats = fp.firstUpdateStats.next } if fp.confirmedTd != nil { for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 { - f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time)) fp.firstUpdateStats = fp.firstUpdateStats.next } } diff --git a/les/lespay/client/fillset.go b/les/lespay/client/fillset.go new file mode 100644 index 000000000..0da850bca --- /dev/null +++ b/les/lespay/client/fillset.go @@ -0,0 +1,107 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "sync" + + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +// FillSet tries to read nodes from an input iterator and add them to a node set by +// setting the specified node state flag(s) until the size of the set reaches the target. +// Note that other mechanisms (like other FillSet instances reading from different inputs) +// can also set the same flag(s) and FillSet will always care about the total number of +// nodes having those flags. +type FillSet struct { + lock sync.Mutex + cond *sync.Cond + ns *nodestate.NodeStateMachine + input enode.Iterator + closed bool + flags nodestate.Flags + count, target int +} + +// NewFillSet creates a new FillSet +func NewFillSet(ns *nodestate.NodeStateMachine, input enode.Iterator, flags nodestate.Flags) *FillSet { + fs := &FillSet{ + ns: ns, + input: input, + flags: flags, + } + fs.cond = sync.NewCond(&fs.lock) + + ns.SubscribeState(flags, func(n *enode.Node, oldState, newState nodestate.Flags) { + fs.lock.Lock() + if oldState.Equals(flags) { + fs.count-- + } + if newState.Equals(flags) { + fs.count++ + } + if fs.target > fs.count { + fs.cond.Signal() + } + fs.lock.Unlock() + }) + + go fs.readLoop() + return fs +} + +// readLoop keeps reading nodes from the input and setting the specified flags for them +// whenever the node set size is under the current target +func (fs *FillSet) readLoop() { + for { + fs.lock.Lock() + for fs.target <= fs.count && !fs.closed { + fs.cond.Wait() + } + + fs.lock.Unlock() + if !fs.input.Next() { + return + } + fs.ns.SetState(fs.input.Node(), fs.flags, nodestate.Flags{}, 0) + } +} + +// SetTarget sets the current target for node set size. If the previous target was not +// reached and FillSet was still waiting for the next node from the input then the next +// incoming node will be added to the set regardless of the target. This ensures that +// all nodes coming from the input are eventually added to the set. +func (fs *FillSet) SetTarget(target int) { + fs.lock.Lock() + defer fs.lock.Unlock() + + fs.target = target + if fs.target > fs.count { + fs.cond.Signal() + } +} + +// Close shuts FillSet down and closes the input iterator +func (fs *FillSet) Close() { + fs.lock.Lock() + defer fs.lock.Unlock() + + fs.closed = true + fs.input.Close() + fs.cond.Signal() +} diff --git a/les/lespay/client/fillset_test.go b/les/lespay/client/fillset_test.go new file mode 100644 index 000000000..58240682c --- /dev/null +++ b/les/lespay/client/fillset_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +type testIter struct { + waitCh chan struct{} + nodeCh chan *enode.Node + node *enode.Node +} + +func (i *testIter) Next() bool { + i.waitCh <- struct{}{} + i.node = <-i.nodeCh + return i.node != nil +} + +func (i *testIter) Node() *enode.Node { + return i.node +} + +func (i *testIter) Close() {} + +func (i *testIter) push() { + var id enode.ID + rand.Read(id[:]) + i.nodeCh <- enode.SignNull(new(enr.Record), id) +} + +func (i *testIter) waiting(timeout time.Duration) bool { + select { + case <-i.waitCh: + return true + case <-time.After(timeout): + return false + } +} + +func TestFillSet(t *testing.T) { + ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup) + iter := &testIter{ + waitCh: make(chan struct{}), + nodeCh: make(chan *enode.Node), + } + fs := NewFillSet(ns, iter, sfTest1) + ns.Start() + + expWaiting := func(i int, push bool) { + for ; i > 0; i-- { + if !iter.waiting(time.Second * 10) { + t.Fatalf("FillSet not waiting for new nodes") + } + if push { + iter.push() + } + } + } + + expNotWaiting := func() { + if iter.waiting(time.Millisecond * 100) { + t.Fatalf("FillSet unexpectedly waiting for new nodes") + } + } + + expNotWaiting() + fs.SetTarget(3) + expWaiting(3, true) + expNotWaiting() + fs.SetTarget(100) + expWaiting(2, true) + expWaiting(1, false) + // lower the target before the previous one has been filled up + fs.SetTarget(0) + iter.push() + expNotWaiting() + fs.SetTarget(10) + expWaiting(4, true) + expNotWaiting() + // remove all previosly set flags + ns.ForEach(sfTest1, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { + ns.SetState(node, nodestate.Flags{}, sfTest1, 0) + }) + // now expect FillSet to fill the set up again with 10 new nodes + expWaiting(10, true) + expNotWaiting() + + fs.Close() + ns.Stop() +} diff --git a/les/lespay/client/queueiterator.go b/les/lespay/client/queueiterator.go new file mode 100644 index 000000000..ad3f8df5b --- /dev/null +++ b/les/lespay/client/queueiterator.go @@ -0,0 +1,123 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "sync" + + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +// QueueIterator returns nodes from the specified selectable set in the same order as +// they entered the set. +type QueueIterator struct { + lock sync.Mutex + cond *sync.Cond + + ns *nodestate.NodeStateMachine + queue []*enode.Node + nextNode *enode.Node + waitCallback func(bool) + fifo, closed bool +} + +// NewQueueIterator creates a new QueueIterator. Nodes are selectable if they have all the required +// and none of the disabled flags set. When a node is selected the selectedFlag is set which also +// disables further selectability until it is removed or times out. +func NewQueueIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, fifo bool, waitCallback func(bool)) *QueueIterator { + qi := &QueueIterator{ + ns: ns, + fifo: fifo, + waitCallback: waitCallback, + } + qi.cond = sync.NewCond(&qi.lock) + + ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) { + oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) + newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) + if newMatch == oldMatch { + return + } + + qi.lock.Lock() + defer qi.lock.Unlock() + + if newMatch { + qi.queue = append(qi.queue, n) + } else { + id := n.ID() + for i, qn := range qi.queue { + if qn.ID() == id { + copy(qi.queue[i:len(qi.queue)-1], qi.queue[i+1:]) + qi.queue = qi.queue[:len(qi.queue)-1] + break + } + } + } + qi.cond.Signal() + }) + return qi +} + +// Next moves to the next selectable node. +func (qi *QueueIterator) Next() bool { + qi.lock.Lock() + if !qi.closed && len(qi.queue) == 0 { + if qi.waitCallback != nil { + qi.waitCallback(true) + } + for !qi.closed && len(qi.queue) == 0 { + qi.cond.Wait() + } + if qi.waitCallback != nil { + qi.waitCallback(false) + } + } + if qi.closed { + qi.nextNode = nil + qi.lock.Unlock() + return false + } + // Move to the next node in queue. + if qi.fifo { + qi.nextNode = qi.queue[0] + copy(qi.queue[:len(qi.queue)-1], qi.queue[1:]) + qi.queue = qi.queue[:len(qi.queue)-1] + } else { + qi.nextNode = qi.queue[len(qi.queue)-1] + qi.queue = qi.queue[:len(qi.queue)-1] + } + qi.lock.Unlock() + return true +} + +// Close ends the iterator. +func (qi *QueueIterator) Close() { + qi.lock.Lock() + qi.closed = true + qi.lock.Unlock() + qi.cond.Signal() +} + +// Node returns the current node. +func (qi *QueueIterator) Node() *enode.Node { + qi.lock.Lock() + defer qi.lock.Unlock() + + return qi.nextNode +} diff --git a/les/lespay/client/queueiterator_test.go b/les/lespay/client/queueiterator_test.go new file mode 100644 index 000000000..a74301c7d --- /dev/null +++ b/les/lespay/client/queueiterator_test.go @@ -0,0 +1,106 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +func testNodeID(i int) enode.ID { + return enode.ID{42, byte(i % 256), byte(i / 256)} +} + +func testNodeIndex(id enode.ID) int { + if id[0] != 42 { + return -1 + } + return int(id[1]) + int(id[2])*256 +} + +func testNode(i int) *enode.Node { + return enode.SignNull(new(enr.Record), testNodeID(i)) +} + +func TestQueueIteratorFIFO(t *testing.T) { + testQueueIterator(t, true) +} + +func TestQueueIteratorLIFO(t *testing.T) { + testQueueIterator(t, false) +} + +func testQueueIterator(t *testing.T, fifo bool) { + ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup) + qi := NewQueueIterator(ns, sfTest2, sfTest3.Or(sfTest4), fifo, nil) + ns.Start() + for i := 1; i <= iterTestNodeCount; i++ { + ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0) + } + next := func() int { + ch := make(chan struct{}) + go func() { + qi.Next() + close(ch) + }() + select { + case <-ch: + case <-time.After(time.Second * 5): + t.Fatalf("Iterator.Next() timeout") + } + node := qi.Node() + ns.SetState(node, sfTest4, nodestate.Flags{}, 0) + return testNodeIndex(node.ID()) + } + exp := func(i int) { + n := next() + if n != i { + t.Errorf("Wrong item returned by iterator (expected %d, got %d)", i, n) + } + } + explist := func(list []int) { + for i := range list { + if fifo { + exp(list[i]) + } else { + exp(list[len(list)-1-i]) + } + } + } + + ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0) + explist([]int{1, 2, 3}) + ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(5), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(5), sfTest3, nodestate.Flags{}, 0) + explist([]int{4, 6}) + ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(2), sfTest3, nodestate.Flags{}, 0) + ns.SetState(testNode(2), nodestate.Flags{}, sfTest3, 0) + explist([]int{1, 3, 2}) + ns.Stop() +} diff --git a/les/lespay/client/valuetracker.go b/les/lespay/client/valuetracker.go index 92bfd694e..4e67b31d9 100644 --- a/les/lespay/client/valuetracker.go +++ b/les/lespay/client/valuetracker.go @@ -213,6 +213,15 @@ func (vt *ValueTracker) StatsExpirer() *utils.Expirer { return &vt.statsExpirer } +// StatsExpirer returns the current expiration factor so that other values can be expired +// with the same rate as the service value statistics. +func (vt *ValueTracker) StatsExpFactor() utils.ExpirationFactor { + vt.statsExpLock.RLock() + defer vt.statsExpLock.RUnlock() + + return vt.statsExpFactor +} + // loadFromDb loads the value tracker's state from the database and converts saved // request basket index mapping if it does not match the specified index to name mapping. func (vt *ValueTracker) loadFromDb(mapping []string) error { @@ -500,16 +509,3 @@ func (vt *ValueTracker) RequestStats() []RequestStatsItem { } return res } - -// TotalServiceValue returns the total service value provided by the given node (as -// a function of the weights which are calculated from the request timeout value). -func (vt *ValueTracker) TotalServiceValue(nv *NodeValueTracker, weights ResponseTimeWeights) float64 { - vt.statsExpLock.RLock() - expFactor := vt.statsExpFactor - vt.statsExpLock.RUnlock() - - nv.lock.Lock() - defer nv.lock.Unlock() - - return nv.rtStats.Value(weights, expFactor) -} diff --git a/les/lespay/client/wrsiterator.go b/les/lespay/client/wrsiterator.go new file mode 100644 index 000000000..8a2e39ad4 --- /dev/null +++ b/les/lespay/client/wrsiterator.go @@ -0,0 +1,128 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "sync" + + "github.com/ethereum/go-ethereum/les/utils" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +// WrsIterator returns nodes from the specified selectable set with a weighted random +// selection. Selection weights are provided by a callback function. +type WrsIterator struct { + lock sync.Mutex + cond *sync.Cond + + ns *nodestate.NodeStateMachine + wrs *utils.WeightedRandomSelect + nextNode *enode.Node + closed bool +} + +// NewWrsIterator creates a new WrsIterator. Nodes are selectable if they have all the required +// and none of the disabled flags set. When a node is selected the selectedFlag is set which also +// disables further selectability until it is removed or times out. +func NewWrsIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, weightField nodestate.Field) *WrsIterator { + wfn := func(i interface{}) uint64 { + n := ns.GetNode(i.(enode.ID)) + if n == nil { + return 0 + } + wt, _ := ns.GetField(n, weightField).(uint64) + return wt + } + + w := &WrsIterator{ + ns: ns, + wrs: utils.NewWeightedRandomSelect(wfn), + } + w.cond = sync.NewCond(&w.lock) + + ns.SubscribeField(weightField, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) { + if state.HasAll(requireFlags) && state.HasNone(disableFlags) { + w.lock.Lock() + w.wrs.Update(n.ID()) + w.lock.Unlock() + w.cond.Signal() + } + }) + + ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) { + oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) + newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) + if newMatch == oldMatch { + return + } + + w.lock.Lock() + if newMatch { + w.wrs.Update(n.ID()) + } else { + w.wrs.Remove(n.ID()) + } + w.lock.Unlock() + w.cond.Signal() + }) + return w +} + +// Next selects the next node. +func (w *WrsIterator) Next() bool { + w.nextNode = w.chooseNode() + return w.nextNode != nil +} + +func (w *WrsIterator) chooseNode() *enode.Node { + w.lock.Lock() + defer w.lock.Unlock() + + for { + for !w.closed && w.wrs.IsEmpty() { + w.cond.Wait() + } + if w.closed { + return nil + } + // Choose the next node at random. Even though w.wrs is guaranteed + // non-empty here, Choose might return nil if all items have weight + // zero. + if c := w.wrs.Choose(); c != nil { + id := c.(enode.ID) + w.wrs.Remove(id) + return w.ns.GetNode(id) + } + } + +} + +// Close ends the iterator. +func (w *WrsIterator) Close() { + w.lock.Lock() + w.closed = true + w.lock.Unlock() + w.cond.Signal() +} + +// Node returns the current node. +func (w *WrsIterator) Node() *enode.Node { + w.lock.Lock() + defer w.lock.Unlock() + return w.nextNode +} diff --git a/les/lespay/client/wrsiterator_test.go b/les/lespay/client/wrsiterator_test.go new file mode 100644 index 000000000..77bb5ee0c --- /dev/null +++ b/les/lespay/client/wrsiterator_test.go @@ -0,0 +1,103 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package client + +import ( + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/nodestate" +) + +var ( + testSetup = &nodestate.Setup{} + sfTest1 = testSetup.NewFlag("test1") + sfTest2 = testSetup.NewFlag("test2") + sfTest3 = testSetup.NewFlag("test3") + sfTest4 = testSetup.NewFlag("test4") + sfiTestWeight = testSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0))) +) + +const iterTestNodeCount = 6 + +func TestWrsIterator(t *testing.T) { + ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup) + w := NewWrsIterator(ns, sfTest2, sfTest3.Or(sfTest4), sfiTestWeight) + ns.Start() + for i := 1; i <= iterTestNodeCount; i++ { + ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0) + ns.SetField(testNode(i), sfiTestWeight, uint64(1)) + } + next := func() int { + ch := make(chan struct{}) + go func() { + w.Next() + close(ch) + }() + select { + case <-ch: + case <-time.After(time.Second * 5): + t.Fatalf("Iterator.Next() timeout") + } + node := w.Node() + ns.SetState(node, sfTest4, nodestate.Flags{}, 0) + return testNodeIndex(node.ID()) + } + set := make(map[int]bool) + expset := func() { + for len(set) > 0 { + n := next() + if !set[n] { + t.Errorf("Item returned by iterator not in the expected set (got %d)", n) + } + delete(set, n) + } + } + + ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0) + set[1] = true + set[2] = true + set[3] = true + expset() + ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0) + ns.SetState(testNode(5), sfTest2.Or(sfTest3), nodestate.Flags{}, 0) + ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0) + set[4] = true + set[6] = true + expset() + ns.SetField(testNode(2), sfiTestWeight, uint64(0)) + ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0) + set[1] = true + set[3] = true + expset() + ns.SetField(testNode(2), sfiTestWeight, uint64(1)) + ns.SetState(testNode(2), nodestate.Flags{}, sfTest2, 0) + ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0) + ns.SetState(testNode(2), sfTest2, sfTest4, 0) + ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0) + set[1] = true + set[2] = true + set[3] = true + expset() + ns.Stop() +} diff --git a/les/metrics.go b/les/metrics.go index 9ef8c3651..c5edb61c3 100644 --- a/les/metrics.go +++ b/les/metrics.go @@ -107,6 +107,13 @@ var ( requestRTT = metrics.NewRegisteredTimer("les/client/req/rtt", nil) requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil) + + serverSelectableGauge = metrics.NewRegisteredGauge("les/client/serverPool/selectable", nil) + serverDialedMeter = metrics.NewRegisteredMeter("les/client/serverPool/dialed", nil) + serverConnectedGauge = metrics.NewRegisteredGauge("les/client/serverPool/connected", nil) + sessionValueMeter = metrics.NewRegisteredMeter("les/client/serverPool/sessionValue", nil) + totalValueGauge = metrics.NewRegisteredGauge("les/client/serverPool/totalValue", nil) + suggestedTimeoutGauge = metrics.NewRegisteredGauge("les/client/serverPool/timeout", nil) ) // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of diff --git a/les/peer.go b/les/peer.go index e92b4580d..4793d9026 100644 --- a/les/peer.go +++ b/les/peer.go @@ -336,7 +336,6 @@ type serverPeer struct { checkpointNumber uint64 // The block height which the checkpoint is registered. checkpoint params.TrustedCheckpoint // The advertised checkpoint sent by server. - poolEntry *poolEntry // Statistic for server peer. fcServer *flowcontrol.ServerNode // Client side mirror token bucket. vtLock sync.Mutex valueTracker *lpc.ValueTracker diff --git a/les/protocol.go b/les/protocol.go index f8ad94a7b..4fd19f9be 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -130,7 +130,6 @@ func init() { } requestMapping[uint32(code)] = rm } - } type errCode int diff --git a/les/retrieve.go b/les/retrieve.go index 5fa68b745..4f77004f2 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -24,22 +24,20 @@ import ( "sync" "time" - "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/light" ) var ( retryQueue = time.Millisecond * 100 - softRequestTimeout = time.Millisecond * 500 hardRequestTimeout = time.Second * 10 ) // retrieveManager is a layer on top of requestDistributor which takes care of // matching replies by request ID and handles timeouts and resends if necessary. type retrieveManager struct { - dist *requestDistributor - peers *serverPeerSet - serverPool peerSelector + dist *requestDistributor + peers *serverPeerSet + softRequestTimeout func() time.Duration lock sync.RWMutex sentReqs map[uint64]*sentReq @@ -48,11 +46,6 @@ type retrieveManager struct { // validatorFunc is a function that processes a reply message type validatorFunc func(distPeer, *Msg) error -// peerSelector receives feedback info about response times and timeouts -type peerSelector interface { - adjustResponseTime(*poolEntry, time.Duration, bool) -} - // sentReq represents a request sent and tracked by retrieveManager type sentReq struct { rm *retrieveManager @@ -99,12 +92,12 @@ const ( ) // newRetrieveManager creates the retrieve manager -func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, serverPool peerSelector) *retrieveManager { +func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, srto func() time.Duration) *retrieveManager { return &retrieveManager{ - peers: peers, - dist: dist, - serverPool: serverPool, - sentReqs: make(map[uint64]*sentReq), + peers: peers, + dist: dist, + sentReqs: make(map[uint64]*sentReq), + softRequestTimeout: srto, } } @@ -325,8 +318,7 @@ func (r *sentReq) tryRequest() { return } - reqSent := mclock.Now() - srto, hrto := false, false + hrto := false r.lock.RLock() s, ok := r.sentTo[p] @@ -338,11 +330,7 @@ func (r *sentReq) tryRequest() { defer func() { // send feedback to server pool and remove peer if hard timeout happened pp, ok := p.(*serverPeer) - if ok && r.rm.serverPool != nil { - respTime := time.Duration(mclock.Now() - reqSent) - r.rm.serverPool.adjustResponseTime(pp.poolEntry, respTime, srto) - } - if hrto { + if hrto && ok { pp.Log().Debug("Request timed out hard") if r.rm.peers != nil { r.rm.peers.unregister(pp.id) @@ -363,8 +351,7 @@ func (r *sentReq) tryRequest() { } r.eventsCh <- reqPeerEvent{event, p} return - case <-time.After(softRequestTimeout): - srto = true + case <-time.After(r.rm.softRequestTimeout()): r.eventsCh <- reqPeerEvent{rpSoftTimeout, p} } diff --git a/les/server.go b/les/server.go index f72f31321..4b623f61e 100644 --- a/les/server.go +++ b/les/server.go @@ -157,7 +157,7 @@ func (s *LesServer) Protocols() []p2p.Protocol { return p.Info() } return nil - }) + }, nil) // Add "les" ENR entries. for i := range ps { ps[i].Attributes = []enr.Entry{&lesEntry{}} diff --git a/les/serverpool.go b/les/serverpool.go index d1c53295a..aff774324 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -1,4 +1,4 @@ -// Copyright 2016 The go-ethereum Authors +// Copyright 2020 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify @@ -17,904 +17,457 @@ package les import ( - "crypto/ecdsa" - "fmt" - "io" - "math" + "errors" "math/rand" - "net" - "strconv" + "reflect" "sync" + "sync/atomic" "time" "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" + lpc "github.com/ethereum/go-ethereum/les/lespay/client" "github.com/ethereum/go-ethereum/les/utils" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/nodestate" "github.com/ethereum/go-ethereum/rlp" ) const ( - // After a connection has been ended or timed out, there is a waiting period - // before it can be selected for connection again. - // waiting period = base delay * (1 + random(1)) - // base delay = shortRetryDelay for the first shortRetryCnt times after a - // successful connection, after that longRetryDelay is applied - shortRetryCnt = 5 - shortRetryDelay = time.Second * 5 - longRetryDelay = time.Minute * 10 - // maxNewEntries is the maximum number of newly discovered (never connected) nodes. - // If the limit is reached, the least recently discovered one is thrown out. - maxNewEntries = 1000 - // maxKnownEntries is the maximum number of known (already connected) nodes. - // If the limit is reached, the least recently connected one is thrown out. - // (not that unlike new entries, known entries are persistent) - maxKnownEntries = 1000 - // target for simultaneously connected servers - targetServerCount = 5 - // target for servers selected from the known table - // (we leave room for trying new ones if there is any) - targetKnownSelect = 3 - // after dialTimeout, consider the server unavailable and adjust statistics - dialTimeout = time.Second * 30 - // targetConnTime is the minimum expected connection duration before a server - // drops a client without any specific reason - targetConnTime = time.Minute * 10 - // new entry selection weight calculation based on most recent discovery time: - // unity until discoverExpireStart, then exponential decay with discoverExpireConst - discoverExpireStart = time.Minute * 20 - discoverExpireConst = time.Minute * 20 - // known entry selection weight is dropped by a factor of exp(-failDropLn) after - // each unsuccessful connection (restored after a successful one) - failDropLn = 0.1 - // known node connection success and quality statistics have a long term average - // and a short term value which is adjusted exponentially with a factor of - // pstatRecentAdjust with each dial/connection and also returned exponentially - // to the average with the time constant pstatReturnToMeanTC - pstatReturnToMeanTC = time.Hour - // node address selection weight is dropped by a factor of exp(-addrFailDropLn) after - // each unsuccessful connection (restored after a successful one) - addrFailDropLn = math.Ln2 - // responseScoreTC and delayScoreTC are exponential decay time constants for - // calculating selection chances from response times and block delay times - responseScoreTC = time.Millisecond * 100 - delayScoreTC = time.Second * 5 - timeoutPow = 10 - // initStatsWeight is used to initialize previously unknown peers with good - // statistics to give a chance to prove themselves - initStatsWeight = 1 + minTimeout = time.Millisecond * 500 // minimum request timeout suggested by the server pool + timeoutRefresh = time.Second * 5 // recalculate timeout if older than this + dialCost = 10000 // cost of a TCP dial (used for known node selection weight calculation) + dialWaitStep = 1.5 // exponential multiplier of redial wait time when no value was provided by the server + queryCost = 500 // cost of a UDP pre-negotiation query + queryWaitStep = 1.02 // exponential multiplier of redial wait time when no value was provided by the server + waitThreshold = time.Hour * 2000 // drop node if waiting time is over the threshold + nodeWeightMul = 1000000 // multiplier constant for node weight calculation + nodeWeightThreshold = 100 // minimum weight for keeping a node in the the known (valuable) set + minRedialWait = 10 // minimum redial wait time in seconds + preNegLimit = 5 // maximum number of simultaneous pre-negotiation queries + maxQueryFails = 100 // number of consecutive UDP query failures before we print a warning ) -// connReq represents a request for peer connection. -type connReq struct { - p *serverPeer - node *enode.Node - result chan *poolEntry -} - -// disconnReq represents a request for peer disconnection. -type disconnReq struct { - entry *poolEntry - stopped bool - done chan struct{} -} - -// registerReq represents a request for peer registration. -type registerReq struct { - entry *poolEntry - done chan struct{} -} - -// serverPool implements a pool for storing and selecting newly discovered and already -// known light server nodes. It received discovered nodes, stores statistics about -// known nodes and takes care of always having enough good quality servers connected. +// serverPool provides a node iterator for dial candidates. The output is a mix of newly discovered +// nodes, a weighted random selection of known (previously valuable) nodes and trusted/paid nodes. type serverPool struct { - db ethdb.Database - dbKey []byte - server *p2p.Server - connWg sync.WaitGroup + clock mclock.Clock + unixTime func() int64 + db ethdb.KeyValueStore - topic discv5.Topic + ns *nodestate.NodeStateMachine + vt *lpc.ValueTracker + mixer *enode.FairMix + mixSources []enode.Iterator + dialIterator enode.Iterator + validSchemes enr.IdentityScheme + trustedURLs []string + fillSet *lpc.FillSet + queryFails uint32 - discSetPeriod chan time.Duration - discNodes chan *enode.Node - discLookups chan bool - - trustedNodes map[enode.ID]*enode.Node - entries map[enode.ID]*poolEntry - timeout, enableRetry chan *poolEntry - adjustStats chan poolStatAdjust - - knownQueue, newQueue poolEntryQueue - knownSelect, newSelect *utils.WeightedRandomSelect - knownSelected, newSelected int - fastDiscover bool - connCh chan *connReq - disconnCh chan *disconnReq - registerCh chan *registerReq - - closeCh chan struct{} - wg sync.WaitGroup + timeoutLock sync.RWMutex + timeout time.Duration + timeWeights lpc.ResponseTimeWeights + timeoutRefreshed mclock.AbsTime } -// newServerPool creates a new serverPool instance -func newServerPool(db ethdb.Database, ulcServers []string) *serverPool { - pool := &serverPool{ +// nodeHistory keeps track of dial costs which determine node weight together with the +// service value calculated by lpc.ValueTracker. +type nodeHistory struct { + dialCost utils.ExpiredValue + redialWaitStart, redialWaitEnd int64 // unix time (seconds) +} + +type nodeHistoryEnc struct { + DialCost utils.ExpiredValue + RedialWaitStart, RedialWaitEnd uint64 +} + +// queryFunc sends a pre-negotiation query and blocks until a response arrives or timeout occurs. +// It returns 1 if the remote node has confirmed that connection is possible, 0 if not +// possible and -1 if no response arrived (timeout). +type queryFunc func(*enode.Node) int + +var ( + serverPoolSetup = &nodestate.Setup{Version: 1} + sfHasValue = serverPoolSetup.NewPersistentFlag("hasValue") + sfQueried = serverPoolSetup.NewFlag("queried") + sfCanDial = serverPoolSetup.NewFlag("canDial") + sfDialing = serverPoolSetup.NewFlag("dialed") + sfWaitDialTimeout = serverPoolSetup.NewFlag("dialTimeout") + sfConnected = serverPoolSetup.NewFlag("connected") + sfRedialWait = serverPoolSetup.NewFlag("redialWait") + sfAlwaysConnect = serverPoolSetup.NewFlag("alwaysConnect") + sfDisableSelection = nodestate.MergeFlags(sfQueried, sfCanDial, sfDialing, sfConnected, sfRedialWait) + + sfiNodeHistory = serverPoolSetup.NewPersistentField("nodeHistory", reflect.TypeOf(nodeHistory{}), + func(field interface{}) ([]byte, error) { + if n, ok := field.(nodeHistory); ok { + ne := nodeHistoryEnc{ + DialCost: n.dialCost, + RedialWaitStart: uint64(n.redialWaitStart), + RedialWaitEnd: uint64(n.redialWaitEnd), + } + enc, err := rlp.EncodeToBytes(&ne) + return enc, err + } else { + return nil, errors.New("invalid field type") + } + }, + func(enc []byte) (interface{}, error) { + var ne nodeHistoryEnc + err := rlp.DecodeBytes(enc, &ne) + n := nodeHistory{ + dialCost: ne.DialCost, + redialWaitStart: int64(ne.RedialWaitStart), + redialWaitEnd: int64(ne.RedialWaitEnd), + } + return n, err + }, + ) + sfiNodeWeight = serverPoolSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0))) + sfiConnectedStats = serverPoolSetup.NewField("connectedStats", reflect.TypeOf(lpc.ResponseTimeStats{})) +) + +// newServerPool creates a new server pool +func newServerPool(db ethdb.KeyValueStore, dbKey []byte, vt *lpc.ValueTracker, discovery enode.Iterator, mixTimeout time.Duration, query queryFunc, clock mclock.Clock, trustedURLs []string) *serverPool { + s := &serverPool{ db: db, - entries: make(map[enode.ID]*poolEntry), - timeout: make(chan *poolEntry, 1), - adjustStats: make(chan poolStatAdjust, 100), - enableRetry: make(chan *poolEntry, 1), - connCh: make(chan *connReq), - disconnCh: make(chan *disconnReq), - registerCh: make(chan *registerReq), - closeCh: make(chan struct{}), - knownSelect: utils.NewWeightedRandomSelect(), - newSelect: utils.NewWeightedRandomSelect(), - fastDiscover: true, - trustedNodes: parseTrustedNodes(ulcServers), + clock: clock, + unixTime: func() int64 { return time.Now().Unix() }, + validSchemes: enode.ValidSchemes, + trustedURLs: trustedURLs, + vt: vt, + ns: nodestate.NewNodeStateMachine(db, []byte(string(dbKey)+"ns:"), clock, serverPoolSetup), + } + s.recalTimeout() + s.mixer = enode.NewFairMix(mixTimeout) + knownSelector := lpc.NewWrsIterator(s.ns, sfHasValue, sfDisableSelection, sfiNodeWeight) + alwaysConnect := lpc.NewQueueIterator(s.ns, sfAlwaysConnect, sfDisableSelection, true, nil) + s.mixSources = append(s.mixSources, knownSelector) + s.mixSources = append(s.mixSources, alwaysConnect) + if discovery != nil { + s.mixSources = append(s.mixSources, discovery) } - pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry) - pool.newQueue = newPoolEntryQueue(maxNewEntries, pool.removeEntry) - return pool + iter := enode.Iterator(s.mixer) + if query != nil { + iter = s.addPreNegFilter(iter, query) + } + s.dialIterator = enode.Filter(iter, func(node *enode.Node) bool { + s.ns.SetState(node, sfDialing, sfCanDial, 0) + s.ns.SetState(node, sfWaitDialTimeout, nodestate.Flags{}, time.Second*10) + return true + }) + + s.ns.SubscribeState(nodestate.MergeFlags(sfWaitDialTimeout, sfConnected), func(n *enode.Node, oldState, newState nodestate.Flags) { + if oldState.Equals(sfWaitDialTimeout) && newState.IsEmpty() { + // dial timeout, no connection + s.setRedialWait(n, dialCost, dialWaitStep) + s.ns.SetState(n, nodestate.Flags{}, sfDialing, 0) + } + }) + + s.ns.AddLogMetrics(sfHasValue, sfDisableSelection, "selectable", nil, nil, serverSelectableGauge) + s.ns.AddLogMetrics(sfDialing, nodestate.Flags{}, "dialed", serverDialedMeter, nil, nil) + s.ns.AddLogMetrics(sfConnected, nodestate.Flags{}, "connected", nil, nil, serverConnectedGauge) + return s } -func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { - pool.server = server - pool.topic = topic - pool.dbKey = append([]byte("serverPool/"), []byte(topic)...) - pool.loadNodes() - pool.connectToTrustedNodes() - - if pool.server.DiscV5 != nil { - pool.discSetPeriod = make(chan time.Duration, 1) - pool.discNodes = make(chan *enode.Node, 100) - pool.discLookups = make(chan bool, 100) - go pool.discoverNodes() - } - pool.checkDial() - pool.wg.Add(1) - go pool.eventLoop() - - // Inject the bootstrap nodes as initial dial candiates. - pool.wg.Add(1) - go func() { - defer pool.wg.Done() - for _, n := range server.BootstrapNodes { - select { - case pool.discNodes <- n: - case <-pool.closeCh: +// addPreNegFilter installs a node filter mechanism that performs a pre-negotiation query. +// Nodes that are filtered out and does not appear on the output iterator are put back +// into redialWait state. +func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enode.Iterator { + s.fillSet = lpc.NewFillSet(s.ns, input, sfQueried) + s.ns.SubscribeState(sfQueried, func(n *enode.Node, oldState, newState nodestate.Flags) { + if newState.Equals(sfQueried) { + fails := atomic.LoadUint32(&s.queryFails) + if fails == maxQueryFails { + log.Warn("UDP pre-negotiation query does not seem to work") + } + if fails > maxQueryFails { + fails = maxQueryFails + } + if rand.Intn(maxQueryFails*2) < int(fails) { + // skip pre-negotiation with increasing chance, max 50% + // this ensures that the client can operate even if UDP is not working at all + s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10) + // set canDial before resetting queried so that FillSet will not read more + // candidates unnecessarily + s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0) return } - } - }() -} - -func (pool *serverPool) stop() { - close(pool.closeCh) - pool.wg.Wait() -} - -// discoverNodes wraps SearchTopic, converting result nodes to enode.Node. -func (pool *serverPool) discoverNodes() { - ch := make(chan *discv5.Node) - go func() { - pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, ch, pool.discLookups) - close(ch) - }() - for n := range ch { - pubkey, err := decodePubkey64(n.ID[:]) - if err != nil { - continue - } - pool.discNodes <- enode.NewV4(pubkey, n.IP, int(n.TCP), int(n.UDP)) - } -} - -// connect should be called upon any incoming connection. If the connection has been -// dialed by the server pool recently, the appropriate pool entry is returned. -// Otherwise, the connection should be rejected. -// Note that whenever a connection has been accepted and a pool entry has been returned, -// disconnect should also always be called. -func (pool *serverPool) connect(p *serverPeer, node *enode.Node) *poolEntry { - log.Debug("Connect new entry", "enode", p.id) - req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)} - select { - case pool.connCh <- req: - case <-pool.closeCh: - return nil - } - return <-req.result -} - -// registered should be called after a successful handshake -func (pool *serverPool) registered(entry *poolEntry) { - log.Debug("Registered new entry", "enode", entry.node.ID()) - req := ®isterReq{entry: entry, done: make(chan struct{})} - select { - case pool.registerCh <- req: - case <-pool.closeCh: - return - } - <-req.done -} - -// disconnect should be called when ending a connection. Service quality statistics -// can be updated optionally (not updated if no registration happened, in this case -// only connection statistics are updated, just like in case of timeout) -func (pool *serverPool) disconnect(entry *poolEntry) { - stopped := false - select { - case <-pool.closeCh: - stopped = true - default: - } - log.Debug("Disconnected old entry", "enode", entry.node.ID()) - req := &disconnReq{entry: entry, stopped: stopped, done: make(chan struct{})} - - // Block until disconnection request is served. - pool.disconnCh <- req - <-req.done -} - -const ( - pseBlockDelay = iota - pseResponseTime - pseResponseTimeout -) - -// poolStatAdjust records are sent to adjust peer block delay/response time statistics -type poolStatAdjust struct { - adjustType int - entry *poolEntry - time time.Duration -} - -// adjustBlockDelay adjusts the block announce delay statistics of a node -func (pool *serverPool) adjustBlockDelay(entry *poolEntry, time time.Duration) { - if entry == nil { - return - } - pool.adjustStats <- poolStatAdjust{pseBlockDelay, entry, time} -} - -// adjustResponseTime adjusts the request response time statistics of a node -func (pool *serverPool) adjustResponseTime(entry *poolEntry, time time.Duration, timeout bool) { - if entry == nil { - return - } - if timeout { - pool.adjustStats <- poolStatAdjust{pseResponseTimeout, entry, time} - } else { - pool.adjustStats <- poolStatAdjust{pseResponseTime, entry, time} - } -} - -// eventLoop handles pool events and mutex locking for all internal functions -func (pool *serverPool) eventLoop() { - defer pool.wg.Done() - lookupCnt := 0 - var convTime mclock.AbsTime - if pool.discSetPeriod != nil { - pool.discSetPeriod <- time.Millisecond * 100 - } - - // disconnect updates service quality statistics depending on the connection time - // and disconnection initiator. - disconnect := func(req *disconnReq, stopped bool) { - // Handle peer disconnection requests. - entry := req.entry - if entry.state == psRegistered { - connAdjust := float64(mclock.Now()-entry.regTime) / float64(targetConnTime) - if connAdjust > 1 { - connAdjust = 1 - } - if stopped { - // disconnect requested by ourselves. - entry.connectStats.add(1, connAdjust) - } else { - // disconnect requested by server side. - entry.connectStats.add(connAdjust, 1) - } - } - entry.state = psNotConnected - - if entry.knownSelected { - pool.knownSelected-- - } else { - pool.newSelected-- - } - pool.setRetryDial(entry) - pool.connWg.Done() - close(req.done) - } - - for { - select { - case entry := <-pool.timeout: - if !entry.removed { - pool.checkDialTimeout(entry) - } - - case entry := <-pool.enableRetry: - if !entry.removed { - entry.delayedRetry = false - pool.updateCheckDial(entry) - } - - case adj := <-pool.adjustStats: - switch adj.adjustType { - case pseBlockDelay: - adj.entry.delayStats.add(float64(adj.time), 1) - case pseResponseTime: - adj.entry.responseStats.add(float64(adj.time), 1) - adj.entry.timeoutStats.add(0, 1) - case pseResponseTimeout: - adj.entry.timeoutStats.add(1, 1) - } - - case node := <-pool.discNodes: - if pool.trustedNodes[node.ID()] == nil { - entry := pool.findOrNewNode(node) - pool.updateCheckDial(entry) - } - - case conv := <-pool.discLookups: - if conv { - if lookupCnt == 0 { - convTime = mclock.Now() - } - lookupCnt++ - if pool.fastDiscover && (lookupCnt == 50 || time.Duration(mclock.Now()-convTime) > time.Minute) { - pool.fastDiscover = false - if pool.discSetPeriod != nil { - pool.discSetPeriod <- time.Minute - } - } - } - - case req := <-pool.connCh: - if pool.trustedNodes[req.p.ID()] != nil { - // ignore trusted nodes - req.result <- &poolEntry{trusted: true} - } else { - // Handle peer connection requests. - entry := pool.entries[req.p.ID()] - if entry == nil { - entry = pool.findOrNewNode(req.node) - } - if entry.state == psConnected || entry.state == psRegistered { - req.result <- nil - continue - } - pool.connWg.Add(1) - entry.peer = req.p - entry.state = psConnected - addr := &poolEntryAddress{ - ip: req.node.IP(), - port: uint16(req.node.TCP()), - lastSeen: mclock.Now(), - } - entry.lastConnected = addr - entry.addr = make(map[string]*poolEntryAddress) - entry.addr[addr.strKey()] = addr - entry.addrSelect = *utils.NewWeightedRandomSelect() - entry.addrSelect.Update(addr) - req.result <- entry - } - - case req := <-pool.registerCh: - if req.entry.trusted { - continue - } - // Handle peer registration requests. - entry := req.entry - entry.state = psRegistered - entry.regTime = mclock.Now() - if !entry.known { - pool.newQueue.remove(entry) - entry.known = true - } - pool.knownQueue.setLatest(entry) - entry.shortRetry = shortRetryCnt - close(req.done) - - case req := <-pool.disconnCh: - if req.entry.trusted { - continue - } - // Handle peer disconnection requests. - disconnect(req, req.stopped) - - case <-pool.closeCh: - if pool.discSetPeriod != nil { - close(pool.discSetPeriod) - } - - // Spawn a goroutine to close the disconnCh after all connections are disconnected. go func() { - pool.connWg.Wait() - close(pool.disconnCh) + q := query(n) + if q == -1 { + atomic.AddUint32(&s.queryFails, 1) + } else { + atomic.StoreUint32(&s.queryFails, 0) + } + if q == 1 { + s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10) + } else { + s.setRedialWait(n, queryCost, queryWaitStep) + } + s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0) }() - - // Handle all remaining disconnection requests before exit. - for req := range pool.disconnCh { - disconnect(req, true) - } - pool.saveNodes() - return } - } -} - -func (pool *serverPool) findOrNewNode(node *enode.Node) *poolEntry { - now := mclock.Now() - entry := pool.entries[node.ID()] - if entry == nil { - log.Debug("Discovered new entry", "id", node.ID()) - entry = &poolEntry{ - node: node, - addr: make(map[string]*poolEntryAddress), - addrSelect: *utils.NewWeightedRandomSelect(), - shortRetry: shortRetryCnt, + }) + return lpc.NewQueueIterator(s.ns, sfCanDial, nodestate.Flags{}, false, func(waiting bool) { + if waiting { + s.fillSet.SetTarget(preNegLimit) + } else { + s.fillSet.SetTarget(0) } - pool.entries[node.ID()] = entry - // initialize previously unknown peers with good statistics to give a chance to prove themselves - entry.connectStats.add(1, initStatsWeight) - entry.delayStats.add(0, initStatsWeight) - entry.responseStats.add(0, initStatsWeight) - entry.timeoutStats.add(0, initStatsWeight) - } - entry.lastDiscovered = now - addr := &poolEntryAddress{ip: node.IP(), port: uint16(node.TCP())} - if a, ok := entry.addr[addr.strKey()]; ok { - addr = a - } else { - entry.addr[addr.strKey()] = addr - } - addr.lastSeen = now - entry.addrSelect.Update(addr) - if !entry.known { - pool.newQueue.setLatest(entry) - } - return entry -} - -// loadNodes loads known nodes and their statistics from the database -func (pool *serverPool) loadNodes() { - enc, err := pool.db.Get(pool.dbKey) - if err != nil { - return - } - var list []*poolEntry - err = rlp.DecodeBytes(enc, &list) - if err != nil { - log.Debug("Failed to decode node list", "err", err) - return - } - for _, e := range list { - log.Debug("Loaded server stats", "id", e.node.ID(), "fails", e.lastConnected.fails, - "conn", fmt.Sprintf("%v/%v", e.connectStats.avg, e.connectStats.weight), - "delay", fmt.Sprintf("%v/%v", time.Duration(e.delayStats.avg), e.delayStats.weight), - "response", fmt.Sprintf("%v/%v", time.Duration(e.responseStats.avg), e.responseStats.weight), - "timeout", fmt.Sprintf("%v/%v", e.timeoutStats.avg, e.timeoutStats.weight)) - pool.entries[e.node.ID()] = e - if pool.trustedNodes[e.node.ID()] == nil { - pool.knownQueue.setLatest(e) - pool.knownSelect.Update((*knownEntry)(e)) - } - } -} - -// connectToTrustedNodes adds trusted server nodes as static trusted peers. -// -// Note: trusted nodes are not handled by the server pool logic, they are not -// added to either the known or new selection pools. They are connected/reconnected -// by p2p.Server whenever possible. -func (pool *serverPool) connectToTrustedNodes() { - //connect to trusted nodes - for _, node := range pool.trustedNodes { - pool.server.AddTrustedPeer(node) - pool.server.AddPeer(node) - log.Debug("Added trusted node", "id", node.ID().String()) - } -} - -// parseTrustedNodes returns valid and parsed enodes -func parseTrustedNodes(trustedNodes []string) map[enode.ID]*enode.Node { - nodes := make(map[enode.ID]*enode.Node) - - for _, node := range trustedNodes { - node, err := enode.Parse(enode.ValidSchemes, node) - if err != nil { - log.Warn("Trusted node URL invalid", "enode", node, "err", err) - continue - } - nodes[node.ID()] = node - } - return nodes -} - -// saveNodes saves known nodes and their statistics into the database. Nodes are -// ordered from least to most recently connected. -func (pool *serverPool) saveNodes() { - list := make([]*poolEntry, len(pool.knownQueue.queue)) - for i := range list { - list[i] = pool.knownQueue.fetchOldest() - } - enc, err := rlp.EncodeToBytes(list) - if err == nil { - pool.db.Put(pool.dbKey, enc) - } -} - -// removeEntry removes a pool entry when the entry count limit is reached. -// Note that it is called by the new/known queues from which the entry has already -// been removed so removing it from the queues is not necessary. -func (pool *serverPool) removeEntry(entry *poolEntry) { - pool.newSelect.Remove((*discoveredEntry)(entry)) - pool.knownSelect.Remove((*knownEntry)(entry)) - entry.removed = true - delete(pool.entries, entry.node.ID()) -} - -// setRetryDial starts the timer which will enable dialing a certain node again -func (pool *serverPool) setRetryDial(entry *poolEntry) { - delay := longRetryDelay - if entry.shortRetry > 0 { - entry.shortRetry-- - delay = shortRetryDelay - } - delay += time.Duration(rand.Int63n(int64(delay) + 1)) - entry.delayedRetry = true - go func() { - select { - case <-pool.closeCh: - case <-time.After(delay): - select { - case <-pool.closeCh: - case pool.enableRetry <- entry: - } - } - }() -} - -// updateCheckDial is called when an entry can potentially be dialed again. It updates -// its selection weights and checks if new dials can/should be made. -func (pool *serverPool) updateCheckDial(entry *poolEntry) { - pool.newSelect.Update((*discoveredEntry)(entry)) - pool.knownSelect.Update((*knownEntry)(entry)) - pool.checkDial() -} - -// checkDial checks if new dials can/should be made. It tries to select servers both -// based on good statistics and recent discovery. -func (pool *serverPool) checkDial() { - fillWithKnownSelects := !pool.fastDiscover - for pool.knownSelected < targetKnownSelect { - entry := pool.knownSelect.Choose() - if entry == nil { - fillWithKnownSelects = false - break - } - pool.dial((*poolEntry)(entry.(*knownEntry)), true) - } - for pool.knownSelected+pool.newSelected < targetServerCount { - entry := pool.newSelect.Choose() - if entry == nil { - break - } - pool.dial((*poolEntry)(entry.(*discoveredEntry)), false) - } - if fillWithKnownSelects { - // no more newly discovered nodes to select and since fast discover period - // is over, we probably won't find more in the near future so select more - // known entries if possible - for pool.knownSelected < targetServerCount { - entry := pool.knownSelect.Choose() - if entry == nil { - break - } - pool.dial((*poolEntry)(entry.(*knownEntry)), true) - } - } -} - -// dial initiates a new connection -func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) { - if pool.server == nil || entry.state != psNotConnected { - return - } - entry.state = psDialed - entry.knownSelected = knownSelected - if knownSelected { - pool.knownSelected++ - } else { - pool.newSelected++ - } - addr := entry.addrSelect.Choose().(*poolEntryAddress) - log.Debug("Dialing new peer", "lesaddr", entry.node.ID().String()+"@"+addr.strKey(), "set", len(entry.addr), "known", knownSelected) - entry.dialed = addr - go func() { - pool.server.AddPeer(entry.node) - select { - case <-pool.closeCh: - case <-time.After(dialTimeout): - select { - case <-pool.closeCh: - case pool.timeout <- entry: - } - } - }() -} - -// checkDialTimeout checks if the node is still in dialed state and if so, resets it -// and adjusts connection statistics accordingly. -func (pool *serverPool) checkDialTimeout(entry *poolEntry) { - if entry.state != psDialed { - return - } - log.Debug("Dial timeout", "lesaddr", entry.node.ID().String()+"@"+entry.dialed.strKey()) - entry.state = psNotConnected - if entry.knownSelected { - pool.knownSelected-- - } else { - pool.newSelected-- - } - entry.connectStats.add(0, 1) - entry.dialed.fails++ - pool.setRetryDial(entry) -} - -const ( - psNotConnected = iota - psDialed - psConnected - psRegistered -) - -// poolEntry represents a server node and stores its current state and statistics. -type poolEntry struct { - peer *serverPeer - pubkey [64]byte // secp256k1 key of the node - addr map[string]*poolEntryAddress - node *enode.Node - lastConnected, dialed *poolEntryAddress - addrSelect utils.WeightedRandomSelect - - lastDiscovered mclock.AbsTime - known, knownSelected, trusted bool - connectStats, delayStats poolStats - responseStats, timeoutStats poolStats - state int - regTime mclock.AbsTime - queueIdx int - removed bool - - delayedRetry bool - shortRetry int -} - -// poolEntryEnc is the RLP encoding of poolEntry. -type poolEntryEnc struct { - Pubkey []byte - IP net.IP - Port uint16 - Fails uint - CStat, DStat, RStat, TStat poolStats -} - -func (e *poolEntry) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, &poolEntryEnc{ - Pubkey: encodePubkey64(e.node.Pubkey()), - IP: e.lastConnected.ip, - Port: e.lastConnected.port, - Fails: e.lastConnected.fails, - CStat: e.connectStats, - DStat: e.delayStats, - RStat: e.responseStats, - TStat: e.timeoutStats, }) } -func (e *poolEntry) DecodeRLP(s *rlp.Stream) error { - var entry poolEntryEnc - if err := s.Decode(&entry); err != nil { - return err +// start starts the server pool. Note that NodeStateMachine should be started first. +func (s *serverPool) start() { + s.ns.Start() + for _, iter := range s.mixSources { + // add sources to mixer at startup because the mixer instantly tries to read them + // which should only happen after NodeStateMachine has been started + s.mixer.AddSource(iter) } - pubkey, err := decodePubkey64(entry.Pubkey) - if err != nil { - return err - } - addr := &poolEntryAddress{ip: entry.IP, port: entry.Port, fails: entry.Fails, lastSeen: mclock.Now()} - e.node = enode.NewV4(pubkey, entry.IP, int(entry.Port), int(entry.Port)) - e.addr = make(map[string]*poolEntryAddress) - e.addr[addr.strKey()] = addr - e.addrSelect = *utils.NewWeightedRandomSelect() - e.addrSelect.Update(addr) - e.lastConnected = addr - e.connectStats = entry.CStat - e.delayStats = entry.DStat - e.responseStats = entry.RStat - e.timeoutStats = entry.TStat - e.shortRetry = shortRetryCnt - e.known = true - return nil -} - -func encodePubkey64(pub *ecdsa.PublicKey) []byte { - return crypto.FromECDSAPub(pub)[1:] -} - -func decodePubkey64(b []byte) (*ecdsa.PublicKey, error) { - return crypto.UnmarshalPubkey(append([]byte{0x04}, b...)) -} - -// discoveredEntry implements wrsItem -type discoveredEntry poolEntry - -// Weight calculates random selection weight for newly discovered entries -func (e *discoveredEntry) Weight() int64 { - if e.state != psNotConnected || e.delayedRetry { - return 0 - } - t := time.Duration(mclock.Now() - e.lastDiscovered) - if t <= discoverExpireStart { - return 1000000000 - } - return int64(1000000000 * math.Exp(-float64(t-discoverExpireStart)/float64(discoverExpireConst))) -} - -// knownEntry implements wrsItem -type knownEntry poolEntry - -// Weight calculates random selection weight for known entries -func (e *knownEntry) Weight() int64 { - if e.state != psNotConnected || !e.known || e.delayedRetry { - return 0 - } - return int64(1000000000 * e.connectStats.recentAvg() * math.Exp(-float64(e.lastConnected.fails)*failDropLn-e.responseStats.recentAvg()/float64(responseScoreTC)-e.delayStats.recentAvg()/float64(delayScoreTC)) * math.Pow(1-e.timeoutStats.recentAvg(), timeoutPow)) -} - -// poolEntryAddress is a separate object because currently it is necessary to remember -// multiple potential network addresses for a pool entry. This will be removed after -// the final implementation of v5 discovery which will retrieve signed and serial -// numbered advertisements, making it clear which IP/port is the latest one. -type poolEntryAddress struct { - ip net.IP - port uint16 - lastSeen mclock.AbsTime // last time it was discovered, connected or loaded from db - fails uint // connection failures since last successful connection (persistent) -} - -func (a *poolEntryAddress) Weight() int64 { - t := time.Duration(mclock.Now() - a.lastSeen) - return int64(1000000*math.Exp(-float64(t)/float64(discoverExpireConst)-float64(a.fails)*addrFailDropLn)) + 1 -} - -func (a *poolEntryAddress) strKey() string { - return a.ip.String() + ":" + strconv.Itoa(int(a.port)) -} - -// poolStats implement statistics for a certain quantity with a long term average -// and a short term value which is adjusted exponentially with a factor of -// pstatRecentAdjust with each update and also returned exponentially to the -// average with the time constant pstatReturnToMeanTC -type poolStats struct { - sum, weight, avg, recent float64 - lastRecalc mclock.AbsTime -} - -// init initializes stats with a long term sum/update count pair retrieved from the database -func (s *poolStats) init(sum, weight float64) { - s.sum = sum - s.weight = weight - var avg float64 - if weight > 0 { - avg = s.sum / weight - } - s.avg = avg - s.recent = avg - s.lastRecalc = mclock.Now() -} - -// recalc recalculates recent value return-to-mean and long term average -func (s *poolStats) recalc() { - now := mclock.Now() - s.recent = s.avg + (s.recent-s.avg)*math.Exp(-float64(now-s.lastRecalc)/float64(pstatReturnToMeanTC)) - if s.sum == 0 { - s.avg = 0 - } else { - if s.sum > s.weight*1e30 { - s.avg = 1e30 + for _, url := range s.trustedURLs { + if node, err := enode.Parse(s.validSchemes, url); err == nil { + s.ns.SetState(node, sfAlwaysConnect, nodestate.Flags{}, 0) } else { - s.avg = s.sum / s.weight + log.Error("Invalid trusted server URL", "url", url, "error", err) } } - s.lastRecalc = now -} - -// add updates the stats with a new value -func (s *poolStats) add(value, weight float64) { - s.weight += weight - s.sum += value * weight - s.recalc() -} - -// recentAvg returns the short-term adjusted average -func (s *poolStats) recentAvg() float64 { - s.recalc() - return s.recent -} - -func (s *poolStats) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, []interface{}{math.Float64bits(s.sum), math.Float64bits(s.weight)}) -} - -func (s *poolStats) DecodeRLP(st *rlp.Stream) error { - var stats struct { - SumUint, WeightUint uint64 - } - if err := st.Decode(&stats); err != nil { - return err - } - s.init(math.Float64frombits(stats.SumUint), math.Float64frombits(stats.WeightUint)) - return nil -} - -// poolEntryQueue keeps track of its least recently accessed entries and removes -// them when the number of entries reaches the limit -type poolEntryQueue struct { - queue map[int]*poolEntry // known nodes indexed by their latest lastConnCnt value - newPtr, oldPtr, maxCnt int - removeFromPool func(*poolEntry) -} - -// newPoolEntryQueue returns a new poolEntryQueue -func newPoolEntryQueue(maxCnt int, removeFromPool func(*poolEntry)) poolEntryQueue { - return poolEntryQueue{queue: make(map[int]*poolEntry), maxCnt: maxCnt, removeFromPool: removeFromPool} -} - -// fetchOldest returns and removes the least recently accessed entry -func (q *poolEntryQueue) fetchOldest() *poolEntry { - if len(q.queue) == 0 { - return nil - } - for { - if e := q.queue[q.oldPtr]; e != nil { - delete(q.queue, q.oldPtr) - q.oldPtr++ - return e + unixTime := s.unixTime() + s.ns.ForEach(sfHasValue, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { + s.calculateWeight(node) + if n, ok := s.ns.GetField(node, sfiNodeHistory).(nodeHistory); ok && n.redialWaitEnd > unixTime { + wait := n.redialWaitEnd - unixTime + lastWait := n.redialWaitEnd - n.redialWaitStart + if wait > lastWait { + // if the time until expiration is larger than the last suggested + // waiting time then the system clock was probably adjusted + wait = lastWait + } + s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second) } - q.oldPtr++ - } + }) } -// remove removes an entry from the queue -func (q *poolEntryQueue) remove(entry *poolEntry) { - if q.queue[entry.queueIdx] == entry { - delete(q.queue, entry.queueIdx) +// stop stops the server pool +func (s *serverPool) stop() { + s.dialIterator.Close() + if s.fillSet != nil { + s.fillSet.Close() } + s.ns.ForEach(sfConnected, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) { + // recalculate weight of connected nodes in order to update hasValue flag if necessary + s.calculateWeight(n) + }) + s.ns.Stop() } -// setLatest adds or updates a recently accessed entry. It also checks if an old entry -// needs to be removed and removes it from the parent pool too with a callback function. -func (q *poolEntryQueue) setLatest(entry *poolEntry) { - if q.queue[entry.queueIdx] == entry { - delete(q.queue, entry.queueIdx) +// registerPeer implements serverPeerSubscriber +func (s *serverPool) registerPeer(p *serverPeer) { + s.ns.SetState(p.Node(), sfConnected, sfDialing.Or(sfWaitDialTimeout), 0) + nvt := s.vt.Register(p.ID()) + s.ns.SetField(p.Node(), sfiConnectedStats, nvt.RtStats()) + p.setValueTracker(s.vt, nvt) + p.updateVtParams() +} + +// unregisterPeer implements serverPeerSubscriber +func (s *serverPool) unregisterPeer(p *serverPeer) { + s.setRedialWait(p.Node(), dialCost, dialWaitStep) + s.ns.SetState(p.Node(), nodestate.Flags{}, sfConnected, 0) + s.ns.SetField(p.Node(), sfiConnectedStats, nil) + s.vt.Unregister(p.ID()) + p.setValueTracker(nil, nil) +} + +// recalTimeout calculates the current recommended timeout. This value is used by +// the client as a "soft timeout" value. It also affects the service value calculation +// of individual nodes. +func (s *serverPool) recalTimeout() { + // Use cached result if possible, avoid recalculating too frequently. + s.timeoutLock.RLock() + refreshed := s.timeoutRefreshed + s.timeoutLock.RUnlock() + now := s.clock.Now() + if refreshed != 0 && time.Duration(now-refreshed) < timeoutRefresh { + return + } + // Cached result is stale, recalculate a new one. + rts := s.vt.RtStats() + + // Add a fake statistic here. It is an easy way to initialize with some + // conservative values when the database is new. As soon as we have a + // considerable amount of real stats this small value won't matter. + rts.Add(time.Second*2, 10, s.vt.StatsExpFactor()) + + // Use either 10% failure rate timeout or twice the median response time + // as the recommended timeout. + timeout := minTimeout + if t := rts.Timeout(0.1); t > timeout { + timeout = t + } + if t := rts.Timeout(0.5) * 2; t > timeout { + timeout = t + } + s.timeoutLock.Lock() + if s.timeout != timeout { + s.timeout = timeout + s.timeWeights = lpc.TimeoutWeights(s.timeout) + + suggestedTimeoutGauge.Update(int64(s.timeout / time.Millisecond)) + totalValueGauge.Update(int64(rts.Value(s.timeWeights, s.vt.StatsExpFactor()))) + } + s.timeoutRefreshed = now + s.timeoutLock.Unlock() +} + +// getTimeout returns the recommended request timeout. +func (s *serverPool) getTimeout() time.Duration { + s.recalTimeout() + s.timeoutLock.RLock() + defer s.timeoutLock.RUnlock() + return s.timeout +} + +// getTimeoutAndWeight returns the recommended request timeout as well as the +// response time weight which is necessary to calculate service value. +func (s *serverPool) getTimeoutAndWeight() (time.Duration, lpc.ResponseTimeWeights) { + s.recalTimeout() + s.timeoutLock.RLock() + defer s.timeoutLock.RUnlock() + return s.timeout, s.timeWeights +} + +// addDialCost adds the given amount of dial cost to the node history and returns the current +// amount of total dial cost +func (s *serverPool) addDialCost(n *nodeHistory, amount int64) uint64 { + logOffset := s.vt.StatsExpirer().LogOffset(s.clock.Now()) + if amount > 0 { + n.dialCost.Add(amount, logOffset) + } + totalDialCost := n.dialCost.Value(logOffset) + if totalDialCost < dialCost { + totalDialCost = dialCost + } + return totalDialCost +} + +// serviceValue returns the service value accumulated in this session and in total +func (s *serverPool) serviceValue(node *enode.Node) (sessionValue, totalValue float64) { + nvt := s.vt.GetNode(node.ID()) + if nvt == nil { + return 0, 0 + } + currentStats := nvt.RtStats() + _, timeWeights := s.getTimeoutAndWeight() + expFactor := s.vt.StatsExpFactor() + + totalValue = currentStats.Value(timeWeights, expFactor) + if connStats, ok := s.ns.GetField(node, sfiConnectedStats).(lpc.ResponseTimeStats); ok { + diff := currentStats + diff.SubStats(&connStats) + sessionValue = diff.Value(timeWeights, expFactor) + sessionValueMeter.Mark(int64(sessionValue)) + } + return +} + +// updateWeight calculates the node weight and updates the nodeWeight field and the +// hasValue flag. It also saves the node state if necessary. +func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDialCost uint64) { + weight := uint64(totalValue * nodeWeightMul / float64(totalDialCost)) + if weight >= nodeWeightThreshold { + s.ns.SetState(node, sfHasValue, nodestate.Flags{}, 0) + s.ns.SetField(node, sfiNodeWeight, weight) } else { - if len(q.queue) == q.maxCnt { - e := q.fetchOldest() - q.remove(e) - q.removeFromPool(e) - } + s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0) + s.ns.SetField(node, sfiNodeWeight, nil) } - entry.queueIdx = q.newPtr - q.queue[entry.queueIdx] = entry - q.newPtr++ + s.ns.Persist(node) // saved if node history or hasValue changed +} + +// setRedialWait calculates and sets the redialWait timeout based on the service value +// and dial cost accumulated during the last session/attempt and in total. +// The waiting time is raised exponentially if no service value has been received in order +// to prevent dialing an unresponsive node frequently for a very long time just because it +// was useful in the past. It can still be occasionally dialed though and once it provides +// a significant amount of service value again its waiting time is quickly reduced or reset +// to the minimum. +// Note: node weight is also recalculated and updated by this function. +func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep float64) { + n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory) + sessionValue, totalValue := s.serviceValue(node) + totalDialCost := s.addDialCost(&n, addDialCost) + + // if the current dial session has yielded at least the average value/dial cost ratio + // then the waiting time should be reset to the minimum. If the session value + // is below average but still positive then timeout is limited to the ratio of + // average / current service value multiplied by the minimum timeout. If the attempt + // was unsuccessful then timeout is raised exponentially without limitation. + // Note: dialCost is used in the formula below even if dial was not attempted at all + // because the pre-negotiation query did not return a positive result. In this case + // the ratio has no meaning anyway and waitFactor is always raised, though in smaller + // steps because queries are cheaper and therefore we can allow more failed attempts. + unixTime := s.unixTime() + plannedTimeout := float64(n.redialWaitEnd - n.redialWaitStart) // last planned redialWait timeout + var actualWait float64 // actual waiting time elapsed + if unixTime > n.redialWaitEnd { + // the planned timeout has elapsed + actualWait = plannedTimeout + } else { + // if the node was redialed earlier then we do not raise the planned timeout + // exponentially because that could lead to the timeout rising very high in + // a short amount of time + // Note that in case of an early redial actualWait also includes the dial + // timeout or connection time of the last attempt but it still serves its + // purpose of preventing the timeout rising quicker than linearly as a function + // of total time elapsed without a successful connection. + actualWait = float64(unixTime - n.redialWaitStart) + } + // raise timeout exponentially if the last planned timeout has elapsed + // (use at least the last planned timeout otherwise) + nextTimeout := actualWait * waitStep + if plannedTimeout > nextTimeout { + nextTimeout = plannedTimeout + } + // we reduce the waiting time if the server has provided service value during the + // connection (but never under the minimum) + a := totalValue * dialCost * float64(minRedialWait) + b := float64(totalDialCost) * sessionValue + if a < b*nextTimeout { + nextTimeout = a / b + } + if nextTimeout < minRedialWait { + nextTimeout = minRedialWait + } + wait := time.Duration(float64(time.Second) * nextTimeout) + if wait < waitThreshold { + n.redialWaitStart = unixTime + n.redialWaitEnd = unixTime + int64(nextTimeout) + s.ns.SetField(node, sfiNodeHistory, n) + s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, wait) + s.updateWeight(node, totalValue, totalDialCost) + } else { + // discard known node statistics if waiting time is very long because the node + // hasn't been responsive for a very long time + s.ns.SetField(node, sfiNodeHistory, nil) + s.ns.SetField(node, sfiNodeWeight, nil) + s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0) + } +} + +// calculateWeight calculates and sets the node weight without altering the node history. +// This function should be called during startup and shutdown only, otherwise setRedialWait +// will keep the weights updated as the underlying statistics are adjusted. +func (s *serverPool) calculateWeight(node *enode.Node) { + n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory) + _, totalValue := s.serviceValue(node) + totalDialCost := s.addDialCost(&n, 0) + s.updateWeight(node, totalValue, totalDialCost) } diff --git a/les/serverpool_test.go b/les/serverpool_test.go new file mode 100644 index 000000000..3d0487d10 --- /dev/null +++ b/les/serverpool_test.go @@ -0,0 +1,352 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "math/rand" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + lpc "github.com/ethereum/go-ethereum/les/lespay/client" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +const ( + spTestNodes = 1000 + spTestTarget = 5 + spTestLength = 10000 + spMinTotal = 40000 + spMaxTotal = 50000 +) + +func testNodeID(i int) enode.ID { + return enode.ID{42, byte(i % 256), byte(i / 256)} +} + +func testNodeIndex(id enode.ID) int { + if id[0] != 42 { + return -1 + } + return int(id[1]) + int(id[2])*256 +} + +type serverPoolTest struct { + db ethdb.KeyValueStore + clock *mclock.Simulated + quit chan struct{} + preNeg, preNegFail bool + vt *lpc.ValueTracker + sp *serverPool + input enode.Iterator + testNodes []spTestNode + trusted []string + waitCount, waitEnded int32 + + cycle, conn, servedConn int + serviceCycles, dialCount int + disconnect map[int][]int +} + +type spTestNode struct { + connectCycles, waitCycles int + nextConnCycle, totalConn int + connected, service bool + peer *serverPeer +} + +func newServerPoolTest(preNeg, preNegFail bool) *serverPoolTest { + nodes := make([]*enode.Node, spTestNodes) + for i := range nodes { + nodes[i] = enode.SignNull(&enr.Record{}, testNodeID(i)) + } + return &serverPoolTest{ + clock: &mclock.Simulated{}, + db: memorydb.New(), + input: enode.CycleNodes(nodes), + testNodes: make([]spTestNode, spTestNodes), + preNeg: preNeg, + preNegFail: preNegFail, + } +} + +func (s *serverPoolTest) beginWait() { + // ensure that dialIterator and the maximal number of pre-neg queries are not all stuck in a waiting state + for atomic.AddInt32(&s.waitCount, 1) > preNegLimit { + atomic.AddInt32(&s.waitCount, -1) + s.clock.Run(time.Second) + } +} + +func (s *serverPoolTest) endWait() { + atomic.AddInt32(&s.waitCount, -1) + atomic.AddInt32(&s.waitEnded, 1) +} + +func (s *serverPoolTest) addTrusted(i int) { + s.trusted = append(s.trusted, enode.SignNull(&enr.Record{}, testNodeID(i)).String()) +} + +func (s *serverPoolTest) start() { + var testQuery queryFunc + if s.preNeg { + testQuery = func(node *enode.Node) int { + idx := testNodeIndex(node.ID()) + n := &s.testNodes[idx] + canConnect := !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle + if s.preNegFail { + // simulate a scenario where UDP queries never work + s.beginWait() + s.clock.Sleep(time.Second * 5) + s.endWait() + return -1 + } else { + switch idx % 3 { + case 0: + // pre-neg returns true only if connection is possible + if canConnect { + return 1 + } else { + return 0 + } + case 1: + // pre-neg returns true but connection might still fail + return 1 + case 2: + // pre-neg returns true if connection is possible, otherwise timeout (node unresponsive) + if canConnect { + return 1 + } else { + s.beginWait() + s.clock.Sleep(time.Second * 5) + s.endWait() + return -1 + } + } + return -1 + } + } + } + + s.vt = lpc.NewValueTracker(s.db, s.clock, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)) + s.sp = newServerPool(s.db, []byte("serverpool:"), s.vt, s.input, 0, testQuery, s.clock, s.trusted) + s.sp.validSchemes = enode.ValidSchemesForTesting + s.sp.unixTime = func() int64 { return int64(s.clock.Now()) / int64(time.Second) } + s.disconnect = make(map[int][]int) + s.sp.start() + s.quit = make(chan struct{}) + go func() { + last := int32(-1) + for { + select { + case <-time.After(time.Millisecond * 100): + c := atomic.LoadInt32(&s.waitEnded) + if c == last { + // advance clock if test is stuck (might happen in rare cases) + s.clock.Run(time.Second) + } + last = c + case <-s.quit: + return + } + } + }() +} + +func (s *serverPoolTest) stop() { + close(s.quit) + s.sp.stop() + s.vt.Stop() + for i := range s.testNodes { + n := &s.testNodes[i] + if n.connected { + n.totalConn += s.cycle + } + n.connected = false + n.peer = nil + n.nextConnCycle = 0 + } + s.conn, s.servedConn = 0, 0 +} + +func (s *serverPoolTest) run() { + for count := spTestLength; count > 0; count-- { + if dcList := s.disconnect[s.cycle]; dcList != nil { + for _, idx := range dcList { + n := &s.testNodes[idx] + s.sp.unregisterPeer(n.peer) + n.totalConn += s.cycle + n.connected = false + n.peer = nil + s.conn-- + if n.service { + s.servedConn-- + } + n.nextConnCycle = s.cycle + n.waitCycles + } + delete(s.disconnect, s.cycle) + } + if s.conn < spTestTarget { + s.dialCount++ + s.beginWait() + s.sp.dialIterator.Next() + s.endWait() + dial := s.sp.dialIterator.Node() + id := dial.ID() + idx := testNodeIndex(id) + n := &s.testNodes[idx] + if !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle { + s.conn++ + if n.service { + s.servedConn++ + } + n.totalConn -= s.cycle + n.connected = true + dc := s.cycle + n.connectCycles + s.disconnect[dc] = append(s.disconnect[dc], idx) + n.peer = &serverPeer{peerCommons: peerCommons{Peer: p2p.NewPeer(id, "", nil)}} + s.sp.registerPeer(n.peer) + if n.service { + s.vt.Served(s.vt.GetNode(id), []lpc.ServedRequest{{ReqType: 0, Amount: 100}}, 0) + } + } + } + s.serviceCycles += s.servedConn + s.clock.Run(time.Second) + s.cycle++ + } +} + +func (s *serverPoolTest) setNodes(count, conn, wait int, service, trusted bool) (res []int) { + for ; count > 0; count-- { + idx := rand.Intn(spTestNodes) + for s.testNodes[idx].connectCycles != 0 || s.testNodes[idx].connected { + idx = rand.Intn(spTestNodes) + } + res = append(res, idx) + s.testNodes[idx] = spTestNode{ + connectCycles: conn, + waitCycles: wait, + service: service, + } + if trusted { + s.addTrusted(idx) + } + } + return +} + +func (s *serverPoolTest) resetNodes() { + for i, n := range s.testNodes { + if n.connected { + n.totalConn += s.cycle + s.sp.unregisterPeer(n.peer) + } + s.testNodes[i] = spTestNode{totalConn: n.totalConn} + } + s.conn, s.servedConn = 0, 0 + s.disconnect = make(map[int][]int) + s.trusted = nil +} + +func (s *serverPoolTest) checkNodes(t *testing.T, nodes []int) { + var sum int + for _, idx := range nodes { + n := &s.testNodes[idx] + if n.connected { + n.totalConn += s.cycle + } + sum += n.totalConn + n.totalConn = 0 + if n.connected { + n.totalConn -= s.cycle + } + } + if sum < spMinTotal || sum > spMaxTotal { + t.Errorf("Total connection amount %d outside expected range %d to %d", sum, spMinTotal, spMaxTotal) + } +} + +func TestServerPool(t *testing.T) { testServerPool(t, false, false) } +func TestServerPoolWithPreNeg(t *testing.T) { testServerPool(t, true, false) } +func TestServerPoolWithPreNegFail(t *testing.T) { testServerPool(t, true, true) } +func testServerPool(t *testing.T, preNeg, fail bool) { + s := newServerPoolTest(preNeg, fail) + nodes := s.setNodes(100, 200, 200, true, false) + s.setNodes(100, 20, 20, false, false) + s.start() + s.run() + s.stop() + s.checkNodes(t, nodes) +} + +func TestServerPoolChangedNodes(t *testing.T) { testServerPoolChangedNodes(t, false) } +func TestServerPoolChangedNodesWithPreNeg(t *testing.T) { testServerPoolChangedNodes(t, true) } +func testServerPoolChangedNodes(t *testing.T, preNeg bool) { + s := newServerPoolTest(preNeg, false) + nodes := s.setNodes(100, 200, 200, true, false) + s.setNodes(100, 20, 20, false, false) + s.start() + s.run() + s.checkNodes(t, nodes) + for i := 0; i < 3; i++ { + s.resetNodes() + nodes := s.setNodes(100, 200, 200, true, false) + s.setNodes(100, 20, 20, false, false) + s.run() + s.checkNodes(t, nodes) + } + s.stop() +} + +func TestServerPoolRestartNoDiscovery(t *testing.T) { testServerPoolRestartNoDiscovery(t, false) } +func TestServerPoolRestartNoDiscoveryWithPreNeg(t *testing.T) { + testServerPoolRestartNoDiscovery(t, true) +} +func testServerPoolRestartNoDiscovery(t *testing.T, preNeg bool) { + s := newServerPoolTest(preNeg, false) + nodes := s.setNodes(100, 200, 200, true, false) + s.setNodes(100, 20, 20, false, false) + s.start() + s.run() + s.stop() + s.checkNodes(t, nodes) + s.input = nil + s.start() + s.run() + s.stop() + s.checkNodes(t, nodes) +} + +func TestServerPoolTrustedNoDiscovery(t *testing.T) { testServerPoolTrustedNoDiscovery(t, false) } +func TestServerPoolTrustedNoDiscoveryWithPreNeg(t *testing.T) { + testServerPoolTrustedNoDiscovery(t, true) +} +func testServerPoolTrustedNoDiscovery(t *testing.T, preNeg bool) { + s := newServerPoolTest(preNeg, false) + trusted := s.setNodes(200, 200, 200, true, true) + s.input = nil + s.start() + s.run() + s.stop() + s.checkNodes(t, trusted) +} diff --git a/les/test_helper.go b/les/test_helper.go index 1f02d2529..2a2bbb440 100644 --- a/les/test_helper.go +++ b/les/test_helper.go @@ -508,7 +508,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer clock = &mclock.Simulated{} } dist := newRequestDistributor(speers, clock) - rm := newRetrieveManager(speers, dist, nil) + rm := newRetrieveManager(speers, dist, func() time.Duration { return time.Millisecond * 500 }) odr := NewLesOdr(cdb, light.TestClientIndexerConfig, rm) sindexers := testIndexers(sdb, nil, light.TestServerIndexerConfig) diff --git a/les/utils/expiredvalue.go b/les/utils/expiredvalue.go index 85f9b88b7..a58587368 100644 --- a/les/utils/expiredvalue.go +++ b/les/utils/expiredvalue.go @@ -63,14 +63,7 @@ func ExpFactor(logOffset Fixed64) ExpirationFactor { // Value calculates the expired value based on a floating point base and integer // power-of-2 exponent. This function should be used by multi-value expired structures. func (e ExpirationFactor) Value(base float64, exp uint64) float64 { - res := base / e.Factor - if exp > e.Exp { - res *= float64(uint64(1) << (exp - e.Exp)) - } - if exp < e.Exp { - res /= float64(uint64(1) << (e.Exp - exp)) - } - return res + return base / e.Factor * math.Pow(2, float64(int64(exp-e.Exp))) } // value calculates the value at the given moment. diff --git a/les/utils/weighted_select.go b/les/utils/weighted_select.go index fbf1f37d6..d6db3c0e6 100644 --- a/les/utils/weighted_select.go +++ b/les/utils/weighted_select.go @@ -16,28 +16,44 @@ package utils -import "math/rand" +import ( + "math/rand" +) -// wrsItem interface should be implemented by any entries that are to be selected from -// a WeightedRandomSelect set. Note that recalculating monotonously decreasing item -// weights on-demand (without constantly calling Update) is allowed -type wrsItem interface { - Weight() int64 -} - -// WeightedRandomSelect is capable of weighted random selection from a set of items -type WeightedRandomSelect struct { - root *wrsNode - idx map[wrsItem]int -} +type ( + // WeightedRandomSelect is capable of weighted random selection from a set of items + WeightedRandomSelect struct { + root *wrsNode + idx map[WrsItem]int + wfn WeightFn + } + WrsItem interface{} + WeightFn func(interface{}) uint64 +) // NewWeightedRandomSelect returns a new WeightedRandomSelect structure -func NewWeightedRandomSelect() *WeightedRandomSelect { - return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[wrsItem]int)} +func NewWeightedRandomSelect(wfn WeightFn) *WeightedRandomSelect { + return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[WrsItem]int), wfn: wfn} +} + +// Update updates an item's weight, adds it if it was non-existent or removes it if +// the new weight is zero. Note that explicitly updating decreasing weights is not necessary. +func (w *WeightedRandomSelect) Update(item WrsItem) { + w.setWeight(item, w.wfn(item)) +} + +// Remove removes an item from the set +func (w *WeightedRandomSelect) Remove(item WrsItem) { + w.setWeight(item, 0) +} + +// IsEmpty returns true if the set is empty +func (w *WeightedRandomSelect) IsEmpty() bool { + return w.root.sumWeight == 0 } // setWeight sets an item's weight to a specific value (removes it if zero) -func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) { +func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) { idx, ok := w.idx[item] if ok { w.root.setWeight(idx, weight) @@ -58,33 +74,22 @@ func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) { } } -// Update updates an item's weight, adds it if it was non-existent or removes it if -// the new weight is zero. Note that explicitly updating decreasing weights is not necessary. -func (w *WeightedRandomSelect) Update(item wrsItem) { - w.setWeight(item, item.Weight()) -} - -// Remove removes an item from the set -func (w *WeightedRandomSelect) Remove(item wrsItem) { - w.setWeight(item, 0) -} - // Choose randomly selects an item from the set, with a chance proportional to its // current weight. If the weight of the chosen element has been decreased since the // last stored value, returns it with a newWeight/oldWeight chance, otherwise just // updates its weight and selects another one -func (w *WeightedRandomSelect) Choose() wrsItem { +func (w *WeightedRandomSelect) Choose() WrsItem { for { if w.root.sumWeight == 0 { return nil } - val := rand.Int63n(w.root.sumWeight) + val := uint64(rand.Int63n(int64(w.root.sumWeight))) choice, lastWeight := w.root.choose(val) - weight := choice.Weight() + weight := w.wfn(choice) if weight != lastWeight { w.setWeight(choice, weight) } - if weight >= lastWeight || rand.Int63n(lastWeight) < weight { + if weight >= lastWeight || uint64(rand.Int63n(int64(lastWeight))) < weight { return choice } } @@ -92,16 +97,16 @@ func (w *WeightedRandomSelect) Choose() wrsItem { const wrsBranches = 8 // max number of branches in the wrsNode tree -// wrsNode is a node of a tree structure that can store wrsItems or further wrsNodes. +// wrsNode is a node of a tree structure that can store WrsItems or further wrsNodes. type wrsNode struct { items [wrsBranches]interface{} - weights [wrsBranches]int64 - sumWeight int64 + weights [wrsBranches]uint64 + sumWeight uint64 level, itemCnt, maxItems int } // insert recursively inserts a new item to the tree and returns the item index -func (n *wrsNode) insert(item wrsItem, weight int64) int { +func (n *wrsNode) insert(item WrsItem, weight uint64) int { branch := 0 for n.items[branch] != nil && (n.level == 0 || n.items[branch].(*wrsNode).itemCnt == n.items[branch].(*wrsNode).maxItems) { branch++ @@ -129,7 +134,7 @@ func (n *wrsNode) insert(item wrsItem, weight int64) int { // setWeight updates the weight of a certain item (which should exist) and returns // the change of the last weight value stored in the tree -func (n *wrsNode) setWeight(idx int, weight int64) int64 { +func (n *wrsNode) setWeight(idx int, weight uint64) uint64 { if n.level == 0 { oldWeight := n.weights[idx] n.weights[idx] = weight @@ -152,12 +157,12 @@ func (n *wrsNode) setWeight(idx int, weight int64) int64 { return diff } -// Choose recursively selects an item from the tree and returns it along with its weight -func (n *wrsNode) choose(val int64) (wrsItem, int64) { +// choose recursively selects an item from the tree and returns it along with its weight +func (n *wrsNode) choose(val uint64) (WrsItem, uint64) { for i, w := range n.weights { if val < w { if n.level == 0 { - return n.items[i].(wrsItem), n.weights[i] + return n.items[i].(WrsItem), n.weights[i] } return n.items[i].(*wrsNode).choose(val) } diff --git a/les/utils/weighted_select_test.go b/les/utils/weighted_select_test.go index e1969e1a6..3e1c0ad98 100644 --- a/les/utils/weighted_select_test.go +++ b/les/utils/weighted_select_test.go @@ -26,17 +26,18 @@ type testWrsItem struct { widx *int } -func (t *testWrsItem) Weight() int64 { +func testWeight(i interface{}) uint64 { + t := i.(*testWrsItem) w := *t.widx if w == -1 || w == t.idx { - return int64(t.idx + 1) + return uint64(t.idx + 1) } return 0 } func TestWeightedRandomSelect(t *testing.T) { testFn := func(cnt int) { - s := NewWeightedRandomSelect() + s := NewWeightedRandomSelect(testWeight) w := -1 list := make([]testWrsItem, cnt) for i := range list { diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go new file mode 100644 index 000000000..7091281ae --- /dev/null +++ b/p2p/nodestate/nodestate.go @@ -0,0 +1,880 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package nodestate + +import ( + "errors" + "reflect" + "sync" + "time" + "unsafe" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" +) + +type ( + // NodeStateMachine connects different system components operating on subsets of + // network nodes. Node states are represented by 64 bit vectors with each bit assigned + // to a state flag. Each state flag has a descriptor structure and the mapping is + // created automatically. It is possible to subscribe to subsets of state flags and + // receive a callback if one of the nodes has a relevant state flag changed. + // Callbacks can also modify further flags of the same node or other nodes. State + // updates only return after all immediate effects throughout the system have happened + // (deadlocks should be avoided by design of the implemented state logic). The caller + // can also add timeouts assigned to a certain node and a subset of state flags. + // If the timeout elapses, the flags are reset. If all relevant flags are reset then + // the timer is dropped. State flags with no timeout are persisted in the database + // if the flag descriptor enables saving. If a node has no state flags set at any + // moment then it is discarded. + // + // Extra node fields can also be registered so system components can also store more + // complex state for each node that is relevant to them, without creating a custom + // peer set. Fields can be shared across multiple components if they all know the + // field ID. Subscription to fields is also possible. Persistent fields should have + // an encoder and a decoder function. + NodeStateMachine struct { + started, stopped bool + lock sync.Mutex + clock mclock.Clock + db ethdb.KeyValueStore + dbNodeKey []byte + nodes map[enode.ID]*nodeInfo + offlineCallbackList []offlineCallback + + // Registered state flags or fields. Modifications are allowed + // only when the node state machine has not been started. + setup *Setup + fields []*fieldInfo + saveFlags bitMask + + // Installed callbacks. Modifications are allowed only when the + // node state machine has not been started. + stateSubs []stateSub + + // Testing hooks, only for testing purposes. + saveNodeHook func(*nodeInfo) + } + + // Flags represents a set of flags from a certain setup + Flags struct { + mask bitMask + setup *Setup + } + + // Field represents a field from a certain setup + Field struct { + index int + setup *Setup + } + + // flagDefinition describes a node state flag. Each registered instance is automatically + // mapped to a bit of the 64 bit node states. + // If persistent is true then the node is saved when state machine is shutdown. + flagDefinition struct { + name string + persistent bool + } + + // fieldDefinition describes an optional node field of the given type. The contents + // of the field are only retained for each node as long as at least one of the + // state flags is set. + fieldDefinition struct { + name string + ftype reflect.Type + encode func(interface{}) ([]byte, error) + decode func([]byte) (interface{}, error) + } + + // stateSetup contains the list of flags and fields used by the application + Setup struct { + Version uint + flags []flagDefinition + fields []fieldDefinition + } + + // bitMask describes a node state or state mask. It represents a subset + // of node flags with each bit assigned to a flag index (LSB represents flag 0). + bitMask uint64 + + // StateCallback is a subscription callback which is called when one of the + // state flags that is included in the subscription state mask is changed. + // Note: oldState and newState are also masked with the subscription mask so only + // the relevant bits are included. + StateCallback func(n *enode.Node, oldState, newState Flags) + + // FieldCallback is a subscription callback which is called when the value of + // a specific field is changed. + FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{}) + + // nodeInfo contains node state, fields and state timeouts + nodeInfo struct { + node *enode.Node + state bitMask + timeouts []*nodeStateTimeout + fields []interface{} + db, dirty bool + } + + nodeInfoEnc struct { + Enr enr.Record + Version uint + State bitMask + Fields [][]byte + } + + stateSub struct { + mask bitMask + callback StateCallback + } + + nodeStateTimeout struct { + mask bitMask + timer mclock.Timer + } + + fieldInfo struct { + fieldDefinition + subs []FieldCallback + } + + offlineCallback struct { + node *enode.Node + state bitMask + fields []interface{} + } +) + +// offlineState is a special state that is assumed to be set before a node is loaded from +// the database and after it is shut down. +const offlineState = bitMask(1) + +// NewFlag creates a new node state flag +func (s *Setup) NewFlag(name string) Flags { + if s.flags == nil { + s.flags = []flagDefinition{{name: "offline"}} + } + f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s} + s.flags = append(s.flags, flagDefinition{name: name}) + return f +} + +// NewPersistentFlag creates a new persistent node state flag +func (s *Setup) NewPersistentFlag(name string) Flags { + if s.flags == nil { + s.flags = []flagDefinition{{name: "offline"}} + } + f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s} + s.flags = append(s.flags, flagDefinition{name: name, persistent: true}) + return f +} + +// OfflineFlag returns the system-defined offline flag belonging to the given setup +func (s *Setup) OfflineFlag() Flags { + return Flags{mask: offlineState, setup: s} +} + +// NewField creates a new node state field +func (s *Setup) NewField(name string, ftype reflect.Type) Field { + f := Field{index: len(s.fields), setup: s} + s.fields = append(s.fields, fieldDefinition{ + name: name, + ftype: ftype, + }) + return f +} + +// NewPersistentField creates a new persistent node field +func (s *Setup) NewPersistentField(name string, ftype reflect.Type, encode func(interface{}) ([]byte, error), decode func([]byte) (interface{}, error)) Field { + f := Field{index: len(s.fields), setup: s} + s.fields = append(s.fields, fieldDefinition{ + name: name, + ftype: ftype, + encode: encode, + decode: decode, + }) + return f +} + +// flagOp implements binary flag operations and also checks whether the operands belong to the same setup +func flagOp(a, b Flags, trueIfA, trueIfB, trueIfBoth bool) Flags { + if a.setup == nil { + if a.mask != 0 { + panic("Node state flags have no setup reference") + } + a.setup = b.setup + } + if b.setup == nil { + if b.mask != 0 { + panic("Node state flags have no setup reference") + } + b.setup = a.setup + } + if a.setup != b.setup { + panic("Node state flags belong to a different setup") + } + res := Flags{setup: a.setup} + if trueIfA { + res.mask |= a.mask & ^b.mask + } + if trueIfB { + res.mask |= b.mask & ^a.mask + } + if trueIfBoth { + res.mask |= a.mask & b.mask + } + return res +} + +// And returns the set of flags present in both a and b +func (a Flags) And(b Flags) Flags { return flagOp(a, b, false, false, true) } + +// AndNot returns the set of flags present in a but not in b +func (a Flags) AndNot(b Flags) Flags { return flagOp(a, b, true, false, false) } + +// Or returns the set of flags present in either a or b +func (a Flags) Or(b Flags) Flags { return flagOp(a, b, true, true, true) } + +// Xor returns the set of flags present in either a or b but not both +func (a Flags) Xor(b Flags) Flags { return flagOp(a, b, true, true, false) } + +// HasAll returns true if b is a subset of a +func (a Flags) HasAll(b Flags) bool { return flagOp(a, b, false, true, false).mask == 0 } + +// HasNone returns true if a and b have no shared flags +func (a Flags) HasNone(b Flags) bool { return flagOp(a, b, false, false, true).mask == 0 } + +// Equals returns true if a and b have the same flags set +func (a Flags) Equals(b Flags) bool { return flagOp(a, b, true, true, false).mask == 0 } + +// IsEmpty returns true if a has no flags set +func (a Flags) IsEmpty() bool { return a.mask == 0 } + +// MergeFlags merges multiple sets of state flags +func MergeFlags(list ...Flags) Flags { + if len(list) == 0 { + return Flags{} + } + res := list[0] + for i := 1; i < len(list); i++ { + res = res.Or(list[i]) + } + return res +} + +// String returns a list of the names of the flags specified in the bit mask +func (f Flags) String() string { + if f.mask == 0 { + return "[]" + } + s := "[" + comma := false + for index, flag := range f.setup.flags { + if f.mask&(bitMask(1)< 8*int(unsafe.Sizeof(bitMask(0))) { + panic("Too many node state flags") + } + ns := &NodeStateMachine{ + db: db, + dbNodeKey: dbKey, + clock: clock, + setup: setup, + nodes: make(map[enode.ID]*nodeInfo), + fields: make([]*fieldInfo, len(setup.fields)), + } + stateNameMap := make(map[string]int) + for index, flag := range setup.flags { + if _, ok := stateNameMap[flag.name]; ok { + panic("Node state flag name collision") + } + stateNameMap[flag.name] = index + if flag.persistent { + ns.saveFlags |= bitMask(1) << uint(index) + } + } + fieldNameMap := make(map[string]int) + for index, field := range setup.fields { + if _, ok := fieldNameMap[field.name]; ok { + panic("Node field name collision") + } + ns.fields[index] = &fieldInfo{fieldDefinition: field} + fieldNameMap[field.name] = index + } + return ns +} + +// stateMask checks whether the set of flags belongs to the same setup and returns its internal bit mask +func (ns *NodeStateMachine) stateMask(flags Flags) bitMask { + if flags.setup != ns.setup && flags.mask != 0 { + panic("Node state flags belong to a different setup") + } + return flags.mask +} + +// fieldIndex checks whether the field belongs to the same setup and returns its internal index +func (ns *NodeStateMachine) fieldIndex(field Field) int { + if field.setup != ns.setup { + panic("Node field belongs to a different setup") + } + return field.index +} + +// SubscribeState adds a node state subscription. The callback is called while the state +// machine mutex is not held and it is allowed to make further state updates. All immediate +// changes throughout the system are processed in the same thread/goroutine. It is the +// responsibility of the implemented state logic to avoid deadlocks caused by the callbacks, +// infinite toggling of flags or hazardous/non-deterministic state changes. +// State subscriptions should be installed before loading the node database or making the +// first state update. +func (ns *NodeStateMachine) SubscribeState(flags Flags, callback StateCallback) { + ns.lock.Lock() + defer ns.lock.Unlock() + + if ns.started { + panic("state machine already started") + } + ns.stateSubs = append(ns.stateSubs, stateSub{ns.stateMask(flags), callback}) +} + +// SubscribeField adds a node field subscription. Same rules apply as for SubscribeState. +func (ns *NodeStateMachine) SubscribeField(field Field, callback FieldCallback) { + ns.lock.Lock() + defer ns.lock.Unlock() + + if ns.started { + panic("state machine already started") + } + f := ns.fields[ns.fieldIndex(field)] + f.subs = append(f.subs, callback) +} + +// newNode creates a new nodeInfo +func (ns *NodeStateMachine) newNode(n *enode.Node) *nodeInfo { + return &nodeInfo{node: n, fields: make([]interface{}, len(ns.fields))} +} + +// checkStarted checks whether the state machine has already been started and panics otherwise. +func (ns *NodeStateMachine) checkStarted() { + if !ns.started { + panic("state machine not started yet") + } +} + +// Start starts the state machine, enabling state and field operations and disabling +// further subscriptions. +func (ns *NodeStateMachine) Start() { + ns.lock.Lock() + if ns.started { + panic("state machine already started") + } + ns.started = true + if ns.db != nil { + ns.loadFromDb() + } + ns.lock.Unlock() + ns.offlineCallbacks(true) +} + +// Stop stops the state machine and saves its state if a database was supplied +func (ns *NodeStateMachine) Stop() { + ns.lock.Lock() + for _, node := range ns.nodes { + fields := make([]interface{}, len(node.fields)) + copy(fields, node.fields) + ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields}) + } + ns.stopped = true + if ns.db != nil { + ns.saveToDb() + ns.lock.Unlock() + } else { + ns.lock.Unlock() + } + ns.offlineCallbacks(false) +} + +// loadFromDb loads persisted node states from the database +func (ns *NodeStateMachine) loadFromDb() { + it := ns.db.NewIterator(ns.dbNodeKey, nil) + for it.Next() { + var id enode.ID + if len(it.Key()) != len(ns.dbNodeKey)+len(id) { + log.Error("Node state db entry with invalid length", "found", len(it.Key()), "expected", len(ns.dbNodeKey)+len(id)) + continue + } + copy(id[:], it.Key()[len(ns.dbNodeKey):]) + ns.decodeNode(id, it.Value()) + } +} + +type dummyIdentity enode.ID + +func (id dummyIdentity) Verify(r *enr.Record, sig []byte) error { return nil } +func (id dummyIdentity) NodeAddr(r *enr.Record) []byte { return id[:] } + +// decodeNode decodes a node database entry and adds it to the node set if successful +func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) { + var enc nodeInfoEnc + if err := rlp.DecodeBytes(data, &enc); err != nil { + log.Error("Failed to decode node info", "id", id, "error", err) + return + } + n, _ := enode.New(dummyIdentity(id), &enc.Enr) + node := ns.newNode(n) + node.db = true + + if enc.Version != ns.setup.Version { + log.Debug("Removing stored node with unknown version", "current", ns.setup.Version, "stored", enc.Version) + ns.deleteNode(id) + return + } + if len(enc.Fields) > len(ns.setup.fields) { + log.Error("Invalid node field count", "id", id, "stored", len(enc.Fields)) + return + } + // Resolve persisted node fields + for i, encField := range enc.Fields { + if len(encField) == 0 { + continue + } + if decode := ns.fields[i].decode; decode != nil { + if field, err := decode(encField); err == nil { + node.fields[i] = field + } else { + log.Error("Failed to decode node field", "id", id, "field name", ns.fields[i].name, "error", err) + return + } + } else { + log.Error("Cannot decode node field", "id", id, "field name", ns.fields[i].name) + return + } + } + // It's a compatible node record, add it to set. + ns.nodes[id] = node + node.state = enc.State + fields := make([]interface{}, len(node.fields)) + copy(fields, node.fields) + ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields}) + log.Debug("Loaded node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup}) +} + +// saveNode saves the given node info to the database +func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error { + if ns.db == nil { + return nil + } + + storedState := node.state & ns.saveFlags + for _, t := range node.timeouts { + storedState &= ^t.mask + } + if storedState == 0 { + if node.db { + node.db = false + ns.deleteNode(id) + } + node.dirty = false + return nil + } + + enc := nodeInfoEnc{ + Enr: *node.node.Record(), + Version: ns.setup.Version, + State: storedState, + Fields: make([][]byte, len(ns.fields)), + } + log.Debug("Saved node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup}) + lastIndex := -1 + for i, f := range node.fields { + if f == nil { + continue + } + encode := ns.fields[i].encode + if encode == nil { + continue + } + blob, err := encode(f) + if err != nil { + return err + } + enc.Fields[i] = blob + lastIndex = i + } + enc.Fields = enc.Fields[:lastIndex+1] + data, err := rlp.EncodeToBytes(&enc) + if err != nil { + return err + } + if err := ns.db.Put(append(ns.dbNodeKey, id[:]...), data); err != nil { + return err + } + node.dirty, node.db = false, true + + if ns.saveNodeHook != nil { + ns.saveNodeHook(node) + } + return nil +} + +// deleteNode removes a node info from the database +func (ns *NodeStateMachine) deleteNode(id enode.ID) { + ns.db.Delete(append(ns.dbNodeKey, id[:]...)) +} + +// saveToDb saves the persistent flags and fields of all nodes that have been changed +func (ns *NodeStateMachine) saveToDb() { + for id, node := range ns.nodes { + if node.dirty { + err := ns.saveNode(id, node) + if err != nil { + log.Error("Failed to save node", "id", id, "error", err) + } + } + } +} + +// updateEnode updates the enode entry belonging to the given node if it already exists +func (ns *NodeStateMachine) updateEnode(n *enode.Node) (enode.ID, *nodeInfo) { + id := n.ID() + node := ns.nodes[id] + if node != nil && n.Seq() > node.node.Seq() { + node.node = n + } + return id, node +} + +// Persist saves the persistent state and fields of the given node immediately +func (ns *NodeStateMachine) Persist(n *enode.Node) error { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.checkStarted() + if id, node := ns.updateEnode(n); node != nil && node.dirty { + err := ns.saveNode(id, node) + if err != nil { + log.Error("Failed to save node", "id", id, "error", err) + } + return err + } + return nil +} + +// SetState updates the given node state flags and processes all resulting callbacks. +// It only returns after all subsequent immediate changes (including those changed by the +// callbacks) have been processed. If a flag with a timeout is set again, the operation +// removes or replaces the existing timeout. +func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { + ns.lock.Lock() + ns.checkStarted() + if ns.stopped { + ns.lock.Unlock() + return + } + + set, reset := ns.stateMask(setFlags), ns.stateMask(resetFlags) + id, node := ns.updateEnode(n) + if node == nil { + if set == 0 { + ns.lock.Unlock() + return + } + node = ns.newNode(n) + ns.nodes[id] = node + } + oldState := node.state + newState := (node.state & (^reset)) | set + changed := oldState ^ newState + node.state = newState + + // Remove the timeout callbacks for all reset and set flags, + // even they are not existent(it's noop). + ns.removeTimeouts(node, set|reset) + + // Register the timeout callback if the new state is not empty + // and timeout itself is required. + if timeout != 0 && newState != 0 { + ns.addTimeout(n, set, timeout) + } + if newState == oldState { + ns.lock.Unlock() + return + } + if newState == 0 { + delete(ns.nodes, id) + if node.db { + ns.deleteNode(id) + } + } else { + if changed&ns.saveFlags != 0 { + node.dirty = true + } + } + ns.lock.Unlock() + // call state update subscription callbacks without holding the mutex + for _, sub := range ns.stateSubs { + if changed&sub.mask != 0 { + sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}) + } + } + if newState == 0 { + // call field subscriptions for discarded fields + for i, v := range node.fields { + if v != nil { + f := ns.fields[i] + if len(f.subs) > 0 { + for _, cb := range f.subs { + cb(n, Flags{setup: ns.setup}, v, nil) + } + } + } + } + } +} + +// offlineCallbacks calls state update callbacks at startup or shutdown +func (ns *NodeStateMachine) offlineCallbacks(start bool) { + for _, cb := range ns.offlineCallbackList { + for _, sub := range ns.stateSubs { + offState := offlineState & sub.mask + onState := cb.state & sub.mask + if offState != onState { + if start { + sub.callback(cb.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}) + } else { + sub.callback(cb.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}) + } + } + } + for i, f := range cb.fields { + if f != nil && ns.fields[i].subs != nil { + for _, fsub := range ns.fields[i].subs { + if start { + fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, nil, f) + } else { + fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, f, nil) + } + } + } + } + } + ns.offlineCallbackList = nil +} + +// AddTimeout adds a node state timeout associated to the given state flag(s). +// After the specified time interval, the relevant states will be reset. +func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.checkStarted() + if ns.stopped { + return + } + ns.addTimeout(n, ns.stateMask(flags), timeout) +} + +// addTimeout adds a node state timeout associated to the given state flag(s). +func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time.Duration) { + _, node := ns.updateEnode(n) + if node == nil { + return + } + mask &= node.state + if mask == 0 { + return + } + ns.removeTimeouts(node, mask) + t := &nodeStateTimeout{mask: mask} + t.timer = ns.clock.AfterFunc(timeout, func() { + ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0) + }) + node.timeouts = append(node.timeouts, t) + if mask&ns.saveFlags != 0 { + node.dirty = true + } +} + +// removeTimeout removes node state timeouts associated to the given state flag(s). +// If a timeout was associated to multiple flags which are not all included in the +// specified remove mask then only the included flags are de-associated and the timer +// stays active. +func (ns *NodeStateMachine) removeTimeouts(node *nodeInfo, mask bitMask) { + for i := 0; i < len(node.timeouts); i++ { + t := node.timeouts[i] + match := t.mask & mask + if match == 0 { + continue + } + t.mask -= match + if t.mask != 0 { + continue + } + t.timer.Stop() + node.timeouts[i] = node.timeouts[len(node.timeouts)-1] + node.timeouts = node.timeouts[:len(node.timeouts)-1] + i-- + if match&ns.saveFlags != 0 { + node.dirty = true + } + } +} + +// GetField retrieves the given field of the given node +func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.checkStarted() + if ns.stopped { + return nil + } + if _, node := ns.updateEnode(n); node != nil { + return node.fields[ns.fieldIndex(field)] + } + return nil +} + +// SetField sets the given field of the given node +func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error { + ns.lock.Lock() + ns.checkStarted() + if ns.stopped { + ns.lock.Unlock() + return nil + } + _, node := ns.updateEnode(n) + if node == nil { + ns.lock.Unlock() + return nil + } + fieldIndex := ns.fieldIndex(field) + f := ns.fields[fieldIndex] + if value != nil && reflect.TypeOf(value) != f.ftype { + log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype) + ns.lock.Unlock() + return errors.New("invalid field type") + } + oldValue := node.fields[fieldIndex] + if value == oldValue { + ns.lock.Unlock() + return nil + } + node.fields[fieldIndex] = value + if f.encode != nil { + node.dirty = true + } + + state := node.state + ns.lock.Unlock() + if len(f.subs) > 0 { + for _, cb := range f.subs { + cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value) + } + } + return nil +} + +// ForEach calls the callback for each node having all of the required and none of the +// disabled flags set +func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) { + ns.lock.Lock() + ns.checkStarted() + type callback struct { + node *enode.Node + state bitMask + } + require, disable := ns.stateMask(requireFlags), ns.stateMask(disableFlags) + var callbacks []callback + for _, node := range ns.nodes { + if node.state&require == require && node.state&disable == 0 { + callbacks = append(callbacks, callback{node.node, node.state & (require | disable)}) + } + } + ns.lock.Unlock() + for _, c := range callbacks { + cb(c.node, Flags{mask: c.state, setup: ns.setup}) + } +} + +// GetNode returns the enode currently associated with the given ID +func (ns *NodeStateMachine) GetNode(id enode.ID) *enode.Node { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.checkStarted() + if node := ns.nodes[id]; node != nil { + return node.node + } + return nil +} + +// AddLogMetrics adds logging and/or metrics for nodes entering, exiting and currently +// being in a given set specified by required and disabled state flags +func (ns *NodeStateMachine) AddLogMetrics(requireFlags, disableFlags Flags, name string, inMeter, outMeter metrics.Meter, gauge metrics.Gauge) { + var count int64 + ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags) { + oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) + newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) + if newMatch == oldMatch { + return + } + + if newMatch { + count++ + if name != "" { + log.Debug("Node entered", "set", name, "id", n.ID(), "count", count) + } + if inMeter != nil { + inMeter.Mark(1) + } + } else { + count-- + if name != "" { + log.Debug("Node left", "set", name, "id", n.ID(), "count", count) + } + if outMeter != nil { + outMeter.Mark(1) + } + } + if gauge != nil { + gauge.Update(count) + } + }) +} diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go new file mode 100644 index 000000000..f6ff3ffc0 --- /dev/null +++ b/p2p/nodestate/nodestate_test.go @@ -0,0 +1,389 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package nodestate + +import ( + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" +) + +func testSetup(flagPersist []bool, fieldType []reflect.Type) (*Setup, []Flags, []Field) { + setup := &Setup{} + flags := make([]Flags, len(flagPersist)) + for i, persist := range flagPersist { + if persist { + flags[i] = setup.NewPersistentFlag(fmt.Sprintf("flag-%d", i)) + } else { + flags[i] = setup.NewFlag(fmt.Sprintf("flag-%d", i)) + } + } + fields := make([]Field, len(fieldType)) + for i, ftype := range fieldType { + switch ftype { + case reflect.TypeOf(uint64(0)): + fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, uint64FieldEnc, uint64FieldDec) + case reflect.TypeOf(""): + fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, stringFieldEnc, stringFieldDec) + default: + fields[i] = setup.NewField(fmt.Sprintf("field-%d", i), ftype) + } + } + return setup, flags, fields +} + +func testNode(b byte) *enode.Node { + r := &enr.Record{} + r.SetSig(dummyIdentity{b}, []byte{42}) + n, _ := enode.New(dummyIdentity{b}, r) + return n +} + +func TestCallback(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, _ := testSetup([]bool{false, false, false}, nil) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + set0 := make(chan struct{}, 1) + set1 := make(chan struct{}, 1) + set2 := make(chan struct{}, 1) + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} }) + ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} }) + ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} }) + + ns.Start() + + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetState(testNode(1), flags[1], Flags{}, time.Second) + ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second) + + for i := 0; i < 3; i++ { + select { + case <-set0: + case <-set1: + case <-set2: + case <-time.After(time.Second): + t.Fatalf("failed to invoke callback") + } + } +} + +func TestPersistentFlags(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, _ := testSetup([]bool{true, true, true, false}, nil) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + saveNode := make(chan *nodeInfo, 5) + ns.saveNodeHook = func(node *nodeInfo) { + saveNode <- node + } + + ns.Start() + + ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved + ns.SetState(testNode(2), flags[1], Flags{}, 0) + ns.SetState(testNode(3), flags[2], Flags{}, 0) + ns.SetState(testNode(4), flags[3], Flags{}, 0) + ns.SetState(testNode(5), flags[0], Flags{}, 0) + ns.Persist(testNode(5)) + select { + case <-saveNode: + case <-time.After(time.Second): + t.Fatalf("Timeout") + } + ns.Stop() + + for i := 0; i < 2; i++ { + select { + case <-saveNode: + case <-time.After(time.Second): + t.Fatalf("Timeout") + } + } + select { + case <-saveNode: + t.Fatalf("Unexpected saveNode") + case <-time.After(time.Millisecond * 100): + } +} + +func TestSetField(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf("")}) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + saveNode := make(chan *nodeInfo, 1) + ns.saveNodeHook = func(node *nodeInfo) { + saveNode <- node + } + + ns.Start() + + // Set field before setting state + ns.SetField(testNode(1), fields[0], "hello world") + field := ns.GetField(testNode(1), fields[0]) + if field != nil { + t.Fatalf("Field shouldn't be set before setting states") + } + // Set field after setting state + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetField(testNode(1), fields[0], "hello world") + field = ns.GetField(testNode(1), fields[0]) + if field == nil { + t.Fatalf("Field should be set after setting states") + } + if err := ns.SetField(testNode(1), fields[0], 123); err == nil { + t.Fatalf("Invalid field should be rejected") + } + // Dirty node should be written back + ns.Stop() + select { + case <-saveNode: + case <-time.After(time.Second): + t.Fatalf("Timeout") + } +} + +func TestUnsetField(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, fields := testSetup([]bool{false}, []reflect.Type{reflect.TypeOf("")}) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + ns.Start() + + ns.SetState(testNode(1), flags[0], Flags{}, time.Second) + ns.SetField(testNode(1), fields[0], "hello world") + + ns.SetState(testNode(1), Flags{}, flags[0], 0) + if field := ns.GetField(testNode(1), fields[0]); field != nil { + t.Fatalf("Field should be unset") + } +} + +func TestSetState(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, _ := testSetup([]bool{false, false, false}, nil) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + type change struct{ old, new Flags } + set := make(chan change, 1) + ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) { + set <- change{ + old: oldState, + new: newState, + } + }) + + ns.Start() + + check := func(expectOld, expectNew Flags, expectChange bool) { + if expectChange { + select { + case c := <-set: + if !c.old.Equals(expectOld) { + t.Fatalf("Old state mismatch") + } + if !c.new.Equals(expectNew) { + t.Fatalf("New state mismatch") + } + case <-time.After(time.Second): + } + return + } + select { + case <-set: + t.Fatalf("Unexpected change") + case <-time.After(time.Millisecond * 100): + return + } + } + ns.SetState(testNode(1), flags[0], Flags{}, 0) + check(Flags{}, flags[0], true) + + ns.SetState(testNode(1), flags[1], Flags{}, 0) + check(flags[0], flags[0].Or(flags[1]), true) + + ns.SetState(testNode(1), flags[2], Flags{}, 0) + check(Flags{}, Flags{}, false) + + ns.SetState(testNode(1), Flags{}, flags[0], 0) + check(flags[0].Or(flags[1]), flags[1], true) + + ns.SetState(testNode(1), Flags{}, flags[1], 0) + check(flags[1], Flags{}, true) + + ns.SetState(testNode(1), Flags{}, flags[2], 0) + check(Flags{}, Flags{}, false) + + ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second) + check(Flags{}, flags[0].Or(flags[1]), true) + clock.Run(time.Second) + check(flags[0].Or(flags[1]), Flags{}, true) +} + +func uint64FieldEnc(field interface{}) ([]byte, error) { + if u, ok := field.(uint64); ok { + enc, err := rlp.EncodeToBytes(&u) + return enc, err + } else { + return nil, errors.New("invalid field type") + } +} + +func uint64FieldDec(enc []byte) (interface{}, error) { + var u uint64 + err := rlp.DecodeBytes(enc, &u) + return u, err +} + +func stringFieldEnc(field interface{}) ([]byte, error) { + if s, ok := field.(string); ok { + return []byte(s), nil + } else { + return nil, errors.New("invalid field type") + } +} + +func stringFieldDec(enc []byte) (interface{}, error) { + return string(enc), nil +} + +func TestPersistentFields(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0)), reflect.TypeOf("")}) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + ns.Start() + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetField(testNode(1), fields[0], uint64(100)) + ns.SetField(testNode(1), fields[1], "hello world") + ns.Stop() + + ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + ns2.Start() + field0 := ns2.GetField(testNode(1), fields[0]) + if !reflect.DeepEqual(field0, uint64(100)) { + t.Fatalf("Field changed") + } + field1 := ns2.GetField(testNode(1), fields[1]) + if !reflect.DeepEqual(field1, "hello world") { + t.Fatalf("Field changed") + } + + s.Version++ + ns3 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + ns3.Start() + if ns3.GetField(testNode(1), fields[0]) != nil { + t.Fatalf("Old field version should have been discarded") + } +} + +func TestFieldSub(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0))}) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + var ( + lastState Flags + lastOldValue, lastNewValue interface{} + ) + ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { + lastState, lastOldValue, lastNewValue = state, oldValue, newValue + }) + check := func(state Flags, oldValue, newValue interface{}) { + if !lastState.Equals(state) || lastOldValue != oldValue || lastNewValue != newValue { + t.Fatalf("Incorrect field sub callback (expected [%v %v %v], got [%v %v %v])", state, oldValue, newValue, lastState, lastOldValue, lastNewValue) + } + } + ns.Start() + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetField(testNode(1), fields[0], uint64(100)) + check(flags[0], nil, uint64(100)) + ns.Stop() + check(s.OfflineFlag(), uint64(100), nil) + + ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { + lastState, lastOldValue, lastNewValue = state, oldValue, newValue + }) + ns2.Start() + check(s.OfflineFlag(), nil, uint64(100)) + ns2.SetState(testNode(1), Flags{}, flags[0], 0) + check(Flags{}, uint64(100), nil) + ns2.Stop() +} + +func TestDuplicatedFlags(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, _ := testSetup([]bool{true}, nil) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + type change struct{ old, new Flags } + set := make(chan change, 1) + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { + set <- change{oldState, newState} + }) + + ns.Start() + defer ns.Stop() + + check := func(expectOld, expectNew Flags, expectChange bool) { + if expectChange { + select { + case c := <-set: + if !c.old.Equals(expectOld) { + t.Fatalf("Old state mismatch") + } + if !c.new.Equals(expectNew) { + t.Fatalf("New state mismatch") + } + case <-time.After(time.Second): + } + return + } + select { + case <-set: + t.Fatalf("Unexpected change") + case <-time.After(time.Millisecond * 100): + return + } + } + ns.SetState(testNode(1), flags[0], Flags{}, time.Second) + check(Flags{}, flags[0], true) + ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s + check(Flags{}, flags[0], false) + + clock.Run(2 * time.Second) + check(flags[0], Flags{}, true) +} diff --git a/params/bootnodes.go b/params/bootnodes.go index 0d72321b0..e1898d762 100644 --- a/params/bootnodes.go +++ b/params/bootnodes.go @@ -65,11 +65,22 @@ var GoerliBootnodes = []string{ const dnsPrefix = "enrtree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@" -// These DNS names provide bootstrap connectivity for public testnets and the mainnet. -// See https://github.com/ethereum/discv4-dns-lists for more information. -var KnownDNSNetworks = map[common.Hash]string{ - MainnetGenesisHash: dnsPrefix + "all.mainnet.ethdisco.net", - RopstenGenesisHash: dnsPrefix + "all.ropsten.ethdisco.net", - RinkebyGenesisHash: dnsPrefix + "all.rinkeby.ethdisco.net", - GoerliGenesisHash: dnsPrefix + "all.goerli.ethdisco.net", +// KnownDNSNetwork returns the address of a public DNS-based node list for the given +// genesis hash and protocol. See https://github.com/ethereum/discv4-dns-lists for more +// information. +func KnownDNSNetwork(genesis common.Hash, protocol string) string { + var net string + switch genesis { + case MainnetGenesisHash: + net = "mainnet" + case RopstenGenesisHash: + net = "ropsten" + case RinkebyGenesisHash: + net = "rinkeby" + case GoerliGenesisHash: + net = "goerli" + default: + return "" + } + return dnsPrefix + protocol + "." + net + ".ethdisco.net" }