p2p, p2p/discover: add signed ENR generation (#17753)

This PR adds enode.LocalNode and integrates it into the p2p
subsystem. This new object is the keeper of the local node
record. For now, a new version of the record is produced every
time the client restarts. We'll make it smarter to avoid that in
the future.

There are a couple of other changes in this commit: discovery now
waits for all of its goroutines at shutdown and the p2p server
now closes the node database after discovery has shut down. This
fixes a leveldb crash in tests. p2p server startup is faster
because it doesn't need to wait for the external IP query
anymore.
This commit is contained in:
Felix Lange 2018-10-12 11:47:24 +02:00 committed by GitHub
parent dcae0d348b
commit 6f607de5d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 979 additions and 278 deletions

@ -119,16 +119,17 @@ func main() {
} }
if *runv5 { if *runv5 {
if _, err := discv5.ListenUDP(nodeKey, conn, realaddr, "", restrictList); err != nil { if _, err := discv5.ListenUDP(nodeKey, conn, "", restrictList); err != nil {
utils.Fatalf("%v", err) utils.Fatalf("%v", err)
} }
} else { } else {
db, _ := enode.OpenDB("")
ln := enode.NewLocalNode(db, nodeKey)
cfg := discover.Config{ cfg := discover.Config{
PrivateKey: nodeKey, PrivateKey: nodeKey,
AnnounceAddr: realaddr,
NetRestrict: restrictList, NetRestrict: restrictList,
} }
if _, err := discover.ListenUDP(conn, cfg); err != nil { if _, err := discover.ListenUDP(conn, ln, cfg); err != nil {
utils.Fatalf("%v", err) utils.Fatalf("%v", err)
} }
} }

