p2p, p2p/discover: misc connectivity improvements (#16069)

* p2p: add DialRatio for configuration of inbound vs. dialed connections

* p2p: add connection flags to PeerInfo

* p2p/netutil: add SameNet, DistinctNetSet

* p2p/discover: improve revalidation and seeding

This changes node revalidation to be periodic instead of on-demand. This
should prevent issues where dead nodes get stuck in closer buckets
because no other node will ever come along to replace them.

Every 5 seconds (on average), the last node in a random bucket is
checked and moved to the front of the bucket if it is still responding.
If revalidation fails, the last node is replaced by an entry of the
'replacement list' containing recently-seen nodes.

Most close buckets are removed because it's very unlikely we'll ever
encounter a node that would fall into any of those buckets.

Table seeding is also improved: we now require a few minutes of table
membership before considering a node as a potential seed node. This
should make it less likely to store short-lived nodes as potential
seeds.

* p2p/discover: fix nits in UDP transport

We would skip sending neighbors replies if there were fewer than
maxNeighbors results and CheckRelayIP returned an error for the last
one. While here, also resolve a TODO about pong reply tokens.
This commit is contained in:
Felix Lange 2018-02-12 13:36:09 +01:00 committed by Péter Szilágyi
parent 1d39912a9b
commit 9123eceb0f
10 changed files with 806 additions and 282 deletions

@ -122,7 +122,12 @@ func main() {
utils.Fatalf("%v", err)
}
} else {
if _, err := discover.ListenUDP(nodeKey, conn, realaddr, nil, "", restrictList); err != nil {
cfg := discover.Config{
PrivateKey: nodeKey,
AnnounceAddr: realaddr,
NetRestrict: restrictList,
}
if _, err := discover.ListenUDP(conn, cfg); err != nil {
utils.Fatalf("%v", err)
}
}

@ -29,6 +29,7 @@ import (
"regexp"
"strconv"
"strings"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
@ -51,9 +52,8 @@ type Node struct {
// with ID.
sha common.Hash
// whether this node is currently being pinged in order to replace
// it in a bucket
contested bool
// Time when the node was added to the table.
addedAt time.Time
}
// NewNode creates a new node. It is mostly meant to be used for

@ -23,10 +23,11 @@
package discover
import (
"crypto/rand"
crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
mrand "math/rand"
"net"
"sort"
"sync"
@ -35,29 +36,45 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/netutil"
)
const (
alpha = 3 // Kademlia concurrency factor
bucketSize = 16 // Kademlia bucket size
hashBits = len(common.Hash{}) * 8
nBuckets = hashBits + 1 // Number of buckets
alpha = 3 // Kademlia concurrency factor
bucketSize = 16 // Kademlia bucket size
maxReplacements = 10 // Size of per-bucket replacement list
maxBondingPingPongs = 16
maxFindnodeFailures = 5
// We keep buckets for the upper 1/15 of distances because
// it's very unlikely we'll ever encounter a node that's closer.
hashBits = len(common.Hash{}) * 8
nBuckets = hashBits / 15 // Number of buckets
bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket
autoRefreshInterval = 1 * time.Hour
seedCount = 30
seedMaxAge = 5 * 24 * time.Hour
// IP address limits.
bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
tableIPLimit, tableSubnet = 10, 24
maxBondingPingPongs = 16 // Limit on the number of concurrent ping/pong interactions
maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
refreshInterval = 30 * time.Minute
revalidateInterval = 10 * time.Second
copyNodesInterval = 30 * time.Second
seedMinTableTime = 5 * time.Minute
seedCount = 30
seedMaxAge = 5 * 24 * time.Hour
)
type Table struct {
mutex sync.Mutex // protects buckets, their content, and nursery
mutex sync.Mutex // protects buckets, bucket content, nursery, rand
buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*Node // bootstrap nodes
db *nodeDB // database of known nodes
rand *mrand.Rand // source of randomness, periodically reseeded
ips netutil.DistinctNetSet
db *nodeDB // database of known nodes
refreshReq chan chan struct{}
initDone chan struct{}
closeReq chan struct{}
closed chan struct{}
@ -89,9 +106,13 @@ type transport interface {
// bucket contains nodes, ordered by their last activity. the entry
// that was most recently active is the first element in entries.
type bucket struct{ entries []*Node }
type bucket struct {
entries []*Node // live entries, sorted by time of last contact
replacements []*Node // recently seen nodes to be used if revalidation fails
ips netutil.DistinctNetSet
}
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) (*Table, error) {
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
// If no node database was given, use an in-memory one
db, err := newNodeDB(nodeDBPath, Version, ourID)
if err != nil {
@ -104,19 +125,42 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
bonding: make(map[NodeID]*bondproc),
bondslots: make(chan struct{}, maxBondingPingPongs),
refreshReq: make(chan chan struct{}),
initDone: make(chan struct{}),
closeReq: make(chan struct{}),
closed: make(chan struct{}),
rand: mrand.New(mrand.NewSource(0)),
ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
}
if err := tab.setFallbackNodes(bootnodes); err != nil {
return nil, err
}
for i := 0; i < cap(tab.bondslots); i++ {
tab.bondslots <- struct{}{}
}
for i := range tab.buckets {
tab.buckets[i] = new(bucket)
tab.buckets[i] = &bucket{
ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
}
}
go tab.refreshLoop()
tab.seedRand()
tab.loadSeedNodes(false)
// Start the background expiration goroutine after loading seeds so that the search for
// seed nodes also considers older nodes that would otherwise be removed by the
// expiration.
tab.db.ensureExpirer()
go tab.loop()
return tab, nil
}
func (tab *Table) seedRand() {
var b [8]byte
crand.Read(b[:])
tab.mutex.Lock()
tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:])))
tab.mutex.Unlock()
}
// Self returns the local node.
// The returned node should not be modified by the caller.
func (tab *Table) Self() *Node {
@ -127,9 +171,12 @@ func (tab *Table) Self() *Node {
// table. It will not write the same node more than once. The nodes in
// the slice are copies and can be modified by the caller.
func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
if !tab.isInitDone() {
return 0
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
// TODO: tree-based buckets would help here
// Find all non-empty buckets and get a fresh slice of their entries.
var buckets [][]*Node
for _, b := range tab.buckets {
@ -141,8 +188,8 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return 0
}
// Shuffle the buckets.
for i := uint32(len(buckets)) - 1; i > 0; i-- {
j := randUint(i)
for i := len(buckets) - 1; i > 0; i-- {
j := tab.rand.Intn(len(buckets))
buckets[i], buckets[j] = buckets[j], buckets[i]
}
// Move head of each bucket into buf, removing buckets that become empty.
@ -161,15 +208,6 @@ func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
return i + 1
}
func randUint(max uint32) uint32 {
if max == 0 {
return 0
}
var b [4]byte
rand.Read(b[:])
return binary.BigEndian.Uint32(b[:]) % max
}
// Close terminates the network listener and flushes the node database.
func (tab *Table) Close() {
select {
@ -180,16 +218,15 @@ func (tab *Table) Close() {
}
}
// SetFallbackNodes sets the initial points of contact. These nodes
// setFallbackNodes sets the initial points of contact. These nodes
// are used to connect to the network if the table is empty and there
// are no known nodes in the database.
func (tab *Table) SetFallbackNodes(nodes []*Node) error {
func (tab *Table) setFallbackNodes(nodes []*Node) error {
for _, n := range nodes {
if err := n.validateComplete(); err != nil {
return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
}
}
tab.mutex.Lock()
tab.nursery = make([]*Node, 0, len(nodes))
for _, n := range nodes {
cpy := *n
@ -198,11 +235,19 @@ func (tab *Table) SetFallbackNodes(nodes []*Node) error {
cpy.sha = crypto.Keccak256Hash(n.ID[:])
tab.nursery = append(tab.nursery, &cpy)
}
tab.mutex.Unlock()
tab.refresh()
return nil
}
// isInitDone returns whether the table's initial seeding procedure has completed.
func (tab *Table) isInitDone() bool {
select {
case <-tab.initDone:
return true
default:
return false
}
}
// Resolve searches for a specific node with the given ID.
// It returns nil if the node could not be found.
func (tab *Table) Resolve(targetID NodeID) *Node {
@ -314,33 +359,49 @@ func (tab *Table) refresh() <-chan struct{} {
return done
}
// refreshLoop schedules doRefresh runs and coordinates shutdown.
func (tab *Table) refreshLoop() {
// loop schedules refresh, revalidate runs and coordinates shutdown.
func (tab *Table) loop() {
var (
timer = time.NewTicker(autoRefreshInterval)
waiting []chan struct{} // accumulates waiting callers while doRefresh runs
done chan struct{} // where doRefresh reports completion
revalidate = time.NewTimer(tab.nextRevalidateTime())
refresh = time.NewTicker(refreshInterval)
copyNodes = time.NewTicker(copyNodesInterval)
revalidateDone = make(chan struct{})
refreshDone = make(chan struct{}) // where doRefresh reports completion
waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
)
defer refresh.Stop()
defer revalidate.Stop()
defer copyNodes.Stop()
// Start initial refresh.
go tab.doRefresh(refreshDone)
loop:
for {
select {
case <-timer.C:
if done == nil {
done = make(chan struct{})
go tab.doRefresh(done)
case <-refresh.C:
tab.seedRand()
if refreshDone == nil {
refreshDone = make(chan struct{})
go tab.doRefresh(refreshDone)
}
case req := <-tab.refreshReq:
waiting = append(waiting, req)
if done == nil {
done = make(chan struct{})
go tab.doRefresh(done)
if refreshDone == nil {
refreshDone = make(chan struct{})
go tab.doRefresh(refreshDone)
}
case <-done:
case <-refreshDone:
for _, ch := range waiting {
close(ch)
}
waiting = nil
done = nil
waiting, refreshDone = nil, nil
case <-revalidate.C:
go tab.doRevalidate(revalidateDone)
case <-revalidateDone:
revalidate.Reset(tab.nextRevalidateTime())
case <-copyNodes.C:
go tab.copyBondedNodes()
case <-tab.closeReq:
break loop
}
@ -349,8 +410,8 @@ loop:
if tab.net != nil {
tab.net.close()
}
if done != nil {
<-done
if refreshDone != nil {
<-refreshDone
}
for _, ch := range waiting {
close(ch)
@ -365,38 +426,109 @@ loop:
func (tab *Table) doRefresh(done chan struct{}) {
defer close(done)
// Load nodes from the database and insert
// them. This should yield a few previously seen nodes that are
// (hopefully) still alive.
tab.loadSeedNodes(true)
// Run self lookup to discover new neighbor nodes.
tab.lookup(tab.self.ID, false)
// The Kademlia paper specifies that the bucket refresh should
// perform a lookup in the least recently used bucket. We cannot
// adhere to this because the findnode target is a 512bit value
// (not hash-sized) and it is not easily possible to generate a
// sha3 preimage that falls into a chosen bucket.
// We perform a lookup with a random target instead.
var target NodeID
rand.Read(target[:])
result := tab.lookup(target, false)
if len(result) > 0 {
// We perform a few lookups with a random target instead.
for i := 0; i < 3; i++ {
var target NodeID
crand.Read(target[:])
tab.lookup(target, false)
}
}
func (tab *Table) loadSeedNodes(bond bool) {
seeds := tab.db.querySeeds(seedCount, seedMaxAge)
seeds = append(seeds, tab.nursery...)
if bond {
seeds = tab.bondall(seeds)
}
for i := range seeds {
seed := seeds[i]
age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.lastPong(seed.ID)) }}
log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
tab.add(seed)
}
}
// doRevalidate checks that the last node in a random bucket is still live
// and replaces or deletes the node if it isn't.
func (tab *Table) doRevalidate(done chan<- struct{}) {
defer func() { done <- struct{}{} }()
last, bi := tab.nodeToRevalidate()
if last == nil {
// No non-empty bucket found.
return
}
// The table is empty. Load nodes from the database and insert
// them. This should yield a few previously seen nodes that are
// (hopefully) still alive.
seeds := tab.db.querySeeds(seedCount, seedMaxAge)
seeds = tab.bondall(append(seeds, tab.nursery...))
// Ping the selected node and wait for a pong.
err := tab.ping(last.ID, last.addr())
if len(seeds) == 0 {
log.Debug("No discv4 seed nodes found")
}
for _, n := range seeds {
age := log.Lazy{Fn: func() time.Duration { return time.Since(tab.db.lastPong(n.ID)) }}
log.Trace("Found seed node in database", "id", n.ID, "addr", n.addr(), "age", age)
}
tab.mutex.Lock()
tab.stuff(seeds)
tab.mutex.Unlock()
defer tab.mutex.Unlock()
b := tab.buckets[bi]
if err == nil {
// The node responded, move it to the front.
log.Debug("Revalidated node", "b", bi, "id", last.ID)
b.bump(last)
return
}
// No reply received, pick a replacement or delete the node if there aren't
// any replacements.
if r := tab.replace(b, last); r != nil {
log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP)
} else {
log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP)
}
}
// Finally, do a self lookup to fill up the buckets.
tab.lookup(tab.self.ID, false)
// nodeToRevalidate returns the last node in a random, non-empty bucket.
func (tab *Table) nodeToRevalidate() (n *Node, bi int) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
for _, bi = range tab.rand.Perm(len(tab.buckets)) {
b := tab.buckets[bi]
if len(b.entries) > 0 {
last := b.entries[len(b.entries)-1]
return last, bi
}
}
return nil, 0
}
func (tab *Table) nextRevalidateTime() time.Duration {
tab.mutex.Lock()
defer tab.mutex.Unlock()
return time.Duration(tab.rand.Int63n(int64(revalidateInterval)))
}
// copyBondedNodes adds nodes from the table to the database if they have been in the table
// longer then minTableTime.
func (tab *Table) copyBondedNodes() {
tab.mutex.Lock()
defer tab.mutex.Unlock()
now := time.Now()
for _, b := range tab.buckets {
for _, n := range b.entries {
if now.Sub(n.addedAt) >= seedMinTableTime {
tab.db.updateNode(n)
}
}
}
}
// closest returns the n nodes in the table that are closest to the
@ -459,15 +591,14 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
if id == tab.self.ID {
return nil, errors.New("is self")
}
// Retrieve a previously known node and any recent findnode failures
node, fails := tab.db.node(id), 0
if node != nil {
fails = tab.db.findFails(id)
if pinged && !tab.isInitDone() {
return nil, errors.New("still initializing")
}
// If the node is unknown (non-bonded) or failed (remotely unknown), bond from scratch
var result error
// Start bonding if we haven't seen this node for a while or if it failed findnode too often.
node, fails := tab.db.node(id), tab.db.findFails(id)
age := time.Since(tab.db.lastPong(id))
if node == nil || fails > 0 || age > nodeDBNodeExpiration {
var result error
if fails > 0 || age > nodeDBNodeExpiration {
log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
tab.bondmu.Lock()
@ -494,10 +625,10 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
node = w.n
}
}
// Add the node to the table even if the bonding ping/pong
// fails. It will be relaced quickly if it continues to be
// unresponsive.
if node != nil {
// Add the node to the table even if the bonding ping/pong
// fails. It will be relaced quickly if it continues to be
// unresponsive.
tab.add(node)
tab.db.updateFindFails(id, 0)
}
@ -522,7 +653,6 @@ func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAdd
}
// Bonding succeeded, update the node database.
w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
tab.db.updateNode(w.n)
close(w.done)
}
@ -534,16 +664,18 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
return err
}
tab.db.updateLastPong(id, time.Now())
// Start the background expiration goroutine after the first
// successful communication. Subsequent calls have no effect if it
// is already running. We do this here instead of somewhere else
// so that the search for seed nodes also considers older nodes
// that would otherwise be removed by the expiration.
tab.db.ensureExpirer()
return nil
}
// bucket returns the bucket for the given node ID hash.
func (tab *Table) bucket(sha common.Hash) *bucket {
d := logdist(tab.self.sha, sha)
if d <= bucketMinDistance {
return tab.buckets[0]
}
return tab.buckets[d-bucketMinDistance-1]
}
// add attempts to add the given node its corresponding bucket. If the
// bucket has space available, adding the node succeeds immediately.
// Otherwise, the node is added if the least recently active node in
@ -551,57 +683,29 @@ func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
//
// The caller must not hold tab.mutex.
func (tab *Table) add(new *Node) {
b := tab.buckets[logdist(tab.self.sha, new.sha)]
tab.mutex.Lock()
defer tab.mutex.Unlock()
if b.bump(new) {
return
}
var oldest *Node
if len(b.entries) == bucketSize {
oldest = b.entries[bucketSize-1]
if oldest.contested {
// The node is already being replaced, don't attempt
// to replace it.
return
}
oldest.contested = true
// Let go of the mutex so other goroutines can access
// the table while we ping the least recently active node.
tab.mutex.Unlock()
err := tab.ping(oldest.ID, oldest.addr())
tab.mutex.Lock()
oldest.contested = false
if err == nil {
// The node responded, don't replace it.
return
}
}
added := b.replace(new, oldest)
if added && tab.nodeAddedHook != nil {
tab.nodeAddedHook(new)
b := tab.bucket(new.sha)
if !tab.bumpOrAdd(b, new) {
// Node is not in table. Add it to the replacement list.
tab.addReplacement(b, new)
}
}
// stuff adds nodes the table to the end of their corresponding bucket
// if the bucket is not full. The caller must hold tab.mutex.
// if the bucket is not full. The caller must not hold tab.mutex.
func (tab *Table) stuff(nodes []*Node) {
outer:
tab.mutex.Lock()
defer tab.mutex.Unlock()
for _, n := range nodes {
if n.ID == tab.self.ID {
continue // don't add self
}
bucket := tab.buckets[logdist(tab.self.sha, n.sha)]
for i := range bucket.entries {
if bucket.entries[i].ID == n.ID {
continue outer // already in bucket
}
}
if len(bucket.entries) < bucketSize {
bucket.entries = append(bucket.entries, n)
if tab.nodeAddedHook != nil {
tab.nodeAddedHook(n)
}
b := tab.bucket(n.sha)
if len(b.entries) < bucketSize {
tab.bumpOrAdd(b, n)
}
}
}
@ -611,36 +715,72 @@ outer:
func (tab *Table) delete(node *Node) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
bucket := tab.buckets[logdist(tab.self.sha, node.sha)]
for i := range bucket.entries {
if bucket.entries[i].ID == node.ID {
bucket.entries = append(bucket.entries[:i], bucket.entries[i+1:]...)
return
}
}
tab.deleteInBucket(tab.bucket(node.sha), node)
}
func (b *bucket) replace(n *Node, last *Node) bool {
// Don't add if b already contains n.
for i := range b.entries {
if b.entries[i].ID == n.ID {
return false
}
func (tab *Table) addIP(b *bucket, ip net.IP) bool {
if netutil.IsLAN(ip) {
return true
}
// Replace last if it is still the last entry or just add n if b
// isn't full. If is no longer the last entry, it has either been
// replaced with someone else or became active.
if len(b.entries) == bucketSize && (last == nil || b.entries[bucketSize-1].ID != last.ID) {
if !tab.ips.Add(ip) {
log.Debug("IP exceeds table limit", "ip", ip)
return false
}
if len(b.entries) < bucketSize {
b.entries = append(b.entries, nil)
if !b.ips.Add(ip) {
log.Debug("IP exceeds bucket limit", "ip", ip)
tab.ips.Remove(ip)
return false
}
copy(b.entries[1:], b.entries)
b.entries[0] = n
return true
}
func (tab *Table) removeIP(b *bucket, ip net.IP) {
if netutil.IsLAN(ip) {
return
}
tab.ips.Remove(ip)
b.ips.Remove(ip)
}
func (tab *Table) addReplacement(b *bucket, n *Node) {
for _, e := range b.replacements {
if e.ID == n.ID {
return // already in list
}
}
if !tab.addIP(b, n.IP) {
return
}
var removed *Node
b.replacements, removed = pushNode(b.replacements, n, maxReplacements)
if removed != nil {
tab.removeIP(b, removed.IP)
}
}
// replace removes n from the replacement list and replaces 'last' with it if it is the
// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced
// with someone else or became active.
func (tab *Table) replace(b *bucket, last *Node) *Node {
if len(b.entries) >= 0 && b.entries[len(b.entries)-1].ID != last.ID {
// Entry has moved, don't replace it.
return nil
}
// Still the last entry.
if len(b.replacements) == 0 {
tab.deleteInBucket(b, last)
return nil
}
r := b.replacements[tab.rand.Intn(len(b.replacements))]
b.replacements = deleteNode(b.replacements, r)
b.entries[len(b.entries)-1] = r
tab.removeIP(b, last.IP)
return r
}
// bump moves the given node to the front of the bucket entry list
// if it is contained in that list.
func (b *bucket) bump(n *Node) bool {
for i := range b.entries {
if b.entries[i].ID == n.ID {
@ -653,6 +793,50 @@ func (b *bucket) bump(n *Node) bool {
return false
}
// bumpOrAdd moves n to the front of the bucket entry list or adds it if the list isn't
// full. The return value is true if n is in the bucket.
func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool {
if b.bump(n) {
return true
}
if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) {
return false
}
b.entries, _ = pushNode(b.entries, n, bucketSize)
b.replacements = deleteNode(b.replacements, n)
n.addedAt = time.Now()
if tab.nodeAddedHook != nil {
tab.nodeAddedHook(n)
}
return true
}
func (tab *Table) deleteInBucket(b *bucket, n *Node) {
b.entries = deleteNode(b.entries, n)
tab.removeIP(b, n.IP)
}
// pushNode adds n to the front of list, keeping at most max items.
func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) {
if len(list) < max {
list = append(list, nil)
}
removed := list[len(list)-1]
copy(list[1:], list)
list[0] = n
return list, removed
}
// deleteNode removes n from list.
func deleteNode(list []*Node, n *Node) []*Node {
for i := range list {
if list[i].ID == n.ID {
return append(list[:i], list[i+1:]...)
}
}
return list
}
// nodesByDistance is a list of nodes, ordered by
// distance to target.
type nodesByDistance struct {

@ -20,6 +20,7 @@ import (
"crypto/ecdsa"
"fmt"
"math/rand"
"sync"
"net"
"reflect"
@ -32,60 +33,65 @@ import (
)
func TestTable_pingReplace(t *testing.T) {
doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "")
defer tab.Close()
pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
// fill up the sender's bucket.
last := fillBucket(tab, 253)
// this call to bond should replace the last node
// in its bucket if the node is not responding.
transport.responding[last.ID] = lastInBucketIsResponding
transport.responding[pingSender.ID] = newNodeIsResponding
tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
// first ping goes to sender (bonding pingback)
if !transport.pinged[pingSender.ID] {
t.Error("table did not ping back sender")
}
if newNodeIsResponding {
// second ping goes to oldest node in bucket
// to see whether it is still alive.
if !transport.pinged[last.ID] {
t.Error("table did not ping last node in bucket")
}
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
if l := len(tab.buckets[253].entries); l != bucketSize {
t.Errorf("wrong bucket size after bond: got %d, want %d", l, bucketSize)
}
if lastInBucketIsResponding || !newNodeIsResponding {
if !contains(tab.buckets[253].entries, last.ID) {
t.Error("last entry was removed")
}
if contains(tab.buckets[253].entries, pingSender.ID) {
t.Error("new entry was added")
}
} else {
if contains(tab.buckets[253].entries, last.ID) {
t.Error("last entry was not removed")
}
if !contains(tab.buckets[253].entries, pingSender.ID) {
t.Error("new entry was not added")
}
}
run := func(newNodeResponding, lastInBucketResponding bool) {
name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding)
t.Run(name, func(t *testing.T) {
t.Parallel()
testPingReplace(t, newNodeResponding, lastInBucketResponding)
})
}
doit(true, true)
doit(false, true)
doit(true, false)
doit(false, false)
run(true, true)
run(false, true)
run(true, false)
run(false, false)
}
func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
// Wait for init so bond is accepted.
<-tab.initDone
// fill up the sender's bucket.
pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
last := fillBucket(tab, pingSender)
// this call to bond should replace the last node
// in its bucket if the node is not responding.
transport.dead[last.ID] = !lastInBucketIsResponding
transport.dead[pingSender.ID] = !newNodeIsResponding
tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
tab.doRevalidate(make(chan struct{}, 1))
// first ping goes to sender (bonding pingback)
if !transport.pinged[pingSender.ID] {
t.Error("table did not ping back sender")
}
if !transport.pinged[last.ID] {
// second ping goes to oldest node in bucket
// to see whether it is still alive.
t.Error("table did not ping last node in bucket")
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
wantSize := bucketSize
if !lastInBucketIsResponding && !newNodeIsResponding {
wantSize--
}
if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
}
if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
}
wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry {
t.Errorf("new entry found: %t, want: %t", found, wantNewEntry)
}
}
func TestBucket_bumpNoDuplicates(t *testing.T) {
@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
}
}
// This checks that the table-wide IP limit is applied correctly.
func TestTable_IPLimit(t *testing.T) {
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self.sha, i)
n.IP = net.IP{172, 0, 1, byte(i)}
tab.add(n)
}
if tab.len() > tableIPLimit {
t.Errorf("too many nodes in table")
}
}
// This checks that the table-wide IP limit is applied correctly.
func TestTable_BucketIPLimit(t *testing.T) {
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
d := 3
for i := 0; i < bucketIPLimit+1; i++ {
n := nodeAtDistance(tab.self.sha, d)
n.IP = net.IP{172, 0, 1, byte(i)}
tab.add(n)
}
if tab.len() > bucketIPLimit {
t.Errorf("too many nodes in table")
}
}
// fillBucket inserts nodes into the given bucket until
// it is full. The node's IDs dont correspond to their
// hashes.
func fillBucket(tab *Table, ld int) (last *Node) {
b := tab.buckets[ld]
func fillBucket(tab *Table, n *Node) (last *Node) {
ld := logdist(tab.self.sha, n.sha)
b := tab.bucket(n.sha)
for len(b.entries) < bucketSize {
b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld))
}
@ -146,30 +186,39 @@ func fillBucket(tab *Table, ld int) (last *Node) {
func nodeAtDistance(base common.Hash, ld int) (n *Node) {
n = new(Node)
n.sha = hashAtDistance(base, ld)
n.IP = net.IP{10, 0, 2, byte(ld)}
n.IP = net.IP{byte(ld), 0, 2, byte(ld)}
copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
return n
}
type pingRecorder struct{ responding, pinged map[NodeID]bool }
type pingRecorder struct {
mu sync.Mutex
dead, pinged map[NodeID]bool
}
func newPingRecorder() *pingRecorder {
return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
return &pingRecorder{
dead: make(map[NodeID]bool),
pinged: make(map[NodeID]bool),
}
}
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
panic("findnode called on pingRecorder")
return nil, nil
}
func (t *pingRecorder) close() {}
func (t *pingRecorder) waitping(from NodeID) error {
return nil // remote always pings
}
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
t.mu.Lock()
defer t.mu.Unlock()
t.pinged[toid] = true
if t.responding[toid] {
return nil
} else {
if t.dead[toid] {
return errTimeout
} else {
return nil
}
}
@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) {
test := func(test *closeTest) bool {
// for any node table, Target and N
tab, _ := newTable(nil, test.Self, &net.UDPAddr{}, "")
transport := newPingRecorder()
tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil)
defer tab.Close()
tab.stuff(test.All)
@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
},
}
test := func(buf []*Node) bool {
tab, _ := newTable(nil, NodeID{}, &net.UDPAddr{}, "")
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
<-tab.initDone
for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets))
tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)})
@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
func TestTable_Lookup(t *testing.T) {
self := nodeAtDistance(common.Hash{}, 0)
tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "")
tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil)
defer tab.Close()
// lookup on empty table returns no nodes

@ -216,9 +216,22 @@ type ReadPacket struct {
Addr *net.UDPAddr
}
// Config holds Table-related settings.
type Config struct {
// These settings are required and configure the UDP listener:
PrivateKey *ecdsa.PrivateKey
// These settings are optional:
AnnounceAddr *net.UDPAddr // local address announced in the DHT
NodeDBPath string // if set, the node database is stored at this filesystem location
NetRestrict *netutil.Netlist // network whitelist
Bootnodes []*Node // list of bootstrap nodes
Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict)
func ListenUDP(c conn, cfg Config) (*Table, error) {
tab, _, err := newUDP(c, cfg)
if err != nil {
return nil, err
}
@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl
return tab, nil
}
func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
func newUDP(c conn, cfg Config) (*Table, *udp, error) {
udp := &udp{
conn: c,
priv: priv,
netrestrict: netrestrict,
priv: cfg.PrivateKey,
netrestrict: cfg.NetRestrict,
closing: make(chan struct{}),
gotreply: make(chan reply),
addpending: make(chan *pending),
}
realaddr := c.LocalAddr().(*net.UDPAddr)
if cfg.AnnounceAddr != nil {
realaddr = cfg.AnnounceAddr
}
// TODO: separate TCP port
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath)
tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
if err != nil {
return nil, nil, err
}
udp.Table = tab
go udp.loop()
go udp.readLoop(unhandled)
go udp.readLoop(cfg.Unhandled)
return udp.Table, udp, nil
}
@ -256,14 +273,20 @@ func (t *udp) close() {
// ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
// TODO: maybe check for ReplyTo field in callback to measure RTT
errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
t.send(toaddr, pingPacket, &ping{
req := &ping{
Version: Version,
From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()),
}
packet, hash, err := encodePacket(t.priv, pingPacket, req)
if err != nil {
return err
}
errc := t.pending(toid, pongPacket, func(p interface{}) bool {
return bytes.Equal(p.(*pong).ReplyTok, hash)
})
t.write(toaddr, req.name(), packet)
return <-errc
}
@ -447,40 +470,45 @@ func init() {
}
}
func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error {
packet, err := encodePacket(t.priv, ptype, req)
func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil {
return err
return hash, err
}
_, err = t.conn.WriteToUDP(packet, toaddr)
log.Trace(">> "+req.name(), "addr", toaddr, "err", err)
return hash, t.write(toaddr, req.name(), packet)
}
func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
_, err := t.conn.WriteToUDP(packet, toaddr)
log.Trace(">> "+what, "addr", toaddr, "err", err)
return err
}
func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
b := new(bytes.Buffer)
b.Write(headSpace)
b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil {
log.Error("Can't encode discv4 packet", "err", err)
return nil, err
return nil, nil, err
}
packet := b.Bytes()
packet = b.Bytes()
sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
if err != nil {
log.Error("Can't sign discv4 packet", "err", err)
return nil, err
return nil, nil, err
}
copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in
// The future.
copy(packet, crypto.Keccak256(packet[macSize:]))
return packet, nil
hash = crypto.Keccak256(packet[macSize:])
copy(packet, hash)
return packet, hash, nil
}
// readLoop runs in its own goroutine. it handles incoming UDP packets.
func (t *udp) readLoop(unhandled chan ReadPacket) {
func (t *udp) readLoop(unhandled chan<- ReadPacket) {
defer t.conn.Close()
if unhandled != nil {
defer close(unhandled)
@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
t.mutex.Unlock()
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit.
for i, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP) != nil {
continue
for _, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n))
}
p.Nodes = append(p.Nodes, nodeToRPC(n))
if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
if len(p.Nodes) == maxNeighbors {
t.send(from, neighborsPacket, &p)
p.Nodes = p.Nodes[:0]
sent = true
}
}
if len(p.Nodes) > 0 || !sent {
t.send(from, neighborsPacket, &p)
}
return nil
}

