From fc01a7ce8e4f865b0d282c21a42c21b0bd1d0b0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felf=C3=B6ldi=20Zsolt?= Date: Tue, 14 Dec 2021 11:34:50 +0100 Subject: [PATCH] les/vflux/client, p2p/nodestate: fix data races (#24058) Fixes #23848 --- les/vflux/client/serverpool_test.go | 39 +++++++++++++++++++++++++---- p2p/nodestate/nodestate.go | 9 ++++++- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/les/vflux/client/serverpool_test.go b/les/vflux/client/serverpool_test.go index 763f72f03a..c7d0245ef2 100644 --- a/les/vflux/client/serverpool_test.go +++ b/les/vflux/client/serverpool_test.go @@ -63,7 +63,11 @@ type ServerPoolTest struct { trusted []string waitCount, waitEnded int32 - lock sync.Mutex + // preNegLock protects the cycle counter, testNodes list and its connected field + // (accessed from both the main thread and the preNeg callback) + preNegLock sync.Mutex + queryWg *sync.WaitGroup // a new wait group is created each time the simulation is started + stopping bool // stopping avoid callind queryWg.Add after queryWg.Wait cycle, conn, servedConn int serviceCycles, dialCount int @@ -111,13 +115,21 @@ func (s *ServerPoolTest) addTrusted(i int) { func (s *ServerPoolTest) start() { var testQuery QueryFunc + s.queryWg = new(sync.WaitGroup) if s.preNeg { testQuery = func(node *enode.Node) int { + s.preNegLock.Lock() + if s.stopping { + s.preNegLock.Unlock() + return 0 + } + s.queryWg.Add(1) idx := testNodeIndex(node.ID()) n := &s.testNodes[idx] - s.lock.Lock() canConnect := !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle - s.lock.Unlock() + s.preNegLock.Unlock() + defer s.queryWg.Done() + if s.preNegFail { // simulate a scenario where UDP queries never work s.beginWait() @@ -181,11 +193,20 @@ func (s *ServerPoolTest) start() { } func (s *ServerPoolTest) stop() { + // disable further queries and wait if one is currently running + s.preNegLock.Lock() + s.stopping = true + s.preNegLock.Unlock() + s.queryWg.Wait() + quit := make(chan struct{}) s.quit <- quit <-quit s.sp.Stop() s.spi.Close() + s.preNegLock.Lock() + s.stopping = false + s.preNegLock.Unlock() for i := range s.testNodes { n := &s.testNodes[i] if n.connected { @@ -205,7 +226,9 @@ func (s *ServerPoolTest) run() { n := &s.testNodes[idx] s.sp.UnregisterNode(n.node) n.totalConn += s.cycle + s.preNegLock.Lock() n.connected = false + s.preNegLock.Unlock() n.node = nil s.conn-- if n.service { @@ -230,7 +253,9 @@ func (s *ServerPoolTest) run() { s.servedConn++ } n.totalConn -= s.cycle + s.preNegLock.Lock() n.connected = true + s.preNegLock.Unlock() dc := s.cycle + n.connectCycles s.disconnect[dc] = append(s.disconnect[dc], idx) n.node = dial @@ -242,9 +267,9 @@ func (s *ServerPoolTest) run() { } s.serviceCycles += s.servedConn s.clock.Run(time.Second) - s.lock.Lock() + s.preNegLock.Lock() s.cycle++ - s.lock.Unlock() + s.preNegLock.Unlock() } } @@ -255,11 +280,13 @@ func (s *ServerPoolTest) setNodes(count, conn, wait int, service, trusted bool) idx = rand.Intn(spTestNodes) } res = append(res, idx) + s.preNegLock.Lock() s.testNodes[idx] = spTestNode{ connectCycles: conn, waitCycles: wait, service: service, } + s.preNegLock.Unlock() if trusted { s.addTrusted(idx) } @@ -273,7 +300,9 @@ func (s *ServerPoolTest) resetNodes() { n.totalConn += s.cycle s.sp.UnregisterNode(n.node) } + s.preNegLock.Lock() s.testNodes[i] = spTestNode{totalConn: n.totalConn} + s.preNegLock.Unlock() } s.conn, s.servedConn = 0, 0 s.disconnect = make(map[int][]int) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index 9323d53cbd..2af0d0a6bd 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -808,7 +808,14 @@ func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time ns.removeTimeouts(node, mask) t := &nodeStateTimeout{mask: mask} t.timer = ns.clock.AfterFunc(timeout, func() { - ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0) + ns.lock.Lock() + defer ns.lock.Unlock() + + if !ns.opStart() { + return + } + ns.setState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0) + ns.opFinish() }) node.timeouts = append(node.timeouts, t) if mask&ns.saveFlags != 0 {