@ -454,9 +454,9 @@ func TestProtocolGather(t *testing.T) {
Count int Count int
Maker InstrumentingWrapper Maker InstrumentingWrapper
}{ }{
"Zero Protocols": {0, InstrumentedServiceMakerA}, "zero": {0, InstrumentedServiceMakerA},
"Single Protocol": {1, InstrumentedServiceMakerB}, "one": {1, InstrumentedServiceMakerB},
"Many Protocols": {25, InstrumentedServiceMakerC}, "many": {10, InstrumentedServiceMakerC},
} }
for id, config := range services { for id, config := range services {
protocols := make([]p2p.Protocol, config.Count) protocols := make([]p2p.Protocol, config.Count)
@ -480,7 +480,7 @@ func TestProtocolGather(t *testing.T) {
defer stack.Stop() defer stack.Stop()
protocols := stack.Server().Protocols protocols := stack.Server().Protocols
if len(protocols) != 26 { if len(protocols) != 11 {
t.Fatalf("mismatching number of protocols launched: have %d, want %d", len(protocols), 26) t.Fatalf("mismatching number of protocols launched: have %d, want %d", len(protocols), 26)
} }
for id, config := range services { for id, config := range services {

@ -71,6 +71,7 @@ type dialstate struct {
maxDynDials int maxDynDials int
ntab discoverTable ntab discoverTable
netrestrict *netutil.Netlist netrestrict *netutil.Netlist
self enode.ID
lookupRunning bool lookupRunning bool
dialing map[enode.ID]connFlag dialing map[enode.ID]connFlag
@ -84,7 +85,6 @@ type dialstate struct {
} }
type discoverTable interface { type discoverTable interface {
Self() *enode.Node
Close() Close()
Resolve(*enode.Node) *enode.Node Resolve(*enode.Node) *enode.Node
LookupRandom() []*enode.Node LookupRandom() []*enode.Node
@ -126,10 +126,11 @@ type waitExpireTask struct {
time.Duration time.Duration
} }
func newDialState(static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
s := &dialstate{ s := &dialstate{
maxDynDials: maxdyn, maxDynDials: maxdyn,
ntab: ntab, ntab: ntab,
self: self,
netrestrict: netrestrict, netrestrict: netrestrict,
static: make(map[enode.ID]*dialTask), static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag), dialing: make(map[enode.ID]connFlag),
@ -266,7 +267,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
return errAlreadyDialing return errAlreadyDialing
case peers[n.ID()] != nil: case peers[n.ID()] != nil:
return errAlreadyConnected return errAlreadyConnected
case s.ntab != nil && n.ID() == s.ntab.Self().ID(): case n.ID() == s.self:
return errSelf return errSelf
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()): case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
return errNotWhitelisted return errNotWhitelisted

@ -89,7 +89,7 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t)
// This test checks that dynamic dials are launched from discovery results. // This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) { func TestDialStateDynDial(t *testing.T) {
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(nil, nil, fakeTable{}, 5, nil), init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil),
rounds: []round{ rounds: []round{
// A discovery query is launched. // A discovery query is launched.
{ {
@ -236,7 +236,7 @@ func TestDialStateDynDialBootnode(t *testing.T) {
newNode(uintID(8), nil), newNode(uintID(8), nil),
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(nil, bootnodes, table, 5, nil), init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil),
rounds: []round{ rounds: []round{
// 2 dynamic dials attempted, bootnodes pending fallback interval // 2 dynamic dials attempted, bootnodes pending fallback interval
{ {
@ -324,7 +324,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(nil, nil, table, 10, nil), init: newDialState(enode.ID{}, nil, nil, table, 10, nil),
rounds: []round{ rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{ {
@ -430,7 +430,7 @@ func TestDialStateNetRestrict(t *testing.T) {
restrict.Add("127.0.2.0/24") restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(nil, nil, table, 10, restrict), init: newDialState(enode.ID{}, nil, nil, table, 10, restrict),
rounds: []round{ rounds: []round{
{ {
new: []task{ new: []task{
@ -453,7 +453,7 @@ func TestDialStateStaticDial(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -557,7 +557,7 @@ func TestDialStaticAfterReset(t *testing.T) {
}, },
} }
dTest := dialtest{ dTest := dialtest{
init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: rounds, rounds: rounds,
} }
runDialTest(t, dTest) runDialTest(t, dTest)
@ -578,7 +578,7 @@ func TestDialStateCache(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(wantStatic, nil, fakeTable{}, 0, nil), init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -640,7 +640,7 @@ func TestDialStateCache(t *testing.T) {
func TestDialResolve(t *testing.T) { func TestDialResolve(t *testing.T) {
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
table := &resolveMock{answer: resolved} table := &resolveMock{answer: resolved}
state := newDialState(nil, nil, table, 0, nil) state := newDialState(enode.ID{}, nil, nil, table, 0, nil)
// Check that the task is generated with an incomplete ID. // Check that the task is generated with an incomplete ID.
dest := newNode(uintID(1), nil) dest := newNode(uintID(1), nil)

@ -72,21 +72,20 @@ type Table struct {
ips netutil.DistinctNetSet ips netutil.DistinctNetSet
db *enode.DB // database of known nodes db *enode.DB // database of known nodes
net transport
refreshReq chan chan struct{} refreshReq chan chan struct{}
initDone chan struct{} initDone chan struct{}
closeReq chan struct{} closeReq chan struct{}
closed chan struct{} closed chan struct{}
nodeAddedHook func(*node) // for testing nodeAddedHook func(*node) // for testing
net transport
self *node // metadata of the local node
} }
// transport is implemented by the UDP transport. // transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP // it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key. // sockets and without generating a private key.
type transport interface { type transport interface {
self() *enode.Node
ping(enode.ID, *net.UDPAddr) error ping(enode.ID, *net.UDPAddr) error
findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error) findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error)
close() close()
@ -100,11 +99,10 @@ type bucket struct {
ips netutil.DistinctNetSet ips netutil.DistinctNetSet
} }
func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.Node) (*Table, error) { func newTable(t transport, db *enode.DB, bootnodes []*enode.Node) (*Table, error) {
tab := &Table{ tab := &Table{
net: t, net: t,
db: db, db: db,
self: wrapNode(self),
refreshReq: make(chan chan struct{}), refreshReq: make(chan chan struct{}),
initDone: make(chan struct{}), initDone: make(chan struct{}),
closeReq: make(chan struct{}), closeReq: make(chan struct{}),
@ -127,6 +125,10 @@ func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.No
return tab, nil return tab, nil
} }
func (tab *Table) self() *enode.Node {
return tab.net.self()
}
func (tab *Table) seedRand() { func (tab *Table) seedRand() {
var b [8]byte var b [8]byte
crand.Read(b[:]) crand.Read(b[:])
@ -136,11 +138,6 @@ func (tab *Table) seedRand() {
tab.mutex.Unlock() tab.mutex.Unlock()
} }
// Self returns the local node.
func (tab *Table) Self() *enode.Node {
return unwrapNode(tab.self)
}
// ReadRandomNodes fills the given slice with random nodes from the table. The results // ReadRandomNodes fills the given slice with random nodes from the table. The results
// are guaranteed to be unique for a single invocation, no node will appear twice. // are guaranteed to be unique for a single invocation, no node will appear twice.
func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) { func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
@ -183,6 +180,10 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
// Close terminates the network listener and flushes the node database. // Close terminates the network listener and flushes the node database.
func (tab *Table) Close() { func (tab *Table) Close() {
if tab.net != nil {
tab.net.close()
}
select { select {
case <-tab.closed: case <-tab.closed:
// already closed. // already closed.
@ -257,7 +258,7 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
) )
// don't query further if we hit ourself. // don't query further if we hit ourself.
// unlikely to happen often in practice. // unlikely to happen often in practice.
asked[tab.self.ID()] = true asked[tab.self().ID()] = true
for { for {
tab.mutex.Lock() tab.mutex.Lock()
@ -340,8 +341,8 @@ func (tab *Table) loop() {
revalidate = time.NewTimer(tab.nextRevalidateTime()) revalidate = time.NewTimer(tab.nextRevalidateTime())
refresh = time.NewTicker(refreshInterval) refresh = time.NewTicker(refreshInterval)
copyNodes = time.NewTicker(copyNodesInterval) copyNodes = time.NewTicker(copyNodesInterval)
revalidateDone = make(chan struct{})
refreshDone = make(chan struct{}) // where doRefresh reports completion refreshDone = make(chan struct{}) // where doRefresh reports completion
revalidateDone chan struct{} // where doRevalidate reports completion
waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
) )
defer refresh.Stop() defer refresh.Stop()
@ -372,9 +373,11 @@ loop:
} }
waiting, refreshDone = nil, nil waiting, refreshDone = nil, nil
case <-revalidate.C: case <-revalidate.C:
revalidateDone = make(chan struct{})
go tab.doRevalidate(revalidateDone) go tab.doRevalidate(revalidateDone)
case <-revalidateDone: case <-revalidateDone:
revalidate.Reset(tab.nextRevalidateTime()) revalidate.Reset(tab.nextRevalidateTime())
revalidateDone = nil
case <-copyNodes.C: case <-copyNodes.C:
go tab.copyLiveNodes() go tab.copyLiveNodes()
case <-tab.closeReq: case <-tab.closeReq:
@ -382,15 +385,15 @@ loop:
} }
} }
if tab.net != nil {
tab.net.close()
}
if refreshDone != nil { if refreshDone != nil {
<-refreshDone <-refreshDone
} }
for _, ch := range waiting { for _, ch := range waiting {
close(ch) close(ch)
} }
if revalidateDone != nil {
<-revalidateDone
}
close(tab.closed) close(tab.closed)
} }
@ -408,7 +411,7 @@ func (tab *Table) doRefresh(done chan struct{}) {
// Run self lookup to discover new neighbor nodes. // Run self lookup to discover new neighbor nodes.
// We can only do this if we have a secp256k1 identity. // We can only do this if we have a secp256k1 identity.
var key ecdsa.PublicKey var key ecdsa.PublicKey
if err := tab.self.Load((*enode.Secp256k1)(&key)); err == nil { if err := tab.self().Load((*enode.Secp256k1)(&key)); err == nil {
tab.lookup(encodePubkey(&key), false) tab.lookup(encodePubkey(&key), false)
} }
@ -530,7 +533,7 @@ func (tab *Table) len() (n int) {
// bucket returns the bucket for the given node ID hash. // bucket returns the bucket for the given node ID hash.
func (tab *Table) bucket(id enode.ID) *bucket { func (tab *Table) bucket(id enode.ID) *bucket {
d := enode.LogDist(tab.self.ID(), id) d := enode.LogDist(tab.self().ID(), id)
if d <= bucketMinDistance { if d <= bucketMinDistance {
return tab.buckets[0] return tab.buckets[0]
} }
@ -543,7 +546,7 @@ func (tab *Table) bucket(id enode.ID) *bucket {
// //
// The caller must not hold tab.mutex. // The caller must not hold tab.mutex.
func (tab *Table) add(n *node) { func (tab *Table) add(n *node) {
if n.ID() == tab.self.ID() { if n.ID() == tab.self().ID() {
return return
} }
@ -576,7 +579,7 @@ func (tab *Table) stuff(nodes []*node) {
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
for _, n := range nodes { for _, n := range nodes {
if n.ID() == tab.self.ID() { if n.ID() == tab.self().ID() {
continue // don't add self continue // don't add self
} }
b := tab.bucket(n.ID()) b := tab.bucket(n.ID())

@ -141,7 +141,7 @@ func TestTable_IPLimit(t *testing.T) {
defer db.Close() defer db.Close()
for i := 0; i < tableIPLimit+1; i++ { for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self.ID(), i, net.IP{172, 0, 1, byte(i)}) n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
tab.add(n) tab.add(n)
} }
if tab.len() > tableIPLimit { if tab.len() > tableIPLimit {
@ -158,7 +158,7 @@ func TestTable_BucketIPLimit(t *testing.T) {
d := 3 d := 3
for i := 0; i < bucketIPLimit+1; i++ { for i := 0; i < bucketIPLimit+1; i++ {
n := nodeAtDistance(tab.self.ID(), d, net.IP{172, 0, 1, byte(i)}) n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)})
tab.add(n) tab.add(n)
} }
if tab.len() > bucketIPLimit { if tab.len() > bucketIPLimit {
@ -240,7 +240,7 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
for i := 0; i < len(buf); i++ { for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets)) ld := cfg.Rand.Intn(len(tab.buckets))
tab.stuff([]*node{nodeAtDistance(tab.self.ID(), ld, intIP(ld))}) tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
} }
gotN := tab.ReadRandomNodes(buf) gotN := tab.ReadRandomNodes(buf)
if gotN != tab.len() { if gotN != tab.len() {
@ -510,6 +510,10 @@ type preminedTestnet struct {
dists [hashBits + 1][]encPubkey dists [hashBits + 1][]encPubkey
} }
func (tn *preminedTestnet) self() *enode.Node {
return nullNode
}
func (tn *preminedTestnet) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { func (tn *preminedTestnet) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
// current log distance is encoded in port number // current log distance is encoded in port number
// fmt.Println("findnode query at dist", toaddr.Port) // fmt.Println("findnode query at dist", toaddr.Port)

@ -28,12 +28,17 @@ import (
"github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/enr"
) )
func newTestTable(t transport) (*Table, *enode.DB) { var nullNode *enode.Node
func init() {
var r enr.Record var r enr.Record
r.Set(enr.IP{0, 0, 0, 0}) r.Set(enr.IP{0, 0, 0, 0})
n := enode.SignNull(&r, enode.ID{}) nullNode = enode.SignNull(&r, enode.ID{})
}
func newTestTable(t transport) (*Table, *enode.DB) {
db, _ := enode.OpenDB("") db, _ := enode.OpenDB("")
tab, _ := newTable(t, n, db, nil) tab, _ := newTable(t, db, nil)
return tab, db return tab, db
} }
@ -70,10 +75,10 @@ func intIP(i int) net.IP {
// fillBucket inserts nodes into the given bucket until it is full. // fillBucket inserts nodes into the given bucket until it is full.
func fillBucket(tab *Table, n *node) (last *node) { func fillBucket(tab *Table, n *node) (last *node) {
ld := enode.LogDist(tab.self.ID(), n.ID()) ld := enode.LogDist(tab.self().ID(), n.ID())
b := tab.bucket(n.ID()) b := tab.bucket(n.ID())
for len(b.entries) < bucketSize { for len(b.entries) < bucketSize {
b.entries = append(b.entries, nodeAtDistance(tab.self.ID(), ld, intIP(ld))) b.entries = append(b.entries, nodeAtDistance(tab.self().ID(), ld, intIP(ld)))
} }
return b.entries[bucketSize-1] return b.entries[bucketSize-1]
} }
@ -81,15 +86,25 @@ func fillBucket(tab *Table, n *node) (last *node) {
type pingRecorder struct { type pingRecorder struct {
mu sync.Mutex mu sync.Mutex
dead, pinged map[enode.ID]bool dead, pinged map[enode.ID]bool
n *enode.Node
} }
func newPingRecorder() *pingRecorder { func newPingRecorder() *pingRecorder {
var r enr.Record
r.Set(enr.IP{0, 0, 0, 0})
n := enode.SignNull(&r, enode.ID{})
return &pingRecorder{ return &pingRecorder{
dead: make(map[enode.ID]bool), dead: make(map[enode.ID]bool),
pinged: make(map[enode.ID]bool), pinged: make(map[enode.ID]bool),
n: n,
} }
} }
func (t *pingRecorder) self() *enode.Node {
return nullNode
}
func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
return nil, nil return nil, nil
} }

@ -23,12 +23,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -118,9 +118,11 @@ type (
) )
func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
ip := addr.IP.To4() ip := net.IP{}
if ip == nil { if ip4 := addr.IP.To4(); ip4 != nil {
ip = addr.IP.To16() ip = ip4
} else if ip6 := addr.IP.To16(); ip6 != nil {
ip = ip6
} }
return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
} }
@ -165,20 +167,19 @@ type conn interface {
LocalAddr() net.Addr LocalAddr() net.Addr
} }
// udp implements the RPC protocol. // udp implements the discovery v4 UDP wire protocol.
type udp struct { type udp struct {
conn conn conn conn
netrestrict *netutil.Netlist netrestrict *netutil.Netlist
priv *ecdsa.PrivateKey priv *ecdsa.PrivateKey
ourEndpoint rpcEndpoint localNode *enode.LocalNode
db *enode.DB
tab *Table
wg sync.WaitGroup
addpending chan *pending addpending chan *pending
gotreply chan reply gotreply chan reply
closing chan struct{} closing chan struct{}
nat nat.Interface
*Table
} }
// pending represents a pending reply. // pending represents a pending reply.
@ -230,60 +231,57 @@ type Config struct {
PrivateKey *ecdsa.PrivateKey PrivateKey *ecdsa.PrivateKey
// These settings are optional: // These settings are optional:
AnnounceAddr *net.UDPAddr // local address announced in the DHT
NodeDBPath string // if set, the node database is stored at this filesystem location
NetRestrict *netutil.Netlist // network whitelist NetRestrict *netutil.Netlist // network whitelist
Bootnodes []*enode.Node // list of bootstrap nodes Bootnodes []*enode.Node // list of bootstrap nodes
Unhandled chan<- ReadPacket // unhandled packets are sent on this channel Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
} }
// ListenUDP returns a new table that listens for UDP packets on laddr. // ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(c conn, cfg Config) (*Table, error) { func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) {
tab, _, err := newUDP(c, cfg) tab, _, err := newUDP(c, ln, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Info("UDP listener up", "self", tab.self)
return tab, nil return tab, nil
} }
func newUDP(c conn, cfg Config) (*Table, *udp, error) { func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) {
realaddr := c.LocalAddr().(*net.UDPAddr)
if cfg.AnnounceAddr != nil {
realaddr = cfg.AnnounceAddr
}
self := enode.NewV4(&cfg.PrivateKey.PublicKey, realaddr.IP, realaddr.Port, realaddr.Port)
db, err := enode.OpenDB(cfg.NodeDBPath)
if err != nil {
return nil, nil, err
}
udp := &udp{ udp := &udp{
conn: c, conn: c,
priv: cfg.PrivateKey, priv: cfg.PrivateKey,
netrestrict: cfg.NetRestrict, netrestrict: cfg.NetRestrict,
localNode: ln,
db: ln.Database(),
closing: make(chan struct{}), closing: make(chan struct{}),
gotreply: make(chan reply), gotreply: make(chan reply),
addpending: make(chan *pending), addpending: make(chan *pending),
} }
// TODO: separate TCP port tab, err := newTable(udp, ln.Database(), cfg.Bootnodes)
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
tab, err := newTable(udp, self, db, cfg.Bootnodes)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
udp.Table = tab udp.tab = tab
udp.wg.Add(2)
go udp.loop() go udp.loop()
go udp.readLoop(cfg.Unhandled) go udp.readLoop(cfg.Unhandled)
return udp.Table, udp, nil return udp.tab, udp, nil
}
func (t *udp) self() *enode.Node {
return t.localNode.Node()
} }
func (t *udp) close() { func (t *udp) close() {
close(t.closing) close(t.closing)
t.conn.Close() t.conn.Close()
t.db.Close() t.wg.Wait()
// TODO: wait for the loops to end. }
func (t *udp) ourEndpoint() rpcEndpoint {
n := t.self()
a := &net.UDPAddr{IP: n.IP(), Port: n.UDP()}
return makeEndpoint(a, uint16(n.TCP()))
} }
// ping sends a ping message to the given node and waits for a reply. // ping sends a ping message to the given node and waits for a reply.
@ -296,7 +294,7 @@ func (t *udp) ping(toid enode.ID, toaddr *net.UDPAddr) error {
func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-chan error { func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-chan error {
req := &ping{ req := &ping{
Version: 4, Version: 4,
From: t.ourEndpoint, From: t.ourEndpoint(),
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()), Expiration: uint64(time.Now().Add(expiration).Unix()),
} }
@ -313,6 +311,7 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch
} }
return ok return ok
}) })
t.localNode.UDPContact(toaddr)
t.write(toaddr, req.name(), packet) t.write(toaddr, req.name(), packet)
return errc return errc
} }
@ -381,6 +380,8 @@ func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool {
// loop runs in its own goroutine. it keeps track of // loop runs in its own goroutine. it keeps track of
// the refresh timer and the pending reply queue. // the refresh timer and the pending reply queue.
func (t *udp) loop() { func (t *udp) loop() {
defer t.wg.Done()
var ( var (
plist = list.New() plist = list.New()
timeout = time.NewTimer(0) timeout = time.NewTimer(0)
@ -542,10 +543,11 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet,
// readLoop runs in its own goroutine. it handles incoming UDP packets. // readLoop runs in its own goroutine. it handles incoming UDP packets.
func (t *udp) readLoop(unhandled chan<- ReadPacket) { func (t *udp) readLoop(unhandled chan<- ReadPacket) {
defer t.conn.Close() defer t.wg.Done()
if unhandled != nil { if unhandled != nil {
defer close(unhandled) defer close(unhandled)
} }
// Discovery packets are defined to be no larger than 1280 bytes. // Discovery packets are defined to be no larger than 1280 bytes.
// Packets larger than this size will be cut at the end and treated // Packets larger than this size will be cut at the end and treated
// as invalid because their hash won't match. // as invalid because their hash won't match.
@ -629,10 +631,11 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte
n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port)) n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port))
t.handleReply(n.ID(), pingPacket, req) t.handleReply(n.ID(), pingPacket, req)
if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration { if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration {
t.sendPing(n.ID(), from, func() { t.addThroughPing(n) }) t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) })
} else { } else {
t.addThroughPing(n) t.tab.addThroughPing(n)
} }
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
t.db.UpdateLastPingReceived(n.ID(), time.Now()) t.db.UpdateLastPingReceived(n.ID(), time.Now())
return nil return nil
} }
@ -647,6 +650,7 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte
if !t.handleReply(fromID, pongPacket, req) { if !t.handleReply(fromID, pongPacket, req) {
return errUnsolicitedReply return errUnsolicitedReply
} }
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
t.db.UpdateLastPongReceived(fromID, time.Now()) t.db.UpdateLastPongReceived(fromID, time.Now())
return nil return nil
} }
@ -668,9 +672,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []
return errUnknownNode return errUnknownNode
} }
target := enode.ID(crypto.Keccak256Hash(req.Target[:])) target := enode.ID(crypto.Keccak256Hash(req.Target[:]))
t.mutex.Lock() t.tab.mutex.Lock()
closest := t.closest(target, bucketSize).entries closest := t.tab.closest(target, bucketSize).entries
t.mutex.Unlock() t.tab.mutex.Unlock()
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool var sent bool

@ -71,7 +71,9 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(), remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
} }
test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey}) db, _ := enode.OpenDB("")
ln := enode.NewLocalNode(db, test.localkey)
test.table, test.udp, _ = newUDP(test.pipe, ln, Config{PrivateKey: test.localkey})
// Wait for initial refresh so the table doesn't send unexpected findnode. // Wait for initial refresh so the table doesn't send unexpected findnode.
<-test.table.initDone <-test.table.initDone
return test return test
@ -355,12 +357,13 @@ func TestUDP_successfulPing(t *testing.T) {
// remote is unknown, the table pings back. // remote is unknown, the table pings back.
hash, _ := test.waitPacketOut(func(p *ping) error { hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) { if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) {
t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint) t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint())
} }
wantTo := rpcEndpoint{ wantTo := rpcEndpoint{
// The mirrored UDP address is the UDP packet sender. // The mirrored UDP address is the UDP packet sender.
IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port), IP: test.remoteaddr.IP,
UDP: uint16(test.remoteaddr.Port),
TCP: 0, TCP: 0,
} }
if !reflect.DeepEqual(p.To, wantTo) { if !reflect.DeepEqual(p.To, wantTo) {

@ -230,7 +230,8 @@ type udp struct {
} }
// ListenUDP returns a new table that listens for UDP packets on laddr. // ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { func ListenUDP(priv *ecdsa.PrivateKey, conn conn, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
realaddr := conn.LocalAddr().(*net.UDPAddr)
transport, err := listenUDP(priv, conn, realaddr) transport, err := listenUDP(priv, conn, realaddr)
if err != nil { if err != nil {
return nil, err return nil, err

246
p2p/enode/localnode.go Normal file

@ -0,0 +1,246 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package enode
import (
"crypto/ecdsa"
"fmt"
"net"
"reflect"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil"
)
const (
// IP tracker configuration
iptrackMinStatements = 10
iptrackWindow = 5 * time.Minute
iptrackContactWindow = 10 * time.Minute
)
// LocalNode produces the signed node record of a local node, i.e. a node run in the
// current process. Setting ENR entries via the Set method updates the record. A new version
// of the record is signed on demand when the Node method is called.
type LocalNode struct {
cur atomic.Value // holds a non-nil node pointer while the record is up-to-date.
id ID
key *ecdsa.PrivateKey
db *DB
// everything below is protected by a lock
mu sync.Mutex
seq uint64
entries map[string]enr.Entry
udpTrack *netutil.IPTracker // predicts external UDP endpoint
staticIP net.IP
fallbackIP net.IP
fallbackUDP int
}
// NewLocalNode creates a local node.
func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
ln := &LocalNode{
id: PubkeyToIDV4(&key.PublicKey),
db: db,
key: key,
udpTrack: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements),
entries: make(map[string]enr.Entry),
}
ln.seq = db.localSeq(ln.id)
ln.invalidate()
return ln
}
// Database returns the node database associated with the local node.
func (ln *LocalNode) Database() *DB {
return ln.db
}
// Node returns the current version of the local node record.
func (ln *LocalNode) Node() *Node {
n := ln.cur.Load().(*Node)
if n != nil {
return n
}
// Record was invalidated, sign a new copy.
ln.mu.Lock()
defer ln.mu.Unlock()
ln.sign()
return ln.cur.Load().(*Node)
}
// ID returns the local node ID.
func (ln *LocalNode) ID() ID {
return ln.id
}
// Set puts the given entry into the local record, overwriting
// any existing value.
func (ln *LocalNode) Set(e enr.Entry) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.set(e)
}
func (ln *LocalNode) set(e enr.Entry) {
val, exists := ln.entries[e.ENRKey()]
if !exists || !reflect.DeepEqual(val, e) {
ln.entries[e.ENRKey()] = e
ln.invalidate()
}
}
// Delete removes the given entry from the local record.
func (ln *LocalNode) Delete(e enr.Entry) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.delete(e)
}
func (ln *LocalNode) delete(e enr.Entry) {
_, exists := ln.entries[e.ENRKey()]
if exists {
delete(ln.entries, e.ENRKey())
ln.invalidate()
}
}
// SetStaticIP sets the local IP to the given one unconditionally.
// This disables endpoint prediction.
func (ln *LocalNode) SetStaticIP(ip net.IP) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.staticIP = ip
ln.updateEndpoints()
}
// SetFallbackIP sets the last-resort IP address. This address is used
// if no endpoint prediction can be made and no static IP is set.
func (ln *LocalNode) SetFallbackIP(ip net.IP) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.fallbackIP = ip
ln.updateEndpoints()
}
// SetFallbackUDP sets the last-resort UDP port. This port is used
// if no endpoint prediction can be made.
func (ln *LocalNode) SetFallbackUDP(port int) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.fallbackUDP = port
ln.updateEndpoints()
}
// UDPEndpointStatement should be called whenever a statement about the local node's
// UDP endpoint is received. It feeds the local endpoint predictor.
func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.udpTrack.AddStatement(fromaddr.String(), endpoint.String())
ln.updateEndpoints()
}
// UDPContact should be called whenever the local node has announced itself to another node
// via UDP. It feeds the local endpoint predictor.
func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.udpTrack.AddContact(toaddr.String())
ln.updateEndpoints()
}
func (ln *LocalNode) updateEndpoints() {
// Determine the endpoints.
newIP := ln.fallbackIP
newUDP := ln.fallbackUDP
if ln.staticIP != nil {
newIP = ln.staticIP
} else if ip, port := predictAddr(ln.udpTrack); ip != nil {
newIP = ip
newUDP = port
}
// Update the record.
if newIP != nil && !newIP.IsUnspecified() {
ln.set(enr.IP(newIP))
if newUDP != 0 {
ln.set(enr.UDP(newUDP))
} else {
ln.delete(enr.UDP(0))
}
} else {
ln.delete(enr.IP{})
}
}
// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based
// endpoint representation to IP and port types.
func predictAddr(t *netutil.IPTracker) (net.IP, int) {
ep := t.PredictEndpoint()
if ep == "" {
return nil, 0
}
ipString, portString, _ := net.SplitHostPort(ep)
ip := net.ParseIP(ipString)
port, _ := strconv.Atoi(portString)
return ip, port
}
func (ln *LocalNode) invalidate() {
ln.cur.Store((*Node)(nil))
}
func (ln *LocalNode) sign() {
if n := ln.cur.Load().(*Node); n != nil {
return // no changes
}
var r enr.Record
for _, e := range ln.entries {
r.Set(e)
}
ln.bumpSeq()
r.SetSeq(ln.seq)
if err := SignV4(&r, ln.key); err != nil {
panic(fmt.Errorf("enode: can't sign record: %v", err))
}
n, err := New(ValidSchemes, &r)
if err != nil {
panic(fmt.Errorf("enode: can't verify local record: %v", err))
}
ln.cur.Store(n)
log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP())
}
func (ln *LocalNode) bumpSeq() {
ln.seq++
ln.db.storeLocalSeq(ln.id, ln.seq)
}

