p2p/discover: move bond logic from table to transport (#17048)

* p2p/discover: move bond logic from table to transport

This commit moves node endpoint verification (bonding) from the table to
the UDP transport implementation. Previously, adding a node to the table
entailed pinging the node if needed. With this change, the ping-back
logic is embedded in the packet handler at a lower level.

It is easy to verify that the basic protocol is unchanged: we still
require a valid pong reply from the node before findnode is accepted.

The node database tracked the time of last ping sent to the node and
time of last valid pong received from the node. Node endpoints are
considered verified when a valid pong is received and the time of last
pong was called 'bond time'. The time of last ping sent was unused. In
this commit, the last ping database entry is repurposed to mean last
ping _received_. This entry is now used to track whether the node needs
to be pinged back.

The other big change is how nodes are added to the table. We used to add
nodes in Table.bond, which ran when a remote node pinged us or when we
encountered the node in a neighbors reply. The transport now adds to the
table directly after the endpoint is verified through ping. To ensure
that the Table can't be filled just by pinging the node repeatedly, we
retain the isInitDone check. During init, only nodes from neighbors
replies are added.

* p2p/discover: reduce findnode failure counter on success

* p2p/discover: remove unused parameter of loadSeedNodes

* p2p/discover: improve ping-back check and comments

* p2p/discover: add neighbors reply nodes always, not just during init
This commit is contained in:
Felix Lange 2018-07-03 15:24:12 +02:00 committed by Péter Szilágyi
parent 9da128db70
commit c73b654fd1
6 changed files with 147 additions and 245 deletions

@ -42,6 +42,7 @@ var (
nodeDBNilNodeID = NodeID{} // Special node ID to use as a nil element. nodeDBNilNodeID = NodeID{} // Special node ID to use as a nil element.
nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
nodeDBCleanupCycle = time.Hour // Time period for running the expiration task. nodeDBCleanupCycle = time.Hour // Time period for running the expiration task.
nodeDBVersion = 5
) )
// nodeDB stores all nodes we know about. // nodeDB stores all nodes we know about.
@ -257,7 +258,7 @@ func (db *nodeDB) expireNodes() error {
} }
// Skip the node if not expired yet (and not self) // Skip the node if not expired yet (and not self)
if !bytes.Equal(id[:], db.self[:]) { if !bytes.Equal(id[:], db.self[:]) {
if seen := db.bondTime(id); seen.After(threshold) { if seen := db.lastPongReceived(id); seen.After(threshold) {
continue continue
} }
} }
@ -267,29 +268,28 @@ func (db *nodeDB) expireNodes() error {
return nil return nil
} }
// lastPing retrieves the time of the last ping packet send to a remote node, // lastPingReceived retrieves the time of the last ping packet sent by the remote node.
// requesting binding. func (db *nodeDB) lastPingReceived(id NodeID) time.Time {
func (db *nodeDB) lastPing(id NodeID) time.Time {
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0) return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0)
} }
// updateLastPing updates the last time we tried contacting a remote node. // updateLastPing updates the last time remote node pinged us.
func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error { func (db *nodeDB) updateLastPingReceived(id NodeID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix())
} }
// bondTime retrieves the time of the last successful pong from remote node. // lastPongReceived retrieves the time of the last successful pong from remote node.
func (db *nodeDB) bondTime(id NodeID) time.Time { func (db *nodeDB) lastPongReceived(id NodeID) time.Time {
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0)
} }
// hasBond reports whether the given node is considered bonded. // hasBond reports whether the given node is considered bonded.
func (db *nodeDB) hasBond(id NodeID) bool { func (db *nodeDB) hasBond(id NodeID) bool {
return time.Since(db.bondTime(id)) < nodeDBNodeExpiration return time.Since(db.lastPongReceived(id)) < nodeDBNodeExpiration
} }
// updateBondTime updates the last pong time of a node. // updateLastPongReceived updates the last pong time of a node.
func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error { func (db *nodeDB) updateLastPongReceived(id NodeID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix())
} }
@ -332,7 +332,7 @@ seek:
if n.ID == db.self { if n.ID == db.self {
continue seek continue seek
} }
if now.Sub(db.bondTime(n.ID)) > maxAge { if now.Sub(db.lastPongReceived(n.ID)) > maxAge {
continue seek continue seek
} }
for i := range nodes { for i := range nodes {

@ -79,7 +79,7 @@ var nodeDBInt64Tests = []struct {
} }
func TestNodeDBInt64(t *testing.T) { func TestNodeDBInt64(t *testing.T) {
db, _ := newNodeDB("", Version, NodeID{}) db, _ := newNodeDB("", nodeDBVersion, NodeID{})
defer db.close() defer db.close()
tests := nodeDBInt64Tests tests := nodeDBInt64Tests
@ -111,27 +111,27 @@ func TestNodeDBFetchStore(t *testing.T) {
inst := time.Now() inst := time.Now()
num := 314 num := 314
db, _ := newNodeDB("", Version, NodeID{}) db, _ := newNodeDB("", nodeDBVersion, NodeID{})
defer db.close() defer db.close()
// Check fetch/store operations on a node ping object // Check fetch/store operations on a node ping object
if stored := db.lastPing(node.ID); stored.Unix() != 0 { if stored := db.lastPingReceived(node.ID); stored.Unix() != 0 {
t.Errorf("ping: non-existing object: %v", stored) t.Errorf("ping: non-existing object: %v", stored)
} }
if err := db.updateLastPing(node.ID, inst); err != nil { if err := db.updateLastPingReceived(node.ID, inst); err != nil {
t.Errorf("ping: failed to update: %v", err) t.Errorf("ping: failed to update: %v", err)
} }
if stored := db.lastPing(node.ID); stored.Unix() != inst.Unix() { if stored := db.lastPingReceived(node.ID); stored.Unix() != inst.Unix() {
t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
} }
// Check fetch/store operations on a node pong object // Check fetch/store operations on a node pong object
if stored := db.bondTime(node.ID); stored.Unix() != 0 { if stored := db.lastPongReceived(node.ID); stored.Unix() != 0 {
t.Errorf("pong: non-existing object: %v", stored) t.Errorf("pong: non-existing object: %v", stored)
} }
if err := db.updateBondTime(node.ID, inst); err != nil { if err := db.updateLastPongReceived(node.ID, inst); err != nil {
t.Errorf("pong: failed to update: %v", err) t.Errorf("pong: failed to update: %v", err)
} }
if stored := db.bondTime(node.ID); stored.Unix() != inst.Unix() { if stored := db.lastPongReceived(node.ID); stored.Unix() != inst.Unix() {
t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
} }
// Check fetch/store operations on a node findnode-failure object // Check fetch/store operations on a node findnode-failure object
@ -216,7 +216,7 @@ var nodeDBSeedQueryNodes = []struct {
} }
func TestNodeDBSeedQuery(t *testing.T) { func TestNodeDBSeedQuery(t *testing.T) {
db, _ := newNodeDB("", Version, nodeDBSeedQueryNodes[1].node.ID) db, _ := newNodeDB("", nodeDBVersion, nodeDBSeedQueryNodes[1].node.ID)
defer db.close() defer db.close()
// Insert a batch of nodes for querying // Insert a batch of nodes for querying
@ -224,7 +224,7 @@ func TestNodeDBSeedQuery(t *testing.T) {
if err := db.updateNode(seed.node); err != nil { if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err) t.Fatalf("node %d: failed to insert: %v", i, err)
} }
if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil {
t.Fatalf("node %d: failed to insert bondTime: %v", i, err) t.Fatalf("node %d: failed to insert bondTime: %v", i, err)
} }
} }
@ -267,7 +267,7 @@ func TestNodeDBPersistency(t *testing.T) {
) )
// Create a persistent database and store some values // Create a persistent database and store some values
db, err := newNodeDB(filepath.Join(root, "database"), Version, NodeID{}) db, err := newNodeDB(filepath.Join(root, "database"), nodeDBVersion, NodeID{})
if err != nil { if err != nil {
t.Fatalf("failed to create persistent database: %v", err) t.Fatalf("failed to create persistent database: %v", err)
} }
@ -277,7 +277,7 @@ func TestNodeDBPersistency(t *testing.T) {
db.close() db.close()
// Reopen the database and check the value // Reopen the database and check the value
db, err = newNodeDB(filepath.Join(root, "database"), Version, NodeID{}) db, err = newNodeDB(filepath.Join(root, "database"), nodeDBVersion, NodeID{})
if err != nil { if err != nil {
t.Fatalf("failed to open persistent database: %v", err) t.Fatalf("failed to open persistent database: %v", err)
} }
@ -287,7 +287,7 @@ func TestNodeDBPersistency(t *testing.T) {
db.close() db.close()
// Change the database version and check flush // Change the database version and check flush
db, err = newNodeDB(filepath.Join(root, "database"), Version+1, NodeID{}) db, err = newNodeDB(filepath.Join(root, "database"), nodeDBVersion+1, NodeID{})
if err != nil { if err != nil {
t.Fatalf("failed to open persistent database: %v", err) t.Fatalf("failed to open persistent database: %v", err)
} }
@ -324,7 +324,7 @@ var nodeDBExpirationNodes = []struct {
} }
func TestNodeDBExpiration(t *testing.T) { func TestNodeDBExpiration(t *testing.T) {
db, _ := newNodeDB("", Version, NodeID{}) db, _ := newNodeDB("", nodeDBVersion, NodeID{})
defer db.close() defer db.close()
// Add all the test nodes and set their last pong time // Add all the test nodes and set their last pong time
@ -332,7 +332,7 @@ func TestNodeDBExpiration(t *testing.T) {
if err := db.updateNode(seed.node); err != nil { if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err) t.Fatalf("node %d: failed to insert: %v", i, err)
} }
if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil {
t.Fatalf("node %d: failed to update bondTime: %v", i, err) t.Fatalf("node %d: failed to update bondTime: %v", i, err)
} }
} }
@ -357,7 +357,7 @@ func TestNodeDBSelfExpiration(t *testing.T) {
break break
} }
} }
db, _ := newNodeDB("", Version, self) db, _ := newNodeDB("", nodeDBVersion, self)
defer db.close() defer db.close()
// Add all the test nodes and set their last pong time // Add all the test nodes and set their last pong time
@ -365,7 +365,7 @@ func TestNodeDBSelfExpiration(t *testing.T) {
if err := db.updateNode(seed.node); err != nil { if err := db.updateNode(seed.node); err != nil {
t.Fatalf("node %d: failed to insert: %v", i, err) t.Fatalf("node %d: failed to insert: %v", i, err)
} }
if err := db.updateBondTime(seed.node.ID, seed.pong); err != nil { if err := db.updateLastPongReceived(seed.node.ID, seed.pong); err != nil {
t.Fatalf("node %d: failed to update bondTime: %v", i, err) t.Fatalf("node %d: failed to update bondTime: %v", i, err)
} }
} }