@ -70,14 +70,15 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
}
realaddr := test.pipe.LocalAddr().(*net.UDPAddr)
test.table, test.udp, _ = newUDP(test.localkey, test.pipe, realaddr, nil, "", nil)
test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey})
// Wait for initial refresh so the table doesn't send unexpected findnode.
<-test.table.initDone
return test
}
// handles a packet as if it had been sent to the transport.
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
enc, err := encodePacket(test.remotekey, ptype, data)
enc, _, err := encodePacket(test.remotekey, ptype, data)
if err != nil {
return test.errorf("packet (%d) encode error: %v", ptype, err)
}
@ -90,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
// waits for a packet to be sent by the transport.
// validate should have type func(*udpTest, X) error, where X is a packet type.
func (test *udpTest) waitPacketOut(validate interface{}) error {
func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
dgram := test.pipe.waitPacketOut()
p, _, _, err := decodePacket(dgram)
p, _, hash, err := decodePacket(dgram)
if err != nil {
return test.errorf("sent packet decode error: %v", err)
return hash, test.errorf("sent packet decode error: %v", err)
}
fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0)
if reflect.TypeOf(p) != exptype {
return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
}
fn.Call([]reflect.Value{reflect.ValueOf(p)})
return nil
return hash, nil
}
func (test *udpTest) errorf(format string, args ...interface{}) error {
@ -351,7 +352,7 @@ func TestUDP_successfulPing(t *testing.T) {
})
// remote is unknown, the table pings back.
test.waitPacketOut(func(p *ping) error {
hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) {
t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint)
}
@ -365,7 +366,7 @@ func TestUDP_successfulPing(t *testing.T) {
}
return nil
})
test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
test.packetIn(nil, pongPacket, &pong{ReplyTok: hash, Expiration: futureExp})
// the node should be added to the table shortly after getting the
// pong packet.