@ -0,0 +1,76 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package enode
import (
"testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enr"
)
func newLocalNodeForTesting() (*LocalNode, *DB) {
db, _ := OpenDB("")
key, _ := crypto.GenerateKey()
return NewLocalNode(db, key), db
}
func TestLocalNode(t *testing.T) {
ln, db := newLocalNodeForTesting()
defer db.Close()
if ln.Node().ID() != ln.ID() {
t.Fatal("inconsistent ID")
}
ln.Set(enr.WithEntry("x", uint(3)))
var x uint
if err := ln.Node().Load(enr.WithEntry("x", &x)); err != nil {
t.Fatal("can't load entry 'x':", err)
} else if x != 3 {
t.Fatal("wrong value for entry 'x':", x)
}
}
func TestLocalNodeSeqPersist(t *testing.T) {
ln, db := newLocalNodeForTesting()
defer db.Close()
if s := ln.Node().Seq(); s != 1 {
t.Fatalf("wrong initial seq %d, want 1", s)
}
ln.Set(enr.WithEntry("x", uint(1)))
if s := ln.Node().Seq(); s != 2 {
t.Fatalf("wrong seq %d after set, want 2", s)
}
// Create a new instance, it should reload the sequence number.
// The number increases just after that because a new record is
// created without the "x" entry.
ln2 := NewLocalNode(db, ln.key)
if s := ln2.Node().Seq(); s != 3 {
t.Fatalf("wrong seq %d on new instance, want 3", s)
}
// Create a new instance with a different node key on the same database.
// This should reset the sequence number.
key, _ := crypto.GenerateKey()
ln3 := NewLocalNode(db, key)
if s := ln3.Node().Seq(); s != 1 {
t.Fatalf("wrong seq %d on instance with changed key, want 1", s)
}
}