@ -25,7 +25,6 @@ package discover
import ( import (
crand "crypto/rand" crand "crypto/rand"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
mrand "math/rand" mrand "math/rand"
"net" "net"
@ -54,9 +53,7 @@ const (
bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
tableIPLimit, tableSubnet = 10, 24 tableIPLimit, tableSubnet = 10, 24
maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions
maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
refreshInterval = 30 * time.Minute refreshInterval = 30 * time.Minute
revalidateInterval = 10 * time.Second revalidateInterval = 10 * time.Second
copyNodesInterval = 30 * time.Second copyNodesInterval = 30 * time.Second
@ -78,28 +75,17 @@ type Table struct {
closeReq chan struct{} closeReq chan struct{}
closed chan struct{} closed chan struct{}
bondmu sync.Mutex
bonding map[NodeID]*bondproc
bondslots chan struct{} // limits total number of active bonding processes
nodeAddedHook func(*Node) // for testing nodeAddedHook func(*Node) // for testing
net transport net transport
self *Node // metadata of the local node self *Node // metadata of the local node
} }
type bondproc struct {
err error
n *Node
done chan struct{}
}
// transport is implemented by the UDP transport. // transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP // it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key. // sockets and without generating a private key.
type transport interface { type transport interface {
ping(NodeID, *net.UDPAddr) error ping(NodeID, *net.UDPAddr) error
waitping(NodeID) error
findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error) findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
close() close()
} }
@ -114,7 +100,7 @@ type bucket struct {
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) { func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
// If no node database was given, use an in-memory one // If no node database was given, use an in-memory one
db, err := newNodeDB(nodeDBPath, Version, ourID) db, err := newNodeDB(nodeDBPath, nodeDBVersion, ourID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -122,8 +108,6 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
net: t, net: t,
db: db, db: db,
self: NewNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)), self: NewNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)),
bonding: make(map[NodeID]*bondproc),
bondslots: make(chan struct{}, maxBondingPingPongs),
refreshReq: make(chan chan struct{}), refreshReq: make(chan chan struct{}),
initDone: make(chan struct{}), initDone: make(chan struct{}),
closeReq: make(chan struct{}), closeReq: make(chan struct{}),
@ -134,16 +118,13 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
if err := tab.setFallbackNodes(bootnodes); err != nil { if err := tab.setFallbackNodes(bootnodes); err != nil {
return nil, err return nil, err
} }
for i := 0; i < cap(tab.bondslots); i++ {
tab.bondslots <- struct{}{}
}
for i := range tab.buckets { for i := range tab.buckets {
tab.buckets[i] = &bucket{ tab.buckets[i] = &bucket{
ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit}, ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
} }
} }
tab.seedRand() tab.seedRand()
tab.loadSeedNodes(false) tab.loadSeedNodes()
// Start the background expiration goroutine after loading seeds so that the search for // Start the background expiration goroutine after loading seeds so that the search for
// seed nodes also considers older nodes that would otherwise be removed by the // seed nodes also considers older nodes that would otherwise be removed by the
// expiration. // expiration.
@ -315,22 +296,7 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
if !asked[n.ID] { if !asked[n.ID] {
asked[n.ID] = true asked[n.ID] = true
pendingQueries++ pendingQueries++
go func() { go tab.findnode(n, targetID, reply)
// Find potential neighbors to bond with
r, err := tab.net.findnode(n.ID, n.addr(), targetID)
if err != nil {
// Bump the failure counter to detect and evacuate non-bonded entries
fails := tab.db.findFails(n.ID) + 1
tab.db.updateFindFails(n.ID, fails)
log.Trace("Bumping findnode failure counter", "id", n.ID, "failcount", fails)
if fails >= maxFindnodeFailures {
log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails)
tab.delete(n)
}
}
reply <- tab.bondall(r)
}()
} }
} }
if pendingQueries == 0 { if pendingQueries == 0 {
@ -349,6 +315,29 @@ func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
return result.entries return result.entries
} }
func (tab *Table) findnode(n *Node, targetID NodeID, reply chan<- []*Node) {
fails := tab.db.findFails(n.ID)
r, err := tab.net.findnode(n.ID, n.addr(), targetID)
if err != nil || len(r) == 0 {
fails++
tab.db.updateFindFails(n.ID, fails)
log.Trace("Findnode failed", "id", n.ID, "failcount", fails, "err", err)
if fails >= maxFindnodeFailures {
log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails)
tab.delete(n)
}
} else if fails > 0 {
tab.db.updateFindFails(n.ID, fails-1)
}
// Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
// just remove those again during revalidation.
for _, n := range r {
tab.add(n)
}
reply <- r
}
func (tab *Table) refresh() <-chan struct{} { func (tab *Table) refresh() <-chan struct{} {
done := make(chan struct{}) done := make(chan struct{})
select { select {
@ -401,7 +390,7 @@ loop:
case <-revalidateDone: case <-revalidateDone:
revalidate.Reset(tab.nextRevalidateTime()) revalidate.Reset(tab.nextRevalidateTime())
case <-copyNodes.C: case <-copyNodes.C:
go tab.copyBondedNodes() go tab.copyLiveNodes()
case <-tab.closeReq: case <-tab.closeReq:
break loop break loop
} }
@ -429,7 +418,7 @@ func (tab *Table) doRefresh(done chan struct{}) {
// Load nodes from the database and insert // Load nodes from the database and insert
// them. This should yield a few previously seen nodes that are // them. This should yield a few previously seen nodes that are
// (hopefully) still alive. // (hopefully) still alive.
tab.loadSeedNodes(true) tab.loadSeedNodes()
// Run self lookup to discover new neighbor nodes. // Run self lookup to discover new neighbor nodes.
tab.lookup(tab.self.ID, false) tab.lookup(tab.self.ID, false)
@ -447,15 +436,12 @@ func (tab *Table) doRefresh(done chan struct{}) {
} }
} }
func (tab *Table) loadSeedNodes(bond bool) { func (tab *Table) loadSeedNodes() {
seeds := tab.db.querySeeds(seedCount, seedMaxAge) seeds := tab.db.querySeeds(seedCount, seedMaxAge)
seeds = append(seeds, tab.nursery...) seeds = append(seeds, tab.nursery...)
if bond {
seeds = tab.bondall(seeds)
}
for i := range seeds { for i := range seeds {
seed := seeds[i] seed := seeds[i]
age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }} age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPongReceived(seed.ID)) }}
log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age) log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
tab.add(seed) tab.add(seed)
} }
@ -473,7 +459,7 @@ func (tab *Table) doRevalidate(done chan<- struct{}) {
} }
// Ping the selected node and wait for a pong. // Ping the selected node and wait for a pong.
err := tab.ping(last.ID, last.addr()) err := tab.net.ping(last.ID, last.addr())
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
@ -515,9 +501,9 @@ func (tab *Table) nextRevalidateTime() time.Duration {
return time.Duration(tab.rand.Int63n(int64(revalidateInterval))) return time.Duration(tab.rand.Int63n(int64(revalidateInterval)))
} }
// copyBondedNodes adds nodes from the table to the database if they have been in the table // copyLiveNodes adds nodes from the table to the database if they have been in the table
// longer then minTableTime. // longer then minTableTime.
func (tab *Table) copyBondedNodes() { func (tab *Table) copyLiveNodes() {
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
@ -553,120 +539,6 @@ func (tab *Table) len() (n int) {
return n return n
} }
// bondall bonds with all given nodes concurrently and returns
// those nodes for which bonding has probably succeeded.
func (tab *Table) bondall(nodes []*Node) (result []*Node) {
rc := make(chan *Node, len(nodes))
for i := range nodes {
go func(n *Node) {
nn, _ := tab.bond(false, n.ID, n.addr(), n.TCP)
rc <- nn
}(nodes[i])
}
for range nodes {
if n := <-rc; n != nil {
result = append(result, n)
}
}
return result
}
// bond ensures the local node has a bond with the given remote node.
// It also attempts to insert the node into the table if bonding succeeds.
// The caller must not hold tab.mutex.
//
// A bond is must be established before sending findnode requests.
// Both sides must have completed a ping/pong exchange for a bond to
// exist. The total number of active bonding processes is limited in
// order to restrain network use.
//
// bond is meant to operate idempotently in that bonding with a remote
// node which still remembers a previously established bond will work.
// The remote node will simply not send a ping back, causing waitping
// to time out.
//
// If pinged is true, the remote node has just pinged us and one half
// of the process can be skipped.
func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
if id == tab.self.ID {
return nil, errors.New("is self")
}
if pinged && !tab.isInitDone() {
return nil, errors.New("still initializing")
}
// Start bonding if we haven't seen this node for a while or if it failed findnode too often.
node, fails := tab.db.node(id), tab.db.findFails(id)
age := time.Since(tab.db.bondTime(id))
var result error
if fails > 0 || age > nodeDBNodeExpiration {
log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
tab.bondmu.Lock()
w := tab.bonding[id]
if w != nil {
// Wait for an existing bonding process to complete.
tab.bondmu.Unlock()
<-w.done
} else {
// Register a new bonding process.
w = &bondproc{done: make(chan struct{})}
tab.bonding[id] = w
tab.bondmu.Unlock()
// Do the ping/pong. The result goes into w.
tab.pingpong(w, pinged, id, addr, tcpPort)
// Unregister the process after it's done.
tab.bondmu.Lock()
delete(tab.bonding, id)
tab.bondmu.Unlock()
}
// Retrieve the bonding results
result = w.err
if result == nil {
node = w.n
}
}
// Add the node to the table even if the bonding ping/pong
// fails. It will be relaced quickly if it continues to be
// unresponsive.
if node != nil {
tab.add(node)
tab.db.updateFindFails(id, 0)
}
return node, result
}
func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
// Request a bonding slot to limit network usage
<-tab.bondslots
defer func() { tab.bondslots <- struct{}{} }()
// Ping the remote side and wait for a pong.
if w.err = tab.ping(id, addr); w.err != nil {
close(w.done)
return
}
if !pinged {
// Give the remote node a chance to ping us before we start
// sending findnode requests. If they still remember us,
// waitping will simply time out.
tab.net.waitping(id)
}
// Bonding succeeded, update the node database.
w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
close(w.done)
}
// ping a remote endpoint and wait for a reply, also updating the node
// database accordingly.
func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
tab.db.updateLastPing(id, time.Now())
if err := tab.net.ping(id, addr); err != nil {
return err
}
tab.db.updateBondTime(id, time.Now())
return nil
}
// bucket returns the bucket for the given node ID hash. // bucket returns the bucket for the given node ID hash.
func (tab *Table) bucket(sha common.Hash) *bucket { func (tab *Table) bucket(sha common.Hash) *bucket {
d := logdist(tab.self.sha, sha) d := logdist(tab.self.sha, sha)
@ -676,23 +548,35 @@ func (tab *Table) bucket(sha common.Hash) *bucket {
return tab.buckets[d-bucketMinDistance-1] return tab.buckets[d-bucketMinDistance-1]
} }
// add attempts to add the given node its corresponding bucket. If the // add attempts to add the given node to its corresponding bucket. If the bucket has space
// bucket has space available, adding the node succeeds immediately. // available, adding the node succeeds immediately. Otherwise, the node is added if the
// Otherwise, the node is added if the least recently active node in // least recently active node in the bucket does not respond to a ping packet.
// the bucket does not respond to a ping packet.
// //
// The caller must not hold tab.mutex. // The caller must not hold tab.mutex.
func (tab *Table) add(new *Node) { func (tab *Table) add(n *Node) {
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
b := tab.bucket(new.sha) b := tab.bucket(n.sha)
if !tab.bumpOrAdd(b, new) { if !tab.bumpOrAdd(b, n) {
// Node is not in table. Add it to the replacement list. // Node is not in table. Add it to the replacement list.
tab.addReplacement(b, new) tab.addReplacement(b, n)
} }
} }
// addThroughPing adds the given node to the table. Compared to plain
// 'add' there is an additional safety measure: if the table is still
// initializing the node is not added. This prevents an attack where the
// table could be filled by just sending ping repeatedly.
//
// The caller must not hold tab.mutex.
func (tab *Table) addThroughPing(n *Node) {
if !tab.isInitDone() {
return
}
tab.add(n)
}
// stuff adds nodes the table to the end of their corresponding bucket // stuff adds nodes the table to the end of their corresponding bucket
// if the bucket is not full. The caller must not hold tab.mutex. // if the bucket is not full. The caller must not hold tab.mutex.
func (tab *Table) stuff(nodes []*Node) { func (tab *Table) stuff(nodes []*Node) {
@ -710,8 +594,7 @@ func (tab *Table) stuff(nodes []*Node) {
} }
} }
// delete removes an entry from the node table (used to evacuate // delete removes an entry from the node table. It is used to evacuate dead nodes.
// failed/non-bonded discovery peers).
func (tab *Table) delete(node *Node) { func (tab *Table) delete(node *Node) {
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()

@ -52,27 +52,22 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil) tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close() defer tab.Close()
// Wait for init so bond is accepted.
<-tab.initDone <-tab.initDone
// fill up the sender's bucket. // Fill up the sender's bucket.
pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99) pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
last := fillBucket(tab, pingSender) last := fillBucket(tab, pingSender)
// this call to bond should replace the last node // Add the sender as if it just pinged us. Revalidate should replace the last node in
// in its bucket if the node is not responding. // its bucket if it is unresponsive. Revalidate again to ensure that
transport.dead[last.ID] = !lastInBucketIsResponding transport.dead[last.ID] = !lastInBucketIsResponding
transport.dead[pingSender.ID] = !newNodeIsResponding transport.dead[pingSender.ID] = !newNodeIsResponding
tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0) tab.add(pingSender)
tab.doRevalidate(make(chan struct{}, 1))
tab.doRevalidate(make(chan struct{}, 1)) tab.doRevalidate(make(chan struct{}, 1))
// first ping goes to sender (bonding pingback)
if !transport.pinged[pingSender.ID] {
t.Error("table did not ping back sender")
}
if !transport.pinged[last.ID] { if !transport.pinged[last.ID] {
// second ping goes to oldest node in bucket // Oldest node in bucket is pinged to see whether it is still alive.
// to see whether it is still alive.
t.Error("table did not ping last node in bucket") t.Error("table did not ping last node in bucket")
} }
@ -83,7 +78,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
wantSize-- wantSize--
} }
if l := len(tab.bucket(pingSender.sha).entries); l != wantSize { if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize) t.Errorf("wrong bucket size after add: got %d, want %d", l, wantSize)
} }
if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding { if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding) t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
@ -206,10 +201,7 @@ func newPingRecorder() *pingRecorder {
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
return nil, nil return nil, nil
} }
func (t *pingRecorder) close() {}
func (t *pingRecorder) waitping(from NodeID) error {
return nil // remote always pings
}
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
@ -222,6 +214,8 @@ func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
} }
} }
func (t *pingRecorder) close() {}
func TestTable_closest(t *testing.T) { func TestTable_closest(t *testing.T) {
t.Parallel() t.Parallel()

@ -32,8 +32,6 @@ import (
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
const Version = 4
// Errors // Errors
var ( var (
errPacketTooSmall = errors.New("too small") errPacketTooSmall = errors.New("too small")
@ -272,21 +270,33 @@ func (t *udp) close() {
// ping sends a ping message to the given node and waits for a reply. // ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
return <-t.sendPing(toid, toaddr, nil)
}
// sendPing sends a ping message to the given node and invokes the callback
// when the reply arrives.
func (t *udp) sendPing(toid NodeID, toaddr *net.UDPAddr, callback func()) <-chan error {
req := &ping{ req := &ping{
Version: Version, Version: 4,
From: t.ourEndpoint, From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()), Expiration: uint64(time.Now().Add(expiration).Unix()),
} }
packet, hash, err := encodePacket(t.priv, pingPacket, req) packet, hash, err := encodePacket(t.priv, pingPacket, req)
if err != nil { if err != nil {
return err errc := make(chan error, 1)
errc <- err
return errc
} }
errc := t.pending(toid, pongPacket, func(p interface{}) bool { errc := t.pending(toid, pongPacket, func(p interface{}) bool {
return bytes.Equal(p.(*pong).ReplyTok, hash) ok := bytes.Equal(p.(*pong).ReplyTok, hash)
if ok && callback != nil {
callback()
}
return ok
}) })
t.write(toaddr, req.name(), packet) t.write(toaddr, req.name(), packet)
return <-errc return errc
} }
func (t *udp) waitping(from NodeID) error { func (t *udp) waitping(from NodeID) error {
@ -296,6 +306,13 @@ func (t *udp) waitping(from NodeID) error {
// findnode sends a findnode request to the given node and waits until // findnode sends a findnode request to the given node and waits until
// the node has sent up to k neighbors. // the node has sent up to k neighbors.
func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
// If we haven't seen a ping from the destination node for a while, it won't remember
// our endpoint proof and reject findnode. Solicit a ping first.
if time.Since(t.db.lastPingReceived(toid)) > nodeDBNodeExpiration {
t.ping(toid, toaddr)
t.waitping(toid)
}
nodes := make([]*Node, 0, bucketSize) nodes := make([]*Node, 0, bucketSize)
nreceived := 0 nreceived := 0
errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
@ -315,8 +332,7 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
Target: target, Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()), Expiration: uint64(time.Now().Add(expiration).Unix()),
}) })
err := <-errc return nodes, <-errc
return nodes, err
} }
// pending adds a reply callback to the pending reply queue. // pending adds a reply callback to the pending reply queue.
@ -587,10 +603,17 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
ReplyTok: mac, ReplyTok: mac,
Expiration: uint64(time.Now().Add(expiration).Unix()), Expiration: uint64(time.Now().Add(expiration).Unix()),
}) })
if !t.handleReply(fromID, pingPacket, req) { t.handleReply(fromID, pingPacket, req)
// Note: we're ignoring the provided IP address right now
go t.bond(true, fromID, from, req.From.TCP) // Add the node to the table. Before doing so, ensure that we have a recent enough pong
// recorded in the database so their findnode requests will be accepted later.
n := NewNode(fromID, from.IP, uint16(from.Port), req.From.TCP)
if time.Since(t.db.lastPongReceived(fromID)) > nodeDBNodeExpiration {
t.sendPing(fromID, from, func() { t.addThroughPing(n) })
} else {
t.addThroughPing(n)
} }
t.db.updateLastPingReceived(fromID, time.Now())
return nil return nil
} }
@ -603,6 +626,7 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) er
if !t.handleReply(fromID, pongPacket, req) { if !t.handleReply(fromID, pongPacket, req) {
return errUnsolicitedReply return errUnsolicitedReply
} }
t.db.updateLastPongReceived(fromID, time.Now())
return nil return nil
} }
@ -613,13 +637,12 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
return errExpired return errExpired
} }
if !t.db.hasBond(fromID) { if !t.db.hasBond(fromID) {
// No bond exists, we don't process the packet. This prevents // No endpoint proof pong exists, we don't process the packet. This prevents an
// an attack vector where the discovery protocol could be used // attack vector where the discovery protocol could be used to amplify traffic in a
// to amplify traffic in a DDOS attack. A malicious actor // DDOS attack. A malicious actor would send a findnode request with the IP address
// would send a findnode request with the IP address and UDP // and UDP port of the target as the source address. The recipient of the findnode
// port of the target as the source address. The recipient of // packet would then send a neighbors packet (which is a much bigger packet than
// the findnode packet would then send a neighbors packet // findnode) to the victim.
// (which is a much bigger packet than findnode) to the victim.
return errUnknownNode return errUnknownNode
} }
target := crypto.Keccak256Hash(req.Target[:]) target := crypto.Keccak256Hash(req.Target[:])

