Merge pull request #592 from fjl/disco-ping-pong

Discovery bonding protocol
Jeffrey Wilcke 2015-04-01 17:10:10 +02:00
commit fd171eff7f
6 changed files with 721 additions and 405 deletions

@@ -32,8 +32,8 @@ var (
 	defaultBootNodes = []*discover.Node{
 		// ETH/DEV cmd/bootnode
 		discover.MustParseNode("enode://09fbeec0d047e9a37e63f60f8618aa9df0e49271f3fadb2c070dc09e2099b95827b63a8b837c6fd01d0802d457dd83e3bd48bd3e6509f8209ed90dabbc30e3d3@52.16.188.185:30303"),
-		// ETH/DEV cpp-ethereum (poc-8.ethdev.com)
-		discover.MustParseNode("enode://4a44599974518ea5b0f14c31c4463692ac0329cb84851f3435e6d1b18ee4eae4aa495f846a0fa1219bd58035671881d44423876e57db2abd57254d0197da0ebe@5.1.83.226:30303"),
+		// ETH/DEV cpp-ethereum (poc-9.ethdev.com)
+		discover.MustParseNode("enode://487611428e6c99a11a9795a6abe7b529e81315ca6aad66e2a2fc76e3adf263faba0d35466c2f8f68d561dbefa8878d4df5f1f2ddb1fbeab7f42ffb8cd328bd4a@5.1.83.226:30303"),
 	}
 )
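A note on the bootnode entries: an enode URL packs the node's 512-bit public key (hex), IP, and UDP port into one string. Because it follows ordinary URL syntax, the pieces can be pulled apart with the standard library alone; a minimal sketch (not the discover package's own parser, which also validates the key):

package main

import (
	"fmt"
	"net"
	"net/url"
)

func main() {
	// enode URL: enode://<128 hex chars of public key>@<ip>:<udp port>
	raw := "enode://09fbeec0d047e9a37e63f60f8618aa9df0e49271f3fadb2c070dc09e2099b95827b63a8b837c6fd01d0802d457dd83e3bd48bd3e6509f8209ed90dabbc30e3d3@52.16.188.185:30303"
	u, err := url.Parse(raw)
	if err != nil {
		panic(err)
	}
	host, port, _ := net.SplitHostPort(u.Host)
	fmt.Println("node ID prefix:", u.User.Username()[:16]) // hex-encoded public key
	fmt.Println("address:", host, "udp port:", port)
}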

@@ -13,6 +13,8 @@ import (
 	"net/url"
 	"strconv"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"time"

 	"github.com/ethereum/go-ethereum/crypto"
@@ -30,7 +32,8 @@ type Node struct {
 	DiscPort int // UDP listening port for discovery protocol
 	TCPPort  int // TCP listening port for RLPx

-	active time.Time
+	// this must be set/read using atomic load and store.
+	activeStamp int64
 }

 func newNode(id NodeID, addr *net.UDPAddr) *Node {
@@ -39,7 +42,6 @@ func newNode(id NodeID, addr *net.UDPAddr) *Node {
 		IP:       addr.IP,
 		DiscPort: addr.Port,
 		TCPPort:  addr.Port,
-		active:   time.Now(),
 	}
 }

@@ -48,6 +50,20 @@ func (n *Node) isValid() bool {
 	return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
 }

+func (n *Node) bumpActive() {
+	stamp := time.Now().Unix()
+	atomic.StoreInt64(&n.activeStamp, stamp)
+}
+
+func (n *Node) active() time.Time {
+	stamp := atomic.LoadInt64(&n.activeStamp)
+	return time.Unix(stamp, 0)
+}
+
+func (n *Node) addr() *net.UDPAddr {
+	return &net.UDPAddr{IP: n.IP, Port: n.DiscPort}
+}
+
 // The string representation of a Node is a URL.
 // Please see ParseNode for a description of the format.
 func (n *Node) String() string {
@@ -304,3 +320,26 @@ func randomID(a NodeID, n int) (b NodeID) {
 	}
 	return b
 }
+
+// nodeDB stores all nodes we know about.
+type nodeDB struct {
+	mu   sync.RWMutex
+	byID map[NodeID]*Node
+}
+
+func (db *nodeDB) get(id NodeID) *Node {
+	db.mu.RLock()
+	defer db.mu.RUnlock()
+	return db.byID[id]
+}
+
+func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+	if db.byID == nil {
+		db.byID = make(map[NodeID]*Node)
+	}
+	n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)}
+	db.byID[n.ID] = n
+	return n
+}
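Worth noting why active time.Time became activeStamp int64: sync/atomic only loads and stores machine-word integers, so keeping the Unix timestamp as an int64 lets concurrent readers and writers skip the table lock entirely. The same pattern in a standalone sketch (names here are mine, not the package's):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

type peer struct {
	activeStamp int64 // Unix seconds; access only via atomic ops
}

func (p *peer) bump()             { atomic.StoreInt64(&p.activeStamp, time.Now().Unix()) }
func (p *peer) active() time.Time { return time.Unix(atomic.LoadInt64(&p.activeStamp), 0) }

func main() {
	p := new(peer)
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() { defer wg.Done(); p.bump() }() // concurrent writers, no mutex needed
	}
	wg.Wait()
	fmt.Println("last active:", p.active())
}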

@@ -14,9 +14,10 @@ import (
 )

 const (
 	alpha      = 3              // Kademlia concurrency factor
 	bucketSize = 16             // Kademlia bucket size
 	nBuckets   = nodeIDBits + 1 // Number of buckets
+	maxBondingPingPongs = 10
 )

 type Table struct {
@@ -24,27 +25,50 @@ type Table struct {
 	buckets [nBuckets]*bucket // index of known nodes by distance
 	nursery []*Node           // bootstrap nodes

+	bondmu    sync.Mutex
+	bonding   map[NodeID]*bondproc
+	bondslots chan struct{} // limits total number of active bonding processes
+
 	net  transport
 	self *Node // metadata of the local node
+	db   *nodeDB
+}
+
+type bondproc struct {
+	err  error
+	n    *Node
+	done chan struct{}
 }

 // transport is implemented by the UDP transport.
 // it is an interface so we can test without opening lots of UDP
 // sockets and without generating a private key.
 type transport interface {
-	ping(*Node) error
-	findnode(e *Node, target NodeID) ([]*Node, error)
+	ping(NodeID, *net.UDPAddr) error
+	waitping(NodeID) error
+	findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
 	close()
 }

 // bucket contains nodes, ordered by their last activity.
+// the entry that was most recently active is the first element
+// in entries.
 type bucket struct {
 	lastLookup time.Time
 	entries    []*Node
 }

 func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
-	tab := &Table{net: t, self: newNode(ourID, ourAddr)}
+	tab := &Table{
+		net:       t,
+		db:        new(nodeDB),
+		self:      newNode(ourID, ourAddr),
+		bonding:   make(map[NodeID]*bondproc),
+		bondslots: make(chan struct{}, maxBondingPingPongs),
+	}
+	for i := 0; i < cap(tab.bondslots); i++ {
+		tab.bondslots <- struct{}{}
+	}
 	for i := range tab.buckets {
 		tab.buckets[i] = new(bucket)
 	}
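bondslots is a counting semaphore built from a buffered channel: newTable fills it with maxBondingPingPongs tokens up front, and each bonding process receives a token before doing network I/O and sends it back when finished. The idiom in isolation, a sketch with illustrative names:

package main

import (
	"fmt"
	"sync"
)

func main() {
	const maxConcurrent = 10
	slots := make(chan struct{}, maxConcurrent)
	for i := 0; i < cap(slots); i++ {
		slots <- struct{}{} // fill with tokens up front
	}

	var wg sync.WaitGroup
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			<-slots // acquire: blocks while maxConcurrent workers are in flight
			defer func() { slots <- struct{}{} }() // release
			_ = id // placeholder for the rate-limited work
		}(i)
	}
	wg.Wait()
	fmt.Println("all done, at most", maxConcurrent, "ran at once")
}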
@@ -107,8 +131,8 @@ func (tab *Table) Lookup(target NodeID) []*Node {
 			asked[n.ID] = true
 			pendingQueries++
 			go func() {
-				result, _ := tab.net.findnode(n, target)
-				reply <- result
+				r, _ := tab.net.findnode(n.ID, n.addr(), target)
+				reply <- tab.bondall(r)
 			}()
 		}
 	}
@@ -116,13 +140,11 @@ func (tab *Table) Lookup(target NodeID) []*Node {
 			// we have asked all closest nodes, stop the search
 			break
 		}
-
 		// wait for the next reply
 		for _, n := range <-reply {
-			cn := n
-			if !seen[n.ID] {
+			if n != nil && !seen[n.ID] {
 				seen[n.ID] = true
-				result.push(cn, bucketSize)
+				result.push(n, bucketSize)
 			}
 		}
 		pendingQueries--
@@ -145,8 +167,9 @@ func (tab *Table) refresh() {
 	result := tab.Lookup(randomID(tab.self.ID, ld))
 	if len(result) == 0 {
 		// bootstrap the table with a self lookup
+		all := tab.bondall(tab.nursery)
 		tab.mutex.Lock()
-		tab.add(tab.nursery)
+		tab.add(all)
 		tab.mutex.Unlock()
 		tab.Lookup(tab.self.ID)
 		// TODO: the Kademlia paper says that we're supposed to perform
@@ -176,45 +199,105 @@ func (tab *Table) len() (n int) {
 	return n
 }

-// bumpOrAdd updates the activity timestamp for the given node and
-// attempts to insert the node into a bucket. The returned Node might
-// not be part of the table. The caller must hold tab.mutex.
-func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
-	b := tab.buckets[logdist(tab.self.ID, node)]
-	if n = b.bump(node); n == nil {
-		n = newNode(node, from)
-		if len(b.entries) == bucketSize {
-			tab.pingReplace(n, b)
-		} else {
-			b.entries = append(b.entries, n)
-		}
-	}
-	return n
-}
-
-func (tab *Table) pingReplace(n *Node, b *bucket) {
-	old := b.entries[bucketSize-1]
-	go func() {
-		if err := tab.net.ping(old); err == nil {
-			// it responded, we don't need to replace it.
-			return
-		}
-		// it didn't respond, replace the node if it is still the oldest node.
-		tab.mutex.Lock()
-		if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
-			// slide down other entries and put the new one in front.
-			// TODO: insert in correct position to keep the order
-			copy(b.entries[1:], b.entries)
-			b.entries[0] = n
-		}
-		tab.mutex.Unlock()
-	}()
-}
-
-// bump updates the activity timestamp for the given node.
-// The caller must hold tab.mutex.
-func (tab *Table) bump(node NodeID) {
-	tab.buckets[logdist(tab.self.ID, node)].bump(node)
-}
+// bondall bonds with all given nodes concurrently and returns
+// those nodes for which bonding has probably succeeded.
+func (tab *Table) bondall(nodes []*Node) (result []*Node) {
+	rc := make(chan *Node, len(nodes))
+	for i := range nodes {
+		go func(n *Node) {
+			nn, _ := tab.bond(false, n.ID, n.addr(), uint16(n.TCPPort))
+			rc <- nn
+		}(nodes[i])
+	}
+	for _ = range nodes {
+		if n := <-rc; n != nil {
+			result = append(result, n)
+		}
+	}
+	return result
+}
+
+// bond ensures the local node has a bond with the given remote node.
+// It also attempts to insert the node into the table if bonding succeeds.
+// The caller must not hold tab.mutex.
+//
+// A bond must be established before sending findnode requests.
+// Both sides must have completed a ping/pong exchange for a bond to
+// exist. The total number of active bonding processes is limited in
+// order to restrain network use.
+//
+// bond is meant to operate idempotently in that bonding with a remote
+// node which still remembers a previously established bond will work.
+// The remote node will simply not send a ping back, causing waitping
+// to time out.
+//
+// If pinged is true, the remote node has just pinged us and one half
+// of the process can be skipped.
+func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
+	var n *Node
+	if n = tab.db.get(id); n == nil {
+		tab.bondmu.Lock()
+		w := tab.bonding[id]
+		if w != nil {
+			// Wait for an existing bonding process to complete.
+			tab.bondmu.Unlock()
+			<-w.done
+		} else {
+			// Register a new bonding process.
+			w = &bondproc{done: make(chan struct{})}
+			tab.bonding[id] = w
+			tab.bondmu.Unlock()
+			// Do the ping/pong. The result goes into w.
+			tab.pingpong(w, pinged, id, addr, tcpPort)
+			// Unregister the process after it's done.
+			tab.bondmu.Lock()
+			delete(tab.bonding, id)
+			tab.bondmu.Unlock()
+		}
+		n = w.n
+		if w.err != nil {
+			return nil, w.err
+		}
+	}
+	tab.mutex.Lock()
+	defer tab.mutex.Unlock()
+	if b := tab.buckets[logdist(tab.self.ID, n.ID)]; !b.bump(n) {
+		tab.pingreplace(n, b)
+	}
+	return n, nil
+}
+
+func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
+	<-tab.bondslots
+	defer func() { tab.bondslots <- struct{}{} }()
+	if w.err = tab.net.ping(id, addr); w.err != nil {
+		close(w.done)
+		return
+	}
+	if !pinged {
+		// Give the remote node a chance to ping us before we start
+		// sending findnode requests. If they still remember us,
+		// waitping will simply time out.
+		tab.net.waitping(id)
+	}
+	w.n = tab.db.add(id, addr, tcpPort)
+	close(w.done)
+}
+
+func (tab *Table) pingreplace(new *Node, b *bucket) {
+	if len(b.entries) == bucketSize {
+		oldest := b.entries[bucketSize-1]
+		if err := tab.net.ping(oldest.ID, oldest.addr()); err == nil {
+			// The node responded, we don't need to replace it.
+			return
+		}
+	} else {
+		// Add a slot at the end so the last entry doesn't
+		// fall off when adding the new node.
+		b.entries = append(b.entries, nil)
+	}
+	copy(b.entries[1:], b.entries)
+	b.entries[0] = new
+}

 // add puts the entries into the table if their corresponding
@@ -240,17 +323,17 @@ outer:
 	}
 }

-func (b *bucket) bump(id NodeID) *Node {
-	for i, n := range b.entries {
-		if n.ID == id {
-			n.active = time.Now()
+func (b *bucket) bump(n *Node) bool {
+	for i := range b.entries {
+		if b.entries[i].ID == n.ID {
+			n.bumpActive()
 			// move it to the front
-			copy(b.entries[1:], b.entries[:i+1])
+			copy(b.entries[1:], b.entries[:i])
 			b.entries[0] = n
-			return n
+			return true
 		}
 	}
-	return nil
+	return false
 }

 // nodesByDistance is a list of nodes, ordered by
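The bonding registry (the bonding map plus each bondproc's done channel) is a duplicate-suppression pattern of the same shape the singleflight package later generalized: the first caller for a given ID runs the ping/pong, and concurrent callers for that ID wait on done and share the result. A minimal sketch of that shape, under my own names:

package main

import (
	"fmt"
	"sync"
)

type call struct {
	done chan struct{}
	val  string
	err  error
}

type group struct {
	mu       sync.Mutex
	inflight map[string]*call
}

// do runs fn at most once per key at a time; concurrent callers share the result.
func (g *group) do(key string, fn func() (string, error)) (string, error) {
	g.mu.Lock()
	if c, ok := g.inflight[key]; ok {
		g.mu.Unlock()
		<-c.done // someone else is already working on this key
		return c.val, c.err
	}
	c := &call{done: make(chan struct{})}
	g.inflight[key] = c
	g.mu.Unlock()

	c.val, c.err = fn() // the expensive part (the ping/pong in the table code)
	close(c.done)

	g.mu.Lock()
	delete(g.inflight, key) // allow future retries
	g.mu.Unlock()
	return c.val, c.err
}

func main() {
	g := &group{inflight: make(map[string]*call)}
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			v, _ := g.do("node-1", func() (string, error) { return "bonded", nil })
			fmt.Println(v)
		}()
	}
	wg.Wait()
}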

@@ -2,78 +2,109 @@ package discover

 import (
 	"crypto/ecdsa"
-	"errors"
 	"fmt"
 	"math/rand"
 	"net"
 	"reflect"
 	"testing"
 	"testing/quick"
-	"time"

 	"github.com/ethereum/go-ethereum/crypto"
 )

-func TestTable_bumpOrAddBucketAssign(t *testing.T) {
-	tab := newTable(nil, NodeID{}, &net.UDPAddr{})
-	for i := 1; i < len(tab.buckets); i++ {
-		tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
-	}
-	for i, b := range tab.buckets {
-		if i > 0 && len(b.entries) != 1 {
-			t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
-		}
-	}
-}
-
-func TestTable_bumpOrAddPingReplace(t *testing.T) {
-	pingC := make(pingC)
-	tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
-	last := fillBucket(tab, 200)
-
-	// this bumpOrAdd should not replace the last node
-	// because the node replies to ping.
-	new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
-
-	pinged := <-pingC
-	if pinged != last.ID {
-		t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
-	}
-
-	tab.mutex.Lock()
-	defer tab.mutex.Unlock()
-	if l := len(tab.buckets[200].entries); l != bucketSize {
-		t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
-	}
-	if !contains(tab.buckets[200].entries, last.ID) {
-		t.Error("last entry was removed")
-	}
-	if contains(tab.buckets[200].entries, new.ID) {
-		t.Error("new entry was added")
-	}
-}
-
-func TestTable_bumpOrAddPingTimeout(t *testing.T) {
-	tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
-	last := fillBucket(tab, 200)
-
-	// this bumpOrAdd should replace the last node
-	// because the node does not reply to ping.
-	new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
-
-	// wait for async bucket update. damn. this needs to go away.
-	time.Sleep(2 * time.Millisecond)
-
-	tab.mutex.Lock()
-	defer tab.mutex.Unlock()
-	if l := len(tab.buckets[200].entries); l != bucketSize {
-		t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
-	}
-	if contains(tab.buckets[200].entries, last.ID) {
-		t.Error("last entry was not removed")
-	}
-	if !contains(tab.buckets[200].entries, new.ID) {
-		t.Error("new entry was not added")
-	}
-}
+func TestTable_pingReplace(t *testing.T) {
+	doit := func(newNodeIsResponding, lastInBucketIsResponding bool) {
+		transport := newPingRecorder()
+		tab := newTable(transport, NodeID{}, &net.UDPAddr{})
+		last := fillBucket(tab, 200)
+		pingSender := randomID(tab.self.ID, 200)
+
+		// this gotPing should replace the last node
+		// if the last node is not responding.
+		transport.responding[last.ID] = lastInBucketIsResponding
+		transport.responding[pingSender] = newNodeIsResponding
+		tab.bond(true, pingSender, &net.UDPAddr{}, 0)
+
+		// first ping goes to sender (bonding pingback)
+		if !transport.pinged[pingSender] {
+			t.Error("table did not ping back sender")
+		}
+		if newNodeIsResponding {
+			// second ping goes to oldest node in bucket
+			// to see whether it is still alive.
+			if !transport.pinged[last.ID] {
+				t.Error("table did not ping last node in bucket")
+			}
+		}
+
+		tab.mutex.Lock()
+		defer tab.mutex.Unlock()
+		if l := len(tab.buckets[200].entries); l != bucketSize {
+			t.Errorf("wrong bucket size after gotPing: got %d, want %d", l, bucketSize)
+		}
+
+		if lastInBucketIsResponding || !newNodeIsResponding {
+			if !contains(tab.buckets[200].entries, last.ID) {
+				t.Error("last entry was removed")
+			}
+			if contains(tab.buckets[200].entries, pingSender) {
+				t.Error("new entry was added")
+			}
+		} else {
+			if contains(tab.buckets[200].entries, last.ID) {
+				t.Error("last entry was not removed")
+			}
+			if !contains(tab.buckets[200].entries, pingSender) {
+				t.Error("new entry was not added")
+			}
+		}
+	}
+
+	doit(true, true)
+	doit(false, true)
+	doit(true, false)
+	doit(false, false)
+}
+
+func TestBucket_bumpNoDuplicates(t *testing.T) {
+	t.Parallel()
+	cfg := &quick.Config{
+		MaxCount: 1000,
+		Rand:     quickrand,
+		Values: func(args []reflect.Value, rand *rand.Rand) {
+			// generate a random list of nodes. this will be the content of the bucket.
+			n := rand.Intn(bucketSize-1) + 1
+			nodes := make([]*Node, n)
+			for i := range nodes {
+				nodes[i] = &Node{ID: randomID(NodeID{}, 200)}
+			}
+			args[0] = reflect.ValueOf(nodes)
+			// generate random bump positions.
+			bumps := make([]int, rand.Intn(100))
+			for i := range bumps {
+				bumps[i] = rand.Intn(len(nodes))
+			}
+			args[1] = reflect.ValueOf(bumps)
+		},
+	}
+	prop := func(nodes []*Node, bumps []int) (ok bool) {
+		b := &bucket{entries: make([]*Node, len(nodes))}
+		copy(b.entries, nodes)
+		for i, pos := range bumps {
+			b.bump(b.entries[pos])
+			if hasDuplicates(b.entries) {
+				t.Logf("bucket has duplicates after %d/%d bumps:", i+1, len(bumps))
+				for _, n := range b.entries {
+					t.Logf("  %p", n)
+				}
+				return false
+			}
+		}
+		return true
+	}
+	if err := quick.Check(prop, cfg); err != nil {
+		t.Error(err)
+	}
+}
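TestBucket_bumpNoDuplicates is a property-based test: testing/quick feeds the prop function inputs produced by the custom Values generator and reports the first input for which it returns false. A minimal sketch of the same machinery, checked against a property we know holds:

package main

import (
	"fmt"
	"math/rand"
	"reflect"
	"testing/quick"
)

func main() {
	cfg := &quick.Config{
		MaxCount: 100,
		Values: func(args []reflect.Value, rand *rand.Rand) {
			// generate a random (possibly empty) int slice as the single input.
			xs := make([]int, rand.Intn(20))
			for i := range xs {
				xs[i] = rand.Intn(1000)
			}
			args[0] = reflect.ValueOf(xs)
		},
	}
	// property: reversing a slice twice yields the original slice.
	prop := func(xs []int) bool {
		ys := append([]int(nil), xs...)
		for pass := 0; pass < 2; pass++ {
			for i, j := 0, len(ys)-1; i < j; i, j = i+1, j-1 {
				ys[i], ys[j] = ys[j], ys[i]
			}
		}
		return reflect.DeepEqual(xs, ys)
	}
	if err := quick.Check(prop, cfg); err != nil {
		fmt.Println("property violated:", err)
		return
	}
	fmt.Println("property held for all generated inputs")
}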
@@ -85,44 +116,27 @@ func fillBucket(tab *Table, ld int) (last *Node) {
 	return b.entries[bucketSize-1]
 }

-type pingC chan NodeID
-
-func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
-	panic("findnode called on pingRecorder")
-}
-func (t pingC) close() {
-	panic("close called on pingRecorder")
-}
-func (t pingC) ping(n *Node) error {
-	if t == nil {
-		return errTimeout
-	}
-	t <- n.ID
-	return nil
-}
-
-func TestTable_bump(t *testing.T) {
-	tab := newTable(nil, NodeID{}, &net.UDPAddr{})
-
-	// add an old entry and two recent ones
-	oldactive := time.Now().Add(-2 * time.Minute)
-	old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
-	others := []*Node{
-		&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
-		&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
-	}
-	tab.add(append(others, old))
-	if tab.buckets[200].entries[0] == old {
-		t.Fatal("old entry is at front of bucket")
-	}
-
-	// bumping the old entry should move it to the front
-	tab.bump(old.ID)
-	if old.active == oldactive {
-		t.Error("activity timestamp not updated")
-	}
-	if tab.buckets[200].entries[0] != old {
-		t.Errorf("bumped entry did not move to the front of bucket")
-	}
-}
+type pingRecorder struct{ responding, pinged map[NodeID]bool }
+
+func newPingRecorder() *pingRecorder {
+	return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)}
+}
+
+func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
+	panic("findnode called on pingRecorder")
+}
+func (t *pingRecorder) close() {
+	panic("close called on pingRecorder")
+}
+func (t *pingRecorder) waitping(from NodeID) error {
+	return nil // remote always pings
+}
+func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
+	t.pinged[toid] = true
+	if t.responding[toid] {
+		return nil
+	} else {
+		return errTimeout
+	}
+}
@@ -210,7 +224,7 @@ func TestTable_Lookup(t *testing.T) {
 		t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
 	}
 	// seed table with initial node (otherwise lookup will terminate immediately)
-	tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
+	tab.add([]*Node{newNode(randomID(target, 200), &net.UDPAddr{Port: 200})})

 	results := tab.Lookup(target)
 	t.Logf("results:")
@@ -238,16 +252,16 @@ type findnodeOracle struct {
 	target NodeID
 }

-func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
-	t.t.Logf("findnode query at dist %d", n.DiscPort)
+func (t findnodeOracle) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
+	t.t.Logf("findnode query at dist %d", toaddr.Port)
 	// current log distance is encoded in port number
 	var result []*Node
-	switch n.DiscPort {
+	switch toaddr.Port {
 	case 0:
 		panic("query to node at distance 0")
 	default:
 		// TODO: add more randomness to distances
-		next := n.DiscPort - 1
+		next := toaddr.Port - 1
 		for i := 0; i < bucketSize; i++ {
 			result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
 		}
@@ -255,11 +269,9 @@ func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
 	return result, nil
 }

 func (t findnodeOracle) close() {}

-func (t findnodeOracle) ping(n *Node) error {
-	return errors.New("ping is not supported by this transport")
-}
+func (t findnodeOracle) waitping(from NodeID) error { return nil }
+func (t findnodeOracle) ping(toid NodeID, toaddr *net.UDPAddr) error { return nil }

 func hasDuplicates(slice []*Node) bool {
 	seen := make(map[NodeID]bool)

@@ -16,13 +16,18 @@ import (
 var log = logger.NewLogger("P2P Discovery")

+const Version = 3
+
 // Errors
 var (
 	errPacketTooSmall = errors.New("too small")
 	errBadHash        = errors.New("bad hash")
 	errExpired        = errors.New("expired")
-	errTimeout        = errors.New("RPC timeout")
-	errClosed         = errors.New("socket closed")
+	errBadVersion       = errors.New("version mismatch")
+	errUnsolicitedReply = errors.New("unsolicited reply")
+	errUnknownNode      = errors.New("unknown node")
+	errTimeout          = errors.New("RPC timeout")
+	errClosed           = errors.New("socket closed")
 )

 // Timeouts
@@ -45,6 +50,7 @@ const (
 // RPC request structures
 type (
 	ping struct {
+		Version    uint   // must match Version
 		IP         string // our IP
 		Port       uint16 // our port
 		Expiration uint64
@@ -76,14 +82,27 @@ type rpcNode struct {
 	ID NodeID
 }

+type packet interface {
+	handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+}
+
+type conn interface {
+	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
+	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
+	Close() error
+	LocalAddr() net.Addr
+}
+
 // udp implements the RPC protocol.
 type udp struct {
-	conn *net.UDPConn
+	conn conn
 	priv *ecdsa.PrivateKey

 	addpending chan *pending
-	replies    chan reply
+	gotreply   chan reply

 	closing chan struct{}
 	nat     nat.Interface

 	*Table
 }
@@ -120,6 +139,9 @@ type reply struct {
 	from  NodeID
 	ptype byte
 	data  interface{}
+	// loop indicates whether there was
+	// a matching request by sending on this channel.
+	matched chan<- bool
 }

 // ListenUDP returns a new table that listens for UDP packets on laddr.
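The new matched field turns reply delivery into a tiny request/response between goroutines: a packet handler sends the reply into gotreply together with a fresh channel, and the loop goroutine answers on that channel whether any pending request matched, which is how unsolicited replies can be rejected synchronously. The handshake in isolation (a sketch with made-up names):

package main

import "fmt"

type envelope struct {
	key     string
	matched chan<- bool // the owner loop reports back on this
}

func main() {
	inbox := make(chan envelope)
	pendingKeys := map[string]bool{"pong-from-A": true}

	// the single owner goroutine: sole reader/writer of pendingKeys.
	go func() {
		for env := range inbox {
			env.matched <- pendingKeys[env.key] // answer the sender directly
		}
	}()

	// a handler asks, and learns synchronously whether anyone was waiting.
	ask := func(key string) bool {
		m := make(chan bool)
		inbox <- envelope{key, m}
		return <-m
	}
	fmt.Println(ask("pong-from-A")) // true: a request was pending
	fmt.Println(ask("pong-from-B")) // false: unsolicited
}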
@@ -132,15 +154,20 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
 	if err != nil {
 		return nil, err
 	}
+	tab, _ := newUDP(priv, conn, natm)
+	log.Infoln("Listening,", tab.self)
+	return tab, nil
+}
+
+func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface) (*Table, *udp) {
 	udp := &udp{
-		conn:       conn,
+		conn:       c,
 		priv:       priv,
 		closing:    make(chan struct{}),
+		gotreply:   make(chan reply),
 		addpending: make(chan *pending),
-		replies:    make(chan reply),
 	}
-	realaddr := conn.LocalAddr().(*net.UDPAddr)
+	realaddr := c.LocalAddr().(*net.UDPAddr)
 	if natm != nil {
 		if !realaddr.IP.IsLoopback() {
 			go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
@@ -151,11 +178,9 @@
 		}
 	}

 	udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
 	go udp.loop()
 	go udp.readLoop()
-	log.Infoln("Listening, ", udp.self)
-	return udp.Table, nil
+	return udp.Table, udp
 }

 func (t *udp) close() {
@@ -165,10 +190,11 @@ func (t *udp) close() {
 }

 // ping sends a ping message to the given node and waits for a reply.
-func (t *udp) ping(e *Node) error {
+func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
 	// TODO: maybe check for ReplyTo field in callback to measure RTT
-	errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
-	t.send(e, pingPacket, ping{
+	errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
+	t.send(toaddr, pingPacket, ping{
+		Version:    Version,
 		IP:         t.self.IP.String(),
 		Port:       uint16(t.self.TCPPort),
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
@@ -176,12 +202,16 @@ func (t *udp) ping(e *Node) error {
 	return <-errc
 }

+func (t *udp) waitping(from NodeID) error {
+	return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
+}
+
 // findnode sends a findnode request to the given node and waits until
 // the node has sent up to k neighbors.
-func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
+func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
 	nodes := make([]*Node, 0, bucketSize)
 	nreceived := 0
-	errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
+	errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
 		reply := r.(*neighbors)
 		for _, n := range reply.Nodes {
 			nreceived++
@@ -191,8 +221,7 @@ func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
 		}
 		return nreceived >= bucketSize
 	})
-
-	t.send(to, findnodePacket, findnode{
+	t.send(toaddr, findnodePacket, findnode{
 		Target:     target,
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
 	})
@@ -214,6 +243,17 @@ func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
 	return ch
 }

+func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
+	matched := make(chan bool)
+	select {
+	case t.gotreply <- reply{from, ptype, req, matched}:
+		// loop will handle it
+		return <-matched
+	case <-t.closing:
+		return false
+	}
+}
+
 // loop runs in its own goroutine. it keeps track of
 // the refresh timer and the pending reply queue.
 func (t *udp) loop() {
@@ -244,6 +284,7 @@ func (t *udp) loop() {
 			for _, p := range pending {
 				p.errc <- errClosed
 			}
+			pending = nil
 			return

 		case p := <-t.addpending:
@@ -251,18 +292,21 @@ func (t *udp) loop() {
 			pending = append(pending, p)
 			rearmTimeout()

-		case reply := <-t.replies:
-			// run matching callbacks, remove if they return false.
-			for i := 0; i < len(pending); i++ {
-				p := pending[i]
-				if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
-					p.errc <- nil
-					copy(pending[i:], pending[i+1:])
-					pending = pending[:len(pending)-1]
-					i--
+		case r := <-t.gotreply:
+			var matched bool
+			for i := 0; i < len(pending); i++ {
+				if p := pending[i]; p.from == r.from && p.ptype == r.ptype {
+					matched = true
+					if p.callback(r.data) {
+						// callback indicates the request is done, remove it.
+						p.errc <- nil
+						copy(pending[i:], pending[i+1:])
+						pending = pending[:len(pending)-1]
+						i--
+					}
 				}
 			}
-			rearmTimeout()
+			r.matched <- matched

 		case now := <-timeout.C:
 			// notify and remove callbacks whose deadline is in the past.
@@ -287,28 +331,11 @@ const (
 var headSpace = make([]byte, headSize)

-func (t *udp) send(to *Node, ptype byte, req interface{}) error {
-	b := new(bytes.Buffer)
-	b.Write(headSpace)
-	b.WriteByte(ptype)
-	if err := rlp.Encode(b, req); err != nil {
-		log.Errorln("error encoding packet:", err)
-		return err
-	}
-	packet := b.Bytes()
-	sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
+func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
+	packet, err := encodePacket(t.priv, ptype, req)
 	if err != nil {
-		log.Errorln("could not sign packet:", err)
 		return err
 	}
-	copy(packet[macSize:], sig)
-	// add the hash to the front. Note: this doesn't protect the
-	// packet in any way. Our public key will be part of this hash in
-	// the future.
-	copy(packet, crypto.Sha3(packet[macSize:]))
-
-	toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
 	log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
 	if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
 		log.DebugDetailln("UDP send failed:", err)
@@ -316,6 +343,28 @@ func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
 	return err
 }

+func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
+	b := new(bytes.Buffer)
+	b.Write(headSpace)
+	b.WriteByte(ptype)
+	if err := rlp.Encode(b, req); err != nil {
+		log.Errorln("error encoding packet:", err)
+		return nil, err
+	}
+	packet := b.Bytes()
+	sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
+	if err != nil {
+		log.Errorln("could not sign packet:", err)
+		return nil, err
+	}
+	copy(packet[macSize:], sig)
+	// add the hash to the front. Note: this doesn't protect the
+	// packet in any way. Our public key will be part of this hash in
+	// the future.
+	copy(packet, crypto.Sha3(packet[macSize:]))
+	return packet, nil
+}
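For orientation, the layout that encodePacket and decodePacket agree on is: a Sha3 hash, then a recoverable ECDSA signature, then the type byte and the RLP payload, with headSize being hash plus signature. A commented sketch of the offsets; the constant values match the usual Keccak digest and secp256k1 recoverable-signature sizes, but treat them as my reading of the code rather than a spec:

package main

import "fmt"

// Packet layout used by the discovery UDP protocol:
//
//	packet = hash(32) || signature(65) || ptype(1) || rlp(payload)
//
// hash      = Sha3(signature || ptype || rlp)  -> integrity check only
// signature = sign(Sha3(ptype || rlp))         -> sender's NodeID is recovered from it
const (
	macSize  = 256 / 8 // 32-byte Sha3 digest
	sigSize  = 520 / 8 // 65-byte recoverable secp256k1 signature
	headSize = macSize + sigSize
)

func main() {
	fmt.Printf("hash bytes      [0:%d)\n", macSize)
	fmt.Printf("signature bytes [%d:%d)\n", macSize, headSize)
	fmt.Printf("type byte       [%d]\n", headSize)
	fmt.Printf("rlp payload     [%d:...]\n", headSize+1)
}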
 // readLoop runs in its own goroutine. it handles incoming UDP packets.
 func (t *udp) readLoop() {
 	defer t.conn.Close()
@@ -325,29 +374,34 @@ func (t *udp) readLoop() {
 		if err != nil {
 			return
 		}
-		if err := t.packetIn(from, buf[:nbytes]); err != nil {
+		packet, fromID, hash, err := decodePacket(buf[:nbytes])
+		if err != nil {
 			log.Debugf("Bad packet from %v: %v\n", from, err)
+			continue
 		}
+		log.DebugDetailf("<<< %v %T %v\n", from, packet, packet)
+		go func() {
+			if err := packet.handle(t, from, fromID, hash); err != nil {
+				log.Debugf("error handling %T from %v: %v", packet, from, err)
+			}
+		}()
 	}
 }

-func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
+func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
 	if len(buf) < headSize+1 {
-		return errPacketTooSmall
+		return nil, NodeID{}, nil, errPacketTooSmall
 	}
 	hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
 	shouldhash := crypto.Sha3(buf[macSize:])
 	if !bytes.Equal(hash, shouldhash) {
-		return errBadHash
+		return nil, NodeID{}, nil, errBadHash
 	}
 	fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
 	if err != nil {
-		return err
+		return nil, NodeID{}, hash, err
 	}
-	var req interface {
-		handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
-	}
+	var req packet
 	switch ptype := sigdata[0]; ptype {
 	case pingPacket:
 		req = new(ping)
@@ -358,31 +412,27 @@ func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
 	case neighborsPacket:
 		req = new(neighbors)
 	default:
-		return fmt.Errorf("unknown type: %d", ptype)
+		return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
 	}
-	if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
-		return err
-	}
-	log.DebugDetailf("<<< %v %T %v\n", from, req, req)
-	return req.handle(t, from, fromID, hash)
+	err = rlp.Decode(bytes.NewReader(sigdata[1:]), req)
+	return req, fromID, hash, err
 }

 func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
 	if expired(req.Expiration) {
 		return errExpired
 	}
-	t.mutex.Lock()
-	// Note: we're ignoring the provided IP address right now
-	n := t.bumpOrAdd(fromID, from)
-	if req.Port != 0 {
-		n.TCPPort = int(req.Port)
+	if req.Version != Version {
+		return errBadVersion
 	}
-	t.mutex.Unlock()
-	t.send(n, pongPacket, pong{
+	t.send(from, pongPacket, pong{
 		ReplyTok:   mac,
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
 	})
+	if !t.handleReply(fromID, pingPacket, req) {
+		// Note: we're ignoring the provided IP address right now
+		t.bond(true, fromID, from, req.Port)
+	}
 	return nil
 }

@@ -390,11 +440,9 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
 	if expired(req.Expiration) {
 		return errExpired
 	}
-	t.mutex.Lock()
-	t.bump(fromID)
-	t.mutex.Unlock()
-
-	t.replies <- reply{fromID, pongPacket, req}
+	if !t.handleReply(fromID, pongPacket, req) {
+		return errUnsolicitedReply
+	}
 	return nil
 }

@@ -402,12 +450,21 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
 	if expired(req.Expiration) {
 		return errExpired
 	}
+	if t.db.get(fromID) == nil {
+		// No bond exists, we don't process the packet. This prevents
+		// an attack vector where the discovery protocol could be used
+		// to amplify traffic in a DDOS attack. A malicious actor
+		// would send a findnode request with the IP address and UDP
+		// port of the target as the source address. The recipient of
+		// the findnode packet would then send a neighbors packet
+		// (which is a much bigger packet than findnode) to the victim.
+		return errUnknownNode
+	}
 	t.mutex.Lock()
-	e := t.bumpOrAdd(fromID, from)
 	closest := t.closest(req.Target, bucketSize).entries
 	t.mutex.Unlock()
-	t.send(e, neighborsPacket, neighbors{
+	t.send(from, neighborsPacket, neighbors{
 		Nodes:      closest,
 		Expiration: uint64(time.Now().Add(expiration).Unix()),
 	})
@@ -418,12 +475,9 @@ func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
 	if expired(req.Expiration) {
 		return errExpired
 	}
-	t.mutex.Lock()
-	t.bump(fromID)
-	t.add(req.Nodes)
-	t.mutex.Unlock()
-
-	t.replies <- reply{fromID, neighborsPacket, req}
+	if !t.handleReply(fromID, neighborsPacket, req) {
+		return errUnsolicitedReply
+	}
 	return nil
 }
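The amplification worry behind the bond check is easy to quantify: a findnode request carries one 64-byte target plus framing, while the neighbors answer can carry up to bucketSize (16) full endpoint records. A back-of-the-envelope sketch; the byte counts below are rough illustrations I assumed for the exercise, not measured values:

package main

import "fmt"

func main() {
	const (
		headSize     = 97 // hash + signature, present on both packets
		findnodeBody = 70 // assumed: 64-byte target + expiration + RLP framing
		nodeRecord   = 80 // assumed: IP, ports, and 64-byte node ID per neighbor
		bucketSize   = 16
	)
	request := headSize + findnodeBody
	response := headSize + bucketSize*nodeRecord
	fmt.Printf("request ~%d bytes, response ~%d bytes, amplification ~%.1fx\n",
		request, response, float64(response)/float64(request))
	// Which is why findnode from an unbonded (possibly spoofed) source is dropped.
}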

@@ -1,10 +1,18 @@
 package discover

 import (
+	"bytes"
+	"crypto/ecdsa"
+	"errors"
 	"fmt"
+	"io"
 	logpkg "log"
 	"net"
 	"os"
+	"path"
+	"reflect"
+	"runtime"
+	"sync"
 	"testing"
 	"time"
@@ -15,22 +23,243 @@ func init() {
 	logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
 }

-func TestUDP_ping(t *testing.T) {
-	t.Parallel()
-
-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	defer n1.Close()
-	defer n2.Close()
-
-	if err := n1.net.ping(n2.self); err != nil {
-		t.Fatalf("ping error: %v", err)
-	}
-	if find(n2, n1.self.ID) == nil {
-		t.Errorf("node 2 does not contain id of node 1")
-	}
-	if e := find(n1, n2.self.ID); e != nil {
-		t.Errorf("node 1 does contains id of node 2: %v", e)
-	}
-}
+type udpTest struct {
+	t                   *testing.T
+	pipe                *dgramPipe
+	table               *Table
+	udp                 *udp
+	sent                [][]byte
+	localkey, remotekey *ecdsa.PrivateKey
+	remoteaddr          *net.UDPAddr
+}
+
+func newUDPTest(t *testing.T) *udpTest {
+	test := &udpTest{
+		t:          t,
+		pipe:       newpipe(),
+		localkey:   newkey(),
+		remotekey:  newkey(),
+		remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303},
+	}
+	test.table, test.udp = newUDP(test.localkey, test.pipe, nil)
+	return test
+}
+
+// handles a packet as if it had been sent to the transport.
+func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
+	enc, err := encodePacket(test.remotekey, ptype, data)
+	if err != nil {
+		return test.errorf("packet (%d) encode error: %v", ptype, err)
+	}
+	test.sent = append(test.sent, enc)
+	err = data.handle(test.udp, test.remoteaddr, PubkeyID(&test.remotekey.PublicKey), enc[:macSize])
+	if err != wantError {
+		return test.errorf("error mismatch: got %q, want %q", err, wantError)
+	}
+	return nil
+}
+
+// waits for a packet to be sent by the transport.
+// validate should have type func(X), where X is a packet type.
+func (test *udpTest) waitPacketOut(validate interface{}) error {
+	dgram := test.pipe.waitPacketOut()
+	p, _, _, err := decodePacket(dgram)
+	if err != nil {
+		return test.errorf("sent packet decode error: %v", err)
+	}
+	fn := reflect.ValueOf(validate)
+	exptype := fn.Type().In(0)
+	if reflect.TypeOf(p) != exptype {
+		return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
+	}
+	fn.Call([]reflect.Value{reflect.ValueOf(p)})
+	return nil
+}
+
+func (test *udpTest) errorf(format string, args ...interface{}) error {
+	_, file, line, ok := runtime.Caller(2) // errorf + waitPacketOut
+	if ok {
+		file = path.Base(file)
+	} else {
+		file = "???"
+		line = 1
+	}
+	err := fmt.Errorf(format, args...)
+	fmt.Printf("\t%s:%d: %v\n", file, line, err)
+	test.t.Fail()
+	return err
+}
+
+// shared test variables
+var (
+	futureExp  = uint64(time.Now().Add(10 * time.Hour).Unix())
+	testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101")
+)
+
+func TestUDP_packetErrors(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	test.packetIn(errExpired, pingPacket, &ping{IP: "foo", Port: 99, Version: Version})
+	test.packetIn(errBadVersion, pingPacket, &ping{IP: "foo", Port: 99, Version: 99, Expiration: futureExp})
+	test.packetIn(errUnsolicitedReply, pongPacket, &pong{ReplyTok: []byte{}, Expiration: futureExp})
+	test.packetIn(errUnknownNode, findnodePacket, &findnode{Expiration: futureExp})
+	test.packetIn(errUnsolicitedReply, neighborsPacket, &neighbors{Expiration: futureExp})
+}
+
+func TestUDP_pingTimeout(t *testing.T) {
+	t.Parallel()
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+	toid := NodeID{1, 2, 3, 4}
+	if err := test.udp.ping(toid, toaddr); err != errTimeout {
+		t.Error("expected timeout error, got", err)
+	}
+}
+
+func TestUDP_findnodeTimeout(t *testing.T) {
+	t.Parallel()
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222}
+	toid := NodeID{1, 2, 3, 4}
+	target := NodeID{4, 5, 6, 7}
+	result, err := test.udp.findnode(toid, toaddr, target)
+	if err != errTimeout {
+		t.Error("expected timeout error, got", err)
+	}
+	if len(result) > 0 {
+		t.Error("expected empty result, got", result)
+	}
+}
+
+func TestUDP_findnode(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	// put a few nodes into the table. their exact
+	// distribution shouldn't matter much, although we need to
+	// take care not to overflow any bucket.
+	target := testTarget
+	nodes := &nodesByDistance{target: target}
+	for i := 0; i < bucketSize; i++ {
+		nodes.push(&Node{
+			IP:       net.IP{1, 2, 3, byte(i)},
+			DiscPort: i + 2,
+			TCPPort:  i + 2,
+			ID:       randomID(test.table.self.ID, i+2),
+		}, bucketSize)
+	}
+	test.table.add(nodes.entries)
+
+	// ensure there's a bond with the test node,
+	// findnode won't be accepted otherwise.
+	test.table.db.add(PubkeyID(&test.remotekey.PublicKey), test.remoteaddr, 99)
+
+	// check that closest neighbors are returned.
+	test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
+	test.waitPacketOut(func(p *neighbors) {
+		expected := test.table.closest(testTarget, bucketSize)
+		if len(p.Nodes) != bucketSize {
+			t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
+		}
+		for i := range p.Nodes {
+			if p.Nodes[i].ID != expected.entries[i].ID {
+				t.Errorf("result mismatch at %d:\n  got: %v\n want: %v", i, p.Nodes[i], expected.entries[i])
+			}
+		}
+	})
+}
+
+func TestUDP_findnodeMultiReply(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	// queue a pending findnode request
+	resultc, errc := make(chan []*Node), make(chan error)
+	go func() {
+		rid := PubkeyID(&test.remotekey.PublicKey)
+		ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget)
+		if err != nil && len(ns) == 0 {
+			errc <- err
+		} else {
+			resultc <- ns
+		}
+	}()
+
+	// wait for the findnode to be sent.
+	// after it is sent, the transport is waiting for a reply
+	test.waitPacketOut(func(p *findnode) {
+		if p.Target != testTarget {
+			t.Errorf("wrong target: got %v, want %v", p.Target, testTarget)
+		}
+	})
+
+	// send the reply as two packets.
+	list := []*Node{
+		MustParseNode("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303"),
+		MustParseNode("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"),
+		MustParseNode("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301"),
+		MustParseNode("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"),
+	}
+	test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[:2]})
+	test.packetIn(nil, neighborsPacket, &neighbors{Expiration: futureExp, Nodes: list[2:]})
+
+	// check that the sent neighbors are all returned by findnode
+	select {
+	case result := <-resultc:
+		if !reflect.DeepEqual(result, list) {
+			t.Errorf("neighbors mismatch:\n  got: %v\n want: %v", result, list)
+		}
+	case err := <-errc:
+		t.Errorf("findnode error: %v", err)
+	case <-time.After(5 * time.Second):
+		t.Error("findnode did not return within 5 seconds")
+	}
+}
+
+func TestUDP_successfulPing(t *testing.T) {
+	test := newUDPTest(t)
+	defer test.table.Close()
+
+	done := make(chan struct{})
+	go func() {
+		test.packetIn(nil, pingPacket, &ping{IP: "foo", Port: 99, Version: Version, Expiration: futureExp})
+		close(done)
+	}()
+
+	// the ping is replied to.
+	test.waitPacketOut(func(p *pong) {
+		pinghash := test.sent[0][:macSize]
+		if !bytes.Equal(p.ReplyTok, pinghash) {
+			t.Errorf("got ReplyTok %x, want %x", p.ReplyTok, pinghash)
+		}
+	})
+
+	// remote is unknown, the table pings back.
+	test.waitPacketOut(func(p *ping) error { return nil })
+	test.packetIn(nil, pongPacket, &pong{Expiration: futureExp})
+
+	// ping should return shortly after getting the pong packet.
+	<-done
+
+	// check that the node was added.
+	rid := PubkeyID(&test.remotekey.PublicKey)
+	rnode := find(test.table, rid)
+	if rnode == nil {
+		t.Fatalf("node %v not found in table", rid)
+	}
+	if !bytes.Equal(rnode.IP, test.remoteaddr.IP) {
+		t.Errorf("node has wrong IP: got %v, want: %v", rnode.IP, test.remoteaddr.IP)
+	}
+	if rnode.DiscPort != test.remoteaddr.Port {
+		t.Errorf("node has wrong UDP port: got %v, want: %v", rnode.DiscPort, test.remoteaddr.Port)
+	}
+	if rnode.TCPPort != 99 {
+		t.Errorf("node has wrong TCP port: got %v, want: %v", rnode.TCPPort, 99)
+	}
+}
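waitPacketOut takes an untyped validate callback and uses reflection to assert the decoded packet's type and invoke the function, so each test can write a closure against the concrete packet struct it expects. The dispatch trick by itself, as a sketch under hypothetical names:

package main

import (
	"fmt"
	"reflect"
)

// dispatch calls fn(v) if v's dynamic type matches fn's parameter type.
func dispatch(v interface{}, fn interface{}) error {
	f := reflect.ValueOf(fn)
	want := f.Type().In(0) // the single parameter type fn expects
	if got := reflect.TypeOf(v); got != want {
		return fmt.Errorf("type mismatch: got %v, want %v", got, want)
	}
	f.Call([]reflect.Value{reflect.ValueOf(v)})
	return nil
}

type pong struct{ seq int }

func main() {
	var v interface{} = &pong{seq: 7}
	err := dispatch(v, func(p *pong) { fmt.Println("got pong, seq", p.seq) })
	fmt.Println("err:", err)
	err = dispatch(v, func(p *struct{}) {}) // wrong type: reported, not called
	fmt.Println("err:", err)
}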
@@ -45,167 +274,66 @@ func find(tab *Table, id NodeID) *Node {
 	return nil
 }

-func TestUDP_findnode(t *testing.T) {
-	t.Parallel()
-
-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	defer n1.Close()
-	defer n2.Close()
-
-	// put a few nodes into n2. the exact distribution shouldn't
-	// matter much, altough we need to take care not to overflow
-	// any bucket.
-	target := randomID(n1.self.ID, 100)
-	nodes := &nodesByDistance{target: target}
-	for i := 0; i < bucketSize; i++ {
-		n2.add([]*Node{&Node{
-			IP:       net.IP{1, 2, 3, byte(i)},
-			DiscPort: i + 2,
-			TCPPort:  i + 2,
-			ID:       randomID(n2.self.ID, i+2),
-		}})
-	}
-	n2.add(nodes.entries)
-	n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
-	expected := n2.closest(target, bucketSize)
-
-	err := runUDP(10, func() error {
-		result, _ := n1.net.findnode(n2.self, target)
-		if len(result) != bucketSize {
-			return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
-		}
-		for i := range result {
-			if result[i].ID != expected.entries[i].ID {
-				return fmt.Errorf("result mismatch at %d:\n  got: %v\n want: %v", i, result[i], expected.entries[i])
-			}
-		}
-		return nil
-	})
-	if err != nil {
-		t.Error(err)
-	}
-}
-
-func TestUDP_replytimeout(t *testing.T) {
-	t.Parallel()
-
-	// reserve a port so we don't talk to an existing service by accident
-	addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
-	fd, err := net.ListenUDP("udp", addr)
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer fd.Close()
-
-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	defer n1.Close()
-	n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
-
-	if err := n1.net.ping(n2); err != errTimeout {
-		t.Error("expected timeout error, got", err)
-	}
-	if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
-		t.Error("expected timeout error, got", err)
-	} else if len(result) > 0 {
-		t.Error("expected empty result, got", result)
-	}
-}
-
-func TestUDP_findnodeMultiReply(t *testing.T) {
-	t.Parallel()
-
-	n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
-	udp2 := n2.net.(*udp)
-	defer n1.Close()
-	defer n2.Close()
-
-	err := runUDP(10, func() error {
-		nodes := make([]*Node, bucketSize)
-		for i := range nodes {
-			nodes[i] = &Node{
-				IP:       net.IP{1, 2, 3, 4},
-				DiscPort: i + 1,
-				TCPPort:  i + 1,
-				ID:       randomID(n2.self.ID, i+1),
-			}
-		}
-		// ask N2 for neighbors. it will send an empty reply back.
-		// the request will wait for up to bucketSize replies.
-		resultc := make(chan []*Node)
-		errc := make(chan error)
-		go func() {
-			ns, err := n1.net.findnode(n2.self, n1.self.ID)
-			if err != nil {
-				errc <- err
-			} else {
-				resultc <- ns
-			}
-		}()
-		// send a few more neighbors packets to N1.
-		// it should collect those.
-		for end := 0; end < len(nodes); {
-			off := end
-			if end = end + 5; end > len(nodes) {
-				end = len(nodes)
-			}
-			udp2.send(n1.self, neighborsPacket, neighbors{
-				Nodes:      nodes[off:end],
-				Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
-			})
-		}
-		// check that they are all returned. we cannot just check for
-		// equality because they might not be returned in the order they
-		// were sent.
-		var result []*Node
-		select {
-		case result = <-resultc:
-		case err := <-errc:
-			return err
-		}
-		if hasDuplicates(result) {
-			return fmt.Errorf("result slice contains duplicates")
-		}
-		if len(result) != len(nodes) {
-			return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
-		}
-		matched := make(map[NodeID]bool)
-		for _, n := range result {
-			for _, expn := range nodes {
-				if n.ID == expn.ID {
-					matched[n.ID] = true
-				}
-			}
-		}
-		if len(matched) != len(nodes) {
-			return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
-		}
-		return nil
-	})
-	if err != nil {
-		t.Error(err)
-	}
-}
-
-// runUDP runs a test n times and returns an error if the test failed
-// in all n runs. This is necessary because UDP is unreliable even for
-// connections on the local machine, causing test failures.
-func runUDP(n int, test func() error) error {
-	errcount := 0
-	errors := ""
-	for i := 0; i < n; i++ {
-		if err := test(); err != nil {
-			errors += fmt.Sprintf("\n#%d: %v", i, err)
-			errcount++
-		}
-	}
-	if errcount == n {
-		return fmt.Errorf("failed on all %d iterations:%s", n, errors)
-	}
-	return nil
-}
+// dgramPipe is a fake UDP socket. It queues all sent datagrams.
+type dgramPipe struct {
+	mu      *sync.Mutex
+	cond    *sync.Cond
+	closing chan struct{}
+	closed  bool
+	queue   [][]byte
+}
+
+func newpipe() *dgramPipe {
+	mu := new(sync.Mutex)
+	return &dgramPipe{
+		closing: make(chan struct{}),
+		cond:    &sync.Cond{L: mu},
+		mu:      mu,
+	}
+}
+
+// WriteToUDP queues a datagram.
+func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) {
+	msg := make([]byte, len(b))
+	copy(msg, b)
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.closed {
+		return 0, errors.New("closed")
+	}
+	c.queue = append(c.queue, msg)
+	c.cond.Signal()
+	return len(b), nil
+}
+
+// ReadFromUDP just hangs until the pipe is closed.
+func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
+	<-c.closing
+	return 0, nil, io.EOF
+}
+
+func (c *dgramPipe) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if !c.closed {
+		close(c.closing)
+		c.closed = true
+	}
+	return nil
+}
+
+func (c *dgramPipe) LocalAddr() net.Addr {
+	return &net.UDPAddr{}
+}
+
+func (c *dgramPipe) waitPacketOut() []byte {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	for len(c.queue) == 0 {
+		c.cond.Wait()
+	}
+	p := c.queue[0]
+	copy(c.queue, c.queue[1:])
+	c.queue = c.queue[:len(c.queue)-1]
+	return p
+}
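dgramPipe pairs a mutex-guarded slice with a sync.Cond so waitPacketOut can block until a datagram arrives; the "for len(queue) == 0 { cond.Wait() }" loop is the standard guard against spurious wakeups. The queue idiom on its own, as a sketch rather than the test helper itself:

package main

import (
	"fmt"
	"sync"
)

type queue struct {
	mu    sync.Mutex
	cond  *sync.Cond
	items []string
}

func newQueue() *queue {
	q := &queue{}
	q.cond = sync.NewCond(&q.mu)
	return q
}

func (q *queue) push(s string) {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.items = append(q.items, s)
	q.cond.Signal() // wake one waiting pop
}

func (q *queue) pop() string {
	q.mu.Lock()
	defer q.mu.Unlock()
	for len(q.items) == 0 { // re-check the condition after every wakeup
		q.cond.Wait() // atomically unlocks mu and sleeps
	}
	s := q.items[0]
	q.items = q.items[1:]
	return s
}

func main() {
	q := newQueue()
	go q.push("datagram")
	fmt.Println(q.pop()) // blocks until the push lands
}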