diff --git a/p2p/dial.go b/p2p/dial.go index 3975b488bf..d190e866af 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -17,11 +17,17 @@ package p2p import ( + "context" + crand "crypto/rand" + "encoding/binary" "errors" "fmt" + mrand "math/rand" "net" + "sync" "time" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/netutil" @@ -33,8 +39,9 @@ const ( // private networks. dialHistoryExpiration = inboundThrottleTime + 5*time.Second - // If no peers are found for this amount of time, the initial bootnodes are dialed. - fallbackInterval = 20 * time.Second + // Config for the "Looking for peers" message. + dialStatsLogInterval = 10 * time.Second // printed at most this often + dialStatsPeerLimit = 3 // but not if more than this many dialed peers // Endpoint resolution is throttled with bounded backoff. initialResolveDelay = 60 * time.Second @@ -42,161 +49,29 @@ const ( ) // NodeDialer is used to connect to nodes in the network, typically by using -// an underlying net.Dialer but also using net.Pipe in tests +// an underlying net.Dialer but also using net.Pipe in tests. type NodeDialer interface { - Dial(*enode.Node) (net.Conn, error) + Dial(context.Context, *enode.Node) (net.Conn, error) } type nodeResolver interface { Resolve(*enode.Node) *enode.Node } -// TCPDialer implements the NodeDialer interface by using a net.Dialer to -// create TCP connections to nodes in the network -type TCPDialer struct { - *net.Dialer +// tcpDialer implements NodeDialer using real TCP connections. +type tcpDialer struct { + d *net.Dialer } -// Dial creates a TCP connection to the node -func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) { - addr := &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()} - return t.Dialer.Dial("tcp", addr.String()) +func (t tcpDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) { + return t.d.DialContext(ctx, "tcp", nodeAddr(dest).String()) } -// dialstate schedules dials and discovery lookups. -// It gets a chance to compute new tasks on every iteration -// of the main loop in Server.run. -type dialstate struct { - maxDynDials int - netrestrict *netutil.Netlist - self enode.ID - bootnodes []*enode.Node // default dials when there are no peers - log log.Logger - - start time.Time // time when the dialer was first used - lookupRunning bool - dialing map[enode.ID]connFlag - lookupBuf []*enode.Node // current discovery lookup results - static map[enode.ID]*dialTask - hist expHeap -} - -type task interface { - Do(*Server) -} - -func newDialState(self enode.ID, maxdyn int, cfg *Config) *dialstate { - s := &dialstate{ - maxDynDials: maxdyn, - self: self, - netrestrict: cfg.NetRestrict, - log: cfg.Logger, - static: make(map[enode.ID]*dialTask), - dialing: make(map[enode.ID]connFlag), - bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)), - } - copy(s.bootnodes, cfg.BootstrapNodes) - if s.log == nil { - s.log = log.Root() - } - for _, n := range cfg.StaticNodes { - s.addStatic(n) - } - return s -} - -func (s *dialstate) addStatic(n *enode.Node) { - // This overwrites the task instead of updating an existing - // entry, giving users the opportunity to force a resolve operation. - s.static[n.ID()] = &dialTask{flags: staticDialedConn, dest: n} -} - -func (s *dialstate) removeStatic(n *enode.Node) { - // This removes a task so future attempts to connect will not be made. - delete(s.static, n.ID()) -} - -func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { - var newtasks []task - addDial := func(flag connFlag, n *enode.Node) bool { - if err := s.checkDial(n, peers); err != nil { - s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err) - return false - } - s.dialing[n.ID()] = flag - newtasks = append(newtasks, &dialTask{flags: flag, dest: n}) - return true - } - - if s.start.IsZero() { - s.start = now - } - s.hist.expire(now) - - // Create dials for static nodes if they are not connected. - for id, t := range s.static { - err := s.checkDial(t.dest, peers) - switch err { - case errNotWhitelisted, errSelf: - s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err) - delete(s.static, t.dest.ID()) - case nil: - s.dialing[id] = t.flags - newtasks = append(newtasks, t) - } - } - - // Compute number of dynamic dials needed. - needDynDials := s.maxDynDials - for _, p := range peers { - if p.rw.is(dynDialedConn) { - needDynDials-- - } - } - for _, flag := range s.dialing { - if flag&dynDialedConn != 0 { - needDynDials-- - } - } - - // If we don't have any peers whatsoever, try to dial a random bootnode. This - // scenario is useful for the testnet (and private networks) where the discovery - // table might be full of mostly bad peers, making it hard to find good ones. - if len(peers) == 0 && len(s.bootnodes) > 0 && needDynDials > 0 && now.Sub(s.start) > fallbackInterval { - bootnode := s.bootnodes[0] - s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...) - s.bootnodes = append(s.bootnodes, bootnode) - if addDial(dynDialedConn, bootnode) { - needDynDials-- - } - } - - // Create dynamic dials from discovery results. - i := 0 - for ; i < len(s.lookupBuf) && needDynDials > 0; i++ { - if addDial(dynDialedConn, s.lookupBuf[i]) { - needDynDials-- - } - } - s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])] - - // Launch a discovery lookup if more candidates are needed. - if len(s.lookupBuf) < needDynDials && !s.lookupRunning { - s.lookupRunning = true - newtasks = append(newtasks, &discoverTask{want: needDynDials - len(s.lookupBuf)}) - } - - // Launch a timer to wait for the next node to expire if all - // candidates have been tried and no task is currently active. - // This should prevent cases where the dialer logic is not ticked - // because there are no pending events. - if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 { - t := &waitExpireTask{s.hist.nextExpiry().Sub(now)} - newtasks = append(newtasks, t) - } - return newtasks +func nodeAddr(n *enode.Node) net.Addr { + return &net.TCPAddr{IP: n.IP(), Port: n.TCP()} } +// checkDial errors: var ( errSelf = errors.New("is self") errAlreadyDialing = errors.New("already dialing") @@ -205,56 +80,412 @@ var ( errNotWhitelisted = errors.New("not contained in netrestrict whitelist") ) -func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error { - _, dialing := s.dialing[n.ID()] - switch { - case dialing: - return errAlreadyDialing - case peers[n.ID()] != nil: - return errAlreadyConnected - case n.ID() == s.self: +// dialer creates outbound connections and submits them into Server. +// Two types of peer connections can be created: +// +// - static dials are pre-configured connections. The dialer attempts +// keep these nodes connected at all times. +// +// - dynamic dials are created from node discovery results. The dialer +// continuously reads candidate nodes from its input iterator and attempts +// to create peer connections to nodes arriving through the iterator. +// +type dialScheduler struct { + dialConfig + setupFunc dialSetupFunc + wg sync.WaitGroup + cancel context.CancelFunc + ctx context.Context + nodesIn chan *enode.Node + doneCh chan *dialTask + addStaticCh chan *enode.Node + remStaticCh chan *enode.Node + addPeerCh chan *conn + remPeerCh chan *conn + + // Everything below here belongs to loop and + // should only be accessed by code on the loop goroutine. + dialing map[enode.ID]*dialTask // active tasks + peers map[enode.ID]connFlag // all connected peers + dialPeers int // current number of dialed peers + + // The static map tracks all static dial tasks. The subset of usable static dial tasks + // (i.e. those passing checkDial) is kept in staticPool. The scheduler prefers + // launching random static tasks from the pool over launching dynamic dials from the + // iterator. + static map[enode.ID]*dialTask + staticPool []*dialTask + + // The dial history keeps recently dialed nodes. Members of history are not dialed. + history expHeap + historyTimer mclock.Timer + historyTimerTime mclock.AbsTime + + // for logStats + lastStatsLog mclock.AbsTime + doneSinceLastLog int +} + +type dialSetupFunc func(net.Conn, connFlag, *enode.Node) error + +type dialConfig struct { + self enode.ID // our own ID + maxDialPeers int // maximum number of dialed peers + maxActiveDials int // maximum number of active dials + netRestrict *netutil.Netlist // IP whitelist, disabled if nil + resolver nodeResolver + dialer NodeDialer + log log.Logger + clock mclock.Clock + rand *mrand.Rand +} + +func (cfg dialConfig) withDefaults() dialConfig { + if cfg.maxActiveDials == 0 { + cfg.maxActiveDials = defaultMaxPendingPeers + } + if cfg.log == nil { + cfg.log = log.Root() + } + if cfg.clock == nil { + cfg.clock = mclock.System{} + } + if cfg.rand == nil { + seedb := make([]byte, 8) + crand.Read(seedb) + seed := int64(binary.BigEndian.Uint64(seedb)) + cfg.rand = mrand.New(mrand.NewSource(seed)) + } + return cfg +} + +func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler { + d := &dialScheduler{ + dialConfig: config.withDefaults(), + setupFunc: setupFunc, + dialing: make(map[enode.ID]*dialTask), + static: make(map[enode.ID]*dialTask), + peers: make(map[enode.ID]connFlag), + doneCh: make(chan *dialTask), + nodesIn: make(chan *enode.Node), + addStaticCh: make(chan *enode.Node), + remStaticCh: make(chan *enode.Node), + addPeerCh: make(chan *conn), + remPeerCh: make(chan *conn), + } + d.lastStatsLog = d.clock.Now() + d.ctx, d.cancel = context.WithCancel(context.Background()) + d.wg.Add(2) + go d.readNodes(it) + go d.loop(it) + return d +} + +// stop shuts down the dialer, canceling all current dial tasks. +func (d *dialScheduler) stop() { + d.cancel() + d.wg.Wait() +} + +// addStatic adds a static dial candidate. +func (d *dialScheduler) addStatic(n *enode.Node) { + select { + case d.addStaticCh <- n: + case <-d.ctx.Done(): + } +} + +// removeStatic removes a static dial candidate. +func (d *dialScheduler) removeStatic(n *enode.Node) { + select { + case d.remStaticCh <- n: + case <-d.ctx.Done(): + } +} + +// peerAdded updates the peer set. +func (d *dialScheduler) peerAdded(c *conn) { + select { + case d.addPeerCh <- c: + case <-d.ctx.Done(): + } +} + +// peerRemoved updates the peer set. +func (d *dialScheduler) peerRemoved(c *conn) { + select { + case d.remPeerCh <- c: + case <-d.ctx.Done(): + } +} + +// loop is the main loop of the dialer. +func (d *dialScheduler) loop(it enode.Iterator) { + var ( + nodesCh chan *enode.Node + historyExp = make(chan struct{}, 1) + ) + +loop: + for { + // Launch new dials if slots are available. + slots := d.freeDialSlots() + slots -= d.startStaticDials(slots) + if slots > 0 { + nodesCh = d.nodesIn + } else { + nodesCh = nil + } + d.rearmHistoryTimer(historyExp) + d.logStats() + + select { + case node := <-nodesCh: + if err := d.checkDial(node); err != nil { + d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IP(), "reason", err) + } else { + d.startDial(newDialTask(node, dynDialedConn)) + } + + case task := <-d.doneCh: + id := task.dest.ID() + delete(d.dialing, id) + d.updateStaticPool(id) + d.doneSinceLastLog++ + + case c := <-d.addPeerCh: + if c.is(dynDialedConn) || c.is(staticDialedConn) { + d.dialPeers++ + } + id := c.node.ID() + d.peers[id] = c.flags + // Remove from static pool because the node is now connected. + task := d.static[id] + if task != nil && task.staticPoolIndex >= 0 { + d.removeFromStaticPool(task.staticPoolIndex) + } + // TODO: cancel dials to connected peers + + case c := <-d.remPeerCh: + if c.is(dynDialedConn) || c.is(staticDialedConn) { + d.dialPeers-- + } + delete(d.peers, c.node.ID()) + d.updateStaticPool(c.node.ID()) + + case node := <-d.addStaticCh: + id := node.ID() + _, exists := d.static[id] + d.log.Trace("Adding static node", "id", id, "ip", node.IP(), "added", !exists) + if exists { + continue loop + } + task := newDialTask(node, staticDialedConn) + d.static[id] = task + if d.checkDial(node) == nil { + d.addToStaticPool(task) + } + + case node := <-d.remStaticCh: + id := node.ID() + task := d.static[id] + d.log.Trace("Removing static node", "id", id, "ok", task != nil) + if task != nil { + delete(d.static, id) + if task.staticPoolIndex >= 0 { + d.removeFromStaticPool(task.staticPoolIndex) + } + } + + case <-historyExp: + d.expireHistory() + + case <-d.ctx.Done(): + it.Close() + break loop + } + } + + d.stopHistoryTimer(historyExp) + for range d.dialing { + <-d.doneCh + } + d.wg.Done() +} + +// readNodes runs in its own goroutine and delivers nodes from +// the input iterator to the nodesIn channel. +func (d *dialScheduler) readNodes(it enode.Iterator) { + defer d.wg.Done() + + for it.Next() { + select { + case d.nodesIn <- it.Node(): + case <-d.ctx.Done(): + } + } +} + +// logStats prints dialer statistics to the log. The message is suppressed when enough +// peers are connected because users should only see it while their client is starting up +// or comes back online. +func (d *dialScheduler) logStats() { + now := d.clock.Now() + if d.lastStatsLog.Add(dialStatsLogInterval) > now { + return + } + if d.dialPeers < dialStatsPeerLimit && d.dialPeers < d.maxDialPeers { + d.log.Info("Looking for peers", "peercount", len(d.peers), "tried", d.doneSinceLastLog, "static", len(d.static)) + } + d.doneSinceLastLog = 0 + d.lastStatsLog = now +} + +// rearmHistoryTimer configures d.historyTimer to fire when the +// next item in d.history expires. +func (d *dialScheduler) rearmHistoryTimer(ch chan struct{}) { + if len(d.history) == 0 || d.historyTimerTime == d.history.nextExpiry() { + return + } + d.stopHistoryTimer(ch) + d.historyTimerTime = d.history.nextExpiry() + timeout := time.Duration(d.historyTimerTime - d.clock.Now()) + d.historyTimer = d.clock.AfterFunc(timeout, func() { ch <- struct{}{} }) +} + +// stopHistoryTimer stops the timer and drains the channel it sends on. +func (d *dialScheduler) stopHistoryTimer(ch chan struct{}) { + if d.historyTimer != nil && !d.historyTimer.Stop() { + <-ch + } +} + +// expireHistory removes expired items from d.history. +func (d *dialScheduler) expireHistory() { + d.historyTimer.Stop() + d.historyTimer = nil + d.historyTimerTime = 0 + d.history.expire(d.clock.Now(), func(hkey string) { + var id enode.ID + copy(id[:], hkey) + d.updateStaticPool(id) + }) +} + +// freeDialSlots returns the number of free dial slots. The result can be negative +// when peers are connected while their task is still running. +func (d *dialScheduler) freeDialSlots() int { + slots := (d.maxDialPeers - d.dialPeers) * 2 + if slots > d.maxActiveDials { + slots = d.maxActiveDials + } + free := slots - len(d.dialing) + return free +} + +// checkDial returns an error if node n should not be dialed. +func (d *dialScheduler) checkDial(n *enode.Node) error { + if n.ID() == d.self { return errSelf - case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()): + } + if _, ok := d.dialing[n.ID()]; ok { + return errAlreadyDialing + } + if _, ok := d.peers[n.ID()]; ok { + return errAlreadyConnected + } + if d.netRestrict != nil && !d.netRestrict.Contains(n.IP()) { return errNotWhitelisted - case s.hist.contains(string(n.ID().Bytes())): + } + if d.history.contains(string(n.ID().Bytes())) { return errRecentlyDialed } return nil } -func (s *dialstate) taskDone(t task, now time.Time) { - switch t := t.(type) { - case *dialTask: - s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration)) - delete(s.dialing, t.dest.ID()) - case *discoverTask: - s.lookupRunning = false - s.lookupBuf = append(s.lookupBuf, t.results...) +// startStaticDials starts n static dial tasks. +func (d *dialScheduler) startStaticDials(n int) (started int) { + for started = 0; started < n && len(d.staticPool) > 0; started++ { + idx := d.rand.Intn(len(d.staticPool)) + task := d.staticPool[idx] + d.startDial(task) + d.removeFromStaticPool(idx) + } + return started +} + +// updateStaticPool attempts to move the given static dial back into staticPool. +func (d *dialScheduler) updateStaticPool(id enode.ID) { + task, ok := d.static[id] + if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest) == nil { + d.addToStaticPool(task) } } -// A dialTask is generated for each node that is dialed. Its -// fields cannot be accessed while the task is running. +func (d *dialScheduler) addToStaticPool(task *dialTask) { + if task.staticPoolIndex >= 0 { + panic("attempt to add task to staticPool twice") + } + d.staticPool = append(d.staticPool, task) + task.staticPoolIndex = len(d.staticPool) - 1 +} + +// removeFromStaticPool removes the task at idx from staticPool. It does that by moving the +// current last element of the pool to idx and then shortening the pool by one. +func (d *dialScheduler) removeFromStaticPool(idx int) { + task := d.staticPool[idx] + end := len(d.staticPool) - 1 + d.staticPool[idx] = d.staticPool[end] + d.staticPool[idx].staticPoolIndex = idx + d.staticPool[end] = nil + d.staticPool = d.staticPool[:end] + task.staticPoolIndex = -1 +} + +// startDial runs the given dial task in a separate goroutine. +func (d *dialScheduler) startDial(task *dialTask) { + d.log.Trace("Starting p2p dial", "id", task.dest.ID(), "ip", task.dest.IP(), "flag", task.flags) + hkey := string(task.dest.ID().Bytes()) + d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration)) + d.dialing[task.dest.ID()] = task + go func() { + task.run(d) + d.doneCh <- task + }() +} + +// A dialTask generated for each node that is dialed. type dialTask struct { - flags connFlag + staticPoolIndex int + flags connFlag + // These fields are private to the task and should not be + // accessed by dialScheduler while the task is running. dest *enode.Node - lastResolved time.Time + lastResolved mclock.AbsTime resolveDelay time.Duration } -func (t *dialTask) Do(srv *Server) { +func newDialTask(dest *enode.Node, flags connFlag) *dialTask { + return &dialTask{dest: dest, flags: flags, staticPoolIndex: -1} +} + +type dialError struct { + error +} + +func (t *dialTask) run(d *dialScheduler) { if t.dest.Incomplete() { - if !t.resolve(srv) { + if !t.resolve(d) { return } } - err := t.dial(srv, t.dest) + + err := t.dial(d, t.dest) if err != nil { - srv.log.Trace("Dial error", "task", t, "err", err) // Try resolving the ID of static nodes if dialing failed. if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { - if t.resolve(srv) { - t.dial(srv, t.dest) + if t.resolve(d) { + t.dial(d, t.dest) } } } @@ -266,46 +497,42 @@ func (t *dialTask) Do(srv *Server) { // Resolve operations are throttled with backoff to avoid flooding the // discovery network with useless queries for nodes that don't exist. // The backoff delay resets when the node is found. -func (t *dialTask) resolve(srv *Server) bool { - if srv.staticNodeResolver == nil { - srv.log.Debug("Can't resolve node", "id", t.dest.ID(), "err", "discovery is disabled") +func (t *dialTask) resolve(d *dialScheduler) bool { + if d.resolver == nil { return false } if t.resolveDelay == 0 { t.resolveDelay = initialResolveDelay } - if time.Since(t.lastResolved) < t.resolveDelay { + if t.lastResolved > 0 && time.Duration(d.clock.Now()-t.lastResolved) < t.resolveDelay { return false } - resolved := srv.staticNodeResolver.Resolve(t.dest) - t.lastResolved = time.Now() + resolved := d.resolver.Resolve(t.dest) + t.lastResolved = d.clock.Now() if resolved == nil { t.resolveDelay *= 2 if t.resolveDelay > maxResolveDelay { t.resolveDelay = maxResolveDelay } - srv.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay) + d.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay) return false } // The node was found. t.resolveDelay = initialResolveDelay t.dest = resolved - srv.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) + d.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) return true } -type dialError struct { - error -} - // dial performs the actual connection attempt. -func (t *dialTask) dial(srv *Server, dest *enode.Node) error { - fd, err := srv.Dialer.Dial(dest) +func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error { + fd, err := d.dialer.Dial(d.ctx, t.dest) if err != nil { + d.log.Trace("Dial error", "id", t.dest.ID(), "addr", nodeAddr(t.dest), "conn", t.flags, "err", cleanupDialErr(err)) return &dialError{err} } mfd := newMeteredConn(fd, false, &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()}) - return srv.SetupConn(mfd, t.flags, dest) + return d.setupFunc(mfd, t.flags, dest) } func (t *dialTask) String() string { @@ -313,37 +540,9 @@ func (t *dialTask) String() string { return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP()) } -// discoverTask runs discovery table operations. -// Only one discoverTask is active at any time. -// discoverTask.Do performs a random lookup. -type discoverTask struct { - want int - results []*enode.Node -} - -func (t *discoverTask) Do(srv *Server) { - t.results = enode.ReadNodes(srv.discmix, t.want) -} - -func (t *discoverTask) String() string { - s := "discovery query" - if len(t.results) > 0 { - s += fmt.Sprintf(" (%d results)", len(t.results)) - } else { - s += fmt.Sprintf(" (want %d)", t.want) +func cleanupDialErr(err error) error { + if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" { + return netErr.Err } - return s -} - -// A waitExpireTask is generated if there are no other tasks -// to keep the loop in Server.run ticking. -type waitExpireTask struct { - time.Duration -} - -func (t waitExpireTask) Do(*Server) { - time.Sleep(t.Duration) -} -func (t waitExpireTask) String() string { - return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration) + return err } diff --git a/p2p/dial_test.go b/p2p/dial_test.go index 6189ec4d0b..cd8dedff1c 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -17,574 +17,656 @@ package p2p import ( - "encoding/binary" + "context" + "errors" + "fmt" + "math/rand" "net" "reflect" - "strings" + "sync" "testing" "time" - "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/internal/testlog" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" ) -func init() { - spew.Config.Indent = "\t" -} - -type dialtest struct { - init *dialstate // state before and after the test. - rounds []round -} - -type round struct { - peers []*Peer // current peer set - done []task // tasks that got done this round - new []task // the result must match this one -} - -func runDialTest(t *testing.T, test dialtest) { - var ( - vtime time.Time - running int - ) - pm := func(ps []*Peer) map[enode.ID]*Peer { - m := make(map[enode.ID]*Peer) - for _, p := range ps { - m[p.ID()] = p - } - return m - } - for i, round := range test.rounds { - for _, task := range round.done { - running-- - if running < 0 { - panic("running task counter underflow") - } - test.init.taskDone(task, vtime) - } - - new := test.init.newTasks(running, pm(round.peers), vtime) - if !sametasks(new, round.new) { - t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v", - i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) - } - t.Logf("round %d (running %d) new tasks: %s", i, running, strings.TrimSpace(spew.Sdump(new))) - - // Time advances by 16 seconds on every round. - vtime = vtime.Add(16 * time.Second) - running += len(new) - } -} - // This test checks that dynamic dials are launched from discovery results. -func TestDialStateDynDial(t *testing.T) { - config := &Config{Logger: testlog.Logger(t, log.LvlTrace)} - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 5, config), - rounds: []round{ - // A discovery query is launched. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - }, - new: []task{ - &discoverTask{want: 3}, - }, +func TestDialSchedDynDial(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 5, + maxDialPeers: 4, + } + runDialTest(t, config, []dialTestRound{ + // 3 out of 4 peers are connected, leaving 2 dial slots. + // 9 nodes are discovered, but only 2 are dialed. + { + peersAdded: []*conn{ + {flags: staticDialedConn, node: newNode(uintID(0x00), "")}, + {flags: dynDialedConn, node: newNode(uintID(0x01), "")}, + {flags: dynDialedConn, node: newNode(uintID(0x02), "")}, }, - // Dynamic dials are launched when it completes. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &discoverTask{results: []*enode.Node{ - newNode(uintID(2), nil), // this one is already connected and not dialed. - newNode(uintID(3), nil), - newNode(uintID(4), nil), - newNode(uintID(5), nil), - newNode(uintID(6), nil), // these are not tried because max dyn dials is 5 - newNode(uintID(7), nil), // ... - }}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - }, + discovered: []*enode.Node{ + newNode(uintID(0x00), "127.0.0.1:30303"), // not dialed because already connected as static peer + newNode(uintID(0x02), "127.0.0.1:30303"), // ... + newNode(uintID(0x03), "127.0.0.1:30303"), + newNode(uintID(0x04), "127.0.0.1:30303"), + newNode(uintID(0x05), "127.0.0.1:30303"), // not dialed because there are only two slots + newNode(uintID(0x06), "127.0.0.1:30303"), // ... + newNode(uintID(0x07), "127.0.0.1:30303"), // ... + newNode(uintID(0x08), "127.0.0.1:30303"), // ... }, - // Some of the dials complete but no new ones are launched yet because - // the sum of active dial count and dynamic peer count is == maxDynDials. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.1:30303"), + newNode(uintID(0x04), "127.0.0.1:30303"), }, - // No new dial tasks are launched in the this round because - // maxDynDials has been reached. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - }, - new: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, + }, + + // One dial completes, freeing one dial slot. + { + failed: []enode.ID{ + uintID(0x04), }, - // In this round, the peer with id 2 drops off. The query - // results from last discovery lookup are reused. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, - }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x05), "127.0.0.1:30303"), }, - // More peers (3,4) drop off and dial for ID 6 completes. - // The last query result from the discovery lookup is reused - // and a new one is spawned because more candidates are needed. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, - &discoverTask{want: 2}, - }, + }, + + // Dial to 0x03 completes, filling the last remaining peer slot. + { + succeeded: []enode.ID{ + uintID(0x03), }, - // Peer 7 is connected, but there still aren't enough dynamic peers - // (4 out of 5). However, a discovery is already running, so ensure - // no new is started. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, - }, + failed: []enode.ID{ + uintID(0x05), }, - // Finish the running node discovery with an empty set. A new lookup - // should be immediately requested. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(0), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(5), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(7), nil)}}, - }, - done: []task{ - &discoverTask{}, - }, - new: []task{ - &discoverTask{want: 2}, - }, + discovered: []*enode.Node{ + newNode(uintID(0x09), "127.0.0.1:30303"), // not dialed because there are no free slots + }, + }, + + // 3 peers drop off, creating 6 dial slots. Check that 5 of those slots + // (i.e. up to maxActiveDialTasks) are used. + { + peersRemoved: []enode.ID{ + uintID(0x00), + uintID(0x01), + uintID(0x02), + }, + discovered: []*enode.Node{ + newNode(uintID(0x0a), "127.0.0.1:30303"), + newNode(uintID(0x0b), "127.0.0.1:30303"), + newNode(uintID(0x0c), "127.0.0.1:30303"), + newNode(uintID(0x0d), "127.0.0.1:30303"), + newNode(uintID(0x0f), "127.0.0.1:30303"), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x06), "127.0.0.1:30303"), + newNode(uintID(0x07), "127.0.0.1:30303"), + newNode(uintID(0x08), "127.0.0.1:30303"), + newNode(uintID(0x09), "127.0.0.1:30303"), + newNode(uintID(0x0a), "127.0.0.1:30303"), }, }, }) } -// Tests that bootnodes are dialed if no peers are connectd, but not otherwise. -func TestDialStateDynDialBootnode(t *testing.T) { - config := &Config{ - BootstrapNodes: []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - }, - Logger: testlog.Logger(t, log.LvlTrace), - } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 5, config), - rounds: []round{ - { - new: []task{ - &discoverTask{want: 5}, - }, - }, - { - done: []task{ - &discoverTask{ - results: []*enode.Node{ - newNode(uintID(4), nil), - newNode(uintID(5), nil), - }, - }, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &discoverTask{want: 3}, - }, - }, - // No dials succeed, bootnodes still pending fallback interval - {}, - // 1 bootnode attempted as fallback interval was reached - { - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - }, - }, - // No dials succeed, 2nd bootnode is attempted - { - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - }, - }, - // No dials succeed, 3rd bootnode is attempted - { - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - }, - }, - // No dials succeed, 1st bootnode is attempted again, expired random nodes retried - { - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &discoverTask{results: []*enode.Node{ - newNode(uintID(6), nil), - }}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, - &discoverTask{want: 4}, - }, - }, - // Random dial succeeds, no more bootnodes are attempted - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(6), nil)}}, - }, - }, - }, - }) -} +// This test checks that candidates that do not match the netrestrict list are not dialed. +func TestDialSchedNetRestrict(t *testing.T) { + t.Parallel() -func newNode(id enode.ID, ip net.IP) *enode.Node { - var r enr.Record - if ip != nil { - r.Set(enr.IP(ip)) - } - return enode.SignNull(&r, id) -} - -// // This test checks that candidates that do not match the netrestrict list are not dialed. -func TestDialStateNetRestrict(t *testing.T) { - // This table always returns the same random nodes - // in the order given below. nodes := []*enode.Node{ - newNode(uintID(1), net.ParseIP("127.0.0.1")), - newNode(uintID(2), net.ParseIP("127.0.0.2")), - newNode(uintID(3), net.ParseIP("127.0.0.3")), - newNode(uintID(4), net.ParseIP("127.0.0.4")), - newNode(uintID(5), net.ParseIP("127.0.2.5")), - newNode(uintID(6), net.ParseIP("127.0.2.6")), - newNode(uintID(7), net.ParseIP("127.0.2.7")), - newNode(uintID(8), net.ParseIP("127.0.2.8")), + newNode(uintID(0x01), "127.0.0.1:30303"), + newNode(uintID(0x02), "127.0.0.2:30303"), + newNode(uintID(0x03), "127.0.0.3:30303"), + newNode(uintID(0x04), "127.0.0.4:30303"), + newNode(uintID(0x05), "127.0.2.5:30303"), + newNode(uintID(0x06), "127.0.2.6:30303"), + newNode(uintID(0x07), "127.0.2.7:30303"), + newNode(uintID(0x08), "127.0.2.8:30303"), } - restrict := new(netutil.Netlist) - restrict.Add("127.0.2.0/24") - - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 10, &Config{NetRestrict: restrict}), - rounds: []round{ - { - new: []task{ - &discoverTask{want: 10}, - }, - }, - { - done: []task{ - &discoverTask{results: nodes}, - }, - new: []task{ - &dialTask{flags: dynDialedConn, dest: nodes[4]}, - &dialTask{flags: dynDialedConn, dest: nodes[5]}, - &dialTask{flags: dynDialedConn, dest: nodes[6]}, - &dialTask{flags: dynDialedConn, dest: nodes[7]}, - &discoverTask{want: 6}, - }, + config := dialConfig{ + netRestrict: new(netutil.Netlist), + maxActiveDials: 10, + maxDialPeers: 10, + } + config.netRestrict.Add("127.0.2.0/24") + runDialTest(t, config, []dialTestRound{ + { + discovered: nodes, + wantNewDials: nodes[4:8], + }, + { + succeeded: []enode.ID{ + nodes[4].ID(), + nodes[5].ID(), + nodes[6].ID(), + nodes[7].ID(), }, }, }) } -// This test checks that static dials are launched. -func TestDialStateStaticDial(t *testing.T) { - config := &Config{ - StaticNodes: []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - newNode(uintID(4), nil), - newNode(uintID(5), nil), - }, - Logger: testlog.Logger(t, log.LvlTrace), +// This test checks that static dials work and obey the limits. +func TestDialSchedStaticDial(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 5, + maxDialPeers: 4, } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 0, config), - rounds: []round{ - // Static dials are launched for the nodes that - // aren't yet connected. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - }, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)}, - }, + runDialTest(t, config, []dialTestRound{ + // Static dials are launched for the nodes that + // aren't yet connected. + { + peersAdded: []*conn{ + {flags: dynDialedConn, node: newNode(uintID(0x01), "127.0.0.1:30303")}, + {flags: dynDialedConn, node: newNode(uintID(0x02), "127.0.0.2:30303")}, }, - // No new tasks are launched in this round because all static - // nodes are either connected or still being dialed. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - }, + update: func(d *dialScheduler) { + // These two are not dialed because they're already connected + // as dynamic peers. + d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303")) + d.addStatic(newNode(uintID(0x02), "127.0.0.2:30303")) + // These nodes will be dialed: + d.addStatic(newNode(uintID(0x03), "127.0.0.3:30303")) + d.addStatic(newNode(uintID(0x04), "127.0.0.4:30303")) + d.addStatic(newNode(uintID(0x05), "127.0.0.5:30303")) + d.addStatic(newNode(uintID(0x06), "127.0.0.6:30303")) + d.addStatic(newNode(uintID(0x07), "127.0.0.7:30303")) + d.addStatic(newNode(uintID(0x08), "127.0.0.8:30303")) + d.addStatic(newNode(uintID(0x09), "127.0.0.9:30303")) }, - // No new dial tasks are launched because all static - // nodes are now connected. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(5), nil)}, - }, - new: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), + newNode(uintID(0x04), "127.0.0.4:30303"), + newNode(uintID(0x05), "127.0.0.5:30303"), + newNode(uintID(0x06), "127.0.0.6:30303"), }, - // Wait a round for dial history to expire, no new tasks should spawn. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(4), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, - }, + }, + // Dial to 0x03 completes, filling a peer slot. One slot remains, + // two dials are launched to attempt to fill it. + { + succeeded: []enode.ID{ + uintID(0x03), }, - // If a static node is dropped, it should be immediately redialed, - // irrespective whether it was originally static or dynamic. - { - done: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(3), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(5), nil)}}, - }, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - }, + failed: []enode.ID{ + uintID(0x04), + uintID(0x05), + uintID(0x06), + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x04): nil, + uintID(0x05): nil, + uintID(0x06): nil, + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x08), "127.0.0.8:30303"), + newNode(uintID(0x09), "127.0.0.9:30303"), + }, + }, + // Peer 0x01 drops and 0x07 connects as inbound peer. + // Only 0x01 is dialed. + { + peersAdded: []*conn{ + {flags: inboundConn, node: newNode(uintID(0x07), "127.0.0.7:30303")}, + }, + peersRemoved: []enode.ID{ + uintID(0x01), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x01), "127.0.0.1:30303"), + }, + }, + }) +} + +// This test checks that removing static nodes stops connecting to them. +func TestDialSchedRemoveStatic(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 1, + maxDialPeers: 1, + } + runDialTest(t, config, []dialTestRound{ + // Add static nodes. + { + update: func(d *dialScheduler) { + d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303")) + d.addStatic(newNode(uintID(0x02), "127.0.0.2:30303")) + d.addStatic(newNode(uintID(0x03), "127.0.0.3:30303")) + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x01), "127.0.0.1:30303"), + }, + }, + // Dial to 0x01 fails. + { + failed: []enode.ID{ + uintID(0x01), + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x01): nil, + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x02), "127.0.0.2:30303"), + }, + }, + // All static nodes are removed. 0x01 is in history, 0x02 is being + // dialed, 0x03 is in staticPool. + { + update: func(d *dialScheduler) { + d.removeStatic(newNode(uintID(0x01), "127.0.0.1:30303")) + d.removeStatic(newNode(uintID(0x02), "127.0.0.2:30303")) + d.removeStatic(newNode(uintID(0x03), "127.0.0.3:30303")) + }, + failed: []enode.ID{ + uintID(0x02), + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x02): nil, + }, + }, + // Since all static nodes are removed, they should not be dialed again. + {}, {}, {}, + }) +} + +// This test checks that static dials are selected at random. +func TestDialSchedManyStaticNodes(t *testing.T) { + t.Parallel() + + config := dialConfig{maxDialPeers: 2} + runDialTest(t, config, []dialTestRound{ + { + peersAdded: []*conn{ + {flags: dynDialedConn, node: newNode(uintID(0xFFFE), "")}, + {flags: dynDialedConn, node: newNode(uintID(0xFFFF), "")}, + }, + update: func(d *dialScheduler) { + for id := uint16(0); id < 2000; id++ { + n := newNode(uintID(id), "127.0.0.1:30303") + d.addStatic(n) + } + }, + }, + { + peersRemoved: []enode.ID{ + uintID(0xFFFE), + uintID(0xFFFF), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x0085), "127.0.0.1:30303"), + newNode(uintID(0x02dc), "127.0.0.1:30303"), + newNode(uintID(0x0285), "127.0.0.1:30303"), + newNode(uintID(0x00cb), "127.0.0.1:30303"), }, }, }) } // This test checks that past dials are not retried for some time. -func TestDialStateCache(t *testing.T) { - config := &Config{ - StaticNodes: []*enode.Node{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - }, - Logger: testlog.Logger(t, log.LvlTrace), +func TestDialSchedHistory(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 3, + maxDialPeers: 3, } - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, 0, config), - rounds: []round{ - // Static dials are launched for the nodes that - // aren't yet connected. - { - peers: nil, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - }, + runDialTest(t, config, []dialTestRound{ + { + update: func(d *dialScheduler) { + d.addStatic(newNode(uintID(0x01), "127.0.0.1:30303")) + d.addStatic(newNode(uintID(0x02), "127.0.0.2:30303")) + d.addStatic(newNode(uintID(0x03), "127.0.0.3:30303")) }, - // No new tasks are launched in this round because all static - // nodes are either connected or still being dialed. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: staticDialedConn, dest: newNode(uintID(2), nil)}, - }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x01), "127.0.0.1:30303"), + newNode(uintID(0x02), "127.0.0.2:30303"), + newNode(uintID(0x03), "127.0.0.3:30303"), }, - // A salvage task is launched to wait for node 3's history - // entry to expire. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - }, - new: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, + }, + // No new tasks are launched in this round because all static + // nodes are either connected or still being dialed. + { + succeeded: []enode.ID{ + uintID(0x01), + uintID(0x02), }, - // Still waiting for node 3's entry to expire in the cache. - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, + failed: []enode.ID{ + uintID(0x03), }, - { - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x03): nil, }, - // The cache entry for node 3 has expired and is retried. - { - done: []task{ - &waitExpireTask{Duration: 19 * time.Second}, - }, - peers: []*Peer{ - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: staticDialedConn, node: newNode(uintID(2), nil)}}, - }, - new: []task{ - &dialTask{flags: staticDialedConn, dest: newNode(uintID(3), nil)}, - }, + }, + // Nothing happens in this round because we're waiting for + // node 0x3's history entry to expire. + {}, + // The cache entry for node 0x03 has expired and is retried. + { + wantNewDials: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), }, }, }) } -func TestDialResolve(t *testing.T) { - config := &Config{ - Logger: testlog.Logger(t, log.LvlTrace), - Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}, - } - resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) - resolver := &resolveMock{answer: resolved} - state := newDialState(enode.ID{}, 0, config) +func TestDialSchedResolve(t *testing.T) { + t.Parallel() - // Check that the task is generated with an incomplete ID. - dest := newNode(uintID(1), nil) - state.addStatic(dest) - tasks := state.newTasks(0, nil, time.Time{}) - if !reflect.DeepEqual(tasks, []task{&dialTask{flags: staticDialedConn, dest: dest}}) { - t.Fatalf("expected dial task, got %#v", tasks) + config := dialConfig{ + maxActiveDials: 1, + maxDialPeers: 1, } + node := newNode(uintID(0x01), "") + resolved := newNode(uintID(0x01), "127.0.0.1:30303") + resolved2 := newNode(uintID(0x01), "127.0.0.55:30303") + runDialTest(t, config, []dialTestRound{ + { + update: func(d *dialScheduler) { + d.addStatic(node) + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x01): resolved, + }, + wantNewDials: []*enode.Node{ + resolved, + }, + }, + { + failed: []enode.ID{ + uintID(0x01), + }, + wantResolves: map[enode.ID]*enode.Node{ + uintID(0x01): resolved2, + }, + wantNewDials: []*enode.Node{ + resolved2, + }, + }, + }) +} - // Now run the task, it should resolve the ID once. - srv := &Server{ - Config: *config, - log: config.Logger, - staticNodeResolver: resolver, - } - tasks[0].Do(srv) - if !reflect.DeepEqual(resolver.calls, []*enode.Node{dest}) { - t.Fatalf("wrong resolve calls, got %v", resolver.calls) - } +// ------- +// Code below here is the framework for the tests above. - // Report it as done to the dialer, which should update the static node record. - state.taskDone(tasks[0], time.Now()) - if state.static[uintID(1)].dest != resolved { - t.Fatalf("state.dest not updated") +type dialTestRound struct { + peersAdded []*conn + peersRemoved []enode.ID + update func(*dialScheduler) // called at beginning of round + discovered []*enode.Node // newly discovered nodes + succeeded []enode.ID // dials which succeed this round + failed []enode.ID // dials which fail this round + wantResolves map[enode.ID]*enode.Node + wantNewDials []*enode.Node // dials that should be launched in this round +} + +func runDialTest(t *testing.T, config dialConfig, rounds []dialTestRound) { + var ( + clock = new(mclock.Simulated) + iterator = newDialTestIterator() + dialer = newDialTestDialer() + resolver = new(dialTestResolver) + peers = make(map[enode.ID]*conn) + setupCh = make(chan *conn) + ) + + // Override config. + config.clock = clock + config.dialer = dialer + config.resolver = resolver + config.log = testlog.Logger(t, log.LvlTrace) + config.rand = rand.New(rand.NewSource(0x1111)) + + // Set up the dialer. The setup function below runs on the dialTask + // goroutine and adds the peer. + var dialsched *dialScheduler + setup := func(fd net.Conn, f connFlag, node *enode.Node) error { + conn := &conn{flags: f, node: node} + dialsched.peerAdded(conn) + setupCh <- conn + return nil + } + dialsched = newDialScheduler(config, iterator, setup) + defer dialsched.stop() + + for i, round := range rounds { + // Apply peer set updates. + for _, c := range round.peersAdded { + if peers[c.node.ID()] != nil { + t.Fatalf("round %d: peer %v already connected", i, c.node.ID()) + } + dialsched.peerAdded(c) + peers[c.node.ID()] = c + } + for _, id := range round.peersRemoved { + c := peers[id] + if c == nil { + t.Fatalf("round %d: can't remove non-existent peer %v", i, id) + } + dialsched.peerRemoved(c) + } + + // Init round. + t.Logf("round %d (%d peers)", i, len(peers)) + resolver.setAnswers(round.wantResolves) + if round.update != nil { + round.update(dialsched) + } + iterator.addNodes(round.discovered) + + // Unblock dialTask goroutines. + if err := dialer.completeDials(round.succeeded, nil); err != nil { + t.Fatalf("round %d: %v", i, err) + } + for range round.succeeded { + conn := <-setupCh + peers[conn.node.ID()] = conn + } + if err := dialer.completeDials(round.failed, errors.New("oops")); err != nil { + t.Fatalf("round %d: %v", i, err) + } + + // Wait for new tasks. + if err := dialer.waitForDials(round.wantNewDials); err != nil { + t.Fatalf("round %d: %v", i, err) + } + if !resolver.checkCalls() { + t.Fatalf("unexpected calls to Resolve: %v", resolver.calls) + } + + clock.Run(16 * time.Second) } } -// compares task lists but doesn't care about the order. -func sametasks(a, b []task) bool { - if len(a) != len(b) { +// dialTestIterator is the input iterator for dialer tests. This works a bit like a channel +// with infinite buffer: nodes are added to the buffer with addNodes, which unblocks Next +// and returns them from the iterator. +type dialTestIterator struct { + cur *enode.Node + + mu sync.Mutex + buf []*enode.Node + cond *sync.Cond + closed bool +} + +func newDialTestIterator() *dialTestIterator { + it := &dialTestIterator{} + it.cond = sync.NewCond(&it.mu) + return it +} + +// addNodes adds nodes to the iterator buffer and unblocks Next. +func (it *dialTestIterator) addNodes(nodes []*enode.Node) { + it.mu.Lock() + defer it.mu.Unlock() + + it.buf = append(it.buf, nodes...) + it.cond.Signal() +} + +// Node returns the current node. +func (it *dialTestIterator) Node() *enode.Node { + return it.cur +} + +// Next moves to the next node. +func (it *dialTestIterator) Next() bool { + it.mu.Lock() + defer it.mu.Unlock() + + it.cur = nil + for len(it.buf) == 0 && !it.closed { + it.cond.Wait() + } + if it.closed { return false } -next: - for _, ta := range a { - for _, tb := range b { - if reflect.DeepEqual(ta, tb) { - continue next - } + it.cur = it.buf[0] + copy(it.buf[:], it.buf[1:]) + it.buf = it.buf[:len(it.buf)-1] + return true +} + +// Close ends the iterator, unblocking Next. +func (it *dialTestIterator) Close() { + it.mu.Lock() + defer it.mu.Unlock() + + it.closed = true + it.buf = nil + it.cond.Signal() +} + +// dialTestDialer is the NodeDialer used by runDialTest. +type dialTestDialer struct { + init chan *dialTestReq + blocked map[enode.ID]*dialTestReq +} + +type dialTestReq struct { + n *enode.Node + unblock chan error +} + +func newDialTestDialer() *dialTestDialer { + return &dialTestDialer{ + init: make(chan *dialTestReq), + blocked: make(map[enode.ID]*dialTestReq), + } +} + +// Dial implements NodeDialer. +func (d *dialTestDialer) Dial(ctx context.Context, n *enode.Node) (net.Conn, error) { + req := &dialTestReq{n: n, unblock: make(chan error, 1)} + select { + case d.init <- req: + select { + case err := <-req.unblock: + pipe, _ := net.Pipe() + return pipe, err + case <-ctx.Done(): + return nil, ctx.Err() + } + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// waitForDials waits for calls to Dial with the given nodes as argument. +// Those calls will be held blocking until completeDials is called with the same nodes. +func (d *dialTestDialer) waitForDials(nodes []*enode.Node) error { + waitset := make(map[enode.ID]*enode.Node) + for _, n := range nodes { + waitset[n.ID()] = n + } + timeout := time.NewTimer(1 * time.Second) + defer timeout.Stop() + + for len(waitset) > 0 { + select { + case req := <-d.init: + want, ok := waitset[req.n.ID()] + if !ok { + return fmt.Errorf("attempt to dial unexpected node %v", req.n.ID()) + } + if !reflect.DeepEqual(req.n, want) { + return fmt.Errorf("ENR of dialed node %v does not match test", req.n.ID()) + } + delete(waitset, req.n.ID()) + d.blocked[req.n.ID()] = req + case <-timeout.C: + var waitlist []enode.ID + for id := range waitset { + waitlist = append(waitlist, id) + } + return fmt.Errorf("timed out waiting for dials to %v", waitlist) + } + } + + return d.checkUnexpectedDial() +} + +func (d *dialTestDialer) checkUnexpectedDial() error { + select { + case req := <-d.init: + return fmt.Errorf("attempt to dial unexpected node %v", req.n.ID()) + case <-time.After(150 * time.Millisecond): + return nil + } +} + +// completeDials unblocks calls to Dial for the given nodes. +func (d *dialTestDialer) completeDials(ids []enode.ID, err error) error { + for _, id := range ids { + req := d.blocked[id] + if req == nil { + return fmt.Errorf("can't complete dial to %v", id) + } + req.unblock <- err + } + return nil +} + +// dialTestResolver tracks calls to resolve. +type dialTestResolver struct { + mu sync.Mutex + calls []enode.ID + answers map[enode.ID]*enode.Node +} + +func (t *dialTestResolver) setAnswers(m map[enode.ID]*enode.Node) { + t.mu.Lock() + defer t.mu.Unlock() + + t.answers = m + t.calls = nil +} + +func (t *dialTestResolver) checkCalls() bool { + t.mu.Lock() + defer t.mu.Unlock() + + for _, id := range t.calls { + if _, ok := t.answers[id]; !ok { + return false } - return false } return true } -func uintID(i uint32) enode.ID { - var id enode.ID - binary.BigEndian.PutUint32(id[:], i) - return id -} +func (t *dialTestResolver) Resolve(n *enode.Node) *enode.Node { + t.mu.Lock() + defer t.mu.Unlock() -// for TestDialResolve -type resolveMock struct { - calls []*enode.Node - answer *enode.Node -} - -func (t *resolveMock) Resolve(n *enode.Node) *enode.Node { - t.calls = append(t.calls, n) - return t.answer + t.calls = append(t.calls, n.ID()) + return t.answers[n.ID()] } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index a2393ba854..e40deb98f0 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -17,15 +17,20 @@ package p2p import ( + "encoding/binary" "errors" "fmt" "math/rand" "net" "reflect" + "strconv" + "strings" "testing" "time" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" ) var discard = Protocol{ @@ -45,10 +50,45 @@ var discard = Protocol{ }, } +// uintID encodes i into a node ID. +func uintID(i uint16) enode.ID { + var id enode.ID + binary.BigEndian.PutUint16(id[:], i) + return id +} + +// newNode creates a node record with the given address. +func newNode(id enode.ID, addr string) *enode.Node { + var r enr.Record + if addr != "" { + // Set the port if present. + if strings.Contains(addr, ":") { + hs, ps, err := net.SplitHostPort(addr) + if err != nil { + panic(fmt.Errorf("invalid address %q", addr)) + } + port, err := strconv.Atoi(ps) + if err != nil { + panic(fmt.Errorf("invalid port in %q", addr)) + } + r.Set(enr.TCP(port)) + r.Set(enr.UDP(port)) + addr = hs + } + // Set the IP. + ip := net.ParseIP(addr) + if ip == nil { + panic(fmt.Errorf("invalid IP %q", addr)) + } + r.Set(enr.IP(ip)) + } + return enode.SignNull(&r, id) +} + func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) { fd1, fd2 := net.Pipe() - c1 := &conn{fd: fd1, node: newNode(randomID(), nil), transport: newTestTransport(&newkey().PublicKey, fd1)} - c2 := &conn{fd: fd2, node: newNode(randomID(), nil), transport: newTestTransport(&newkey().PublicKey, fd2)} + c1 := &conn{fd: fd1, node: newNode(randomID(), ""), transport: newTestTransport(&newkey().PublicKey, fd1)} + c2 := &conn{fd: fd2, node: newNode(randomID(), ""), transport: newTestTransport(&newkey().PublicKey, fd2)} for _, p := range protos { c1.caps = append(c1.caps, p.cap()) c2.caps = append(c2.caps, p.cap()) diff --git a/p2p/server.go b/p2p/server.go index fda72e3793..1ed1be2ac6 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -51,7 +51,6 @@ const ( discmixTimeout = 5 * time.Second // Connectivity defaults. - maxActiveDialTasks = 16 defaultMaxPendingPeers = 50 defaultDialRatio = 3 @@ -156,6 +155,8 @@ type Config struct { // Logger is a custom logger to use with the p2p.Server. Logger log.Logger `toml:",omitempty"` + + clock mclock.Clock } // Server manages all peer connections. @@ -183,13 +184,10 @@ type Server struct { ntab *discover.UDPv4 DiscV5 *discv5.Network discmix *enode.FairMix - - staticNodeResolver nodeResolver + dialsched *dialScheduler // Channels into the run loop. quit chan struct{} - addstatic chan *enode.Node - removestatic chan *enode.Node addtrusted chan *enode.Node removetrusted chan *enode.Node peerOp chan peerOpFunc @@ -302,47 +300,57 @@ func (srv *Server) LocalNode() *enode.LocalNode { // Peers returns all connected peers. func (srv *Server) Peers() []*Peer { var ps []*Peer - select { - // Note: We'd love to put this function into a variable but - // that seems to cause a weird compiler error in some - // environments. - case srv.peerOp <- func(peers map[enode.ID]*Peer) { + srv.doPeerOp(func(peers map[enode.ID]*Peer) { for _, p := range peers { ps = append(ps, p) } - }: - <-srv.peerOpDone - case <-srv.quit: - } + }) return ps } // PeerCount returns the number of connected peers. func (srv *Server) PeerCount() int { var count int - select { - case srv.peerOp <- func(ps map[enode.ID]*Peer) { count = len(ps) }: - <-srv.peerOpDone - case <-srv.quit: - } + srv.doPeerOp(func(ps map[enode.ID]*Peer) { + count = len(ps) + }) return count } -// AddPeer connects to the given node and maintains the connection until the -// server is shut down. If the connection fails for any reason, the server will -// attempt to reconnect the peer. +// AddPeer adds the given node to the static node set. When there is room in the peer set, +// the server will connect to the node. If the connection fails for any reason, the server +// will attempt to reconnect the peer. func (srv *Server) AddPeer(node *enode.Node) { - select { - case srv.addstatic <- node: - case <-srv.quit: - } + srv.dialsched.addStatic(node) } -// RemovePeer disconnects from the given node +// RemovePeer removes a node from the static node set. It also disconnects from the given +// node if it is currently connected as a peer. +// +// This method blocks until all protocols have exited and the peer is removed. Do not use +// RemovePeer in protocol implementations, call Disconnect on the Peer instead. func (srv *Server) RemovePeer(node *enode.Node) { - select { - case srv.removestatic <- node: - case <-srv.quit: + var ( + ch chan *PeerEvent + sub event.Subscription + ) + // Disconnect the peer on the main loop. + srv.doPeerOp(func(peers map[enode.ID]*Peer) { + srv.dialsched.removeStatic(node) + if peer := peers[node.ID()]; peer != nil { + ch = make(chan *PeerEvent, 1) + sub = srv.peerFeed.Subscribe(ch) + peer.Disconnect(DiscRequested) + } + }) + // Wait for the peer connection to end. + if ch != nil { + defer sub.Unsubscribe() + for ev := range ch { + if ev.Peer == node.ID() && ev.Type == PeerEventTypeDrop { + return + } + } } } @@ -437,6 +445,9 @@ func (srv *Server) Start() (err error) { if srv.log == nil { srv.log = log.Root() } + if srv.clock == nil { + srv.clock = mclock.System{} + } if srv.NoDial && srv.ListenAddr == "" { srv.log.Warn("P2P server will be useless, neither dialing nor listening") } @@ -451,15 +462,10 @@ func (srv *Server) Start() (err error) { if srv.listenFunc == nil { srv.listenFunc = net.Listen } - if srv.Dialer == nil { - srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}} - } srv.quit = make(chan struct{}) srv.delpeer = make(chan peerDrop) srv.checkpointPostHandshake = make(chan *conn) srv.checkpointAddPeer = make(chan *conn) - srv.addstatic = make(chan *enode.Node) - srv.removestatic = make(chan *enode.Node) srv.addtrusted = make(chan *enode.Node) srv.removetrusted = make(chan *enode.Node) srv.peerOp = make(chan peerOpFunc) @@ -476,11 +482,10 @@ func (srv *Server) Start() (err error) { if err := srv.setupDiscovery(); err != nil { return err } + srv.setupDialScheduler() - dynPeers := srv.maxDialedConns() - dialer := newDialState(srv.localnode.ID(), dynPeers, &srv.Config) srv.loopWG.Add(1) - go srv.run(dialer) + go srv.run() return nil } @@ -583,7 +588,6 @@ func (srv *Server) setupDiscovery() error { } srv.ntab = ntab srv.discmix.AddSource(ntab.RandomNodes()) - srv.staticNodeResolver = ntab } // Discovery V5 @@ -606,6 +610,47 @@ func (srv *Server) setupDiscovery() error { return nil } +func (srv *Server) setupDialScheduler() { + config := dialConfig{ + self: srv.localnode.ID(), + maxDialPeers: srv.maxDialedConns(), + maxActiveDials: srv.MaxPendingPeers, + log: srv.Logger, + netRestrict: srv.NetRestrict, + dialer: srv.Dialer, + clock: srv.clock, + } + if srv.ntab != nil { + config.resolver = srv.ntab + } + if config.dialer == nil { + config.dialer = tcpDialer{&net.Dialer{Timeout: defaultDialTimeout}} + } + srv.dialsched = newDialScheduler(config, srv.discmix, srv.SetupConn) + for _, n := range srv.StaticNodes { + srv.dialsched.addStatic(n) + } +} + +func (srv *Server) maxInboundConns() int { + return srv.MaxPeers - srv.maxDialedConns() +} + +func (srv *Server) maxDialedConns() (limit int) { + if srv.NoDial || srv.MaxPeers == 0 { + return 0 + } + if srv.DialRatio == 0 { + limit = srv.MaxPeers / defaultDialRatio + } else { + limit = srv.MaxPeers / srv.DialRatio + } + if limit == 0 { + limit = 1 + } + return limit +} + func (srv *Server) setupListening() error { // Launch the listener. listener, err := srv.listenFunc("tcp", srv.ListenAddr) @@ -632,112 +677,55 @@ func (srv *Server) setupListening() error { return nil } -type dialer interface { - newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task - taskDone(task, time.Time) - addStatic(*enode.Node) - removeStatic(*enode.Node) +// doPeerOp runs fn on the main loop. +func (srv *Server) doPeerOp(fn peerOpFunc) { + select { + case srv.peerOp <- fn: + <-srv.peerOpDone + case <-srv.quit: + } } -func (srv *Server) run(dialstate dialer) { +// run is the main loop of the server. +func (srv *Server) run() { srv.log.Info("Started P2P networking", "self", srv.localnode.Node().URLv4()) defer srv.loopWG.Done() defer srv.nodedb.Close() defer srv.discmix.Close() + defer srv.dialsched.stop() var ( peers = make(map[enode.ID]*Peer) inboundCount = 0 trusted = make(map[enode.ID]bool, len(srv.TrustedNodes)) - taskdone = make(chan task, maxActiveDialTasks) - tick = time.NewTicker(30 * time.Second) - runningTasks []task - queuedTasks []task // tasks that can't run yet ) - defer tick.Stop() - // Put trusted nodes into a map to speed up checks. // Trusted peers are loaded on startup or added via AddTrustedPeer RPC. for _, n := range srv.TrustedNodes { trusted[n.ID()] = true } - // removes t from runningTasks - delTask := func(t task) { - for i := range runningTasks { - if runningTasks[i] == t { - runningTasks = append(runningTasks[:i], runningTasks[i+1:]...) - break - } - } - } - // starts until max number of active tasks is satisfied - startTasks := func(ts []task) (rest []task) { - i := 0 - for ; len(runningTasks) < maxActiveDialTasks && i < len(ts); i++ { - t := ts[i] - srv.log.Trace("New dial task", "task", t) - go func() { t.Do(srv); taskdone <- t }() - runningTasks = append(runningTasks, t) - } - return ts[i:] - } - scheduleTasks := func() { - // Start from queue first. - queuedTasks = append(queuedTasks[:0], startTasks(queuedTasks)...) - // Query dialer for new tasks and start as many as possible now. - if len(runningTasks) < maxActiveDialTasks { - nt := dialstate.newTasks(len(runningTasks)+len(queuedTasks), peers, time.Now()) - queuedTasks = append(queuedTasks, startTasks(nt)...) - } - } - running: for { - scheduleTasks() - select { - case <-tick.C: - // This is just here to ensure the dial scheduler runs occasionally. - case <-srv.quit: // The server was stopped. Run the cleanup logic. break running - case n := <-srv.addstatic: - // This channel is used by AddPeer to add to the - // ephemeral static peer list. Add it to the dialer, - // it will keep the node connected. - srv.log.Trace("Adding static node", "node", n) - dialstate.addStatic(n) - - case n := <-srv.removestatic: - // This channel is used by RemovePeer to send a - // disconnect request to a peer and begin the - // stop keeping the node connected. - srv.log.Trace("Removing static node", "node", n) - dialstate.removeStatic(n) - if p, ok := peers[n.ID()]; ok { - p.Disconnect(DiscRequested) - } - case n := <-srv.addtrusted: - // This channel is used by AddTrustedPeer to add an enode + // This channel is used by AddTrustedPeer to add a node // to the trusted node set. srv.log.Trace("Adding trusted node", "node", n) trusted[n.ID()] = true - // Mark any already-connected peer as trusted if p, ok := peers[n.ID()]; ok { p.rw.set(trustedConn, true) } case n := <-srv.removetrusted: - // This channel is used by RemoveTrustedPeer to remove an enode + // This channel is used by RemoveTrustedPeer to remove a node // from the trusted node set. srv.log.Trace("Removing trusted node", "node", n) delete(trusted, n.ID()) - - // Unmark any already-connected peer as trusted if p, ok := peers[n.ID()]; ok { p.rw.set(trustedConn, false) } @@ -747,14 +735,6 @@ running: op(peers) srv.peerOpDone <- struct{}{} - case t := <-taskdone: - // A task got done. Tell dialstate about it so it - // can update its state and remove it from the active - // tasks list. - srv.log.Trace("Dial task done", "task", t) - dialstate.taskDone(t, time.Now()) - delTask(t) - case c := <-srv.checkpointPostHandshake: // A connection has passed the encryption handshake so // the remote identity is known (but hasn't been verified yet). @@ -771,33 +751,25 @@ running: err := srv.addPeerChecks(peers, inboundCount, c) if err == nil { // The handshakes are done and it passed all checks. - p := newPeer(srv.log, c, srv.Protocols) - // If message events are enabled, pass the peerFeed - // to the peer - if srv.EnableMsgEvents { - p.events = &srv.peerFeed - } - name := truncateName(c.name) - p.log.Debug("Adding p2p peer", "addr", p.RemoteAddr(), "peers", len(peers)+1, "name", name) - go srv.runPeer(p) + p := srv.launchPeer(c) peers[c.node.ID()] = p - if p.Inbound() { - inboundCount++ - } + srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(), "conn", c.flags, "addr", p.RemoteAddr(), "name", truncateName(c.name)) + srv.dialsched.peerAdded(c) if conn, ok := c.fd.(*meteredConn); ok { conn.handshakeDone(p) } + if p.Inbound() { + inboundCount++ + } } - // The dialer logic relies on the assumption that - // dial tasks complete after the peer has been added or - // discarded. Unblock the task last. c.cont <- err case pd := <-srv.delpeer: // A peer disconnected. d := common.PrettyDuration(mclock.Now() - pd.created) - pd.log.Debug("Removing p2p peer", "addr", pd.RemoteAddr(), "peers", len(peers)-1, "duration", d, "req", pd.requested, "err", pd.err) delete(peers, pd.ID()) + srv.log.Debug("Removing p2p peer", "peercount", len(peers), "id", pd.ID(), "duration", d, "req", pd.requested, "err", pd.err) + srv.dialsched.peerRemoved(pd.rw) if pd.Inbound() { inboundCount-- } @@ -822,14 +794,14 @@ running: // is closed. for len(peers) > 0 { p := <-srv.delpeer - p.log.Trace("<-delpeer (spindown)", "remainingTasks", len(runningTasks)) + p.log.Trace("<-delpeer (spindown)") delete(peers, p.ID()) } } func (srv *Server) postHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error { switch { - case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: + case !c.is(trustedConn) && len(peers) >= srv.MaxPeers: return DiscTooManyPeers case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns(): return DiscTooManyPeers @@ -852,21 +824,6 @@ func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c * return srv.postHandshakeChecks(peers, inboundCount, c) } -func (srv *Server) maxInboundConns() int { - return srv.MaxPeers - srv.maxDialedConns() -} - -func (srv *Server) maxDialedConns() int { - if srv.NoDiscovery || srv.NoDial { - return 0 - } - r := srv.DialRatio - if r == 0 { - r = defaultDialRatio - } - return srv.MaxPeers / r -} - // listenLoop runs in its own goroutine and accepts // inbound connections. func (srv *Server) listenLoop() { @@ -935,18 +892,20 @@ func (srv *Server) listenLoop() { } func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error { - if remoteIP != nil { - // Reject connections that do not match NetRestrict. - if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) { - return fmt.Errorf("not whitelisted in NetRestrict") - } - // Reject Internet peers that try too often. - srv.inboundHistory.expire(time.Now()) - if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) { - return fmt.Errorf("too many attempts") - } - srv.inboundHistory.add(remoteIP.String(), time.Now().Add(inboundThrottleTime)) + if remoteIP == nil { + return nil } + // Reject connections that do not match NetRestrict. + if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) { + return fmt.Errorf("not whitelisted in NetRestrict") + } + // Reject Internet peers that try too often. + now := srv.clock.Now() + srv.inboundHistory.expire(now, nil) + if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) { + return fmt.Errorf("too many attempts") + } + srv.inboundHistory.add(remoteIP.String(), now.Add(inboundThrottleTime)) return nil } @@ -958,7 +917,6 @@ func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) err := srv.setupConn(c, flags, dialDest) if err != nil { c.close(err) - srv.log.Trace("Setting up connection failed", "addr", fd.RemoteAddr(), "err", err) } return err } @@ -977,7 +935,9 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro if dialDest != nil { dialPubkey = new(ecdsa.PublicKey) if err := dialDest.Load((*enode.Secp256k1)(dialPubkey)); err != nil { - return errors.New("dial destination doesn't have a secp256k1 public key") + err = errors.New("dial destination doesn't have a secp256k1 public key") + srv.log.Trace("Setting up connection failed", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err) + return err } } @@ -1006,7 +966,7 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro // Run the capability negotiation handshake. phs, err := c.doProtoHandshake(srv.ourHandshake) if err != nil { - clog.Trace("Failed proto handshake", "err", err) + clog.Trace("Failed p2p handshake", "err", err) return err } if id := c.node.ID(); !bytes.Equal(crypto.Keccak256(phs.ID), id[:]) { @@ -1020,9 +980,6 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro return err } - // If the checks completed successfully, the connection has been added as a peer and - // runPeer has been launched. - clog.Trace("Connection set up", "inbound", dialDest == nil) return nil } @@ -1054,15 +1011,22 @@ func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error { return <-c.cont } +func (srv *Server) launchPeer(c *conn) *Peer { + p := newPeer(srv.log, c, srv.Protocols) + if srv.EnableMsgEvents { + // If message events are enabled, pass the peerFeed + // to the peer. + p.events = &srv.peerFeed + } + go srv.runPeer(p) + return p +} + // runPeer runs in its own goroutine for each peer. -// it waits until the Peer logic returns and removes -// the peer. func (srv *Server) runPeer(p *Peer) { if srv.newPeerHook != nil { srv.newPeerHook(p) } - - // broadcast peer add srv.peerFeed.Send(&PeerEvent{ Type: PeerEventTypeAdd, Peer: p.ID(), @@ -1070,10 +1034,18 @@ func (srv *Server) runPeer(p *Peer) { LocalAddress: p.LocalAddr().String(), }) - // run the protocol + // Run the per-peer main loop. remoteRequested, err := p.run() - // broadcast peer drop + // Announce disconnect on the main loop to update the peer set. + // The main loop waits for existing peers to be sent on srv.delpeer + // before returning, so this send should not select on srv.quit. + srv.delpeer <- peerDrop{p, err, remoteRequested} + + // Broadcast peer drop to external subscribers. This needs to be + // after the send to delpeer so subscribers have a consistent view of + // the peer set (i.e. Server.Peers() doesn't include the peer when the + // event is received. srv.peerFeed.Send(&PeerEvent{ Type: PeerEventTypeDrop, Peer: p.ID(), @@ -1081,10 +1053,6 @@ func (srv *Server) runPeer(p *Peer) { RemoteAddress: p.RemoteAddr().String(), LocalAddress: p.LocalAddr().String(), }) - - // Note: run waits for existing peers to be sent on srv.delpeer - // before returning, so this send should not select on srv.quit. - srv.delpeer <- peerDrop{p, err, remoteRequested} } // NodeInfo represents a short summary of the information known about the host. diff --git a/p2p/server_test.go b/p2p/server_test.go index 383445c833..958eb29129 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -34,10 +34,6 @@ import ( "golang.org/x/crypto/sha3" ) -// func init() { -// log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) -// } - type testTransport struct { rpub *ecdsa.PublicKey *rlpx @@ -72,11 +68,12 @@ func (c *testTransport) close(err error) { func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *Server { config := Config{ - Name: "test", - MaxPeers: 10, - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), - Logger: testlog.Logger(t, log.LvlTrace), + Name: "test", + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + NoDiscovery: true, + PrivateKey: newkey(), + Logger: testlog.Logger(t, log.LvlTrace), } server := &Server{ Config: config, @@ -131,11 +128,10 @@ func TestServerDial(t *testing.T) { t.Fatalf("could not setup listener: %v", err) } defer listener.Close() - accepted := make(chan net.Conn) + accepted := make(chan net.Conn, 1) go func() { conn, err := listener.Accept() if err != nil { - t.Error("accept error:", err) return } accepted <- conn @@ -205,155 +201,38 @@ func TestServerDial(t *testing.T) { } } -// This test checks that tasks generated by dialstate are -// actually executed and taskdone is called for them. -func TestServerTaskScheduling(t *testing.T) { - var ( - done = make(chan *testTask) - quit, returned = make(chan struct{}), make(chan struct{}) - tc = 0 - tg = taskgen{ - newFunc: func(running int, peers map[enode.ID]*Peer) []task { - tc++ - return []task{&testTask{index: tc - 1}} - }, - doneFunc: func(t task) { - select { - case done <- t.(*testTask): - case <-quit: - } - }, - } - ) +// This test checks that RemovePeer disconnects the peer if it is connected. +func TestServerRemovePeerDisconnect(t *testing.T) { + srv1 := &Server{Config: Config{ + PrivateKey: newkey(), + MaxPeers: 1, + NoDiscovery: true, + Logger: testlog.Logger(t, log.LvlTrace).New("server", "1"), + }} + srv2 := &Server{Config: Config{ + PrivateKey: newkey(), + MaxPeers: 1, + NoDiscovery: true, + NoDial: true, + ListenAddr: "127.0.0.1:0", + Logger: testlog.Logger(t, log.LvlTrace).New("server", "2"), + }} + srv1.Start() + defer srv1.Stop() + srv2.Start() + defer srv2.Stop() - // The Server in this test isn't actually running - // because we're only interested in what run does. - db, _ := enode.OpenDB("") - srv := &Server{ - Config: Config{MaxPeers: 10}, - localnode: enode.NewLocalNode(db, newkey()), - nodedb: db, - discmix: enode.NewFairMix(0), - quit: make(chan struct{}), - running: true, - log: log.New(), + if !syncAddPeer(srv1, srv2.Self()) { + t.Fatal("peer not connected") } - srv.loopWG.Add(1) - go func() { - srv.run(tg) - close(returned) - }() - - var gotdone []*testTask - for i := 0; i < 100; i++ { - gotdone = append(gotdone, <-done) - } - for i, task := range gotdone { - if task.index != i { - t.Errorf("task %d has wrong index, got %d", i, task.index) - break - } - if !task.called { - t.Errorf("task %d was not called", i) - break - } - } - - close(quit) - srv.Stop() - select { - case <-returned: - case <-time.After(500 * time.Millisecond): - t.Error("Server.run did not return within 500ms") + srv1.RemovePeer(srv2.Self()) + if srv1.PeerCount() > 0 { + t.Fatal("removed peer still connected") } } -// This test checks that Server doesn't drop tasks, -// even if newTasks returns more than the maximum number of tasks. -func TestServerManyTasks(t *testing.T) { - alltasks := make([]task, 300) - for i := range alltasks { - alltasks[i] = &testTask{index: i} - } - - var ( - db, _ = enode.OpenDB("") - srv = &Server{ - quit: make(chan struct{}), - localnode: enode.NewLocalNode(db, newkey()), - nodedb: db, - running: true, - log: log.New(), - discmix: enode.NewFairMix(0), - } - done = make(chan *testTask) - start, end = 0, 0 - ) - defer srv.Stop() - srv.loopWG.Add(1) - go srv.run(taskgen{ - newFunc: func(running int, peers map[enode.ID]*Peer) []task { - start, end = end, end+maxActiveDialTasks+10 - if end > len(alltasks) { - end = len(alltasks) - } - return alltasks[start:end] - }, - doneFunc: func(tt task) { - done <- tt.(*testTask) - }, - }) - - doneset := make(map[int]bool) - timeout := time.After(2 * time.Second) - for len(doneset) < len(alltasks) { - select { - case tt := <-done: - if doneset[tt.index] { - t.Errorf("task %d got done more than once", tt.index) - } else { - doneset[tt.index] = true - } - case <-timeout: - t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks)) - for i := 0; i < len(alltasks); i++ { - if !doneset[i] { - t.Logf("task %d not done", i) - } - } - return - } - } -} - -type taskgen struct { - newFunc func(running int, peers map[enode.ID]*Peer) []task - doneFunc func(task) -} - -func (tg taskgen) newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task { - return tg.newFunc(running, peers) -} -func (tg taskgen) taskDone(t task, now time.Time) { - tg.doneFunc(t) -} -func (tg taskgen) addStatic(*enode.Node) { -} -func (tg taskgen) removeStatic(*enode.Node) { -} - -type testTask struct { - index int - called bool -} - -func (t *testTask) Do(srv *Server) { - t.called = true -} - -// This test checks that connections are disconnected -// just after the encryption handshake when the server is -// at capacity. Trusted connections should still be accepted. +// This test checks that connections are disconnected just after the encryption handshake +// when the server is at capacity. Trusted connections should still be accepted. func TestServerAtCap(t *testing.T) { trustedNode := newkey() trustedID := enode.PubkeyToIDV4(&trustedNode.PublicKey) @@ -363,7 +242,8 @@ func TestServerAtCap(t *testing.T) { MaxPeers: 10, NoDial: true, NoDiscovery: true, - TrustedNodes: []*enode.Node{newNode(trustedID, nil)}, + TrustedNodes: []*enode.Node{newNode(trustedID, "")}, + Logger: testlog.Logger(t, log.LvlTrace), }, } if err := srv.Start(); err != nil { @@ -401,14 +281,14 @@ func TestServerAtCap(t *testing.T) { } // Remove from trusted set and try again - srv.RemoveTrustedPeer(newNode(trustedID, nil)) + srv.RemoveTrustedPeer(newNode(trustedID, "")) c = newconn(trustedID) if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers { t.Error("wrong error for insert:", err) } // Add anotherID to trusted set and try again - srv.AddTrustedPeer(newNode(anotherID, nil)) + srv.AddTrustedPeer(newNode(anotherID, "")) c = newconn(anotherID) if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil { t.Error("unexpected error for trusted conn @posthandshake:", err) @@ -439,9 +319,9 @@ func TestServerPeerLimits(t *testing.T) { NoDial: true, NoDiscovery: true, Protocols: []Protocol{discard}, + Logger: testlog.Logger(t, log.LvlTrace), }, newTransport: func(fd net.Conn) transport { return tp }, - log: log.New(), } if err := srv.Start(); err != nil { t.Fatalf("couldn't start server: %v", err) @@ -724,3 +604,23 @@ func (l *fakeAddrListener) Accept() (net.Conn, error) { func (c *fakeAddrConn) RemoteAddr() net.Addr { return c.remoteAddr } + +func syncAddPeer(srv *Server, node *enode.Node) bool { + var ( + ch = make(chan *PeerEvent) + sub = srv.SubscribeEvents(ch) + timeout = time.After(2 * time.Second) + ) + defer sub.Unsubscribe() + srv.AddPeer(node) + for { + select { + case ev := <-ch: + if ev.Type == PeerEventTypeAdd && ev.Peer == node.ID() { + return true + } + case <-timeout: + return false + } + } +} diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index 9787082e18..651d9546ae 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -17,6 +17,7 @@ package adapters import ( + "context" "errors" "fmt" "math" @@ -126,7 +127,7 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { // Dial implements the p2p.NodeDialer interface by connecting to the node using // an in-memory net.Pipe -func (s *SimAdapter) Dial(dest *enode.Node) (conn net.Conn, err error) { +func (s *SimAdapter) Dial(ctx context.Context, dest *enode.Node) (conn net.Conn, err error) { node, ok := s.GetNode(dest.ID()) if !ok { return nil, fmt.Errorf("unknown node: %s", dest.ID()) diff --git a/p2p/util.go b/p2p/util.go index 018cc40e98..3c5f6b8508 100644 --- a/p2p/util.go +++ b/p2p/util.go @@ -18,7 +18,8 @@ package p2p import ( "container/heap" - "time" + + "github.com/ethereum/go-ethereum/common/mclock" ) // expHeap tracks strings and their expiry time. @@ -27,16 +28,16 @@ type expHeap []expItem // expItem is an entry in addrHistory. type expItem struct { item string - exp time.Time + exp mclock.AbsTime } // nextExpiry returns the next expiry time. -func (h *expHeap) nextExpiry() time.Time { +func (h *expHeap) nextExpiry() mclock.AbsTime { return (*h)[0].exp } // add adds an item and sets its expiry time. -func (h *expHeap) add(item string, exp time.Time) { +func (h *expHeap) add(item string, exp mclock.AbsTime) { heap.Push(h, expItem{item, exp}) } @@ -51,15 +52,18 @@ func (h expHeap) contains(item string) bool { } // expire removes items with expiry time before 'now'. -func (h *expHeap) expire(now time.Time) { - for h.Len() > 0 && h.nextExpiry().Before(now) { - heap.Pop(h) +func (h *expHeap) expire(now mclock.AbsTime, onExp func(string)) { + for h.Len() > 0 && h.nextExpiry() < now { + item := heap.Pop(h) + if onExp != nil { + onExp(item.(expItem).item) + } } } // heap.Interface boilerplate func (h expHeap) Len() int { return len(h) } -func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) } +func (h expHeap) Less(i, j int) bool { return h[i].exp < h[j].exp } func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) } func (h *expHeap) Pop() interface{} { diff --git a/p2p/util_test.go b/p2p/util_test.go index c9f2648dc9..cc0d2b215f 100644 --- a/p2p/util_test.go +++ b/p2p/util_test.go @@ -19,30 +19,32 @@ package p2p import ( "testing" "time" + + "github.com/ethereum/go-ethereum/common/mclock" ) func TestExpHeap(t *testing.T) { var h expHeap var ( - basetime = time.Unix(4000, 0) + basetime = mclock.AbsTime(10) exptimeA = basetime.Add(2 * time.Second) exptimeB = basetime.Add(3 * time.Second) exptimeC = basetime.Add(4 * time.Second) ) - h.add("a", exptimeA) h.add("b", exptimeB) + h.add("a", exptimeA) h.add("c", exptimeC) - if !h.nextExpiry().Equal(exptimeA) { + if h.nextExpiry() != exptimeA { t.Fatal("wrong nextExpiry") } if !h.contains("a") || !h.contains("b") || !h.contains("c") { t.Fatal("heap doesn't contain all live items") } - h.expire(exptimeA.Add(1)) - if !h.nextExpiry().Equal(exptimeB) { + h.expire(exptimeA.Add(1), nil) + if h.nextExpiry() != exptimeB { t.Fatal("wrong nextExpiry") } if h.contains("a") {