@ -124,7 +124,7 @@ func TestUDP_packetErrors(t *testing.T) {
test := newUDPTest(t) test := newUDPTest(t)
defer test.table.Close() defer test.table.Close()
test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version}) test.packetIn(errExpired, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4})
test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp}) test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp}) test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp}) test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
@ -247,7 +247,7 @@ func TestUDP_findnode(t *testing.T) {
// ensure there's a bond with the test node, // ensure there's a bond with the test node,
// findnode won't be accepted otherwise. // findnode won't be accepted otherwise.
test.table.db.updateBondTime(PubkeyID(&test.remotekey.PublicKey), time.Now()) test.table.db.updateLastPongReceived(PubkeyID(&test.remotekey.PublicKey), time.Now())
// check that closest neighbors are returned. // check that closest neighbors are returned.
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
@ -273,10 +273,12 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
test := newUDPTest(t) test := newUDPTest(t)
defer test.table.Close() defer test.table.Close()
rid := PubkeyID(&test.remotekey.PublicKey)
test.table.db.updateLastPingReceived(rid, time.Now())
// queue a pending findnode request // queue a pending findnode request
resultc, errc := make(chan []*Node), make(chan error) resultc, errc := make(chan []*Node), make(chan error)
go func() { go func() {
rid := PubkeyID(&test.remotekey.PublicKey)
ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget) ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
if err != nil && len(ns) == 0 { if err != nil && len(ns) == 0 {
errc <- err errc <- err
@ -328,7 +330,7 @@ func TestUDP_successfulPing(t *testing.T) {
defer test.table.Close() defer test.table.Close()
// The remote side sends a ping packet to initiate the exchange. // The remote side sends a ping packet to initiate the exchange.
go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: Version, Expiration: futureExp}) go test.packetIn(nil, pingPacket, &ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp})
// the ping is replied to. // the ping is replied to.
test.waitPacketOut(func(p *pong) { test.waitPacketOut(func(p *pong) {