@ -98,6 +98,13 @@ func (n *Node) Pubkey() *ecdsa.PublicKey {
return &key return &key
} }
// Record returns the node's record. The return value is a copy and may
// be modified by the caller.
func (n *Node) Record() *enr.Record {
cpy := n.r
return &cpy
}
// checks whether n is a valid complete node. // checks whether n is a valid complete node.
func (n *Node) ValidateComplete() error { func (n *Node) ValidateComplete() error {
if n.Incomplete() { if n.Incomplete() {

@ -35,11 +35,24 @@ import (
"github.com/syndtr/goleveldb/leveldb/util" "github.com/syndtr/goleveldb/leveldb/util"
) )
// Keys in the node database.
const (
dbVersionKey = "version" // Version of the database to flush if changes
dbItemPrefix = "n:" // Identifier to prefix node entries with
dbDiscoverRoot = ":discover"
dbDiscoverSeq = dbDiscoverRoot + ":seq"
dbDiscoverPing = dbDiscoverRoot + ":lastping"
dbDiscoverPong = dbDiscoverRoot + ":lastpong"
dbDiscoverFindFails = dbDiscoverRoot + ":findfail"
dbLocalRoot = ":local"
dbLocalSeq = dbLocalRoot + ":seq"
)
var ( var (
nodeDBNilID = ID{} // Special node ID to use as a nil element. dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped. dbCleanupCycle = time.Hour // Time period for running the expiration task.
nodeDBCleanupCycle = time.Hour // Time period for running the expiration task. dbVersion = 7
nodeDBVersion = 6
) )
// DB is the node database, storing previously seen nodes and any collected metadata about // DB is the node database, storing previously seen nodes and any collected metadata about
@ -50,17 +63,6 @@ type DB struct {
quit chan struct{} // Channel to signal the expiring thread to stop quit chan struct{} // Channel to signal the expiring thread to stop
} }
// Schema layout for the node database
var (
nodeDBVersionKey = []byte("version") // Version of the database to flush if changes
nodeDBItemPrefix = []byte("n:") // Identifier to prefix node entries with
nodeDBDiscoverRoot = ":discover"
nodeDBDiscoverPing = nodeDBDiscoverRoot + ":lastping"
nodeDBDiscoverPong = nodeDBDiscoverRoot + ":lastpong"
nodeDBDiscoverFindFails = nodeDBDiscoverRoot + ":findfail"
)
// OpenDB opens a node database for storing and retrieving infos about known peers in the // OpenDB opens a node database for storing and retrieving infos about known peers in the
// network. If no path is given an in-memory, temporary database is constructed. // network. If no path is given an in-memory, temporary database is constructed.
func OpenDB(path string) (*DB, error) { func OpenDB(path string) (*DB, error) {
@ -93,13 +95,13 @@ func newPersistentDB(path string) (*DB, error) {
// The nodes contained in the cache correspond to a certain protocol version. // The nodes contained in the cache correspond to a certain protocol version.
// Flush all nodes if the version doesn't match. // Flush all nodes if the version doesn't match.
currentVer := make([]byte, binary.MaxVarintLen64) currentVer := make([]byte, binary.MaxVarintLen64)
currentVer = currentVer[:binary.PutVarint(currentVer, int64(nodeDBVersion))] currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))]
blob, err := db.Get(nodeDBVersionKey, nil) blob, err := db.Get([]byte(dbVersionKey), nil)
switch err { switch err {
case leveldb.ErrNotFound: case leveldb.ErrNotFound:
// Version not found (i.e. empty cache), insert it // Version not found (i.e. empty cache), insert it
if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil { if err := db.Put([]byte(dbVersionKey), currentVer, nil); err != nil {
db.Close() db.Close()
return nil, err return nil, err
} }
@ -120,28 +122,27 @@ func newPersistentDB(path string) (*DB, error) {
// makeKey generates the leveldb key-blob from a node id and its particular // makeKey generates the leveldb key-blob from a node id and its particular
// field of interest. // field of interest.
func makeKey(id ID, field string) []byte { func makeKey(id ID, field string) []byte {
if bytes.Equal(id[:], nodeDBNilID[:]) { if (id == ID{}) {
return []byte(field) return []byte(field)
} }
return append(nodeDBItemPrefix, append(id[:], field...)...) return append([]byte(dbItemPrefix), append(id[:], field...)...)
} }
// splitKey tries to split a database key into a node id and a field part. // splitKey tries to split a database key into a node id and a field part.
func splitKey(key []byte) (id ID, field string) { func splitKey(key []byte) (id ID, field string) {
// If the key is not of a node, return it plainly // If the key is not of a node, return it plainly
if !bytes.HasPrefix(key, nodeDBItemPrefix) { if !bytes.HasPrefix(key, []byte(dbItemPrefix)) {
return ID{}, string(key) return ID{}, string(key)
} }
// Otherwise split the id and field // Otherwise split the id and field
item := key[len(nodeDBItemPrefix):] item := key[len(dbItemPrefix):]
copy(id[:], item[:len(id)]) copy(id[:], item[:len(id)])
field = string(item[len(id):]) field = string(item[len(id):])
return id, field return id, field
} }
// fetchInt64 retrieves an integer instance associated with a particular // fetchInt64 retrieves an integer associated with a particular key.
// database key.
func (db *DB) fetchInt64(key []byte) int64 { func (db *DB) fetchInt64(key []byte) int64 {
blob, err := db.lvl.Get(key, nil) blob, err := db.lvl.Get(key, nil)
if err != nil { if err != nil {
@ -154,18 +155,33 @@ func (db *DB) fetchInt64(key []byte) int64 {
return val return val
} }
// storeInt64 update a specific database entry to the current time instance as a // storeInt64 stores an integer in the given key.
// unix timestamp.
func (db *DB) storeInt64(key []byte, n int64) error { func (db *DB) storeInt64(key []byte, n int64) error {
blob := make([]byte, binary.MaxVarintLen64) blob := make([]byte, binary.MaxVarintLen64)
blob = blob[:binary.PutVarint(blob, n)] blob = blob[:binary.PutVarint(blob, n)]
return db.lvl.Put(key, blob, nil)
}
// fetchUint64 retrieves an integer associated with a particular key.
func (db *DB) fetchUint64(key []byte) uint64 {
blob, err := db.lvl.Get(key, nil)
if err != nil {
return 0
}
val, _ := binary.Uvarint(blob)
return val
}
// storeUint64 stores an integer in the given key.
func (db *DB) storeUint64(key []byte, n uint64) error {
blob := make([]byte, binary.MaxVarintLen64)
blob = blob[:binary.PutUvarint(blob, n)]
return db.lvl.Put(key, blob, nil) return db.lvl.Put(key, blob, nil)
} }
// Node retrieves a node with a given id from the database. // Node retrieves a node with a given id from the database.
func (db *DB) Node(id ID) *Node { func (db *DB) Node(id ID) *Node {
blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil) blob, err := db.lvl.Get(makeKey(id, dbDiscoverRoot), nil)
if err != nil { if err != nil {
return nil return nil
} }
@ -184,11 +200,31 @@ func mustDecodeNode(id, data []byte) *Node {
// UpdateNode inserts - potentially overwriting - a node into the peer database. // UpdateNode inserts - potentially overwriting - a node into the peer database.
func (db *DB) UpdateNode(node *Node) error { func (db *DB) UpdateNode(node *Node) error {
if node.Seq() < db.NodeSeq(node.ID()) {
return nil
}
blob, err := rlp.EncodeToBytes(&node.r) blob, err := rlp.EncodeToBytes(&node.r)
if err != nil { if err != nil {
return err return err
} }
return db.lvl.Put(makeKey(node.ID(), nodeDBDiscoverRoot), blob, nil) if err := db.lvl.Put(makeKey(node.ID(), dbDiscoverRoot), blob, nil); err != nil {
return err
}
return db.storeUint64(makeKey(node.ID(), dbDiscoverSeq), node.Seq())
}
// NodeSeq returns the stored record sequence number of the given node.
func (db *DB) NodeSeq(id ID) uint64 {
return db.fetchUint64(makeKey(id, dbDiscoverSeq))
}
// Resolve returns the stored record of the node if it has a larger sequence
// number than n.
func (db *DB) Resolve(n *Node) *Node {
if n.Seq() > db.NodeSeq(n.ID()) {
return n
}
return db.Node(n.ID())
} }
// DeleteNode deletes all information/keys associated with a node. // DeleteNode deletes all information/keys associated with a node.
@ -218,7 +254,7 @@ func (db *DB) ensureExpirer() {
// expirer should be started in a go routine, and is responsible for looping ad // expirer should be started in a go routine, and is responsible for looping ad
// infinitum and dropping stale data from the database. // infinitum and dropping stale data from the database.
func (db *DB) expirer() { func (db *DB) expirer() {
tick := time.NewTicker(nodeDBCleanupCycle) tick := time.NewTicker(dbCleanupCycle)
defer tick.Stop() defer tick.Stop()
for { for {
select { select {
@ -235,7 +271,7 @@ func (db *DB) expirer() {
// expireNodes iterates over the database and deletes all nodes that have not // expireNodes iterates over the database and deletes all nodes that have not
// been seen (i.e. received a pong from) for some allotted time. // been seen (i.e. received a pong from) for some allotted time.
func (db *DB) expireNodes() error { func (db *DB) expireNodes() error {
threshold := time.Now().Add(-nodeDBNodeExpiration) threshold := time.Now().Add(-dbNodeExpiration)
// Find discovered nodes that are older than the allowance // Find discovered nodes that are older than the allowance
it := db.lvl.NewIterator(nil, nil) it := db.lvl.NewIterator(nil, nil)
@ -244,7 +280,7 @@ func (db *DB) expireNodes() error {
for it.Next() { for it.Next() {
// Skip the item if not a discovery node // Skip the item if not a discovery node
id, field := splitKey(it.Key()) id, field := splitKey(it.Key())
if field != nodeDBDiscoverRoot { if field != dbDiscoverRoot {
continue continue
} }
// Skip the node if not expired yet (and not self) // Skip the node if not expired yet (and not self)
@ -260,34 +296,44 @@ func (db *DB) expireNodes() error {
// LastPingReceived retrieves the time of the last ping packet received from // LastPingReceived retrieves the time of the last ping packet received from
// a remote node. // a remote node.
func (db *DB) LastPingReceived(id ID) time.Time { func (db *DB) LastPingReceived(id ID) time.Time {
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0) return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPing)), 0)
} }
// UpdateLastPingReceived updates the last time we tried contacting a remote node. // UpdateLastPingReceived updates the last time we tried contacting a remote node.
func (db *DB) UpdateLastPingReceived(id ID, instance time.Time) error { func (db *DB) UpdateLastPingReceived(id ID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix()) return db.storeInt64(makeKey(id, dbDiscoverPing), instance.Unix())
} }
// LastPongReceived retrieves the time of the last successful pong from remote node. // LastPongReceived retrieves the time of the last successful pong from remote node.
func (db *DB) LastPongReceived(id ID) time.Time { func (db *DB) LastPongReceived(id ID) time.Time {
// Launch expirer // Launch expirer
db.ensureExpirer() db.ensureExpirer()
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0) return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPong)), 0)
} }
// UpdateLastPongReceived updates the last pong time of a node. // UpdateLastPongReceived updates the last pong time of a node.
func (db *DB) UpdateLastPongReceived(id ID, instance time.Time) error { func (db *DB) UpdateLastPongReceived(id ID, instance time.Time) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix()) return db.storeInt64(makeKey(id, dbDiscoverPong), instance.Unix())
} }
// FindFails retrieves the number of findnode failures since bonding. // FindFails retrieves the number of findnode failures since bonding.
func (db *DB) FindFails(id ID) int { func (db *DB) FindFails(id ID) int {
return int(db.fetchInt64(makeKey(id, nodeDBDiscoverFindFails))) return int(db.fetchInt64(makeKey(id, dbDiscoverFindFails)))
} }
// UpdateFindFails updates the number of findnode failures since bonding. // UpdateFindFails updates the number of findnode failures since bonding.
func (db *DB) UpdateFindFails(id ID, fails int) error { func (db *DB) UpdateFindFails(id ID, fails int) error {
return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails)) return db.storeInt64(makeKey(id, dbDiscoverFindFails), int64(fails))
}
// LocalSeq retrieves the local record sequence counter.
func (db *DB) localSeq(id ID) uint64 {
return db.fetchUint64(makeKey(id, dbLocalSeq))
}
// storeLocalSeq stores the local record sequence counter.
func (db *DB) storeLocalSeq(id ID, n uint64) {
db.storeUint64(makeKey(id, dbLocalSeq), n)
} }
// QuerySeeds retrieves random nodes to be used as potential seed nodes // QuerySeeds retrieves random nodes to be used as potential seed nodes
@ -309,7 +355,7 @@ seek:
ctr := id[0] ctr := id[0]
rand.Read(id[:]) rand.Read(id[:])
id[0] = ctr + id[0]%16 id[0] = ctr + id[0]%16
it.Seek(makeKey(id, nodeDBDiscoverRoot)) it.Seek(makeKey(id, dbDiscoverRoot))
n := nextNode(it) n := nextNode(it)
if n == nil { if n == nil {
@ -334,7 +380,7 @@ seek:
func nextNode(it iterator.Iterator) *Node { func nextNode(it iterator.Iterator) *Node {
for end := false; !end; end = !it.Next() { for end := false; !end; end = !it.Next() {
id, field := splitKey(it.Key()) id, field := splitKey(it.Key())
if field != nodeDBDiscoverRoot { if field != dbDiscoverRoot {
continue continue
} }
return mustDecodeNode(id[:], it.Value()) return mustDecodeNode(id[:], it.Value())

@ -332,7 +332,7 @@ var nodeDBExpirationNodes = []struct {
30303, 30303,
30303, 30303,
), ),
pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute), pong: time.Now().Add(-dbNodeExpiration + time.Minute),
exp: false, exp: false,
}, { }, {
node: NewV4( node: NewV4(
@ -341,7 +341,7 @@ var nodeDBExpirationNodes = []struct {
30303, 30303,
30303, 30303,
), ),
pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute), pong: time.Now().Add(-dbNodeExpiration - time.Minute),
exp: true, exp: true,
}, },
} }

@ -156,7 +156,7 @@ func (r *Record) Set(e Entry) {
} }
func (r *Record) invalidate() { func (r *Record) invalidate() {
if r.signature == nil { if r.signature != nil {
r.seq++ r.seq++
} }
r.signature = nil r.signature = nil

@ -169,6 +169,18 @@ func TestDirty(t *testing.T) {
} }
} }
func TestSeq(t *testing.T) {
var r Record
assert.Equal(t, uint64(0), r.Seq())
r.Set(UDP(1))
assert.Equal(t, uint64(0), r.Seq())
signTest([]byte{5}, &r)
assert.Equal(t, uint64(0), r.Seq())
r.Set(UDP(2))
assert.Equal(t, uint64(1), r.Seq())
}
// TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record. // TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
func TestGetSetOverwrite(t *testing.T) { func TestGetSetOverwrite(t *testing.T) {
var r Record var r Record

@ -129,21 +129,15 @@ func Map(m Interface, c chan struct{}, protocol string, extport, intport int, na
// ExtIP assumes that the local machine is reachable on the given // ExtIP assumes that the local machine is reachable on the given
// external IP address, and that any required ports were mapped manually. // external IP address, and that any required ports were mapped manually.
// Mapping operations will not return an error but won't actually do anything. // Mapping operations will not return an error but won't actually do anything.
func ExtIP(ip net.IP) Interface { type ExtIP net.IP
if ip == nil {
panic("IP must not be nil")
}
return extIP(ip)
}
type extIP net.IP func (n ExtIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
func (n ExtIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
// These do nothing. // These do nothing.
func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
func (extIP) DeleteMapping(string, int, int) error { return nil } func (ExtIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
func (ExtIP) DeleteMapping(string, int, int) error { return nil }
// Any returns a port mapper that tries to discover any supported // Any returns a port mapper that tries to discover any supported
// mechanism on the local network. // mechanism on the local network.

@ -28,7 +28,7 @@ import (
func TestAutoDiscRace(t *testing.T) { func TestAutoDiscRace(t *testing.T) {
ad := startautodisc("thing", func() Interface { ad := startautodisc("thing", func() Interface {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
return extIP{33, 44, 55, 66} return ExtIP{33, 44, 55, 66}
}) })
// Spawn a few concurrent calls to ad.ExternalIP. // Spawn a few concurrent calls to ad.ExternalIP.

130
p2p/netutil/iptrack.go Normal file

@ -0,0 +1,130 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package netutil
import (
"time"
"github.com/ethereum/go-ethereum/common/mclock"
)
// IPTracker predicts the external endpoint, i.e. IP address and port, of the local host
// based on statements made by other hosts.
type IPTracker struct {
window time.Duration
contactWindow time.Duration
minStatements int
clock mclock.Clock
statements map[string]ipStatement
contact map[string]mclock.AbsTime
lastStatementGC mclock.AbsTime
lastContactGC mclock.AbsTime
}
type ipStatement struct {
endpoint string
time mclock.AbsTime
}
// NewIPTracker creates an IP tracker.
//
// The window parameters configure the amount of past network events which are kept. The
// minStatements parameter enforces a minimum number of statements which must be recorded
// before any prediction is made. Higher values for these parameters decrease 'flapping' of
// predictions as network conditions change. Window duration values should typically be in
// the range of minutes.
func NewIPTracker(window, contactWindow time.Duration, minStatements int) *IPTracker {
return &IPTracker{
window: window,
contactWindow: contactWindow,
statements: make(map[string]ipStatement),
minStatements: minStatements,
contact: make(map[string]mclock.AbsTime),
clock: mclock.System{},
}
}
// PredictFullConeNAT checks whether the local host is behind full cone NAT. It predicts by
// checking whether any statement has been received from a node we didn't contact before
// the statement was made.
func (it *IPTracker) PredictFullConeNAT() bool {
now := it.clock.Now()
it.gcContact(now)
it.gcStatements(now)
for host, st := range it.statements {
if c, ok := it.contact[host]; !ok || c > st.time {
return true
}
}
return false
}
// PredictEndpoint returns the current prediction of the external endpoint.
func (it *IPTracker) PredictEndpoint() string {
it.gcStatements(it.clock.Now())
// The current strategy is simple: find the endpoint with most statements.
counts := make(map[string]int)
maxcount, max := 0, ""
for _, s := range it.statements {
c := counts[s.endpoint] + 1
counts[s.endpoint] = c
if c > maxcount && c >= it.minStatements {
maxcount, max = c, s.endpoint
}
}
return max
}
// AddStatement records that a certain host thinks our external endpoint is the one given.
func (it *IPTracker) AddStatement(host, endpoint string) {
now := it.clock.Now()
it.statements[host] = ipStatement{endpoint, now}
if time.Duration(now-it.lastStatementGC) >= it.window {
it.gcStatements(now)
}
}
// AddContact records that a packet containing our endpoint information has been sent to a
// certain host.
func (it *IPTracker) AddContact(host string) {
now := it.clock.Now()
it.contact[host] = now
if time.Duration(now-it.lastContactGC) >= it.contactWindow {
it.gcContact(now)
}
}
func (it *IPTracker) gcStatements(now mclock.AbsTime) {
it.lastStatementGC = now
cutoff := now.Add(-it.window)
for host, s := range it.statements {
if s.time < cutoff {
delete(it.statements, host)
}
}
}
func (it *IPTracker) gcContact(now mclock.AbsTime) {
it.lastContactGC = now
cutoff := now.Add(-it.contactWindow)
for host, ct := range it.contact {
if ct < cutoff {
delete(it.contact, host)
}
}
}

138
p2p/netutil/iptrack_test.go Normal file

@ -0,0 +1,138 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package netutil
import (
"fmt"
mrand "math/rand"
"testing"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
)
const (
opStatement = iota
opContact
opPredict
opCheckFullCone
)
type iptrackTestEvent struct {
op int
time int // absolute, in milliseconds
ip, from string
}
func TestIPTracker(t *testing.T) {
tests := map[string][]iptrackTestEvent{
"minStatements": {
{opPredict, 0, "", ""},
{opStatement, 0, "127.0.0.1", "127.0.0.2"},
{opPredict, 1000, "", ""},
{opStatement, 1000, "127.0.0.1", "127.0.0.3"},
{opPredict, 1000, "", ""},
{opStatement, 1000, "127.0.0.1", "127.0.0.4"},
{opPredict, 1000, "127.0.0.1", ""},
},
"window": {
{opStatement, 0, "127.0.0.1", "127.0.0.2"},
{opStatement, 2000, "127.0.0.1", "127.0.0.3"},
{opStatement, 3000, "127.0.0.1", "127.0.0.4"},
{opPredict, 10000, "127.0.0.1", ""},
{opPredict, 10001, "", ""}, // first statement expired
{opStatement, 10100, "127.0.0.1", "127.0.0.2"},
{opPredict, 10200, "127.0.0.1", ""},
},
"fullcone": {
{opContact, 0, "", "127.0.0.2"},
{opStatement, 10, "127.0.0.1", "127.0.0.2"},
{opContact, 2000, "", "127.0.0.3"},
{opStatement, 2010, "127.0.0.1", "127.0.0.3"},
{opContact, 3000, "", "127.0.0.4"},
{opStatement, 3010, "127.0.0.1", "127.0.0.4"},
{opCheckFullCone, 3500, "false", ""},
},
"fullcone_2": {
{opContact, 0, "", "127.0.0.2"},
{opStatement, 10, "127.0.0.1", "127.0.0.2"},
{opContact, 2000, "", "127.0.0.3"},
{opStatement, 2010, "127.0.0.1", "127.0.0.3"},
{opStatement, 3000, "127.0.0.1", "127.0.0.4"},
{opContact, 3010, "", "127.0.0.4"},
{opCheckFullCone, 3500, "true", ""},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) { runIPTrackerTest(t, test) })
}
}
func runIPTrackerTest(t *testing.T, evs []iptrackTestEvent) {
var (
clock mclock.Simulated
it = NewIPTracker(10*time.Second, 10*time.Second, 3)
)
it.clock = &clock
for i, ev := range evs {
evtime := time.Duration(ev.time) * time.Millisecond
clock.Run(evtime - time.Duration(clock.Now()))
switch ev.op {
case opStatement:
it.AddStatement(ev.from, ev.ip)
case opContact:
it.AddContact(ev.from)
case opPredict:
if pred := it.PredictEndpoint(); pred != ev.ip {
t.Errorf("op %d: wrong prediction %q, want %q", i, pred, ev.ip)
}
case opCheckFullCone:
pred := fmt.Sprintf("%t", it.PredictFullConeNAT())
if pred != ev.ip {
t.Errorf("op %d: wrong prediction %s, want %s", i, pred, ev.ip)
}
}
}
}
// This checks that old statements and contacts are GCed even if Predict* isn't called.
func TestIPTrackerForceGC(t *testing.T) {
var (
clock mclock.Simulated
window = 10 * time.Second
rate = 50 * time.Millisecond
max = int(window/rate) + 1
it = NewIPTracker(window, window, 3)
)
it.clock = &clock
for i := 0; i < 5*max; i++ {
e1 := make([]byte, 4)
e2 := make([]byte, 4)
mrand.Read(e1)
mrand.Read(e2)
it.AddStatement(string(e1), string(e2))
it.AddContact(string(e1))
clock.Run(rate)
}
if len(it.contact) > 2*max {
t.Errorf("contacts not GCed, have %d", len(it.contact))
}
if len(it.statements) > 2*max {
t.Errorf("statements not GCed, have %d", len(it.statements))
}
}

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
) )
// Protocol represents a P2P subprotocol implementation. // Protocol represents a P2P subprotocol implementation.
@ -52,6 +53,9 @@ type Protocol struct {
// about a certain peer in the network. If an info retrieval function is set, // about a certain peer in the network. If an info retrieval function is set,
// but returns nil, it is assumed that the protocol handshake is still running. // but returns nil, it is assumed that the protocol handshake is still running.
PeerInfo func(id enode.ID) interface{} PeerInfo func(id enode.ID) interface{}
// Attributes contains protocol specific information for the node record.
Attributes []enr.Entry
} }
func (p Protocol) cap() Cap { func (p Protocol) cap() Cap {
@ -64,10 +68,6 @@ type Cap struct {
Version uint Version uint
} }
func (cap Cap) RlpData() interface{} {
return []interface{}{cap.Name, cap.Version}
}
func (cap Cap) String() string { func (cap Cap) String() string {
return fmt.Sprintf("%s/%d", cap.Name, cap.Version) return fmt.Sprintf("%s/%d", cap.Name, cap.Version)
} }
@ -79,3 +79,5 @@ func (cs capsByNameAndVersion) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
func (cs capsByNameAndVersion) Less(i, j int) bool { func (cs capsByNameAndVersion) Less(i, j int) bool {
return cs[i].Name < cs[j].Name || (cs[i].Name == cs[j].Name && cs[i].Version < cs[j].Version) return cs[i].Name < cs[j].Name || (cs[i].Name == cs[j].Name && cs[i].Version < cs[j].Version)
} }
func (capsByNameAndVersion) ENRKey() string { return "cap" }

@ -20,9 +20,11 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/ecdsa" "crypto/ecdsa"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sort"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -35,8 +37,10 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp"
) )
const ( const (
@ -160,6 +164,8 @@ type Server struct {
lock sync.Mutex // protects running lock sync.Mutex // protects running
running bool running bool
nodedb *enode.DB
localnode *enode.LocalNode
ntab discoverTable ntab discoverTable
listener net.Listener listener net.Listener
ourHandshake *protoHandshake ourHandshake *protoHandshake
@ -347,43 +353,13 @@ func (srv *Server) SubscribeEvents(ch chan *PeerEvent) event.Subscription {
// Self returns the local node's endpoint information. // Self returns the local node's endpoint information.
func (srv *Server) Self() *enode.Node { func (srv *Server) Self() *enode.Node {
srv.lock.Lock() srv.lock.Lock()
running, listener, ntab := srv.running, srv.listener, srv.ntab ln := srv.localnode
srv.lock.Unlock() srv.lock.Unlock()
if !running { if ln == nil {
return enode.NewV4(&srv.PrivateKey.PublicKey, net.ParseIP("0.0.0.0"), 0, 0) return enode.NewV4(&srv.PrivateKey.PublicKey, net.ParseIP("0.0.0.0"), 0, 0)
} }
return srv.makeSelf(listener, ntab) return ln.Node()
}
func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *enode.Node {
// If the node is running but discovery is off, manually assemble the node infos.
if ntab == nil {
addr := srv.tcpAddr(listener)
return enode.NewV4(&srv.PrivateKey.PublicKey, addr.IP, addr.Port, 0)
}
// Otherwise return the discovery node.
return ntab.Self()
}
func (srv *Server) tcpAddr(listener net.Listener) net.TCPAddr {
addr := net.TCPAddr{IP: net.IP{0, 0, 0, 0}}
if listener == nil {
return addr // Inbound connections disabled, use zero address.
}
// Otherwise inject the listener address too.
if a, ok := listener.Addr().(*net.TCPAddr); ok {
addr = *a
}
if srv.NAT != nil {
if ip, err := srv.NAT.ExternalIP(); err == nil {
addr.IP = ip
}
}
if addr.IP.IsUnspecified() {
addr.IP = net.IP{127, 0, 0, 1}
}
return addr
} }
// Stop terminates the server and all active peer connections. // Stop terminates the server and all active peer connections.
@ -443,7 +419,9 @@ func (srv *Server) Start() (err error) {
if srv.log == nil { if srv.log == nil {
srv.log = log.New() srv.log = log.New()
} }
srv.log.Info("Starting P2P networking") if srv.NoDial && srv.ListenAddr == "" {
srv.log.Warn("P2P server will be useless, neither dialing nor listening")
}
// static fields // static fields
if srv.PrivateKey == nil { if srv.PrivateKey == nil {
@ -466,65 +444,120 @@ func (srv *Server) Start() (err error) {
srv.peerOp = make(chan peerOpFunc) srv.peerOp = make(chan peerOpFunc)
srv.peerOpDone = make(chan struct{}) srv.peerOpDone = make(chan struct{})
var ( if err := srv.setupLocalNode(); err != nil {
conn *net.UDPConn return err
sconn *sharedUDPConn }
realaddr *net.UDPAddr if srv.ListenAddr != "" {
unhandled chan discover.ReadPacket if err := srv.setupListening(); err != nil {
) return err
}
}
if err := srv.setupDiscovery(); err != nil {
return err
}
dynPeers := srv.maxDialedConns()
dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
srv.loopWG.Add(1)
go srv.run(dialer)
return nil
}
func (srv *Server) setupLocalNode() error {
// Create the devp2p handshake.
pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey)
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]}
for _, p := range srv.Protocols {
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
}
sort.Sort(capsByNameAndVersion(srv.ourHandshake.Caps))
// Create the local node.
db, err := enode.OpenDB(srv.Config.NodeDatabase)
if err != nil {
return err
}
srv.nodedb = db
srv.localnode = enode.NewLocalNode(db, srv.PrivateKey)
srv.localnode.SetFallbackIP(net.IP{127, 0, 0, 1})
srv.localnode.Set(capsByNameAndVersion(srv.ourHandshake.Caps))
// TODO: check conflicts
for _, p := range srv.Protocols {
for _, e := range p.Attributes {
srv.localnode.Set(e)
}
}
switch srv.NAT.(type) {
case nil:
// No NAT interface, do nothing.
case nat.ExtIP:
// ExtIP doesn't block, set the IP right away.
ip, _ := srv.NAT.ExternalIP()
srv.localnode.SetStaticIP(ip)
default:
// Ask the router about the IP. This takes a while and blocks startup,
// do it in the background.
srv.loopWG.Add(1)
go func() {
defer srv.loopWG.Done()
if ip, err := srv.NAT.ExternalIP(); err == nil {
srv.localnode.SetStaticIP(ip)
}
}()
}
return nil
}
func (srv *Server) setupDiscovery() error {
if srv.NoDiscovery && !srv.DiscoveryV5 {
return nil
}
if !srv.NoDiscovery || srv.DiscoveryV5 {
addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr) addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr)
if err != nil { if err != nil {
return err return err
} }
conn, err = net.ListenUDP("udp", addr) conn, err := net.ListenUDP("udp", addr)
if err != nil { if err != nil {
return err return err
} }
realaddr = conn.LocalAddr().(*net.UDPAddr) realaddr := conn.LocalAddr().(*net.UDPAddr)
srv.log.Debug("UDP listener up", "addr", realaddr)
if srv.NAT != nil { if srv.NAT != nil {
if !realaddr.IP.IsLoopback() { if !realaddr.IP.IsLoopback() {
go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
} }
// TODO: react to external IP changes over time.
if ext, err := srv.NAT.ExternalIP(); err == nil {
realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
}
}
} }
srv.localnode.SetFallbackUDP(realaddr.Port)
if !srv.NoDiscovery && srv.DiscoveryV5 { // Discovery V4
var unhandled chan discover.ReadPacket
var sconn *sharedUDPConn
if !srv.NoDiscovery {
if srv.DiscoveryV5 {
unhandled = make(chan discover.ReadPacket, 100) unhandled = make(chan discover.ReadPacket, 100)
sconn = &sharedUDPConn{conn, unhandled} sconn = &sharedUDPConn{conn, unhandled}
} }
// node table
if !srv.NoDiscovery {
cfg := discover.Config{ cfg := discover.Config{
PrivateKey: srv.PrivateKey, PrivateKey: srv.PrivateKey,
AnnounceAddr: realaddr,
NodeDBPath: srv.NodeDatabase,
NetRestrict: srv.NetRestrict, NetRestrict: srv.NetRestrict,
Bootnodes: srv.BootstrapNodes, Bootnodes: srv.BootstrapNodes,
Unhandled: unhandled, Unhandled: unhandled,
} }
ntab, err := discover.ListenUDP(conn, cfg) ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
if err != nil { if err != nil {
return err return err
} }
srv.ntab = ntab srv.ntab = ntab
} }
// Discovery V5
if srv.DiscoveryV5 { if srv.DiscoveryV5 {
var ( var ntab *discv5.Network
ntab *discv5.Network var err error
err error
)
if sconn != nil { if sconn != nil {
ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, "", srv.NetRestrict)
} else { } else {
ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase) ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, "", srv.NetRestrict)
} }
if err != nil { if err != nil {
return err return err
@ -534,32 +567,10 @@ func (srv *Server) Start() (err error) {
} }
srv.DiscV5 = ntab srv.DiscV5 = ntab
} }
dynPeers := srv.maxDialedConns()
dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake
pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey)
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]}
for _, p := range srv.Protocols {
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
}
// listen/dial
if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil {
return err
}
}
if srv.NoDial && srv.ListenAddr == "" {
srv.log.Warn("P2P server will be useless, neither dialing nor listening")
}
srv.loopWG.Add(1)
go srv.run(dialer)
return nil return nil
} }
func (srv *Server) startListening() error { func (srv *Server) setupListening() error {
// Launch the TCP listener. // Launch the TCP listener.
listener, err := net.Listen("tcp", srv.ListenAddr) listener, err := net.Listen("tcp", srv.ListenAddr)
if err != nil { if err != nil {
@ -568,8 +579,11 @@ func (srv *Server) startListening() error {
laddr := listener.Addr().(*net.TCPAddr) laddr := listener.Addr().(*net.TCPAddr)
srv.ListenAddr = laddr.String() srv.ListenAddr = laddr.String()
srv.listener = listener srv.listener = listener
srv.localnode.Set(enr.TCP(laddr.Port))
srv.loopWG.Add(1) srv.loopWG.Add(1)
go srv.listenLoop() go srv.listenLoop()
// Map the TCP listening port if NAT is configured. // Map the TCP listening port if NAT is configured.
if !laddr.IP.IsLoopback() && srv.NAT != nil { if !laddr.IP.IsLoopback() && srv.NAT != nil {
srv.loopWG.Add(1) srv.loopWG.Add(1)
@ -589,7 +603,10 @@ type dialer interface {
} }
func (srv *Server) run(dialstate dialer) { func (srv *Server) run(dialstate dialer) {
srv.log.Info("Started P2P networking", "self", srv.localnode.Node())
defer srv.loopWG.Done() defer srv.loopWG.Done()
defer srv.nodedb.Close()
var ( var (
peers = make(map[enode.ID]*Peer) peers = make(map[enode.ID]*Peer)
inboundCount = 0 inboundCount = 0
@ -781,7 +798,7 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int
return DiscTooManyPeers return DiscTooManyPeers
case peers[c.node.ID()] != nil: case peers[c.node.ID()] != nil:
return DiscAlreadyConnected return DiscAlreadyConnected
case c.node.ID() == srv.Self().ID(): case c.node.ID() == srv.localnode.ID():
return DiscSelf return DiscSelf
default: default:
return nil return nil
@ -802,15 +819,11 @@ func (srv *Server) maxDialedConns() int {
return srv.MaxPeers / r return srv.MaxPeers / r
} }
type tempError interface {
Temporary() bool
}
// listenLoop runs in its own goroutine and accepts // listenLoop runs in its own goroutine and accepts
// inbound connections. // inbound connections.
func (srv *Server) listenLoop() { func (srv *Server) listenLoop() {
defer srv.loopWG.Done() defer srv.loopWG.Done()
srv.log.Info("RLPx listener up", "self", srv.Self()) srv.log.Debug("TCP listener up", "addr", srv.listener.Addr())
tokens := defaultMaxPendingPeers tokens := defaultMaxPendingPeers
if srv.MaxPendingPeers > 0 { if srv.MaxPendingPeers > 0 {
@ -831,7 +844,7 @@ func (srv *Server) listenLoop() {
) )
for { for {
fd, err = srv.listener.Accept() fd, err = srv.listener.Accept()
if tempErr, ok := err.(tempError); ok && tempErr.Temporary() { if netutil.IsTemporaryError(err) {
srv.log.Debug("Temporary read error", "err", err) srv.log.Debug("Temporary read error", "err", err)
continue continue
} else if err != nil { } else if err != nil {
@ -864,10 +877,6 @@ func (srv *Server) listenLoop() {
// as a peer. It returns when the connection has been added as a peer // as a peer. It returns when the connection has been added as a peer
// or the handshakes have failed. // or the handshakes have failed.
func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error { func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error {
self := srv.Self()
if self == nil {
return errors.New("shutdown")
}
c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)} c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)}
err := srv.setupConn(c, flags, dialDest) err := srv.setupConn(c, flags, dialDest)
if err != nil { if err != nil {
@ -1003,6 +1012,7 @@ type NodeInfo struct {
ID string `json:"id"` // Unique node identifier (also the encryption key) ID string `json:"id"` // Unique node identifier (also the encryption key)
Name string `json:"name"` // Name of the node, including client type, version, OS, custom data Name string `json:"name"` // Name of the node, including client type, version, OS, custom data
Enode string `json:"enode"` // Enode URL for adding this peer from remote peers Enode string `json:"enode"` // Enode URL for adding this peer from remote peers
ENR string `json:"enr"` // Ethereum Node Record
IP string `json:"ip"` // IP address of the node IP string `json:"ip"` // IP address of the node
Ports struct { Ports struct {
Discovery int `json:"discovery"` // UDP listening port for discovery protocol Discovery int `json:"discovery"` // UDP listening port for discovery protocol
@ -1014,9 +1024,8 @@ type NodeInfo struct {
// NodeInfo gathers and returns a collection of metadata known about the host. // NodeInfo gathers and returns a collection of metadata known about the host.
func (srv *Server) NodeInfo() *NodeInfo { func (srv *Server) NodeInfo() *NodeInfo {
node := srv.Self()
// Gather and assemble the generic node infos // Gather and assemble the generic node infos
node := srv.Self()
info := &NodeInfo{ info := &NodeInfo{
Name: srv.Name, Name: srv.Name,
Enode: node.String(), Enode: node.String(),
@ -1027,6 +1036,9 @@ func (srv *Server) NodeInfo() *NodeInfo {
} }
info.Ports.Discovery = node.UDP() info.Ports.Discovery = node.UDP()
info.Ports.Listener = node.TCP() info.Ports.Listener = node.TCP()
if enc, err := rlp.EncodeToBytes(node.Record()); err == nil {
info.ENR = "0x" + hex.EncodeToString(enc)
}
// Gather all the running protocol infos (only once per protocol type) // Gather all the running protocol infos (only once per protocol type)
for _, proto := range srv.Protocols { for _, proto := range srv.Protocols {

@ -225,8 +225,11 @@ func TestServerTaskScheduling(t *testing.T) {
// The Server in this test isn't actually running // The Server in this test isn't actually running
// because we're only interested in what run does. // because we're only interested in what run does.
db, _ := enode.OpenDB("")
srv := &Server{ srv := &Server{
Config: Config{MaxPeers: 10}, Config: Config{MaxPeers: 10},
localnode: enode.NewLocalNode(db, newkey()),
nodedb: db,
quit: make(chan struct{}), quit: make(chan struct{}),
ntab: fakeTable{}, ntab: fakeTable{},
running: true, running: true,
@ -271,8 +274,11 @@ func TestServerManyTasks(t *testing.T) {
} }
var ( var (
db, _ = enode.OpenDB("")
srv = &Server{ srv = &Server{
quit: make(chan struct{}), quit: make(chan struct{}),
localnode: enode.NewLocalNode(db, newkey()),
nodedb: db,
ntab: fakeTable{}, ntab: fakeTable{},
running: true, running: true,
log: log.New(), log: log.New(),