les: wait for all task goroutines before dropping the peer (#20010)

* les: wait all task routines before drop the peer

* les: address comments

* les: fix issue
This commit is contained in:
gary rong 2019-08-27 19:07:25 +08:00 committed by Péter Szilágyi
parent a978adfd7c
commit 68502595f6
7 changed files with 84 additions and 53 deletions

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"math/rand" "math/rand"
"sync"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -312,7 +313,7 @@ func (h *serverHandler) measure(setup *benchmarkSetup, count int) error {
}() }()
go func() { go func() {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if err := h.handleMsg(serverPeer); err != nil { if err := h.handleMsg(serverPeer, &sync.WaitGroup{}); err != nil {
errCh <- err errCh <- err
return return
} }

@ -181,52 +181,53 @@ func (f *clientPool) stop() {
f.lock.Unlock() f.lock.Unlock()
} }
// registerPeer implements peerSetNotify
func (f *clientPool) registerPeer(p *peer) {
c := f.connect(p, 0)
if c != nil {
p.balanceTracker = &c.balanceTracker
}
}
// connect should be called after a successful handshake. If the connection was // connect should be called after a successful handshake. If the connection was
// rejected, there is no need to call disconnect. // rejected, there is no need to call disconnect.
func (f *clientPool) connect(peer clientPeer, capacity uint64) *clientInfo { func (f *clientPool) connect(peer clientPeer, capacity uint64) bool {
f.lock.Lock() f.lock.Lock()
defer f.lock.Unlock() defer f.lock.Unlock()
// Short circuit is clientPool is already closed.
if f.closed { if f.closed {
return nil return false
} }
address := peer.freeClientId() // Dedup connected peers.
id := peer.ID() id, freeID := peer.ID(), peer.freeClientId()
idStr := peerIdToString(id)
if _, ok := f.connectedMap[id]; ok { if _, ok := f.connectedMap[id]; ok {
clientRejectedMeter.Mark(1) clientRejectedMeter.Mark(1)
log.Debug("Client already connected", "address", address, "id", idStr) log.Debug("Client already connected", "address", freeID, "id", peerIdToString(id))
return nil return false
} }
// Create a clientInfo but do not add it yet
now := f.clock.Now() now := f.clock.Now()
// create a clientInfo but do not add it yet
e := &clientInfo{pool: f, peer: peer, address: address, queueIndex: -1, id: id}
posBalance := f.getPosBalance(id).value posBalance := f.getPosBalance(id).value
e.priority = posBalance != 0 e := &clientInfo{pool: f, peer: peer, address: freeID, queueIndex: -1, id: id, priority: posBalance != 0}
var negBalance uint64 var negBalance uint64
nb := f.negBalanceMap[address] nb := f.negBalanceMap[freeID]
if nb != nil { if nb != nil {
negBalance = uint64(math.Exp(float64(nb.logValue-f.logOffset(now)) / fixedPointMultiplier)) negBalance = uint64(math.Exp(float64(nb.logValue-f.logOffset(now)) / fixedPointMultiplier))
} }
// If the client is a free client, assign with a low free capacity,
// Otherwise assign with the given value(priority client)
if !e.priority { if !e.priority {
capacity = f.freeClientCap capacity = f.freeClientCap
} }
// check whether it fits into connectedQueue // Ensure the capacity will never lower than the free capacity.
if capacity < f.freeClientCap { if capacity < f.freeClientCap {
capacity = f.freeClientCap capacity = f.freeClientCap
} }
e.capacity = capacity e.capacity = capacity
e.balanceTracker.init(f.clock, capacity) e.balanceTracker.init(f.clock, capacity)
e.balanceTracker.setBalance(posBalance, negBalance) e.balanceTracker.setBalance(posBalance, negBalance)
f.setClientPriceFactors(e) f.setClientPriceFactors(e)
// If the number of clients already connected in the clientpool exceeds its
// capacity, evict some clients with lowest priority.
//
// If the priority of the newly added client is lower than the priority of
// all connected clients, the client is rejected.
newCapacity := f.connectedCapacity + capacity newCapacity := f.connectedCapacity + capacity
newCount := f.connectedQueue.Size() + 1 newCount := f.connectedQueue.Size() + 1
if newCapacity > f.capacityLimit || newCount > f.countLimit { if newCapacity > f.capacityLimit || newCount > f.countLimit {
@ -248,8 +249,8 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) *clientInfo {
f.connectedQueue.Push(c) f.connectedQueue.Push(c)
} }
clientRejectedMeter.Mark(1) clientRejectedMeter.Mark(1)
log.Debug("Client rejected", "address", address, "id", idStr) log.Debug("Client rejected", "address", freeID, "id", peerIdToString(id))
return nil return false
} }
// accept new client, drop old ones // accept new client, drop old ones
for _, c := range kickList { for _, c := range kickList {
@ -258,7 +259,7 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) *clientInfo {
} }
// client accepted, finish setting it up // client accepted, finish setting it up
if nb != nil { if nb != nil {
delete(f.negBalanceMap, address) delete(f.negBalanceMap, freeID)
f.negBalanceQueue.Remove(nb.queueIndex) f.negBalanceQueue.Remove(nb.queueIndex)
} }
if e.priority { if e.priority {
@ -272,13 +273,8 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) *clientInfo {
e.peer.updateCapacity(e.capacity) e.peer.updateCapacity(e.capacity)
} }
clientConnectedMeter.Mark(1) clientConnectedMeter.Mark(1)
log.Debug("Client accepted", "address", address) log.Debug("Client accepted", "address", freeID)
return e return true
}
// unregisterPeer implements peerSetNotify
func (f *clientPool) unregisterPeer(p *peer) {
f.disconnect(p)
} }
// disconnect should be called when a connection is terminated. If the disconnection // disconnect should be called when a connection is terminated. If the disconnection
@ -378,6 +374,18 @@ func (f *clientPool) setLimits(count int, totalCap uint64) {
}) })
} }
// requestCost feeds request cost after serving a request from the given peer.
func (f *clientPool) requestCost(p *peer, cost uint64) {
f.lock.Lock()
defer f.lock.Unlock()
info, exist := f.connectedMap[p.ID()]
if !exist || f.closed {
return
}
info.balanceTracker.requestCost(cost)
}
// logOffset calculates the time-dependent offset for the logarithmic // logOffset calculates the time-dependent offset for the logarithmic
// representation of negative balance // representation of negative balance
func (f *clientPool) logOffset(now mclock.AbsTime) int64 { func (f *clientPool) logOffset(now mclock.AbsTime) int64 {

@ -83,14 +83,14 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
// pool should accept new peers up to its connected limit // pool should accept new peers up to its connected limit
for i := 0; i < connLimit; i++ { for i := 0; i < connLimit; i++ {
if pool.connect(poolTestPeer(i), 0) != nil { if pool.connect(poolTestPeer(i), 0) {
connected[i] = true connected[i] = true
} else { } else {
t.Fatalf("Test peer #%d rejected", i) t.Fatalf("Test peer #%d rejected", i)
} }
} }
// since all accepted peers are new and should not be kicked out, the next one should be rejected // since all accepted peers are new and should not be kicked out, the next one should be rejected
if pool.connect(poolTestPeer(connLimit), 0) != nil { if pool.connect(poolTestPeer(connLimit), 0) {
connected[connLimit] = true connected[connLimit] = true
t.Fatalf("Peer accepted over connected limit") t.Fatalf("Peer accepted over connected limit")
} }
@ -116,7 +116,7 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
connTicks[i] += tickCounter connTicks[i] += tickCounter
} }
} else { } else {
if pool.connect(poolTestPeer(i), 0) != nil { if pool.connect(poolTestPeer(i), 0) {
connected[i] = true connected[i] = true
connTicks[i] -= tickCounter connTicks[i] -= tickCounter
} }
@ -159,7 +159,7 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
} }
// a previously unknown peer should be accepted now // a previously unknown peer should be accepted now
if pool.connect(poolTestPeer(54321), 0) == nil { if !pool.connect(poolTestPeer(54321), 0) {
t.Fatalf("Previously unknown peer rejected") t.Fatalf("Previously unknown peer rejected")
} }
@ -173,7 +173,7 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
pool.connect(poolTestPeer(i), 0) pool.connect(poolTestPeer(i), 0)
} }
// expect pool to remember known nodes and kick out one of them to accept a new one // expect pool to remember known nodes and kick out one of them to accept a new one
if pool.connect(poolTestPeer(54322), 0) == nil { if !pool.connect(poolTestPeer(54322), 0) {
t.Errorf("Previously unknown peer rejected after restarting pool") t.Errorf("Previously unknown peer rejected after restarting pool")
} }
pool.stop() pool.stop()