@ -18,8 +18,11 @@
package netutil
import (
"bytes"
"errors"
"fmt"
"net"
"sort"
"strings"
)
@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error {
}
return nil
}
// SameNet reports whether two IP addresses have an equal prefix of the given bit length.
func SameNet(bits uint, ip, other net.IP) bool {
ip4, other4 := ip.To4(), other.To4()
switch {
case (ip4 == nil) != (other4 == nil):
return false
case ip4 != nil:
return sameNet(bits, ip4, other4)
default:
return sameNet(bits, ip.To16(), other.To16())
}
}
func sameNet(bits uint, ip, other net.IP) bool {
nb := int(bits / 8)
mask := ^byte(0xFF >> (bits % 8))
if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask {
return false
}
return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb])
}
// DistinctNetSet tracks IPs, ensuring that at most N of them
// fall into the same network range.
type DistinctNetSet struct {
Subnet uint // number of common prefix bits
Limit uint // maximum number of IPs in each subnet
members map[string]uint
buf net.IP
}
// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
// number of existing IPs in the defined range exceeds the limit.
func (s *DistinctNetSet) Add(ip net.IP) bool {
key := s.key(ip)
n := s.members[string(key)]
if n < s.Limit {
s.members[string(key)] = n + 1
return true
}
return false
}
// Remove removes an IP from the set.
func (s *DistinctNetSet) Remove(ip net.IP) {
key := s.key(ip)
if n, ok := s.members[string(key)]; ok {
if n == 1 {
delete(s.members, string(key))
} else {
s.members[string(key)] = n - 1
}
}
}
// Contains whether the given IP is contained in the set.
func (s DistinctNetSet) Contains(ip net.IP) bool {
key := s.key(ip)
_, ok := s.members[string(key)]
return ok
}
// Len returns the number of tracked IPs.
func (s DistinctNetSet) Len() int {
n := uint(0)
for _, i := range s.members {
n += i
}
return int(n)
}
// key encodes the map key for an address into a temporary buffer.
//
// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
// The remainder of the key is the IP, truncated to the number of bits.
func (s *DistinctNetSet) key(ip net.IP) net.IP {
// Lazily initialize storage.
if s.members == nil {
s.members = make(map[string]uint)
s.buf = make(net.IP, 17)
}
// Canonicalize ip and bits.
typ := byte('6')
if ip4 := ip.To4(); ip4 != nil {
typ, ip = '4', ip4
}
bits := s.Subnet
if bits > uint(len(ip)*8) {
bits = uint(len(ip) * 8)
}
// Encode the prefix into s.buf.
nb := int(bits / 8)
mask := ^byte(0xFF >> (bits % 8))
s.buf[0] = typ
buf := append(s.buf[:1], ip[:nb]...)
if nb < len(ip) && mask != 0 {
buf = append(buf, ip[nb]&mask)
}
return buf
}
// String implements fmt.Stringer
func (s DistinctNetSet) String() string {
var buf bytes.Buffer
buf.WriteString("{")
keys := make([]string, 0, len(s.members))
for k := range s.members {
keys = append(keys, k)
}
sort.Strings(keys)
for i, k := range keys {
var ip net.IP
if k[0] == '4' {
ip = make(net.IP, 4)
} else {
ip = make(net.IP, 16)
}
copy(ip, k[1:])
fmt.Fprintf(&buf, "%v×%d", ip, s.members[k])
if i != len(keys)-1 {
buf.WriteString(" ")
}
}
buf.WriteString("}")
return buf.String()
}

