diff --git a/cmd/devp2p/crawl.go b/cmd/devp2p/crawl.go index 7fefbd7a1c..9259b4894c 100644 --- a/cmd/devp2p/crawl.go +++ b/cmd/devp2p/crawl.go @@ -20,14 +20,13 @@ import ( "time" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" ) type crawler struct { input nodeSet output nodeSet - disc *discover.UDPv4 + disc resolver iters []enode.Iterator inputIter enode.Iterator ch chan *enode.Node @@ -37,7 +36,11 @@ type crawler struct { revalidateInterval time.Duration } -func newCrawler(input nodeSet, disc *discover.UDPv4, iters ...enode.Iterator) *crawler { +type resolver interface { + RequestENR(*enode.Node) (*enode.Node, error) +} + +func newCrawler(input nodeSet, disc resolver, iters ...enode.Iterator) *crawler { c := &crawler{ input: input, output: make(nodeSet, len(input)), diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 9525bec668..8580c61216 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -81,6 +81,18 @@ var ( Name: "bootnodes", Usage: "Comma separated nodes used for bootstrapping", } + nodekeyFlag = cli.StringFlag{ + Name: "nodekey", + Usage: "Hex-encoded node key", + } + nodedbFlag = cli.StringFlag{ + Name: "nodedb", + Usage: "Nodes database location", + } + listenAddrFlag = cli.StringFlag{ + Name: "addr", + Usage: "Listening address", + } crawlTimeoutFlag = cli.DurationFlag{ Name: "timeout", Usage: "Time limit for the crawl.", @@ -172,6 +184,62 @@ func discv4Crawl(ctx *cli.Context) error { return nil } +// startV4 starts an ephemeral discovery V4 node. +func startV4(ctx *cli.Context) *discover.UDPv4 { + ln, config := makeDiscoveryConfig(ctx) + socket := listen(ln, ctx.String(listenAddrFlag.Name)) + disc, err := discover.ListenV4(socket, ln, config) + if err != nil { + exit(err) + } + return disc +} + +func makeDiscoveryConfig(ctx *cli.Context) (*enode.LocalNode, discover.Config) { + var cfg discover.Config + + if ctx.IsSet(nodekeyFlag.Name) { + key, err := crypto.HexToECDSA(ctx.String(nodekeyFlag.Name)) + if err != nil { + exit(fmt.Errorf("-%s: %v", nodekeyFlag.Name, err)) + } + cfg.PrivateKey = key + } else { + cfg.PrivateKey, _ = crypto.GenerateKey() + } + + if commandHasFlag(ctx, bootnodesFlag) { + bn, err := parseBootnodes(ctx) + if err != nil { + exit(err) + } + cfg.Bootnodes = bn + } + + dbpath := ctx.String(nodedbFlag.Name) + db, err := enode.OpenDB(dbpath) + if err != nil { + exit(err) + } + ln := enode.NewLocalNode(db, cfg.PrivateKey) + return ln, cfg +} + +func listen(ln *enode.LocalNode, addr string) *net.UDPConn { + if addr == "" { + addr = "0.0.0.0:0" + } + socket, err := net.ListenPacket("udp4", addr) + if err != nil { + exit(err) + } + usocket := socket.(*net.UDPConn) + uaddr := socket.LocalAddr().(*net.UDPAddr) + ln.SetFallbackIP(net.IP{127, 0, 0, 1}) + ln.SetFallbackUDP(uaddr.Port) + return usocket +} + func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) { s := params.RinkebyBootnodes if ctx.IsSet(bootnodesFlag.Name) { @@ -187,40 +255,3 @@ func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) { } return nodes, nil } - -// startV4 starts an ephemeral discovery V4 node. -func startV4(ctx *cli.Context) *discover.UDPv4 { - socket, ln, cfg, err := listen() - if err != nil { - exit(err) - } - if commandHasFlag(ctx, bootnodesFlag) { - bn, err := parseBootnodes(ctx) - if err != nil { - exit(err) - } - cfg.Bootnodes = bn - } - disc, err := discover.ListenV4(socket, ln, cfg) - if err != nil { - exit(err) - } - return disc -} - -func listen() (*net.UDPConn, *enode.LocalNode, discover.Config, error) { - var cfg discover.Config - cfg.PrivateKey, _ = crypto.GenerateKey() - db, _ := enode.OpenDB("") - ln := enode.NewLocalNode(db, cfg.PrivateKey) - - socket, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{0, 0, 0, 0}}) - if err != nil { - db.Close() - return nil, nil, cfg, err - } - addr := socket.LocalAddr().(*net.UDPAddr) - ln.SetFallbackIP(net.IP{127, 0, 0, 1}) - ln.SetFallbackUDP(addr.Port) - return socket, ln, cfg, nil -} diff --git a/cmd/devp2p/discv5cmd.go b/cmd/devp2p/discv5cmd.go new file mode 100644 index 0000000000..f871821ea2 --- /dev/null +++ b/cmd/devp2p/discv5cmd.go @@ -0,0 +1,123 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "fmt" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/p2p/discover" + "gopkg.in/urfave/cli.v1" +) + +var ( + discv5Command = cli.Command{ + Name: "discv5", + Usage: "Node Discovery v5 tools", + Subcommands: []cli.Command{ + discv5PingCommand, + discv5ResolveCommand, + discv5CrawlCommand, + discv5ListenCommand, + }, + } + discv5PingCommand = cli.Command{ + Name: "ping", + Usage: "Sends ping to a node", + Action: discv5Ping, + } + discv5ResolveCommand = cli.Command{ + Name: "resolve", + Usage: "Finds a node in the DHT", + Action: discv5Resolve, + Flags: []cli.Flag{bootnodesFlag}, + } + discv5CrawlCommand = cli.Command{ + Name: "crawl", + Usage: "Updates a nodes.json file with random nodes found in the DHT", + Action: discv5Crawl, + Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag}, + } + discv5ListenCommand = cli.Command{ + Name: "listen", + Usage: "Runs a node", + Action: discv5Listen, + Flags: []cli.Flag{ + bootnodesFlag, + nodekeyFlag, + nodedbFlag, + listenAddrFlag, + }, + } +) + +func discv5Ping(ctx *cli.Context) error { + n := getNodeArg(ctx) + disc := startV5(ctx) + defer disc.Close() + + fmt.Println(disc.Ping(n)) + return nil +} + +func discv5Resolve(ctx *cli.Context) error { + n := getNodeArg(ctx) + disc := startV5(ctx) + defer disc.Close() + + fmt.Println(disc.Resolve(n)) + return nil +} + +func discv5Crawl(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need nodes file as argument") + } + nodesFile := ctx.Args().First() + var inputSet nodeSet + if common.FileExist(nodesFile) { + inputSet = loadNodesJSON(nodesFile) + } + + disc := startV5(ctx) + defer disc.Close() + c := newCrawler(inputSet, disc, disc.RandomNodes()) + c.revalidateInterval = 10 * time.Minute + output := c.run(ctx.Duration(crawlTimeoutFlag.Name)) + writeNodesJSON(nodesFile, output) + return nil +} + +func discv5Listen(ctx *cli.Context) error { + disc := startV5(ctx) + defer disc.Close() + + fmt.Println(disc.Self()) + select {} +} + +// startV5 starts an ephemeral discovery v5 node. +func startV5(ctx *cli.Context) *discover.UDPv5 { + ln, config := makeDiscoveryConfig(ctx) + socket := listen(ln, ctx.String(listenAddrFlag.Name)) + disc, err := discover.ListenV5(socket, ln, config) + if err != nil { + exit(err) + } + return disc +} diff --git a/cmd/devp2p/main.go b/cmd/devp2p/main.go index b895941f25..19aec77ed4 100644 --- a/cmd/devp2p/main.go +++ b/cmd/devp2p/main.go @@ -59,6 +59,7 @@ func init() { app.Commands = []cli.Command{ enrdumpCommand, discv4Command, + discv5Command, dnsCommand, nodesetCommand, } diff --git a/p2p/discover/common.go b/p2p/discover/common.go index cef6a9fc4f..3708bfb72c 100644 --- a/p2p/discover/common.go +++ b/p2p/discover/common.go @@ -20,8 +20,10 @@ import ( "crypto/ecdsa" "net" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" ) @@ -39,10 +41,25 @@ type Config struct { PrivateKey *ecdsa.PrivateKey // These settings are optional: - NetRestrict *netutil.Netlist // network whitelist - Bootnodes []*enode.Node // list of bootstrap nodes - Unhandled chan<- ReadPacket // unhandled packets are sent on this channel - Log log.Logger // if set, log messages go here + NetRestrict *netutil.Netlist // network whitelist + Bootnodes []*enode.Node // list of bootstrap nodes + Unhandled chan<- ReadPacket // unhandled packets are sent on this channel + Log log.Logger // if set, log messages go here + ValidSchemes enr.IdentityScheme // allowed identity schemes + Clock mclock.Clock +} + +func (cfg Config) withDefaults() Config { + if cfg.Log == nil { + cfg.Log = log.Root() + } + if cfg.ValidSchemes == nil { + cfg.ValidSchemes = enode.ValidSchemes + } + if cfg.Clock == nil { + cfg.Clock = mclock.System{} + } + return cfg } // ListenUDP starts listening for discovery packets on the given UDP socket. @@ -51,8 +68,15 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { } // ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled -// channel if configured. This is exported for internal use, do not use this type. +// channel if configured. type ReadPacket struct { Data []byte Addr *net.UDPAddr } + +func min(x, y int) int { + if x > y { + return y + } + return x +} diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go index ab825fb05d..40b271e6d9 100644 --- a/p2p/discover/lookup.go +++ b/p2p/discover/lookup.go @@ -150,7 +150,7 @@ func (it *lookup) query(n *node, reply chan<- []*node) { } else if len(r) == 0 { fails++ it.tab.db.UpdateFindFails(n.ID(), n.IP(), fails) - it.tab.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err) + it.tab.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "results", len(r), "err", err) if fails >= maxFindnodeFailures { it.tab.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails) it.tab.delete(n) diff --git a/p2p/discover/node.go b/p2p/discover/node.go index a7d9ce7368..230638b6d1 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -18,6 +18,7 @@ package discover import ( "crypto/ecdsa" + "crypto/elliptic" "errors" "math/big" "net" @@ -45,13 +46,13 @@ func encodePubkey(key *ecdsa.PublicKey) encPubkey { return e } -func decodePubkey(e encPubkey) (*ecdsa.PublicKey, error) { - p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)} +func decodePubkey(curve elliptic.Curve, e encPubkey) (*ecdsa.PublicKey, error) { + p := &ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)} half := len(e) / 2 p.X.SetBytes(e[:half]) p.Y.SetBytes(e[half:]) if !p.Curve.IsOnCurve(p.X, p.Y) { - return nil, errors.New("invalid secp256k1 curve point") + return nil, errors.New("invalid curve point") } return p, nil } diff --git a/p2p/discover/table.go b/p2p/discover/table.go index e5a5793e35..6d48ab00cd 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -424,6 +424,10 @@ func (tab *Table) len() (n int) { // bucket returns the bucket for the given node ID hash. func (tab *Table) bucket(id enode.ID) *bucket { d := enode.LogDist(tab.self().ID(), id) + return tab.bucketAtDistance(d) +} + +func (tab *Table) bucketAtDistance(d int) *bucket { if d <= bucketMinDistance { return tab.buckets[0] } diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index e35e48c5e6..44b62e751b 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -24,7 +24,6 @@ import ( "fmt" "math/rand" "net" - "reflect" "sort" "sync" @@ -56,6 +55,23 @@ func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node { return wrapNode(enode.SignNull(&r, idAtDistance(base, ld))) } +// nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld. +func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node { + results := make([]*enode.Node, n) + for i := range results { + results[i] = unwrapNode(nodeAtDistance(base, ld, intIP(i))) + } + return results +} + +func nodesToRecords(nodes []*enode.Node) []*enr.Record { + records := make([]*enr.Record, len(nodes)) + for i := range nodes { + records[i] = nodes[i].Record() + } + return records +} + // idAtDistance returns a random hash such that enode.LogDist(a, b) == n func idAtDistance(a enode.ID, n int) (b enode.ID) { if n == 0 { @@ -173,9 +189,16 @@ func hasDuplicates(slice []*node) bool { } func checkNodesEqual(got, want []*enode.Node) error { - if reflect.DeepEqual(got, want) { - return nil + if len(got) == len(want) { + for i := range got { + if !nodeEqual(got[i], want[i]) { + goto NotEqual + } + return nil + } } + +NotEqual: output := new(bytes.Buffer) fmt.Fprintf(output, "got %d nodes:\n", len(got)) for _, n := range got { @@ -188,6 +211,10 @@ func checkNodesEqual(got, want []*enode.Node) error { return errors.New(output.String()) } +func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool { + return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP()) +} + func sortByID(nodes []*enode.Node) { sort.Slice(nodes, func(i, j int) bool { return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes()) diff --git a/p2p/discover/v4_lookup_test.go b/p2p/discover/v4_lookup_test.go index 9b4042c5a2..83480d35e8 100644 --- a/p2p/discover/v4_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" ) func TestUDPv4_Lookup(t *testing.T) { @@ -32,7 +33,7 @@ func TestUDPv4_Lookup(t *testing.T) { test := newUDPTest(t) // Lookup on empty table returns no nodes. - targetKey, _ := decodePubkey(lookupTestnet.target) + targetKey, _ := decodePubkey(crypto.S256(), lookupTestnet.target) if results := test.udp.LookupPubkey(targetKey); len(results) > 0 { t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) } @@ -59,15 +60,7 @@ func TestUDPv4_Lookup(t *testing.T) { if len(results) != bucketSize { t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize) } - if hasDuplicates(wrapNodes(results)) { - t.Errorf("result set contains duplicate entries") - } - if !sortedByDistanceTo(lookupTestnet.target.id(), wrapNodes(results)) { - t.Errorf("result set not sorted by distance to target") - } - if err := checkNodesEqual(results, lookupTestnet.closest(bucketSize)); err != nil { - t.Errorf("results aren't the closest %d nodes\n%v", bucketSize, err) - } + checkLookupResults(t, lookupTestnet, results) } func TestUDPv4_LookupIterator(t *testing.T) { @@ -156,6 +149,26 @@ func serveTestnet(test *udpTest, testnet *preminedTestnet) { } } +// checkLookupResults verifies that the results of a lookup are the closest nodes to +// the testnet's target. +func checkLookupResults(t *testing.T, tn *preminedTestnet, results []*enode.Node) { + t.Helper() + t.Logf("results:") + for _, e := range results { + t.Logf(" ld=%d, %x", enode.LogDist(tn.target.id(), e.ID()), e.ID().Bytes()) + } + if hasDuplicates(wrapNodes(results)) { + t.Errorf("result set contains duplicate entries") + } + if !sortedByDistanceTo(tn.target.id(), wrapNodes(results)) { + t.Errorf("result set not sorted by distance to target") + } + wantNodes := tn.closest(len(results)) + if err := checkNodesEqual(results, wantNodes); err != nil { + t.Error(err) + } +} + // This is the test network for the Lookup test. // The nodes were obtained by running lookupTestnet.mine with a random NodeID as target. var lookupTestnet = &preminedTestnet{ @@ -242,8 +255,12 @@ func (tn *preminedTestnet) nodes() []*enode.Node { func (tn *preminedTestnet) node(dist, index int) *enode.Node { key := tn.dists[dist][index] - ip := net.IP{127, byte(dist >> 8), byte(dist), byte(index)} - return enode.NewV4(&key.PublicKey, ip, 0, 5000) + rec := new(enr.Record) + rec.Set(enr.IP{127, byte(dist >> 8), byte(dist), byte(index)}) + rec.Set(enr.UDP(5000)) + enode.SignV4(rec, key) + n, _ := enode.New(enode.ValidSchemes, rec) + return n } func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.PrivateKey) { @@ -261,6 +278,19 @@ func (tn *preminedTestnet) nodesAtDistance(dist int) []rpcNode { return result } +func (tn *preminedTestnet) neighborsAtDistance(base *enode.Node, distance uint, elems int) []*enode.Node { + nodes := nodesByDistance{target: base.ID()} + for d := range lookupTestnet.dists { + for i := range lookupTestnet.dists[d] { + n := lookupTestnet.node(d, i) + if uint(enode.LogDist(n.ID(), base.ID())) == distance { + nodes.push(wrapNode(n), elems) + } + } + } + return unwrapNodes(nodes.entries) +} + func (tn *preminedTestnet) closest(n int) (nodes []*enode.Node) { for d := range tn.dists { for i := range tn.dists[d] { diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index bfb66fcb19..6af05f93dd 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -47,6 +47,7 @@ var ( errTimeout = errors.New("RPC timeout") errClockWarp = errors.New("reply deadline too far in the future") errClosed = errors.New("socket closed") + errLowPort = errors.New("low port") ) const ( @@ -176,7 +177,7 @@ func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*node, error) { if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { return nil, errors.New("not contained in netrestrict whitelist") } - key, err := decodePubkey(rn.ID) + key, err := decodePubkey(crypto.S256(), rn.ID) if err != nil { return nil, err } @@ -209,7 +210,7 @@ type UDPv4 struct { addReplyMatcher chan *replyMatcher gotreply chan reply closeCtx context.Context - cancelCloseCtx func() + cancelCloseCtx context.CancelFunc } // replyMatcher represents a pending reply. @@ -258,6 +259,7 @@ type reply struct { } func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { + cfg = cfg.withDefaults() closeCtx, cancel := context.WithCancel(context.Background()) t := &UDPv4{ conn: c, @@ -271,9 +273,6 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { cancelCloseCtx: cancel, log: cfg.Log, } - if t.log == nil { - t.log = log.Root() - } tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log) if err != nil { @@ -812,7 +811,7 @@ func (req *pingV4) preverify(t *UDPv4, from *net.UDPAddr, fromID enode.ID, fromK if expired(req.Expiration) { return errExpired } - key, err := decodePubkey(fromKey) + key, err := decodePubkey(crypto.S256(), fromKey) if err != nil { return errors.New("invalid public key") } diff --git a/p2p/discover/v4_udp_test.go b/p2p/discover/v4_udp_test.go index b4e024e7ef..ea7194e43e 100644 --- a/p2p/discover/v4_udp_test.go +++ b/p2p/discover/v4_udp_test.go @@ -41,10 +41,6 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) -func init() { - spew.Config.DisableMethods = true -} - // shared test variables var ( futureExp = uint64(time.Now().Add(10 * time.Hour).Unix()) @@ -117,9 +113,12 @@ func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr * func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) { test.t.Helper() - dgram, ok := test.pipe.receive() - if !ok { + dgram, err := test.pipe.receive() + if err == errClosed { return true + } else if err != nil { + test.t.Error("packet receive error:", err) + return false } p, _, hash, err := decodeV4(dgram.data) if err != nil { @@ -671,17 +670,30 @@ func (c *dgramPipe) LocalAddr() net.Addr { return &net.UDPAddr{IP: testLocal.IP, Port: int(testLocal.UDP)} } -func (c *dgramPipe) receive() (dgram, bool) { +func (c *dgramPipe) receive() (dgram, error) { c.mu.Lock() defer c.mu.Unlock() - for len(c.queue) == 0 && !c.closed { + + var timedOut bool + timer := time.AfterFunc(3*time.Second, func() { + c.mu.Lock() + timedOut = true + c.mu.Unlock() + c.cond.Broadcast() + }) + defer timer.Stop() + + for len(c.queue) == 0 && !c.closed && !timedOut { c.cond.Wait() } if c.closed { - return dgram{}, false + return dgram{}, errClosed + } + if timedOut { + return dgram{}, errTimeout } p := c.queue[0] copy(c.queue, c.queue[1:]) c.queue = c.queue[:len(c.queue)-1] - return p, true + return p, nil } diff --git a/p2p/discover/v5_encoding.go b/p2p/discover/v5_encoding.go new file mode 100644 index 0000000000..842234e790 --- /dev/null +++ b/p2p/discover/v5_encoding.go @@ -0,0 +1,659 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/ecdsa" + "crypto/elliptic" + crand "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "hash" + "net" + "time" + + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" + "golang.org/x/crypto/hkdf" +) + +// TODO concurrent WHOAREYOU tie-breaker +// TODO deal with WHOAREYOU amplification factor (min packet size?) +// TODO add counter to nonce +// TODO rehandshake after X packets + +// Discovery v5 packet types. +const ( + p_pingV5 byte = iota + 1 + p_pongV5 + p_findnodeV5 + p_nodesV5 + p_requestTicketV5 + p_ticketV5 + p_regtopicV5 + p_regconfirmationV5 + p_topicqueryV5 + p_unknownV5 = byte(255) // any non-decryptable packet + p_whoareyouV5 = byte(254) // the WHOAREYOU packet +) + +// Discovery v5 packet structures. +type ( + // unknownV5 represents any packet that can't be decrypted. + unknownV5 struct { + AuthTag []byte + } + + // WHOAREYOU contains the handshake challenge. + whoareyouV5 struct { + AuthTag []byte + IDNonce [32]byte // To be signed by recipient. + RecordSeq uint64 // ENR sequence number of recipient + + node *enode.Node + sent mclock.AbsTime + } + + // PING is sent during liveness checks. + pingV5 struct { + ReqID []byte + ENRSeq uint64 + } + + // PONG is the reply to PING. + pongV5 struct { + ReqID []byte + ENRSeq uint64 + ToIP net.IP // These fields should mirror the UDP envelope address of the ping + ToPort uint16 // packet, which provides a way to discover the the external address (after NAT). + } + + // FINDNODE is a query for nodes in the given bucket. + findnodeV5 struct { + ReqID []byte + Distance uint + } + + // NODES is the reply to FINDNODE and TOPICQUERY. + nodesV5 struct { + ReqID []byte + Total uint8 + Nodes []*enr.Record + } + + // REQUESTTICKET requests a ticket for a topic queue. + requestTicketV5 struct { + ReqID []byte + Topic []byte + } + + // TICKET is the response to REQUESTTICKET. + ticketV5 struct { + ReqID []byte + Ticket []byte + } + + // REGTOPIC registers the sender in a topic queue using a ticket. + regtopicV5 struct { + ReqID []byte + Ticket []byte + ENR *enr.Record + } + + // REGCONFIRMATION is the reply to REGTOPIC. + regconfirmationV5 struct { + ReqID []byte + Registered bool + } + + // TOPICQUERY asks for nodes with the given topic. + topicqueryV5 struct { + ReqID []byte + Topic []byte + } +) + +const ( + // Encryption/authentication parameters. + authSchemeName = "gcm" + aesKeySize = 16 + gcmNonceSize = 12 + idNoncePrefix = "discovery-id-nonce" + handshakeTimeout = time.Second +) + +var ( + errTooShort = errors.New("packet too short") + errUnexpectedHandshake = errors.New("unexpected auth response, not in handshake") + errHandshakeNonceMismatch = errors.New("wrong nonce in auth response") + errInvalidAuthKey = errors.New("invalid ephemeral pubkey") + errUnknownAuthScheme = errors.New("unknown auth scheme in handshake") + errNoRecord = errors.New("expected ENR in handshake but none sent") + errInvalidNonceSig = errors.New("invalid ID nonce signature") + zeroNonce = make([]byte, gcmNonceSize) +) + +// wireCodec encodes and decodes discovery v5 packets. +type wireCodec struct { + sha256 hash.Hash + localnode *enode.LocalNode + privkey *ecdsa.PrivateKey + myChtagHash enode.ID + myWhoareyouMagic []byte + + sc *sessionCache +} + +type handshakeSecrets struct { + writeKey, readKey, authRespKey []byte +} + +type authHeader struct { + authHeaderList + isHandshake bool +} + +type authHeaderList struct { + Auth []byte // authentication info of packet + IDNonce [32]byte // IDNonce of WHOAREYOU + Scheme string // name of encryption/authentication scheme + EphemeralKey []byte // ephemeral public key + Response []byte // encrypted authResponse +} + +type authResponse struct { + Version uint + Signature []byte + Record *enr.Record `rlp:"nil"` // sender's record +} + +func (h *authHeader) DecodeRLP(r *rlp.Stream) error { + k, _, err := r.Kind() + if err != nil { + return err + } + if k == rlp.Byte || k == rlp.String { + return r.Decode(&h.Auth) + } + h.isHandshake = true + return r.Decode(&h.authHeaderList) +} + +// ephemeralKey decodes the ephemeral public key in the header. +func (h *authHeaderList) ephemeralKey(curve elliptic.Curve) *ecdsa.PublicKey { + var key encPubkey + copy(key[:], h.EphemeralKey) + pubkey, _ := decodePubkey(curve, key) + return pubkey +} + +// newWireCodec creates a wire codec. +func newWireCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock) *wireCodec { + c := &wireCodec{ + sha256: sha256.New(), + localnode: ln, + privkey: key, + sc: newSessionCache(1024, clock), + } + // Create magic strings for packet matching. + self := ln.ID() + c.myWhoareyouMagic = c.sha256sum(self[:], []byte("WHOAREYOU")) + copy(c.myChtagHash[:], c.sha256sum(self[:])) + return c +} + +// encode encodes a packet to a node. 'id' and 'addr' specify the destination node. The +// 'challenge' parameter should be the most recently received WHOAREYOU packet from that +// node. +func (c *wireCodec) encode(id enode.ID, addr string, packet packetV5, challenge *whoareyouV5) ([]byte, []byte, error) { + if packet.kind() == p_whoareyouV5 { + p := packet.(*whoareyouV5) + enc, err := c.encodeWhoareyou(id, p) + if err == nil { + c.sc.storeSentHandshake(id, addr, p) + } + return enc, nil, err + } + // Ensure calling code sets node if needed. + if challenge != nil && challenge.node == nil { + panic("BUG: missing challenge.node in encode") + } + writeKey := c.sc.writeKey(id, addr) + if writeKey != nil || challenge != nil { + return c.encodeEncrypted(id, addr, packet, writeKey, challenge) + } + return c.encodeRandom(id) +} + +// encodeRandom encodes a random packet. +func (c *wireCodec) encodeRandom(toID enode.ID) ([]byte, []byte, error) { + tag := xorTag(c.sha256sum(toID[:]), c.localnode.ID()) + r := make([]byte, 44) // TODO randomize size + if _, err := crand.Read(r); err != nil { + return nil, nil, err + } + nonce := make([]byte, gcmNonceSize) + if _, err := crand.Read(nonce); err != nil { + return nil, nil, fmt.Errorf("can't get random data: %v", err) + } + b := new(bytes.Buffer) + b.Write(tag[:]) + rlp.Encode(b, nonce) + b.Write(r) + return b.Bytes(), nonce, nil +} + +// encodeWhoareyou encodes WHOAREYOU. +func (c *wireCodec) encodeWhoareyou(toID enode.ID, packet *whoareyouV5) ([]byte, error) { + // Sanity check node field to catch misbehaving callers. + if packet.RecordSeq > 0 && packet.node == nil { + panic("BUG: missing node in whoareyouV5 with non-zero seq") + } + b := new(bytes.Buffer) + b.Write(c.sha256sum(toID[:], []byte("WHOAREYOU"))) + err := rlp.Encode(b, packet) + return b.Bytes(), err +} + +// encodeEncrypted encodes an encrypted packet. +func (c *wireCodec) encodeEncrypted(toID enode.ID, toAddr string, packet packetV5, writeKey []byte, challenge *whoareyouV5) (enc []byte, authTag []byte, err error) { + nonce := make([]byte, gcmNonceSize) + if _, err := crand.Read(nonce); err != nil { + return nil, nil, fmt.Errorf("can't get random data: %v", err) + } + + var headEnc []byte + if challenge == nil { + // Regular packet, use existing key and simply encode nonce. + headEnc, _ = rlp.EncodeToBytes(nonce) + } else { + // We're answering WHOAREYOU, generate new keys and encrypt with those. + header, sec, err := c.makeAuthHeader(nonce, challenge) + if err != nil { + return nil, nil, err + } + if headEnc, err = rlp.EncodeToBytes(header); err != nil { + return nil, nil, err + } + c.sc.storeNewSession(toID, toAddr, sec.readKey, sec.writeKey) + writeKey = sec.writeKey + } + + // Encode the packet. + body := new(bytes.Buffer) + body.WriteByte(packet.kind()) + if err := rlp.Encode(body, packet); err != nil { + return nil, nil, err + } + tag := xorTag(c.sha256sum(toID[:]), c.localnode.ID()) + headsize := len(tag) + len(headEnc) + headbuf := make([]byte, headsize) + copy(headbuf[:], tag[:]) + copy(headbuf[len(tag):], headEnc) + + // Encrypt the body. + enc, err = encryptGCM(headbuf, writeKey, nonce, body.Bytes(), tag[:]) + return enc, nonce, err +} + +// encodeAuthHeader creates the auth header on a call packet following WHOAREYOU. +func (c *wireCodec) makeAuthHeader(nonce []byte, challenge *whoareyouV5) (*authHeaderList, *handshakeSecrets, error) { + resp := &authResponse{Version: 5} + + // Add our record to response if it's newer than what remote + // side has. + ln := c.localnode.Node() + if challenge.RecordSeq < ln.Seq() { + resp.Record = ln.Record() + } + + // Create the ephemeral key. This needs to be first because the + // key is part of the ID nonce signature. + var remotePubkey = new(ecdsa.PublicKey) + if err := challenge.node.Load((*enode.Secp256k1)(remotePubkey)); err != nil { + return nil, nil, fmt.Errorf("can't find secp256k1 key for recipient") + } + ephkey, err := crypto.GenerateKey() + if err != nil { + return nil, nil, fmt.Errorf("can't generate ephemeral key") + } + ephpubkey := encodePubkey(&ephkey.PublicKey) + + // Add ID nonce signature to response. + idsig, err := c.signIDNonce(challenge.IDNonce[:], ephpubkey[:]) + if err != nil { + return nil, nil, fmt.Errorf("can't sign: %v", err) + } + resp.Signature = idsig + + // Create session keys. + sec := c.deriveKeys(c.localnode.ID(), challenge.node.ID(), ephkey, remotePubkey, challenge) + if sec == nil { + return nil, nil, fmt.Errorf("key derivation failed") + } + + // Encrypt the authentication response and assemble the auth header. + respRLP, err := rlp.EncodeToBytes(resp) + if err != nil { + return nil, nil, fmt.Errorf("can't encode auth response: %v", err) + } + respEnc, err := encryptGCM(nil, sec.authRespKey, zeroNonce, respRLP, nil) + if err != nil { + return nil, nil, fmt.Errorf("can't encrypt auth response: %v", err) + } + head := &authHeaderList{ + Auth: nonce, + Scheme: authSchemeName, + IDNonce: challenge.IDNonce, + EphemeralKey: ephpubkey[:], + Response: respEnc, + } + return head, sec, err +} + +// deriveKeys generates session keys using elliptic-curve Diffie-Hellman key agreement. +func (c *wireCodec) deriveKeys(n1, n2 enode.ID, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, challenge *whoareyouV5) *handshakeSecrets { + eph := ecdh(priv, pub) + if eph == nil { + return nil + } + + info := []byte("discovery v5 key agreement") + info = append(info, n1[:]...) + info = append(info, n2[:]...) + kdf := hkdf.New(c.sha256reset, eph, challenge.IDNonce[:], info) + sec := handshakeSecrets{ + writeKey: make([]byte, aesKeySize), + readKey: make([]byte, aesKeySize), + authRespKey: make([]byte, aesKeySize), + } + kdf.Read(sec.writeKey) + kdf.Read(sec.readKey) + kdf.Read(sec.authRespKey) + for i := range eph { + eph[i] = 0 + } + return &sec +} + +// signIDNonce creates the ID nonce signature. +func (c *wireCodec) signIDNonce(nonce, ephkey []byte) ([]byte, error) { + idsig, err := crypto.Sign(c.idNonceHash(nonce, ephkey), c.privkey) + if err != nil { + return nil, fmt.Errorf("can't sign: %v", err) + } + return idsig[:len(idsig)-1], nil // remove recovery ID +} + +// idNonceHash computes the hash of id nonce with prefix. +func (c *wireCodec) idNonceHash(nonce, ephkey []byte) []byte { + h := c.sha256reset() + h.Write([]byte(idNoncePrefix)) + h.Write(nonce) + h.Write(ephkey) + return h.Sum(nil) +} + +// decode decodes a discovery packet. +func (c *wireCodec) decode(input []byte, addr string) (enode.ID, *enode.Node, packetV5, error) { + // Delete timed-out handshakes. This must happen before decoding to avoid + // processing the same handshake twice. + c.sc.handshakeGC() + + if len(input) < 32 { + return enode.ID{}, nil, nil, errTooShort + } + if bytes.HasPrefix(input, c.myWhoareyouMagic) { + p, err := c.decodeWhoareyou(input) + return enode.ID{}, nil, p, err + } + sender := xorTag(input[:32], c.myChtagHash) + p, n, err := c.decodeEncrypted(sender, addr, input) + return sender, n, p, err +} + +// decodeWhoareyou decode a WHOAREYOU packet. +func (c *wireCodec) decodeWhoareyou(input []byte) (packetV5, error) { + packet := new(whoareyouV5) + err := rlp.DecodeBytes(input[32:], packet) + return packet, err +} + +// decodeEncrypted decodes an encrypted discovery packet. +func (c *wireCodec) decodeEncrypted(fromID enode.ID, fromAddr string, input []byte) (packetV5, *enode.Node, error) { + // Decode packet header. + var head authHeader + r := bytes.NewReader(input[32:]) + err := rlp.Decode(r, &head) + if err != nil { + return nil, nil, err + } + + // Decrypt and process auth response. + readKey, node, err := c.decodeAuth(fromID, fromAddr, &head) + if err != nil { + return nil, nil, err + } + + // Decrypt and decode the packet body. + headsize := len(input) - r.Len() + bodyEnc := input[headsize:] + body, err := decryptGCM(readKey, head.Auth, bodyEnc, input[:32]) + if err != nil { + if !head.isHandshake { + // Can't decrypt, start handshake. + return &unknownV5{AuthTag: head.Auth}, nil, nil + } + return nil, nil, fmt.Errorf("handshake failed: %v", err) + } + if len(body) == 0 { + return nil, nil, errTooShort + } + p, err := decodePacketBodyV5(body[0], body[1:]) + return p, node, err +} + +// decodeAuth processes an auth header. +func (c *wireCodec) decodeAuth(fromID enode.ID, fromAddr string, head *authHeader) ([]byte, *enode.Node, error) { + if !head.isHandshake { + return c.sc.readKey(fromID, fromAddr), nil, nil + } + + // Remote is attempting handshake. Verify against our last WHOAREYOU. + challenge := c.sc.getHandshake(fromID, fromAddr) + if challenge == nil { + return nil, nil, errUnexpectedHandshake + } + if head.IDNonce != challenge.IDNonce { + return nil, nil, errHandshakeNonceMismatch + } + sec, n, err := c.decodeAuthResp(fromID, fromAddr, &head.authHeaderList, challenge) + if err != nil { + return nil, n, err + } + // Swap keys to match remote. + sec.readKey, sec.writeKey = sec.writeKey, sec.readKey + c.sc.storeNewSession(fromID, fromAddr, sec.readKey, sec.writeKey) + c.sc.deleteHandshake(fromID, fromAddr) + return sec.readKey, n, err +} + +// decodeAuthResp decodes and verifies an authentication response. +func (c *wireCodec) decodeAuthResp(fromID enode.ID, fromAddr string, head *authHeaderList, challenge *whoareyouV5) (*handshakeSecrets, *enode.Node, error) { + // Decrypt / decode the response. + if head.Scheme != authSchemeName { + return nil, nil, errUnknownAuthScheme + } + ephkey := head.ephemeralKey(c.privkey.Curve) + if ephkey == nil { + return nil, nil, errInvalidAuthKey + } + sec := c.deriveKeys(fromID, c.localnode.ID(), c.privkey, ephkey, challenge) + respPT, err := decryptGCM(sec.authRespKey, zeroNonce, head.Response, nil) + if err != nil { + return nil, nil, fmt.Errorf("can't decrypt auth response header: %v", err) + } + var resp authResponse + if err := rlp.DecodeBytes(respPT, &resp); err != nil { + return nil, nil, fmt.Errorf("invalid auth response: %v", err) + } + + // Verify response node record. The remote node should include the record + // if we don't have one or if ours is older than the latest version. + node := challenge.node + if resp.Record != nil { + if node == nil || node.Seq() < resp.Record.Seq() { + n, err := enode.New(enode.ValidSchemes, resp.Record) + if err != nil { + return nil, nil, fmt.Errorf("invalid node record: %v", err) + } + if n.ID() != fromID { + return nil, nil, fmt.Errorf("record in auth respose has wrong ID: %v", n.ID()) + } + node = n + } + } + if node == nil { + return nil, nil, errNoRecord + } + + // Verify ID nonce signature. + err = c.verifyIDSignature(challenge.IDNonce[:], head.EphemeralKey, resp.Signature, node) + if err != nil { + return nil, nil, err + } + return sec, node, nil +} + +// verifyIDSignature checks that signature over idnonce was made by the node with given record. +func (c *wireCodec) verifyIDSignature(nonce, ephkey, sig []byte, n *enode.Node) error { + switch idscheme := n.Record().IdentityScheme(); idscheme { + case "v4": + var pk ecdsa.PublicKey + n.Load((*enode.Secp256k1)(&pk)) // cannot fail because record is valid + if !crypto.VerifySignature(crypto.FromECDSAPub(&pk), c.idNonceHash(nonce, ephkey), sig) { + return errInvalidNonceSig + } + return nil + default: + return fmt.Errorf("can't verify ID nonce signature against scheme %q", idscheme) + } +} + +// decodePacketBody decodes the body of an encrypted discovery packet. +func decodePacketBodyV5(ptype byte, body []byte) (packetV5, error) { + var dec packetV5 + switch ptype { + case p_pingV5: + dec = new(pingV5) + case p_pongV5: + dec = new(pongV5) + case p_findnodeV5: + dec = new(findnodeV5) + case p_nodesV5: + dec = new(nodesV5) + case p_requestTicketV5: + dec = new(requestTicketV5) + case p_ticketV5: + dec = new(ticketV5) + case p_regtopicV5: + dec = new(regtopicV5) + case p_regconfirmationV5: + dec = new(regconfirmationV5) + case p_topicqueryV5: + dec = new(topicqueryV5) + default: + return nil, fmt.Errorf("unknown packet type %d", ptype) + } + if err := rlp.DecodeBytes(body, dec); err != nil { + return nil, err + } + return dec, nil +} + +// sha256reset returns the shared hash instance. +func (c *wireCodec) sha256reset() hash.Hash { + c.sha256.Reset() + return c.sha256 +} + +// sha256sum computes sha256 on the concatenation of inputs. +func (c *wireCodec) sha256sum(inputs ...[]byte) []byte { + c.sha256.Reset() + for _, b := range inputs { + c.sha256.Write(b) + } + return c.sha256.Sum(nil) +} + +func xorTag(a []byte, b enode.ID) enode.ID { + var r enode.ID + for i := range r { + r[i] = a[i] ^ b[i] + } + return r +} + +// ecdh creates a shared secret. +func ecdh(privkey *ecdsa.PrivateKey, pubkey *ecdsa.PublicKey) []byte { + secX, secY := pubkey.ScalarMult(pubkey.X, pubkey.Y, privkey.D.Bytes()) + if secX == nil { + return nil + } + sec := make([]byte, 33) + sec[0] = 0x02 | byte(secY.Bit(0)) + math.ReadBits(secX, sec[1:]) + return sec +} + +// encryptGCM encrypts pt using AES-GCM with the given key and nonce. +func encryptGCM(dest, key, nonce, pt, authData []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + panic(fmt.Errorf("can't create block cipher: %v", err)) + } + aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize) + if err != nil { + panic(fmt.Errorf("can't create GCM: %v", err)) + } + return aesgcm.Seal(dest, nonce, pt, authData), nil +} + +// decryptGCM decrypts ct using AES-GCM with the given key and nonce. +func decryptGCM(key, nonce, ct, authData []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("can't create block cipher: %v", err) + } + if len(nonce) != gcmNonceSize { + return nil, fmt.Errorf("invalid GCM nonce size: %d", len(nonce)) + } + aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize) + if err != nil { + return nil, fmt.Errorf("can't create GCM: %v", err) + } + pt := make([]byte, 0, len(ct)) + return aesgcm.Open(pt, nonce, ct, authData) +} diff --git a/p2p/discover/v5_encoding_test.go b/p2p/discover/v5_encoding_test.go new file mode 100644 index 0000000000..77e6bae6ae --- /dev/null +++ b/p2p/discover/v5_encoding_test.go @@ -0,0 +1,373 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "bytes" + "crypto/ecdsa" + "encoding/hex" + "fmt" + "net" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +var ( + testKeyA, _ = crypto.HexToECDSA("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") + testKeyB, _ = crypto.HexToECDSA("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") + testIDnonce = [32]byte{5, 6, 7, 8, 9, 10, 11, 12} +) + +func TestDeriveKeysV5(t *testing.T) { + t.Parallel() + + var ( + n1 = enode.ID{1} + n2 = enode.ID{2} + challenge = &whoareyouV5{} + db, _ = enode.OpenDB("") + ln = enode.NewLocalNode(db, testKeyA) + c = newWireCodec(ln, testKeyA, mclock.System{}) + ) + defer db.Close() + + sec1 := c.deriveKeys(n1, n2, testKeyA, &testKeyB.PublicKey, challenge) + sec2 := c.deriveKeys(n1, n2, testKeyB, &testKeyA.PublicKey, challenge) + if sec1 == nil || sec2 == nil { + t.Fatal("key agreement failed") + } + if !reflect.DeepEqual(sec1, sec2) { + t.Fatalf("keys not equal:\n %+v\n %+v", sec1, sec2) + } +} + +// This test checks the basic handshake flow where A talks to B and A has no secrets. +func TestHandshakeV5(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) + resp := net.nodeB.expectDecode(t, p_unknownV5, packet) + + // A <- B WHOAREYOU + challenge := &whoareyouV5{ + AuthTag: resp.(*unknownV5).AuthTag, + IDNonce: testIDnonce, + RecordSeq: 0, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // A -> B FINDNODE + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecode(t, p_findnodeV5, findnode) + if len(net.nodeB.c.sc.handshakes) > 0 { + t.Fatalf("node B didn't remove handshake from challenge map") + } + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) + net.nodeA.expectDecode(t, p_nodesV5, nodes) +} + +// This test checks that handshake attempts are removed within the timeout. +func TestHandshakeV5_timeout(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) + resp := net.nodeB.expectDecode(t, p_unknownV5, packet) + + // A <- B WHOAREYOU + challenge := &whoareyouV5{ + AuthTag: resp.(*unknownV5).AuthTag, + IDNonce: testIDnonce, + RecordSeq: 0, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // A -> B FINDNODE after timeout + net.clock.Run(handshakeTimeout + 1) + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, findnode) +} + +// This test checks handshake behavior when no record is sent in the auth response. +func TestHandshakeV5_norecord(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) + resp := net.nodeB.expectDecode(t, p_unknownV5, packet) + + // A <- B WHOAREYOU + nodeA := net.nodeA.n() + if nodeA.Seq() == 0 { + t.Fatal("need non-zero sequence number") + } + challenge := &whoareyouV5{ + AuthTag: resp.(*unknownV5).AuthTag, + IDNonce: testIDnonce, + RecordSeq: nodeA.Seq(), + node: nodeA, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // A -> B FINDNODE + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecode(t, p_findnodeV5, findnode) + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) + net.nodeA.expectDecode(t, p_nodesV5, nodes) +} + +// In this test, A tries to send FINDNODE with existing secrets but B doesn't know +// anything about A. +func TestHandshakeV5_rekey(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + initKeys := &handshakeSecrets{ + readKey: []byte("BBBBBBBBBBBBBBBB"), + writeKey: []byte("AAAAAAAAAAAAAAAA"), + } + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeys.readKey, initKeys.writeKey) + + // A -> B FINDNODE (encrypted with zero keys) + findnode, authTag := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) + net.nodeB.expectDecode(t, p_unknownV5, findnode) + + // A <- B WHOAREYOU + challenge := &whoareyouV5{AuthTag: authTag, IDNonce: testIDnonce} + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // Check that new keys haven't been stored yet. + if s := net.nodeA.c.sc.session(net.nodeB.id(), net.nodeB.addr()); !bytes.Equal(s.writeKey, initKeys.writeKey) || !bytes.Equal(s.readKey, initKeys.readKey) { + t.Fatal("node A stored keys too early") + } + if s := net.nodeB.c.sc.session(net.nodeA.id(), net.nodeA.addr()); s != nil { + t.Fatal("node B stored keys too early") + } + + // A -> B FINDNODE encrypted with new keys + findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecode(t, p_findnodeV5, findnode) + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) + net.nodeA.expectDecode(t, p_nodesV5, nodes) +} + +// In this test A and B have different keys before the handshake. +func TestHandshakeV5_rekey2(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + initKeysA := &handshakeSecrets{ + readKey: []byte("BBBBBBBBBBBBBBBB"), + writeKey: []byte("AAAAAAAAAAAAAAAA"), + } + initKeysB := &handshakeSecrets{ + readKey: []byte("CCCCCCCCCCCCCCCC"), + writeKey: []byte("DDDDDDDDDDDDDDDD"), + } + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA.readKey, initKeysA.writeKey) + net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB.readKey, initKeysA.writeKey) + + // A -> B FINDNODE encrypted with initKeysA + findnode, authTag := net.nodeA.encode(t, net.nodeB, &findnodeV5{Distance: 3}) + net.nodeB.expectDecode(t, p_unknownV5, findnode) + + // A <- B WHOAREYOU + challenge := &whoareyouV5{AuthTag: authTag, IDNonce: testIDnonce} + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) + + // A -> B FINDNODE encrypted with new keys + findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) + net.nodeB.expectDecode(t, p_findnodeV5, findnode) + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) + net.nodeA.expectDecode(t, p_nodesV5, nodes) +} + +// This test checks some malformed packets. +func TestDecodeErrorsV5(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + net.nodeA.expectDecodeErr(t, errTooShort, []byte{}) + // TODO some more tests would be nice :) +} + +// This benchmark checks performance of authHeader decoding, verification and key derivation. +func BenchmarkV5_DecodeAuthSecp256k1(b *testing.B) { + net := newHandshakeTest() + defer net.close() + + var ( + idA = net.nodeA.id() + addrA = net.nodeA.addr() + challenge = &whoareyouV5{AuthTag: []byte("authresp"), RecordSeq: 0, node: net.nodeB.n()} + nonce = make([]byte, gcmNonceSize) + ) + header, _, _ := net.nodeA.c.makeAuthHeader(nonce, challenge) + challenge.node = nil // force ENR signature verification in decoder + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _, err := net.nodeB.c.decodeAuthResp(idA, addrA, header, challenge) + if err != nil { + b.Fatal(err) + } + } +} + +// This benchmark checks how long it takes to decode an encrypted ping packet. +func BenchmarkV5_DecodePing(b *testing.B) { + net := newHandshakeTest() + defer net.close() + + r := []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17} + w := []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134} + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), r, w) + net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), w, r) + addrB := net.nodeA.addr() + ping := &pingV5{ReqID: []byte("reqid"), ENRSeq: 5} + enc, _, err := net.nodeA.c.encode(net.nodeB.id(), addrB, ping, nil) + if err != nil { + b.Fatalf("can't encode: %v", err) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _, p, _ := net.nodeB.c.decode(enc, addrB) + if _, ok := p.(*pingV5); !ok { + b.Fatalf("wrong packet type %T", p) + } + } +} + +var pp = spew.NewDefaultConfig() + +type handshakeTest struct { + nodeA, nodeB handshakeTestNode + clock mclock.Simulated +} + +type handshakeTestNode struct { + ln *enode.LocalNode + c *wireCodec +} + +func newHandshakeTest() *handshakeTest { + t := new(handshakeTest) + t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock) + t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock) + return t +} + +func (t *handshakeTest) close() { + t.nodeA.ln.Database().Close() + t.nodeB.ln.Database().Close() +} + +func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock) { + db, _ := enode.OpenDB("") + n.ln = enode.NewLocalNode(db, key) + n.ln.SetStaticIP(ip) + n.c = newWireCodec(n.ln, key, clock) +} + +func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p packetV5) ([]byte, []byte) { + t.Helper() + return n.encodeWithChallenge(t, to, nil, p) +} + +func (n *handshakeTestNode) encodeWithChallenge(t testing.TB, to handshakeTestNode, c *whoareyouV5, p packetV5) ([]byte, []byte) { + t.Helper() + // Copy challenge and add destination node. This avoids sharing 'c' among the two codecs. + var challenge *whoareyouV5 + if c != nil { + challengeCopy := *c + challenge = &challengeCopy + challenge.node = to.n() + } + // Encode to destination. + enc, authTag, err := n.c.encode(to.id(), to.addr(), p, challenge) + if err != nil { + t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err)) + } + t.Logf("(%s) -> (%s) %s\n%s", n.ln.ID().TerminalString(), to.id().TerminalString(), p.name(), hex.Dump(enc)) + return enc, authTag +} + +func (n *handshakeTestNode) expectDecode(t *testing.T, ptype byte, p []byte) packetV5 { + t.Helper() + dec, err := n.decode(p) + if err != nil { + t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err)) + } + t.Logf("(%s) %#v", n.ln.ID().TerminalString(), pp.NewFormatter(dec)) + if dec.kind() != ptype { + t.Fatalf("expected packet type %d, got %d", ptype, dec.kind()) + } + return dec +} + +func (n *handshakeTestNode) expectDecodeErr(t *testing.T, wantErr error, p []byte) { + t.Helper() + if _, err := n.decode(p); !reflect.DeepEqual(err, wantErr) { + t.Fatal(fmt.Errorf("(%s) got err %q, want %q", n.ln.ID().TerminalString(), err, wantErr)) + } +} + +func (n *handshakeTestNode) decode(input []byte) (packetV5, error) { + _, _, p, err := n.c.decode(input, "127.0.0.1") + return p, err +} + +func (n *handshakeTestNode) n() *enode.Node { + return n.ln.Node() +} + +func (n *handshakeTestNode) addr() string { + return n.ln.Node().IP().String() +} + +func (n *handshakeTestNode) id() enode.ID { + return n.ln.ID() +} diff --git a/p2p/discover/v5_session.go b/p2p/discover/v5_session.go new file mode 100644 index 0000000000..8a0eeb6977 --- /dev/null +++ b/p2p/discover/v5_session.go @@ -0,0 +1,123 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + crand "crypto/rand" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/hashicorp/golang-lru/simplelru" +) + +// The sessionCache keeps negotiated encryption keys and +// state for in-progress handshakes in the Discovery v5 wire protocol. +type sessionCache struct { + sessions *simplelru.LRU + handshakes map[sessionID]*whoareyouV5 + clock mclock.Clock +} + +// sessionID identifies a session or handshake. +type sessionID struct { + id enode.ID + addr string +} + +// session contains session information +type session struct { + writeKey []byte + readKey []byte + nonceCounter uint32 +} + +func newSessionCache(maxItems int, clock mclock.Clock) *sessionCache { + cache, err := simplelru.NewLRU(maxItems, nil) + if err != nil { + panic("can't create session cache") + } + return &sessionCache{ + sessions: cache, + handshakes: make(map[sessionID]*whoareyouV5), + clock: clock, + } +} + +// nextNonce creates a nonce for encrypting a message to the given session. +func (sc *sessionCache) nextNonce(id enode.ID, addr string) []byte { + n := make([]byte, gcmNonceSize) + crand.Read(n) + return n +} + +// session returns the current session for the given node, if any. +func (sc *sessionCache) session(id enode.ID, addr string) *session { + item, ok := sc.sessions.Get(sessionID{id, addr}) + if !ok { + return nil + } + return item.(*session) +} + +// readKey returns the current read key for the given node. +func (sc *sessionCache) readKey(id enode.ID, addr string) []byte { + if s := sc.session(id, addr); s != nil { + return s.readKey + } + return nil +} + +// writeKey returns the current read key for the given node. +func (sc *sessionCache) writeKey(id enode.ID, addr string) []byte { + if s := sc.session(id, addr); s != nil { + return s.writeKey + } + return nil +} + +// storeNewSession stores new encryption keys in the cache. +func (sc *sessionCache) storeNewSession(id enode.ID, addr string, r, w []byte) { + sc.sessions.Add(sessionID{id, addr}, &session{ + readKey: r, writeKey: w, + }) +} + +// getHandshake gets the handshake challenge we previously sent to the given remote node. +func (sc *sessionCache) getHandshake(id enode.ID, addr string) *whoareyouV5 { + return sc.handshakes[sessionID{id, addr}] +} + +// storeSentHandshake stores the handshake challenge sent to the given remote node. +func (sc *sessionCache) storeSentHandshake(id enode.ID, addr string, challenge *whoareyouV5) { + challenge.sent = sc.clock.Now() + sc.handshakes[sessionID{id, addr}] = challenge +} + +// deleteHandshake deletes handshake data for the given node. +func (sc *sessionCache) deleteHandshake(id enode.ID, addr string) { + delete(sc.handshakes, sessionID{id, addr}) +} + +// handshakeGC deletes timed-out handshakes. +func (sc *sessionCache) handshakeGC() { + deadline := sc.clock.Now().Add(-handshakeTimeout) + for key, challenge := range sc.handshakes { + if challenge.sent < deadline { + delete(sc.handshakes, key) + } + } +} diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go new file mode 100644 index 0000000000..e667be1690 --- /dev/null +++ b/p2p/discover/v5_udp.go @@ -0,0 +1,832 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "bytes" + "context" + "crypto/ecdsa" + crand "crypto/rand" + "errors" + "fmt" + "io" + "math" + "net" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/netutil" +) + +const ( + lookupRequestLimit = 3 // max requests against a single node during lookup + findnodeResultLimit = 15 // applies in FINDNODE handler + totalNodesResponseLimit = 5 // applies in waitForNodes + nodesResponseItemLimit = 3 // applies in sendNodes + + respTimeoutV5 = 700 * time.Millisecond +) + +// codecV5 is implemented by wireCodec (and testCodec). +// +// The UDPv5 transport is split into two objects: the codec object deals with +// encoding/decoding and with the handshake; the UDPv5 object handles higher-level concerns. +type codecV5 interface { + // encode encodes a packet. The 'challenge' parameter is non-nil for calls which got a + // WHOAREYOU response. + encode(fromID enode.ID, fromAddr string, p packetV5, challenge *whoareyouV5) (enc []byte, authTag []byte, err error) + + // decode decodes a packet. It returns an *unknownV5 packet if decryption fails. + // The fromNode return value is non-nil when the input contains a handshake response. + decode(input []byte, fromAddr string) (fromID enode.ID, fromNode *enode.Node, p packetV5, err error) +} + +// packetV5 is implemented by all discv5 packet type structs. +type packetV5 interface { + // These methods provide information and set the request ID. + name() string + kind() byte + setreqid([]byte) + // handle should perform the appropriate action to handle the packet, i.e. this is the + // place to send the response. + handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) +} + +// UDPv5 is the implementation of protocol version 5. +type UDPv5 struct { + // static fields + conn UDPConn + tab *Table + netrestrict *netutil.Netlist + priv *ecdsa.PrivateKey + localNode *enode.LocalNode + db *enode.DB + log log.Logger + clock mclock.Clock + validSchemes enr.IdentityScheme + + // channels into dispatch + packetInCh chan ReadPacket + readNextCh chan struct{} + callCh chan *callV5 + callDoneCh chan *callV5 + respTimeoutCh chan *callTimeout + + // state of dispatch + codec codecV5 + activeCallByNode map[enode.ID]*callV5 + activeCallByAuth map[string]*callV5 + callQueue map[enode.ID][]*callV5 + + // shutdown stuff + closeOnce sync.Once + closeCtx context.Context + cancelCloseCtx context.CancelFunc + wg sync.WaitGroup +} + +// callV5 represents a remote procedure call against another node. +type callV5 struct { + node *enode.Node + packet packetV5 + responseType byte // expected packet type of response + reqid []byte + ch chan packetV5 // responses sent here + err chan error // errors sent here + + // Valid for active calls only: + authTag []byte // authTag of request packet + handshakeCount int // # times we attempted handshake for this call + challenge *whoareyouV5 // last sent handshake challenge + timeout mclock.Timer +} + +// callTimeout is the response timeout event of a call. +type callTimeout struct { + c *callV5 + timer mclock.Timer +} + +// ListenV5 listens on the given connection. +func ListenV5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { + t, err := newUDPv5(conn, ln, cfg) + if err != nil { + return nil, err + } + go t.tab.loop() + t.wg.Add(2) + go t.readLoop() + go t.dispatch() + return t, nil +} + +// newUDPv5 creates a UDPv5 transport, but doesn't start any goroutines. +func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { + closeCtx, cancelCloseCtx := context.WithCancel(context.Background()) + cfg = cfg.withDefaults() + t := &UDPv5{ + // static fields + conn: conn, + localNode: ln, + db: ln.Database(), + netrestrict: cfg.NetRestrict, + priv: cfg.PrivateKey, + log: cfg.Log, + validSchemes: cfg.ValidSchemes, + clock: cfg.Clock, + // channels into dispatch + packetInCh: make(chan ReadPacket, 1), + readNextCh: make(chan struct{}, 1), + callCh: make(chan *callV5), + callDoneCh: make(chan *callV5), + respTimeoutCh: make(chan *callTimeout), + // state of dispatch + codec: newWireCodec(ln, cfg.PrivateKey, cfg.Clock), + activeCallByNode: make(map[enode.ID]*callV5), + activeCallByAuth: make(map[string]*callV5), + callQueue: make(map[enode.ID][]*callV5), + // shutdown + closeCtx: closeCtx, + cancelCloseCtx: cancelCloseCtx, + } + tab, err := newTable(t, t.db, cfg.Bootnodes, cfg.Log) + if err != nil { + return nil, err + } + t.tab = tab + return t, nil +} + +// Self returns the local node record. +func (t *UDPv5) Self() *enode.Node { + return t.localNode.Node() +} + +// Close shuts down packet processing. +func (t *UDPv5) Close() { + t.closeOnce.Do(func() { + t.cancelCloseCtx() + t.conn.Close() + t.wg.Wait() + t.tab.close() + }) +} + +// Ping sends a ping message to the given node. +func (t *UDPv5) Ping(n *enode.Node) error { + _, err := t.ping(n) + return err +} + +// Resolve searches for a specific node with the given ID and tries to get the most recent +// version of the node record for it. It returns n if the node could not be resolved. +func (t *UDPv5) Resolve(n *enode.Node) *enode.Node { + if intable := t.tab.getNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + n = intable + } + // Try asking directly. This works if the node is still responding on the endpoint we have. + if resp, err := t.RequestENR(n); err == nil { + return resp + } + // Otherwise do a network lookup. + result := t.Lookup(n.ID()) + for _, rn := range result { + if rn.ID() == n.ID() && rn.Seq() > n.Seq() { + return rn + } + } + return n +} + +func (t *UDPv5) RandomNodes() enode.Iterator { + if t.tab.len() == 0 { + // All nodes were dropped, refresh. The very first query will hit this + // case and run the bootstrapping logic. + <-t.tab.refresh() + } + + return newLookupIterator(t.closeCtx, t.newRandomLookup) +} + +// Lookup performs a recursive lookup for the given target. +// It returns the closest nodes to target. +func (t *UDPv5) Lookup(target enode.ID) []*enode.Node { + return t.newLookup(t.closeCtx, target).run() +} + +// lookupRandom looks up a random target. +// This is needed to satisfy the transport interface. +func (t *UDPv5) lookupRandom() []*enode.Node { + return t.newRandomLookup(t.closeCtx).run() +} + +// lookupSelf looks up our own node ID. +// This is needed to satisfy the transport interface. +func (t *UDPv5) lookupSelf() []*enode.Node { + return t.newLookup(t.closeCtx, t.Self().ID()).run() +} + +func (t *UDPv5) newRandomLookup(ctx context.Context) *lookup { + var target enode.ID + crand.Read(target[:]) + return t.newLookup(ctx, target) +} + +func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup { + return newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { + return t.lookupWorker(n, target) + }) +} + +// lookupWorker performs FINDNODE calls against a single node during lookup. +func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) { + var ( + dists = lookupDistances(target, destNode.ID()) + nodes = nodesByDistance{target: target} + err error + ) + for i := 0; i < lookupRequestLimit && len(nodes.entries) < findnodeResultLimit; i++ { + var r []*enode.Node + r, err = t.findnode(unwrapNode(destNode), dists[i]) + if err == errClosed { + return nil, err + } + for _, n := range r { + if n.ID() != t.Self().ID() { + nodes.push(wrapNode(n), findnodeResultLimit) + } + } + } + return nodes.entries, err +} + +// lookupDistances computes the distance parameter for FINDNODE calls to dest. +// It chooses distances adjacent to logdist(target, dest), e.g. for a target +// with logdist(target, dest) = 255 the result is [255, 256, 254]. +func lookupDistances(target, dest enode.ID) (dists []int) { + td := enode.LogDist(target, dest) + dists = append(dists, td) + for i := 1; len(dists) < lookupRequestLimit; i++ { + if td+i < 256 { + dists = append(dists, td+i) + } + if td-i > 0 { + dists = append(dists, td-i) + } + } + return dists +} + +// ping calls PING on a node and waits for a PONG response. +func (t *UDPv5) ping(n *enode.Node) (uint64, error) { + resp := t.call(n, p_pongV5, &pingV5{ENRSeq: t.localNode.Node().Seq()}) + defer t.callDone(resp) + select { + case pong := <-resp.ch: + return pong.(*pongV5).ENRSeq, nil + case err := <-resp.err: + return 0, err + } +} + +// requestENR requests n's record. +func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) { + nodes, err := t.findnode(n, 0) + if err != nil { + return nil, err + } + if len(nodes) != 1 { + return nil, fmt.Errorf("%d nodes in response for distance zero", len(nodes)) + } + return nodes[0], nil +} + +// requestTicket calls REQUESTTICKET on a node and waits for a TICKET response. +func (t *UDPv5) requestTicket(n *enode.Node) ([]byte, error) { + resp := t.call(n, p_ticketV5, &pingV5{}) + defer t.callDone(resp) + select { + case response := <-resp.ch: + return response.(*ticketV5).Ticket, nil + case err := <-resp.err: + return nil, err + } +} + +// findnode calls FINDNODE on a node and waits for responses. +func (t *UDPv5) findnode(n *enode.Node, distance int) ([]*enode.Node, error) { + resp := t.call(n, p_nodesV5, &findnodeV5{Distance: uint(distance)}) + return t.waitForNodes(resp, distance) +} + +// waitForNodes waits for NODES responses to the given call. +func (t *UDPv5) waitForNodes(c *callV5, distance int) ([]*enode.Node, error) { + defer t.callDone(c) + + var ( + nodes []*enode.Node + seen = make(map[enode.ID]struct{}) + received, total = 0, -1 + ) + for { + select { + case responseP := <-c.ch: + response := responseP.(*nodesV5) + for _, record := range response.Nodes { + node, err := t.verifyResponseNode(c, record, distance, seen) + if err != nil { + t.log.Debug("Invalid record in "+response.name(), "id", c.node.ID(), "err", err) + continue + } + nodes = append(nodes, node) + } + if total == -1 { + total = min(int(response.Total), totalNodesResponseLimit) + } + if received++; received == total { + return nodes, nil + } + case err := <-c.err: + return nodes, err + } + } +} + +// verifyResponseNode checks validity of a record in a NODES response. +func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distance int, seen map[enode.ID]struct{}) (*enode.Node, error) { + node, err := enode.New(t.validSchemes, r) + if err != nil { + return nil, err + } + if err := netutil.CheckRelayIP(c.node.IP(), node.IP()); err != nil { + return nil, err + } + if c.node.UDP() <= 1024 { + return nil, errLowPort + } + if distance != -1 { + if d := enode.LogDist(c.node.ID(), node.ID()); d != distance { + return nil, fmt.Errorf("wrong distance %d", d) + } + } + if _, ok := seen[node.ID()]; ok { + return nil, fmt.Errorf("duplicate record") + } + seen[node.ID()] = struct{}{} + return node, nil +} + +// call sends the given call and sets up a handler for response packets (of type c.responseType). +// Responses are dispatched to the call's response channel. +func (t *UDPv5) call(node *enode.Node, responseType byte, packet packetV5) *callV5 { + c := &callV5{ + node: node, + packet: packet, + responseType: responseType, + reqid: make([]byte, 8), + ch: make(chan packetV5, 1), + err: make(chan error, 1), + } + // Assign request ID. + crand.Read(c.reqid) + packet.setreqid(c.reqid) + // Send call to dispatch. + select { + case t.callCh <- c: + case <-t.closeCtx.Done(): + c.err <- errClosed + } + return c +} + +// callDone tells dispatch that the active call is done. +func (t *UDPv5) callDone(c *callV5) { + select { + case t.callDoneCh <- c: + case <-t.closeCtx.Done(): + } +} + +// dispatch runs in its own goroutine, handles incoming packets and deals with calls. +// +// For any destination node there is at most one 'active call', stored in the t.activeCall* +// maps. A call is made active when it is sent. The active call can be answered by a +// matching response, in which case c.ch receives the response; or by timing out, in which case +// c.err receives the error. When the function that created the call signals the active +// call is done through callDone, the next call from the call queue is started. +// +// Calls may also be answered by a WHOAREYOU packet referencing the call packet's authTag. +// When that happens the call is simply re-sent to complete the handshake. We allow one +// handshake attempt per call. +func (t *UDPv5) dispatch() { + defer t.wg.Done() + + // Arm first read. + t.readNextCh <- struct{}{} + + for { + select { + case c := <-t.callCh: + id := c.node.ID() + t.callQueue[id] = append(t.callQueue[id], c) + t.sendNextCall(id) + + case ct := <-t.respTimeoutCh: + active := t.activeCallByNode[ct.c.node.ID()] + if ct.c == active && ct.timer == active.timeout { + ct.c.err <- errTimeout + } + + case c := <-t.callDoneCh: + id := c.node.ID() + active := t.activeCallByNode[id] + if active != c { + panic("BUG: callDone for inactive call") + } + c.timeout.Stop() + delete(t.activeCallByAuth, string(c.authTag)) + delete(t.activeCallByNode, id) + t.sendNextCall(id) + + case p := <-t.packetInCh: + t.handlePacket(p.Data, p.Addr) + // Arm next read. + t.readNextCh <- struct{}{} + + case <-t.closeCtx.Done(): + close(t.readNextCh) + for id, queue := range t.callQueue { + for _, c := range queue { + c.err <- errClosed + } + delete(t.callQueue, id) + } + for id, c := range t.activeCallByNode { + c.err <- errClosed + delete(t.activeCallByNode, id) + delete(t.activeCallByAuth, string(c.authTag)) + } + return + } + } +} + +// startResponseTimeout sets the response timer for a call. +func (t *UDPv5) startResponseTimeout(c *callV5) { + if c.timeout != nil { + c.timeout.Stop() + } + var ( + timer mclock.Timer + done = make(chan struct{}) + ) + timer = t.clock.AfterFunc(respTimeoutV5, func() { + <-done + select { + case t.respTimeoutCh <- &callTimeout{c, timer}: + case <-t.closeCtx.Done(): + } + }) + c.timeout = timer + close(done) +} + +// sendNextCall sends the next call in the call queue if there is no active call. +func (t *UDPv5) sendNextCall(id enode.ID) { + queue := t.callQueue[id] + if len(queue) == 0 || t.activeCallByNode[id] != nil { + return + } + t.activeCallByNode[id] = queue[0] + t.sendCall(t.activeCallByNode[id]) + if len(queue) == 1 { + delete(t.callQueue, id) + } else { + copy(queue, queue[1:]) + t.callQueue[id] = queue[:len(queue)-1] + } +} + +// sendCall encodes and sends a request packet to the call's recipient node. +// This performs a handshake if needed. +func (t *UDPv5) sendCall(c *callV5) { + if len(c.authTag) > 0 { + // The call already has an authTag from a previous handshake attempt. Remove the + // entry for the authTag because we're about to generate a new authTag for this + // call. + delete(t.activeCallByAuth, string(c.authTag)) + } + + addr := &net.UDPAddr{IP: c.node.IP(), Port: c.node.UDP()} + newTag, _ := t.send(c.node.ID(), addr, c.packet, c.challenge) + c.authTag = newTag + t.activeCallByAuth[string(c.authTag)] = c + t.startResponseTimeout(c) +} + +// sendResponse sends a response packet to the given node. +// This doesn't trigger a handshake even if no keys are available. +func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet packetV5) error { + _, err := t.send(toID, toAddr, packet, nil) + return err +} + +// send sends a packet to the given node. +func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet packetV5, c *whoareyouV5) ([]byte, error) { + addr := toAddr.String() + enc, authTag, err := t.codec.encode(toID, addr, packet, c) + if err != nil { + t.log.Warn(">> "+packet.name(), "id", toID, "addr", addr, "err", err) + return authTag, err + } + _, err = t.conn.WriteToUDP(enc, toAddr) + t.log.Trace(">> "+packet.name(), "id", toID, "addr", addr) + return authTag, err +} + +// readLoop runs in its own goroutine and reads packets from the network. +func (t *UDPv5) readLoop() { + defer t.wg.Done() + + buf := make([]byte, maxPacketSize) + for range t.readNextCh { + nbytes, from, err := t.conn.ReadFromUDP(buf) + if netutil.IsTemporaryError(err) { + // Ignore temporary read errors. + t.log.Debug("Temporary UDP read error", "err", err) + continue + } else if err != nil { + // Shut down the loop for permament errors. + if err != io.EOF { + t.log.Debug("UDP read error", "err", err) + } + return + } + t.dispatchReadPacket(from, buf[:nbytes]) + } +} + +// dispatchReadPacket sends a packet into the dispatch loop. +func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool { + select { + case t.packetInCh <- ReadPacket{content, from}: + return true + case <-t.closeCtx.Done(): + return false + } +} + +// handlePacket decodes and processes an incoming packet from the network. +func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { + addr := fromAddr.String() + fromID, fromNode, packet, err := t.codec.decode(rawpacket, addr) + if err != nil { + t.log.Debug("Bad discv5 packet", "id", fromID, "addr", addr, "err", err) + return err + } + if fromNode != nil { + // Handshake succeeded, add to table. + t.tab.addSeenNode(wrapNode(fromNode)) + } + if packet.kind() != p_whoareyouV5 { + // WHOAREYOU logged separately to report the sender ID. + t.log.Trace("<< "+packet.name(), "id", fromID, "addr", addr) + } + packet.handle(t, fromID, fromAddr) + return nil +} + +// handleCallResponse dispatches a response packet to the call waiting for it. +func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, reqid []byte, p packetV5) { + ac := t.activeCallByNode[fromID] + if ac == nil || !bytes.Equal(reqid, ac.reqid) { + t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.name()), "id", fromID, "addr", fromAddr) + return + } + if !fromAddr.IP.Equal(ac.node.IP()) || fromAddr.Port != ac.node.UDP() { + t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.name()), "id", fromID, "addr", fromAddr) + return + } + if p.kind() != ac.responseType { + t.log.Debug(fmt.Sprintf("Wrong disv5 response type %s", p.name()), "id", fromID, "addr", fromAddr) + return + } + t.startResponseTimeout(ac) + ac.ch <- p +} + +// getNode looks for a node record in table and database. +func (t *UDPv5) getNode(id enode.ID) *enode.Node { + if n := t.tab.getNode(id); n != nil { + return n + } + if n := t.localNode.Database().Node(id); n != nil { + return n + } + return nil +} + +// UNKNOWN + +func (p *unknownV5) name() string { return "UNKNOWN/v5" } +func (p *unknownV5) kind() byte { return p_unknownV5 } +func (p *unknownV5) setreqid(id []byte) {} + +func (p *unknownV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + challenge := &whoareyouV5{AuthTag: p.AuthTag} + crand.Read(challenge.IDNonce[:]) + if n := t.getNode(fromID); n != nil { + challenge.node = n + challenge.RecordSeq = n.Seq() + } + t.sendResponse(fromID, fromAddr, challenge) +} + +// WHOAREYOU + +func (p *whoareyouV5) name() string { return "WHOAREYOU/v5" } +func (p *whoareyouV5) kind() byte { return p_whoareyouV5 } +func (p *whoareyouV5) setreqid(id []byte) {} + +func (p *whoareyouV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + c, err := p.matchWithCall(t, p.AuthTag) + if err != nil { + t.log.Debug("Invalid WHOAREYOU/v5", "addr", fromAddr, "err", err) + return + } + // Resend the call that was answered by WHOAREYOU. + t.log.Trace("<< "+p.name(), "id", c.node.ID(), "addr", fromAddr) + c.handshakeCount++ + c.challenge = p + p.node = c.node + t.sendCall(c) +} + +var ( + errChallengeNoCall = errors.New("no matching call") + errChallengeTwice = errors.New("second handshake") +) + +// matchWithCall checks whether the handshake attempt matches the active call. +func (p *whoareyouV5) matchWithCall(t *UDPv5, authTag []byte) (*callV5, error) { + c := t.activeCallByAuth[string(authTag)] + if c == nil { + return nil, errChallengeNoCall + } + if c.handshakeCount > 0 { + return nil, errChallengeTwice + } + return c, nil +} + +// PING + +func (p *pingV5) name() string { return "PING/v5" } +func (p *pingV5) kind() byte { return p_pingV5 } +func (p *pingV5) setreqid(id []byte) { p.ReqID = id } + +func (p *pingV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.sendResponse(fromID, fromAddr, &pongV5{ + ReqID: p.ReqID, + ToIP: fromAddr.IP, + ToPort: uint16(fromAddr.Port), + ENRSeq: t.localNode.Node().Seq(), + }) +} + +// PONG + +func (p *pongV5) name() string { return "PONG/v5" } +func (p *pongV5) kind() byte { return p_pongV5 } +func (p *pongV5) setreqid(id []byte) { p.ReqID = id } + +func (p *pongV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}) + t.handleCallResponse(fromID, fromAddr, p.ReqID, p) +} + +// FINDNODE + +func (p *findnodeV5) name() string { return "FINDNODE/v5" } +func (p *findnodeV5) kind() byte { return p_findnodeV5 } +func (p *findnodeV5) setreqid(id []byte) { p.ReqID = id } + +func (p *findnodeV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + if p.Distance == 0 { + t.sendNodes(fromID, fromAddr, p.ReqID, []*enode.Node{t.Self()}) + return + } + if p.Distance > 256 { + p.Distance = 256 + } + // Get bucket entries. + t.tab.mutex.Lock() + nodes := unwrapNodes(t.tab.bucketAtDistance(int(p.Distance)).entries) + t.tab.mutex.Unlock() + if len(nodes) > findnodeResultLimit { + nodes = nodes[:findnodeResultLimit] + } + t.sendNodes(fromID, fromAddr, p.ReqID, nodes) +} + +// sendNodes sends the given records in one or more NODES packets. +func (t *UDPv5) sendNodes(toID enode.ID, toAddr *net.UDPAddr, reqid []byte, nodes []*enode.Node) { + // TODO livenessChecks > 1 + // TODO CheckRelayIP + total := uint8(math.Ceil(float64(len(nodes)) / 3)) + resp := &nodesV5{ReqID: reqid, Total: total, Nodes: make([]*enr.Record, 3)} + sent := false + for len(nodes) > 0 { + items := min(nodesResponseItemLimit, len(nodes)) + resp.Nodes = resp.Nodes[:items] + for i := 0; i < items; i++ { + resp.Nodes[i] = nodes[i].Record() + } + t.sendResponse(toID, toAddr, resp) + nodes = nodes[items:] + sent = true + } + // Ensure at least one response is sent. + if !sent { + resp.Total = 1 + resp.Nodes = nil + t.sendResponse(toID, toAddr, resp) + } +} + +// NODES + +func (p *nodesV5) name() string { return "NODES/v5" } +func (p *nodesV5) kind() byte { return p_nodesV5 } +func (p *nodesV5) setreqid(id []byte) { p.ReqID = id } + +func (p *nodesV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.handleCallResponse(fromID, fromAddr, p.ReqID, p) +} + +// REQUESTTICKET + +func (p *requestTicketV5) name() string { return "REQUESTTICKET/v5" } +func (p *requestTicketV5) kind() byte { return p_requestTicketV5 } +func (p *requestTicketV5) setreqid(id []byte) { p.ReqID = id } + +func (p *requestTicketV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.sendResponse(fromID, fromAddr, &ticketV5{ReqID: p.ReqID}) +} + +// TICKET + +func (p *ticketV5) name() string { return "TICKET/v5" } +func (p *ticketV5) kind() byte { return p_ticketV5 } +func (p *ticketV5) setreqid(id []byte) { p.ReqID = id } + +func (p *ticketV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.handleCallResponse(fromID, fromAddr, p.ReqID, p) +} + +// REGTOPIC + +func (p *regtopicV5) name() string { return "REGTOPIC/v5" } +func (p *regtopicV5) kind() byte { return p_regtopicV5 } +func (p *regtopicV5) setreqid(id []byte) { p.ReqID = id } + +func (p *regtopicV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.sendResponse(fromID, fromAddr, ®confirmationV5{ReqID: p.ReqID, Registered: false}) +} + +// REGCONFIRMATION + +func (p *regconfirmationV5) name() string { return "REGCONFIRMATION/v5" } +func (p *regconfirmationV5) kind() byte { return p_regconfirmationV5 } +func (p *regconfirmationV5) setreqid(id []byte) { p.ReqID = id } + +func (p *regconfirmationV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + t.handleCallResponse(fromID, fromAddr, p.ReqID, p) +} + +// TOPICQUERY + +func (p *topicqueryV5) name() string { return "TOPICQUERY/v5" } +func (p *topicqueryV5) kind() byte { return p_topicqueryV5 } +func (p *topicqueryV5) setreqid(id []byte) { p.ReqID = id } + +func (p *topicqueryV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { +} diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go new file mode 100644 index 0000000000..15ea0402c2 --- /dev/null +++ b/p2p/discover/v5_udp_test.go @@ -0,0 +1,622 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "bytes" + "crypto/ecdsa" + "encoding/binary" + "fmt" + "math/rand" + "net" + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/internal/testlog" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" +) + +// Real sockets, real crypto: this test checks end-to-end connectivity for UDPv5. +func TestEndToEndV5(t *testing.T) { + t.Parallel() + + var nodes []*UDPv5 + for i := 0; i < 5; i++ { + var cfg Config + if len(nodes) > 0 { + bn := nodes[0].Self() + cfg.Bootnodes = []*enode.Node{bn} + } + node := startLocalhostV5(t, cfg) + nodes = append(nodes, node) + defer node.Close() + } + + last := nodes[len(nodes)-1] + target := nodes[rand.Intn(len(nodes)-2)].Self() + results := last.Lookup(target.ID()) + if len(results) == 0 || results[0].ID() != target.ID() { + t.Fatalf("lookup returned wrong results: %v", results) + } +} + +func startLocalhostV5(t *testing.T, cfg Config) *UDPv5 { + cfg.PrivateKey = newkey() + db, _ := enode.OpenDB("") + ln := enode.NewLocalNode(db, cfg.PrivateKey) + + // Prefix logs with node ID. + lprefix := fmt.Sprintf("(%s)", ln.ID().TerminalString()) + lfmt := log.TerminalFormat(false) + cfg.Log = testlog.Logger(t, log.LvlTrace) + cfg.Log.SetHandler(log.FuncHandler(func(r *log.Record) error { + t.Logf("%s %s", lprefix, lfmt.Format(r)) + return nil + })) + + // Listen. + socket, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) + if err != nil { + t.Fatal(err) + } + realaddr := socket.LocalAddr().(*net.UDPAddr) + ln.SetStaticIP(realaddr.IP) + ln.Set(enr.UDP(realaddr.Port)) + udp, err := ListenV5(socket, ln, cfg) + if err != nil { + t.Fatal(err) + } + return udp +} + +// This test checks that incoming PING calls are handled correctly. +func TestUDPv5_pingHandling(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + test.packetIn(&pingV5{ReqID: []byte("foo")}) + test.waitPacketOut(func(p *pongV5, addr *net.UDPAddr, authTag []byte) { + if !bytes.Equal(p.ReqID, []byte("foo")) { + t.Error("wrong request ID in response:", p.ReqID) + } + if p.ENRSeq != test.table.self().Seq() { + t.Error("wrong ENR sequence number in response:", p.ENRSeq) + } + }) +} + +// This test checks that incoming 'unknown' packets trigger the handshake. +func TestUDPv5_unknownPacket(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + authTag := [12]byte{1, 2, 3} + check := func(p *whoareyouV5, wantSeq uint64) { + t.Helper() + if !bytes.Equal(p.AuthTag, authTag[:]) { + t.Error("wrong token in WHOAREYOU:", p.AuthTag, authTag[:]) + } + if p.IDNonce == ([32]byte{}) { + t.Error("all zero ID nonce") + } + if p.RecordSeq != wantSeq { + t.Errorf("wrong record seq %d in WHOAREYOU, want %d", p.RecordSeq, wantSeq) + } + } + + // Unknown packet from unknown node. + test.packetIn(&unknownV5{AuthTag: authTag[:]}) + test.waitPacketOut(func(p *whoareyouV5, addr *net.UDPAddr, _ []byte) { + check(p, 0) + }) + + // Make node known. + n := test.getNode(test.remotekey, test.remoteaddr).Node() + test.table.addSeenNode(wrapNode(n)) + + test.packetIn(&unknownV5{AuthTag: authTag[:]}) + test.waitPacketOut(func(p *whoareyouV5, addr *net.UDPAddr, _ []byte) { + check(p, n.Seq()) + }) +} + +// This test checks that incoming FINDNODE calls are handled correctly. +func TestUDPv5_findnodeHandling(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + // Create test nodes and insert them into the table. + nodes := nodesAtDistance(test.table.self().ID(), 253, 10) + fillTable(test.table, wrapNodes(nodes)) + + // Requesting with distance zero should return the node's own record. + test.packetIn(&findnodeV5{ReqID: []byte{0}, Distance: 0}) + test.expectNodes([]byte{0}, 1, []*enode.Node{test.udp.Self()}) + + // Requesting with distance > 256 caps it at 256. + test.packetIn(&findnodeV5{ReqID: []byte{1}, Distance: 4234098}) + test.expectNodes([]byte{1}, 1, nil) + + // This request gets no nodes because the corresponding bucket is empty. + test.packetIn(&findnodeV5{ReqID: []byte{2}, Distance: 254}) + test.expectNodes([]byte{2}, 1, nil) + + // This request gets all test nodes. + test.packetIn(&findnodeV5{ReqID: []byte{3}, Distance: 253}) + test.expectNodes([]byte{3}, 4, nodes) +} + +func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes []*enode.Node) { + nodeSet := make(map[enode.ID]*enr.Record) + for _, n := range wantNodes { + nodeSet[n.ID()] = n.Record() + } + for { + test.waitPacketOut(func(p *nodesV5, addr *net.UDPAddr, authTag []byte) { + if len(p.Nodes) > 3 { + test.t.Fatalf("too many nodes in response") + } + if p.Total != wantTotal { + test.t.Fatalf("wrong total response count %d", p.Total) + } + if !bytes.Equal(p.ReqID, wantReqID) { + test.t.Fatalf("wrong request ID in response: %v", p.ReqID) + } + for _, record := range p.Nodes { + n, _ := enode.New(enode.ValidSchemesForTesting, record) + want := nodeSet[n.ID()] + if want == nil { + test.t.Fatalf("unexpected node in response: %v", n) + } + if !reflect.DeepEqual(record, want) { + test.t.Fatalf("wrong record in response: %v", n) + } + delete(nodeSet, n.ID()) + } + }) + if len(nodeSet) == 0 { + return + } + } +} + +// This test checks that outgoing PING calls work. +func TestUDPv5_pingCall(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + remote := test.getNode(test.remotekey, test.remoteaddr).Node() + done := make(chan error, 1) + + // This ping times out. + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) {}) + if err := <-done; err != errTimeout { + t.Fatalf("want errTimeout, got %q", err) + } + + // This ping works. + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetInFrom(test.remotekey, test.remoteaddr, &pongV5{ReqID: p.ReqID}) + }) + if err := <-done; err != nil { + t.Fatal(err) + } + + // This ping gets a reply from the wrong endpoint. + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 55, 22}, Port: 10101} + test.packetInFrom(test.remotekey, wrongAddr, &pongV5{ReqID: p.ReqID}) + }) + if err := <-done; err != errTimeout { + t.Fatalf("want errTimeout for reply from wrong IP, got %q", err) + } +} + +// This test checks that outgoing FINDNODE calls work and multiple NODES +// replies are aggregated. +func TestUDPv5_findnodeCall(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + // Launch the request: + var ( + distance = 230 + remote = test.getNode(test.remotekey, test.remoteaddr).Node() + nodes = nodesAtDistance(remote.ID(), distance, 8) + done = make(chan error, 1) + response []*enode.Node + ) + go func() { + var err error + response, err = test.udp.findnode(remote, distance) + done <- err + }() + + // Serve the responses: + test.waitPacketOut(func(p *findnodeV5, addr *net.UDPAddr, authTag []byte) { + if p.Distance != uint(distance) { + t.Fatalf("wrong bucket: %d", p.Distance) + } + test.packetIn(&nodesV5{ + ReqID: p.ReqID, + Total: 2, + Nodes: nodesToRecords(nodes[:4]), + }) + test.packetIn(&nodesV5{ + ReqID: p.ReqID, + Total: 2, + Nodes: nodesToRecords(nodes[4:]), + }) + }) + + // Check results: + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(response, nodes) { + t.Fatalf("wrong nodes in response") + } + + // TODO: check invalid IPs + // TODO: check invalid/unsigned record +} + +// This test checks that pending calls are re-sent when a handshake happens. +func TestUDPv5_callResend(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + remote := test.getNode(test.remotekey, test.remoteaddr).Node() + done := make(chan error, 2) + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + + // Ping answered by WHOAREYOU. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&whoareyouV5{AuthTag: authTag}) + }) + // Ping should be re-sent. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&pongV5{ReqID: p.ReqID}) + }) + // Answer the other ping. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&pongV5{ReqID: p.ReqID}) + }) + if err := <-done; err != nil { + t.Fatalf("unexpected ping error: %v", err) + } + if err := <-done; err != nil { + t.Fatalf("unexpected ping error: %v", err) + } +} + +// This test ensures we don't allow multiple rounds of WHOAREYOU for a single call. +func TestUDPv5_multipleHandshakeRounds(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + remote := test.getNode(test.remotekey, test.remoteaddr).Node() + done := make(chan error, 1) + go func() { + _, err := test.udp.ping(remote) + done <- err + }() + + // Ping answered by WHOAREYOU. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&whoareyouV5{AuthTag: authTag}) + }) + // Ping answered by WHOAREYOU again. + test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&whoareyouV5{AuthTag: authTag}) + }) + if err := <-done; err != errTimeout { + t.Fatalf("unexpected ping error: %q", err) + } +} + +// This test checks that calls with n replies may take up to n * respTimeout. +func TestUDPv5_callTimeoutReset(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + // Launch the request: + var ( + distance = 230 + remote = test.getNode(test.remotekey, test.remoteaddr).Node() + nodes = nodesAtDistance(remote.ID(), distance, 8) + done = make(chan error, 1) + ) + go func() { + _, err := test.udp.findnode(remote, distance) + done <- err + }() + + // Serve two responses, slowly. + test.waitPacketOut(func(p *findnodeV5, addr *net.UDPAddr, authTag []byte) { + time.Sleep(respTimeout - 50*time.Millisecond) + test.packetIn(&nodesV5{ + ReqID: p.ReqID, + Total: 2, + Nodes: nodesToRecords(nodes[:4]), + }) + + time.Sleep(respTimeout - 50*time.Millisecond) + test.packetIn(&nodesV5{ + ReqID: p.ReqID, + Total: 2, + Nodes: nodesToRecords(nodes[4:]), + }) + }) + if err := <-done; err != nil { + t.Fatalf("unexpected error: %q", err) + } +} + +// This test checks that lookup works. +func TestUDPv5_lookup(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + + // Lookup on empty table returns no nodes. + if results := test.udp.Lookup(lookupTestnet.target.id()); len(results) > 0 { + t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) + } + + // Ensure the tester knows all nodes in lookupTestnet by IP. + for d, nn := range lookupTestnet.dists { + for i, key := range nn { + n := lookupTestnet.node(d, i) + test.getNode(key, &net.UDPAddr{IP: n.IP(), Port: n.UDP()}) + } + } + + // Seed table with initial node. + fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}) + + // Start the lookup. + resultC := make(chan []*enode.Node, 1) + go func() { + resultC <- test.udp.Lookup(lookupTestnet.target.id()) + test.close() + }() + + // Answer lookup packets. + for done := false; !done; { + done = test.waitPacketOut(func(p packetV5, to *net.UDPAddr, authTag []byte) { + recipient, key := lookupTestnet.nodeByAddr(to) + switch p := p.(type) { + case *pingV5: + test.packetInFrom(key, to, &pongV5{ReqID: p.ReqID}) + case *findnodeV5: + nodes := lookupTestnet.neighborsAtDistance(recipient, p.Distance, 3) + response := &nodesV5{ReqID: p.ReqID, Total: 1, Nodes: nodesToRecords(nodes)} + test.packetInFrom(key, to, response) + } + }) + } + + // Verify result nodes. + checkLookupResults(t, lookupTestnet, <-resultC) +} + +// udpV5Test is the framework for all tests above. +// It runs the UDPv5 transport on a virtual socket and allows testing outgoing packets. +type udpV5Test struct { + t *testing.T + pipe *dgramPipe + table *Table + db *enode.DB + udp *UDPv5 + localkey, remotekey *ecdsa.PrivateKey + remoteaddr *net.UDPAddr + nodesByID map[enode.ID]*enode.LocalNode + nodesByIP map[string]*enode.LocalNode +} + +type testCodec struct { + test *udpV5Test + id enode.ID + ctr uint64 +} + +type testCodecFrame struct { + NodeID enode.ID + AuthTag []byte + Ptype byte + Packet rlp.RawValue +} + +func (c *testCodec) encode(toID enode.ID, addr string, p packetV5, _ *whoareyouV5) ([]byte, []byte, error) { + c.ctr++ + authTag := make([]byte, 8) + binary.BigEndian.PutUint64(authTag, c.ctr) + penc, _ := rlp.EncodeToBytes(p) + frame, err := rlp.EncodeToBytes(testCodecFrame{c.id, authTag, p.kind(), penc}) + return frame, authTag, err +} + +func (c *testCodec) decode(input []byte, addr string) (enode.ID, *enode.Node, packetV5, error) { + frame, p, err := c.decodeFrame(input) + if err != nil { + return enode.ID{}, nil, nil, err + } + if p.kind() == p_whoareyouV5 { + frame.NodeID = enode.ID{} // match wireCodec behavior + } + return frame.NodeID, nil, p, nil +} + +func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p packetV5, err error) { + if err = rlp.DecodeBytes(input, &frame); err != nil { + return frame, nil, fmt.Errorf("invalid frame: %v", err) + } + switch frame.Ptype { + case p_unknownV5: + dec := new(unknownV5) + err = rlp.DecodeBytes(frame.Packet, &dec) + p = dec + case p_whoareyouV5: + dec := new(whoareyouV5) + err = rlp.DecodeBytes(frame.Packet, &dec) + p = dec + default: + p, err = decodePacketBodyV5(frame.Ptype, frame.Packet) + } + return frame, p, err +} + +func newUDPV5Test(t *testing.T) *udpV5Test { + test := &udpV5Test{ + t: t, + pipe: newpipe(), + localkey: newkey(), + remotekey: newkey(), + remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, + nodesByID: make(map[enode.ID]*enode.LocalNode), + nodesByIP: make(map[string]*enode.LocalNode), + } + test.db, _ = enode.OpenDB("") + ln := enode.NewLocalNode(test.db, test.localkey) + ln.SetStaticIP(net.IP{10, 0, 0, 1}) + ln.Set(enr.UDP(30303)) + test.udp, _ = ListenV5(test.pipe, ln, Config{ + PrivateKey: test.localkey, + Log: testlog.Logger(t, log.LvlTrace), + ValidSchemes: enode.ValidSchemesForTesting, + }) + test.udp.codec = &testCodec{test: test, id: ln.ID()} + test.table = test.udp.tab + test.nodesByID[ln.ID()] = ln + // Wait for initial refresh so the table doesn't send unexpected findnode. + <-test.table.initDone + return test +} + +// handles a packet as if it had been sent to the transport. +func (test *udpV5Test) packetIn(packet packetV5) { + test.t.Helper() + test.packetInFrom(test.remotekey, test.remoteaddr, packet) +} + +// handles a packet as if it had been sent to the transport by the key/endpoint. +func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet packetV5) { + test.t.Helper() + + ln := test.getNode(key, addr) + codec := &testCodec{test: test, id: ln.ID()} + enc, _, err := codec.encode(test.udp.Self().ID(), addr.String(), packet, nil) + if err != nil { + test.t.Errorf("%s encode error: %v", packet.name(), err) + } + if test.udp.dispatchReadPacket(addr, enc) { + <-test.udp.readNextCh // unblock UDPv5.dispatch + } +} + +// getNode ensures the test knows about a node at the given endpoint. +func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr *net.UDPAddr) *enode.LocalNode { + id := encodePubkey(&key.PublicKey).id() + ln := test.nodesByID[id] + if ln == nil { + db, _ := enode.OpenDB("") + ln = enode.NewLocalNode(db, key) + ln.SetStaticIP(addr.IP) + ln.Set(enr.UDP(addr.Port)) + test.nodesByID[id] = ln + } + test.nodesByIP[string(addr.IP)] = ln + return ln +} + +func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { + test.t.Helper() + fn := reflect.ValueOf(validate) + exptype := fn.Type().In(0) + + dgram, err := test.pipe.receive() + if err == errClosed { + return true + } + if err == errTimeout { + test.t.Fatalf("timed out waiting for %v", exptype) + return false + } + ln := test.nodesByIP[string(dgram.to.IP)] + if ln == nil { + test.t.Fatalf("attempt to send to non-existing node %v", &dgram.to) + return false + } + codec := &testCodec{test: test, id: ln.ID()} + frame, p, err := codec.decodeFrame(dgram.data) + if err != nil { + test.t.Errorf("sent packet decode error: %v", err) + return false + } + if !reflect.TypeOf(p).AssignableTo(exptype) { + test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) + return false + } + fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(frame.AuthTag)}) + return false +} + +func (test *udpV5Test) close() { + test.t.Helper() + + test.udp.Close() + test.db.Close() + for id, n := range test.nodesByID { + if id != test.udp.Self().ID() { + n.Database().Close() + } + } + if len(test.pipe.queue) != 0 { + test.t.Fatalf("%d unmatched UDP packets in queue", len(test.pipe.queue)) + } +} diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go index 44332640c7..bd066ce857 100644 --- a/p2p/enode/nodedb.go +++ b/p2p/enode/nodedb.go @@ -41,6 +41,7 @@ const ( dbNodePrefix = "n:" // Identifier to prefix node entries with dbLocalPrefix = "local:" dbDiscoverRoot = "v4" + dbDiscv5Root = "v5" // These fields are stored per ID and IP, the full key is "n::v4::findfail". // Use nodeItemKey to create those keys. @@ -172,6 +173,16 @@ func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) { return id, ip, field } +func v5Key(id ID, ip net.IP, field string) []byte { + return bytes.Join([][]byte{ + []byte(dbNodePrefix), + id[:], + []byte(dbDiscv5Root), + ip.To16(), + []byte(field), + }, []byte{':'}) +} + // localItemKey returns the key of a local node item. func localItemKey(id ID, field string) []byte { key := append([]byte(dbLocalPrefix), id[:]...) @@ -378,6 +389,16 @@ func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error { return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails)) } +// FindFailsV5 retrieves the discv5 findnode failure counter. +func (db *DB) FindFailsV5(id ID, ip net.IP) int { + return int(db.fetchInt64(v5Key(id, ip, dbNodeFindFails))) +} + +// UpdateFindFailsV5 stores the discv5 findnode failure counter. +func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error { + return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails)) +} + // LocalSeq retrieves the local record sequence counter. func (db *DB) localSeq(id ID) uint64 { return db.fetchUint64(localItemKey(id, dbLocalSeq)) diff --git a/p2p/enode/nodedb_test.go b/p2p/enode/nodedb_test.go index 2adb14145d..d2b187896f 100644 --- a/p2p/enode/nodedb_test.go +++ b/p2p/enode/nodedb_test.go @@ -462,3 +462,14 @@ func TestDBExpiration(t *testing.T) { } } } + +// This test checks that expiration works when discovery v5 data is present +// in the database. +func TestDBExpireV5(t *testing.T) { + db, _ := OpenDB("") + defer db.Close() + + ip := net.IP{127, 0, 0, 1} + db.UpdateFindFailsV5(ID{}, ip, 4) + db.expireNodes() +}