@ -94,6 +94,7 @@ type peer struct {
sendQueue *execQueue sendQueue *execQueue
errCh chan error errCh chan error
// responseLock ensures that responses are queued in the same order as // responseLock ensures that responses are queued in the same order as
// RequestProcessed is called // RequestProcessed is called
responseLock sync.Mutex responseLock sync.Mutex
@ -107,11 +108,10 @@ type peer struct {
updateTime mclock.AbsTime updateTime mclock.AbsTime
frozen uint32 // 1 if client is in frozen state frozen uint32 // 1 if client is in frozen state
fcClient *flowcontrol.ClientNode // nil if the peer is server only fcClient *flowcontrol.ClientNode // nil if the peer is server only
fcServer *flowcontrol.ServerNode // nil if the peer is client only fcServer *flowcontrol.ServerNode // nil if the peer is client only
fcParams flowcontrol.ServerParams fcParams flowcontrol.ServerParams
fcCosts requestCostTable fcCosts requestCostTable
balanceTracker *balanceTracker // set by clientPool.connect, used and removed by serverHandler.
trusted bool trusted bool
onlyAnnounce bool onlyAnnounce bool

@ -112,9 +112,7 @@ func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
maxCapacity = totalRecharge maxCapacity = totalRecharge
} }
srv.fcManager.SetCapacityLimits(srv.freeCapacity, maxCapacity, srv.freeCapacity*2) srv.fcManager.SetCapacityLimits(srv.freeCapacity, maxCapacity, srv.freeCapacity*2)
srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, 10000, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) }) srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, 10000, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) })
srv.peers.notify(srv.clientPool)
checkpoint := srv.latestLocalCheckpoint() checkpoint := srv.latestLocalCheckpoint()
if !checkpoint.Empty() { if !checkpoint.Empty() {

@ -54,7 +54,10 @@ const (
MaxTxStatus = 256 // Amount of transactions to queried per request MaxTxStatus = 256 // Amount of transactions to queried per request
) )
var errTooManyInvalidRequest = errors.New("too many invalid requests made") var (
errTooManyInvalidRequest = errors.New("too many invalid requests made")
errFullClientPool = errors.New("client pool is full")
)
// serverHandler is responsible for serving light client and process // serverHandler is responsible for serving light client and process
// all incoming light requests. // all incoming light requests.
@ -124,23 +127,26 @@ func (h *serverHandler) handle(p *peer) error {
} }
defer p.fcClient.Disconnect() defer p.fcClient.Disconnect()
// Disconnect the inbound peer if it's rejected by clientPool
if !h.server.clientPool.connect(p, 0) {
p.Log().Debug("Light Ethereum peer registration failed", "err", errFullClientPool)
return errFullClientPool
}
// Register the peer locally // Register the peer locally
if err := h.server.peers.Register(p); err != nil { if err := h.server.peers.Register(p); err != nil {
h.server.clientPool.disconnect(p)
p.Log().Error("Light Ethereum peer registration failed", "err", err) p.Log().Error("Light Ethereum peer registration failed", "err", err)
return err return err
} }
clientConnectionGauge.Update(int64(h.server.peers.Len())) clientConnectionGauge.Update(int64(h.server.peers.Len()))
// add dummy balance tracker for tests var wg sync.WaitGroup // Wait group used to track all in-flight task routines.
if p.balanceTracker == nil {
p.balanceTracker = &balanceTracker{}
p.balanceTracker.init(&mclock.System{}, 1)
}
connectedAt := mclock.Now() connectedAt := mclock.Now()
defer func() { defer func() {
p.balanceTracker = nil wg.Wait() // Ensure all background task routines have exited.
h.server.peers.Unregister(p.id) h.server.peers.Unregister(p.id)
h.server.clientPool.disconnect(p)
clientConnectionGauge.Update(int64(h.server.peers.Len())) clientConnectionGauge.Update(int64(h.server.peers.Len()))
connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
}() }()
@ -153,7 +159,7 @@ func (h *serverHandler) handle(p *peer) error {
return err return err
default: default:
} }
if err := h.handleMsg(p); err != nil { if err := h.handleMsg(p, &wg); err != nil {
p.Log().Debug("Light Ethereum message handling failed", "err", err) p.Log().Debug("Light Ethereum message handling failed", "err", err)
return err return err
} }
@ -162,7 +168,7 @@ func (h *serverHandler) handle(p *peer) error {
// handleMsg is invoked whenever an inbound message is received from a remote // handleMsg is invoked whenever an inbound message is received from a remote
// peer. The remote connection is torn down upon returning any error. // peer. The remote connection is torn down upon returning any error.
func (h *serverHandler) handleMsg(p *peer) error { func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error {
// Read the next message from the remote peer, and ensure it's fully consumed // Read the next message from the remote peer, and ensure it's fully consumed
msg, err := p.rw.ReadMsg() msg, err := p.rw.ReadMsg()
if err != nil { if err != nil {
@ -243,7 +249,7 @@ func (h *serverHandler) handleMsg(p *peer) error {
// Feed cost tracker request serving statistic. // Feed cost tracker request serving statistic.
h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost) h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost)
// Reduce priority "balance" for the specific peer. // Reduce priority "balance" for the specific peer.
p.balanceTracker.requestCost(realCost) h.server.clientPool.requestCost(p, realCost)
} }
if reply != nil { if reply != nil {
p.queueSend(func() { p.queueSend(func() {
@ -273,7 +279,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
} }
query := req.Query query := req.Query
if accept(req.ReqID, query.Amount, MaxHeaderFetch) { if accept(req.ReqID, query.Amount, MaxHeaderFetch) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
hashMode := query.Origin.Hash != (common.Hash{}) hashMode := query.Origin.Hash != (common.Hash{})
first := true first := true
maxNonCanonical := uint64(100) maxNonCanonical := uint64(100)
@ -387,7 +395,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
) )
reqCnt := len(req.Hashes) reqCnt := len(req.Hashes)
if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) { if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
for i, hash := range req.Hashes { for i, hash := range req.Hashes {
if i != 0 && !task.waitOrStop() { if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime) sendResponse(req.ReqID, 0, nil, task.servingTime)
@ -433,7 +443,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
) )
reqCnt := len(req.Reqs) reqCnt := len(req.Reqs)
if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) { if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
for i, request := range req.Reqs { for i, request := range req.Reqs {
if i != 0 && !task.waitOrStop() { if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime) sendResponse(req.ReqID, 0, nil, task.servingTime)
@ -502,7 +514,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
) )
reqCnt := len(req.Hashes) reqCnt := len(req.Hashes)
if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) { if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
for i, hash := range req.Hashes { for i, hash := range req.Hashes {
if i != 0 && !task.waitOrStop() { if i != 0 && !task.waitOrStop() {
sendResponse(req.ReqID, 0, nil, task.servingTime) sendResponse(req.ReqID, 0, nil, task.servingTime)
@ -557,7 +571,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
) )
reqCnt := len(req.Reqs) reqCnt := len(req.Reqs)
if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) { if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
nodes := light.NewNodeSet() nodes := light.NewNodeSet()
for i, request := range req.Reqs { for i, request := range req.Reqs {
@ -658,7 +674,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
) )
reqCnt := len(req.Reqs) reqCnt := len(req.Reqs)
if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) { if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
var ( var (
lastIdx uint64 lastIdx uint64
lastType uint lastType uint
@ -725,7 +743,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
} }
reqCnt := len(req.Txs) reqCnt := len(req.Txs)
if accept(req.ReqID, uint64(reqCnt), MaxTxSend) { if accept(req.ReqID, uint64(reqCnt), MaxTxSend) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
stats := make([]light.TxStatus, len(req.Txs)) stats := make([]light.TxStatus, len(req.Txs))
for i, tx := range req.Txs { for i, tx := range req.Txs {
if i != 0 && !task.waitOrStop() { if i != 0 && !task.waitOrStop() {
@ -771,7 +791,9 @@ func (h *serverHandler) handleMsg(p *peer) error {
} }
reqCnt := len(req.Hashes) reqCnt := len(req.Hashes)
if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) { if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
stats := make([]light.TxStatus, len(req.Hashes)) stats := make([]light.TxStatus, len(req.Hashes))
for i, hash := range req.Hashes { for i, hash := range req.Hashes {
if i != 0 && !task.waitOrStop() { if i != 0 && !task.waitOrStop() {

@ -280,6 +280,8 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da
} }
server.costTracker, server.freeCapacity = newCostTracker(db, server.config) server.costTracker, server.freeCapacity = newCostTracker(db, server.config)
server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism. server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism.
server.clientPool = newClientPool(db, 1, 10000, clock, nil)
server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool
server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true }) server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true })
if server.oracle != nil { if server.oracle != nil {
server.oracle.start(simulation) server.oracle.start(simulation)