@ -17,9 +17,11 @@
package netutil
import (
"fmt"
"net"
"reflect"
"testing"
"testing/quick"
"github.com/davecgh/go-spew/spew"
)
@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) {
CheckRelayIP(sender, addr)
}
}
func TestSameNet(t *testing.T) {
tests := []struct {
ip, other string
bits uint
want bool
}{
{"0.0.0.0", "0.0.0.0", 32, true},
{"0.0.0.0", "0.0.0.1", 0, true},
{"0.0.0.0", "0.0.0.1", 31, true},
{"0.0.0.0", "0.0.0.1", 32, false},
{"0.33.0.1", "0.34.0.2", 8, true},
{"0.33.0.1", "0.34.0.2", 13, true},
{"0.33.0.1", "0.34.0.2", 15, false},
}
for _, test := range tests {
if ok := SameNet(test.bits, parseIP(test.ip), parseIP(test.other)); ok != test.want {
t.Errorf("SameNet(%d, %s, %s) == %t, want %t", test.bits, test.ip, test.other, ok, test.want)
}
}
}
func ExampleSameNet() {
// This returns true because the IPs are in the same /24 network:
fmt.Println(SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3}))
// This call returns false:
fmt.Println(SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3}))
// Output:
// true
// false
}
func TestDistinctNetSet(t *testing.T) {
ops := []struct {
add, remove string
fails bool
}{
{add: "127.0.0.1"},
{add: "127.0.0.2"},
{add: "127.0.0.3", fails: true},
{add: "127.32.0.1"},
{add: "127.32.0.2"},
{add: "127.32.0.3", fails: true},
{add: "127.33.0.1", fails: true},
{add: "127.34.0.1"},
{add: "127.34.0.2"},
{add: "127.34.0.3", fails: true},
// Make room for an address, then add again.
{remove: "127.0.0.1"},
{add: "127.0.0.3"},
{add: "127.0.0.3", fails: true},
}
set := DistinctNetSet{Subnet: 15, Limit: 2}
for _, op := range ops {
var desc string
if op.add != "" {
desc = fmt.Sprintf("Add(%s)", op.add)
if ok := set.Add(parseIP(op.add)); ok != !op.fails {
t.Errorf("%s == %t, want %t", desc, ok, !op.fails)
}
} else {
desc = fmt.Sprintf("Remove(%s)", op.remove)
set.Remove(parseIP(op.remove))
}
t.Logf("%s: %v", desc, set)
}
}
func TestDistinctNetSetAddRemove(t *testing.T) {
cfg := &quick.Config{}
fn := func(ips []net.IP) bool {
s := DistinctNetSet{Limit: 3, Subnet: 2}
for _, ip := range ips {
s.Add(ip)
}
for _, ip := range ips {
s.Remove(ip)
}
return s.Len() == 0
}
if err := quick.Check(fn, cfg); err != nil {
t.Fatal(err)
}
}

@ -419,6 +419,9 @@ type PeerInfo struct {
Network struct {
LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection
RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection
Inbound bool `json:"inbound"`
Trusted bool `json:"trusted"`
Static bool `json:"static"`
} `json:"network"`
Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields
}
@ -439,6 +442,9 @@ func (p *Peer) Info() *PeerInfo {
}
info.Network.LocalAddress = p.LocalAddr().String()
info.Network.RemoteAddress = p.RemoteAddr().String()
info.Network.Inbound = p.rw.is(inboundConn)
info.Network.Trusted = p.rw.is(trustedConn)
info.Network.Static = p.rw.is(staticDialedConn)
// Gather all the running protocol infos
for _, proto := range p.running {

@ -40,11 +40,10 @@ const (
refreshPeersInterval = 30 * time.Second
staticPeerCheckInterval = 15 * time.Second
// Maximum number of concurrently handshaking inbound connections.
maxAcceptConns = 50
// Maximum number of concurrently dialing outbound connections.
maxActiveDialTasks = 16
// Connectivity defaults.
maxActiveDialTasks = 16
defaultMaxPendingPeers = 50
defaultDialRatio = 3
// Maximum time allowed for reading a complete message.
// This is effectively the amount of time a connection can be idle.
@ -70,6 +69,11 @@ type Config struct {
// Zero defaults to preset values.
MaxPendingPeers int `toml:",omitempty"`
// DialRatio controls the ratio of inbound to dialed connections.
// Example: a DialRatio of 2 allows 1/2 of connections to be dialed.
// Setting DialRatio to zero defaults it to 3.
DialRatio int `toml:",omitempty"`
// NoDiscovery can be used to disable the peer discovery mechanism.
// Disabling is useful for protocol debugging (manual topology).
NoDiscovery bool
@ -427,7 +431,6 @@ func (srv *Server) Start() (err error) {
if err != nil {
return err
}
realaddr = conn.LocalAddr().(*net.UDPAddr)
if srv.NAT != nil {
if !realaddr.IP.IsLoopback() {
@ -447,11 +450,16 @@ func (srv *Server) Start() (err error) {
// node table
if !srv.NoDiscovery {
ntab, err := discover.ListenUDP(srv.PrivateKey, conn, realaddr, unhandled, srv.NodeDatabase, srv.NetRestrict)
if err != nil {
return err
cfg := discover.Config{
PrivateKey: srv.PrivateKey,
AnnounceAddr: realaddr,
NodeDBPath: srv.NodeDatabase,
NetRestrict: srv.NetRestrict,
Bootnodes: srv.BootstrapNodes,
Unhandled: unhandled,
}
if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil {
ntab, err := discover.ListenUDP(conn, cfg)
if err != nil {
return err
}
srv.ntab = ntab
@ -476,10 +484,7 @@ func (srv *Server) Start() (err error) {
srv.DiscV5 = ntab
}
dynPeers := (srv.MaxPeers + 1) / 2
if srv.NoDiscovery {
dynPeers = 0
}
dynPeers := srv.maxDialedConns()
dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake
@ -536,6 +541,7 @@ func (srv *Server) run(dialstate dialer) {
defer srv.loopWG.Done()
var (
peers = make(map[discover.NodeID]*Peer)
inboundCount = 0
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
taskdone = make(chan task, maxActiveDialTasks)
runningTasks []task
@ -621,14 +627,14 @@ running:
}
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
select {
case c.cont <- srv.encHandshakeChecks(peers, c):
case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
case <-srv.quit:
break running
}
case c := <-srv.addpeer:
// At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified.
err := srv.protoHandshakeChecks(peers, c)
err := srv.protoHandshakeChecks(peers, inboundCount, c)
if err == nil {
// The handshakes are done and it passed all checks.
p := newPeer(c, srv.Protocols)
@ -639,8 +645,11 @@ running:
}
name := truncateName(c.name)
srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
peers[c.id] = p
go srv.runPeer(p)
peers[c.id] = p
if p.Inbound() {
inboundCount++
}
}
// The dialer logic relies on the assumption that
// dial tasks complete after the peer has been added or
@ -655,6 +664,9 @@ running:
d := common.PrettyDuration(mclock.Now() - pd.created)
pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err)
delete(peers, pd.ID())
if pd.Inbound() {
inboundCount--
}
}
}
@ -681,20 +693,22 @@ running:
}
}
func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer
}
// Repeat the encryption handshake checks because the
// peer set might have changed between the handshakes.
return srv.encHandshakeChecks(peers, c)
return srv.encHandshakeChecks(peers, inboundCount, c)
}
func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers
case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns():
return DiscTooManyPeers
case peers[c.id] != nil:
return DiscAlreadyConnected
case c.id == srv.Self().ID:
@ -704,6 +718,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn)
}
}
func (srv *Server) maxInboundConns() int {
return srv.MaxPeers - srv.maxDialedConns()
}
func (srv *Server) maxDialedConns() int {
if srv.NoDiscovery || srv.NoDial {
return 0
}
r := srv.DialRatio
if r == 0 {
r = defaultDialRatio
}
return srv.MaxPeers / r
}
type tempError interface {
Temporary() bool
}
@ -714,10 +743,7 @@ func (srv *Server) listenLoop() {
defer srv.loopWG.Done()
srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab))
// This channel acts as a semaphore limiting
// active inbound connections that are lingering pre-handshake.
// If all slots are taken, no further connections are accepted.
tokens := maxAcceptConns
tokens := defaultMaxPendingPeers
if srv.MaxPendingPeers > 0 {
tokens = srv.MaxPendingPeers
}
@ -758,9 +784,6 @@ func (srv *Server) listenLoop() {
fd = newMeteredConn(fd, true)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
// Spawn the handler. It will give the slot back when the connection
// has been established.
go func() {
srv.SetupConn(fd, inboundConn, nil)
slots <- struct{}{}