p2p/discover: implement v5.1 wire protocol (#21647)

This change implements the Discovery v5.1 wire protocol and
also adds an interactive test suite for this protocol.
This commit is contained in:
Felix Lange 2020-10-14 12:28:17 +02:00 committed by GitHub
parent 4eb01b21c8
commit 524aaf5ec6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 3236 additions and 1427 deletions

86
cmd/devp2p/README.md Normal file

@ -0,0 +1,86 @@
# The devp2p command
The devp2p command line tool is a utility for low-level peer-to-peer debugging and
protocol development purposes. It can do many things.
### ENR Decoding
Use `devp2p enrdump <base64>` to verify and display an Ethereum Node Record.
### Node Key Management
The `devp2p key ...` command family deals with node key files.
Run `devp2p key generate mynode.key` to create a new node key in the `mynode.key` file.
Run `devp2p key to-enode mynode.key -ip 127.0.0.1 -tcp 30303` to create an enode:// URL
corresponding to the given node key and address information.
### Maintaining DNS Discovery Node Lists
The devp2p command can create and publish DNS discovery node lists.
Run `devp2p dns sign <directory>` to update the signature of a DNS discovery tree.
Run `devp2p dns sync <enrtree-URL>` to download a complete DNS discovery tree.
Run `devp2p dns to-cloudflare <directory>` to publish a tree to CloudFlare DNS.
Run `devp2p dns to-route53 <directory>` to publish a tree to Amazon Route53.
You can find more information about these commands in the [DNS Discovery Setup Guide][dns-tutorial].
### Discovery v4 Utilities
The `devp2p discv4 ...` command family deals with the [Node Discovery v4][discv4]
protocol.
Run `devp2p discv4 ping <enode/ENR>` to ping a node.
Run `devp2p discv4 resolve <enode/ENR>` to find the most recent node record of a node in
the DHT.
Run `devp2p discv4 crawl <nodes.json path>` to create or update a JSON node set.
### Discovery v5 Utilities
The `devp2p discv5 ...` command family deals with the [Node Discovery v5][discv5]
protocol. This protocol is currently under active development.
Run `devp2p discv5 ping <ENR>` to ping a node.
Run `devp2p discv5 resolve <ENR>` to find the most recent node record of a node in
the discv5 DHT.
Run `devp2p discv5 listen` to run a Discovery v5 node.
Run `devp2p discv5 crawl <nodes.json path>` to create or update a JSON node set containing
discv5 nodes.
### Discovery Test Suites
The devp2p command also contains interactive test suites for Discovery v4 and Discovery
v5.
To run these tests against your implementation, you need to set up a networking
environment where two separate UDP listening addresses are available on the same machine.
The two listening addresses must also be routed such that they are able to reach the node
you want to test.
For example, if you want to run the test on your local host, and the node under test is
also on the local host, you need to assign two IP addresses (or a larger range) to your
loopback interface. On macOS, this can be done by executing the following command:
sudo ifconfig lo0 add 127.0.0.2
You can now run either test suite as follows: Start the node under test first, ensuring
that it won't talk to the Internet (i.e. disable bootstrapping). An easy way to prevent
unintended connections to the global DHT is listening on `127.0.0.1`.
Now get the ENR of your node and store it in the `NODE` environment variable.
Start the test by running `devp2p discv5 test -listen1 127.0.0.1 -listen2 127.0.0.2 $NODE`.
[dns-tutorial]: https://geth.ethereum.org/docs/developers/dns-discovery-setup
[discv4]: https://github.com/ethereum/devp2p/tree/master/discv4.md
[discv5]: https://github.com/ethereum/devp2p/tree/master/discv5/discv5.md

@ -286,7 +286,11 @@ func listen(ln *enode.LocalNode, addr string) *net.UDPConn {
} }
usocket := socket.(*net.UDPConn) usocket := socket.(*net.UDPConn)
uaddr := socket.LocalAddr().(*net.UDPAddr) uaddr := socket.LocalAddr().(*net.UDPAddr)
ln.SetFallbackIP(net.IP{127, 0, 0, 1}) if uaddr.IP.IsUnspecified() {
ln.SetFallbackIP(net.IP{127, 0, 0, 1})
} else {
ln.SetFallbackIP(uaddr.IP)
}
ln.SetFallbackUDP(uaddr.Port) ln.SetFallbackUDP(uaddr.Port)
return usocket return usocket
} }
@ -294,7 +298,11 @@ func listen(ln *enode.LocalNode, addr string) *net.UDPConn {
func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) { func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) {
s := params.RinkebyBootnodes s := params.RinkebyBootnodes
if ctx.IsSet(bootnodesFlag.Name) { if ctx.IsSet(bootnodesFlag.Name) {
s = strings.Split(ctx.String(bootnodesFlag.Name), ",") input := ctx.String(bootnodesFlag.Name)
if input == "" {
return nil, nil
}
s = strings.Split(input, ",")
} }
nodes := make([]*enode.Node, len(s)) nodes := make([]*enode.Node, len(s))
var err error var err error

@ -18,9 +18,13 @@ package main
import ( import (
"fmt" "fmt"
"os"
"time" "time"
"github.com/ethereum/go-ethereum/cmd/devp2p/internal/v5test"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"gopkg.in/urfave/cli.v1" "gopkg.in/urfave/cli.v1"
) )
@ -33,6 +37,7 @@ var (
discv5PingCommand, discv5PingCommand,
discv5ResolveCommand, discv5ResolveCommand,
discv5CrawlCommand, discv5CrawlCommand,
discv5TestCommand,
discv5ListenCommand, discv5ListenCommand,
}, },
} }
@ -53,6 +58,12 @@ var (
Action: discv5Crawl, Action: discv5Crawl,
Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag}, Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag},
} }
discv5TestCommand = cli.Command{
Name: "test",
Usage: "Runs protocol tests against a node",
Action: discv5Test,
Flags: []cli.Flag{testPatternFlag, testListen1Flag, testListen2Flag},
}
discv5ListenCommand = cli.Command{ discv5ListenCommand = cli.Command{
Name: "listen", Name: "listen",
Usage: "Runs a node", Usage: "Runs a node",
@ -103,6 +114,30 @@ func discv5Crawl(ctx *cli.Context) error {
return nil return nil
} }
func discv5Test(ctx *cli.Context) error {
// Disable logging unless explicitly enabled.
if !ctx.GlobalIsSet("verbosity") && !ctx.GlobalIsSet("vmodule") {
log.Root().SetHandler(log.DiscardHandler())
}
// Filter and run test cases.
suite := &v5test.Suite{
Dest: getNodeArg(ctx),
Listen1: ctx.String(testListen1Flag.Name),
Listen2: ctx.String(testListen2Flag.Name),
}
tests := suite.AllTests()
if ctx.IsSet(testPatternFlag.Name) {
tests = utesting.MatchTests(tests, ctx.String(testPatternFlag.Name))
}
results := utesting.RunTests(tests, os.Stdout)
if fails := utesting.CountFailures(results); fails > 0 {
return fmt.Errorf("%v/%v tests passed.", len(tests)-fails, len(tests))
}
fmt.Printf("%v/%v passed\n", len(tests), len(tests))
return nil
}
func discv5Listen(ctx *cli.Context) error { func discv5Listen(ctx *cli.Context) error {
disc := startV5(ctx) disc := startV5(ctx)
defer disc.Close() defer disc.Close()

@ -0,0 +1,377 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of go-ethereum.
//
// go-ethereum is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// go-ethereum is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
package v5test
import (
"bytes"
"net"
"sync"
"time"
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p/discover/v5wire"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/netutil"
)
// Suite is the discv5 test suite.
type Suite struct {
Dest *enode.Node
Listen1, Listen2 string // listening addresses
}
func (s *Suite) listen1(log logger) (*conn, net.PacketConn) {
c := newConn(s.Dest, log)
l := c.listen(s.Listen1)
return c, l
}
func (s *Suite) listen2(log logger) (*conn, net.PacketConn, net.PacketConn) {
c := newConn(s.Dest, log)
l1, l2 := c.listen(s.Listen1), c.listen(s.Listen2)
return c, l1, l2
}
func (s *Suite) AllTests() []utesting.Test {
return []utesting.Test{
{Name: "Ping", Fn: s.TestPing},
{Name: "PingLargeRequestID", Fn: s.TestPingLargeRequestID},
{Name: "PingMultiIP", Fn: s.TestPingMultiIP},
{Name: "PingHandshakeInterrupted", Fn: s.TestPingHandshakeInterrupted},
{Name: "TalkRequest", Fn: s.TestTalkRequest},
{Name: "FindnodeZeroDistance", Fn: s.TestFindnodeZeroDistance},
{Name: "FindnodeResults", Fn: s.TestFindnodeResults},
}
}
// This test sends PING and expects a PONG response.
func (s *Suite) TestPing(t *utesting.T) {
conn, l1 := s.listen1(t)
defer conn.close()
ping := &v5wire.Ping{ReqID: conn.nextReqID()}
switch resp := conn.reqresp(l1, ping).(type) {
case *v5wire.Pong:
checkPong(t, resp, ping, l1)
default:
t.Fatal("expected PONG, got", resp.Name())
}
}
func checkPong(t *utesting.T, pong *v5wire.Pong, ping *v5wire.Ping, c net.PacketConn) {
if !bytes.Equal(pong.ReqID, ping.ReqID) {
t.Fatalf("wrong request ID %x in PONG, want %x", pong.ReqID, ping.ReqID)
}
if !pong.ToIP.Equal(laddr(c).IP) {
t.Fatalf("wrong destination IP %v in PONG, want %v", pong.ToIP, laddr(c).IP)
}
if int(pong.ToPort) != laddr(c).Port {
t.Fatalf("wrong destination port %v in PONG, want %v", pong.ToPort, laddr(c).Port)
}
}
// This test sends PING with a 9-byte request ID, which isn't allowed by the spec.
// The remote node should not respond.
func (s *Suite) TestPingLargeRequestID(t *utesting.T) {
conn, l1 := s.listen1(t)
defer conn.close()
ping := &v5wire.Ping{ReqID: make([]byte, 9)}
switch resp := conn.reqresp(l1, ping).(type) {
case *v5wire.Pong:
t.Errorf("PONG response with unknown request ID %x", resp.ReqID)
case *readError:
if resp.err == v5wire.ErrInvalidReqID {
t.Error("response with oversized request ID")
} else if !netutil.IsTimeout(resp.err) {
t.Error(resp)
}
}
}
// In this test, a session is established from one IP as usual. The session is then reused
// on another IP, which shouldn't work. The remote node should respond with WHOAREYOU for
// the attempt from a different IP.
func (s *Suite) TestPingMultiIP(t *utesting.T) {
conn, l1, l2 := s.listen2(t)
defer conn.close()
// Create the session on l1.
ping := &v5wire.Ping{ReqID: conn.nextReqID()}
resp := conn.reqresp(l1, ping)
if resp.Kind() != v5wire.PongMsg {
t.Fatal("expected PONG, got", resp)
}
checkPong(t, resp.(*v5wire.Pong), ping, l1)
// Send on l2. This reuses the session because there is only one codec.
ping2 := &v5wire.Ping{ReqID: conn.nextReqID()}
conn.write(l2, ping2, nil)
switch resp := conn.read(l2).(type) {
case *v5wire.Pong:
t.Fatalf("remote responded to PING from %v for session on IP %v", laddr(l2).IP, laddr(l1).IP)
case *v5wire.Whoareyou:
t.Logf("got WHOAREYOU for new session as expected")
resp.Node = s.Dest
conn.write(l2, ping2, resp)
default:
t.Fatal("expected WHOAREYOU, got", resp)
}
// Catch the PONG on l2.
switch resp := conn.read(l2).(type) {
case *v5wire.Pong:
checkPong(t, resp, ping2, l2)
default:
t.Fatal("expected PONG, got", resp)
}
// Try on l1 again.
ping3 := &v5wire.Ping{ReqID: conn.nextReqID()}
conn.write(l1, ping3, nil)
switch resp := conn.read(l1).(type) {
case *v5wire.Pong:
t.Fatalf("remote responded to PING from %v for session on IP %v", laddr(l1).IP, laddr(l2).IP)
case *v5wire.Whoareyou:
t.Logf("got WHOAREYOU for new session as expected")
default:
t.Fatal("expected WHOAREYOU, got", resp)
}
}
// This test starts a handshake, but doesn't finish it and sends a second ordinary message
// packet instead of a handshake message packet. The remote node should respond with
// another WHOAREYOU challenge for the second packet.
func (s *Suite) TestPingHandshakeInterrupted(t *utesting.T) {
conn, l1 := s.listen1(t)
defer conn.close()
// First PING triggers challenge.
ping := &v5wire.Ping{ReqID: conn.nextReqID()}
conn.write(l1, ping, nil)
switch resp := conn.read(l1).(type) {
case *v5wire.Whoareyou:
t.Logf("got WHOAREYOU for PING")
default:
t.Fatal("expected WHOAREYOU, got", resp)
}
// Send second PING.
ping2 := &v5wire.Ping{ReqID: conn.nextReqID()}
switch resp := conn.reqresp(l1, ping2).(type) {
case *v5wire.Pong:
checkPong(t, resp, ping2, l1)
default:
t.Fatal("expected WHOAREYOU, got", resp)
}
}
// This test sends TALKREQ and expects an empty TALKRESP response.
func (s *Suite) TestTalkRequest(t *utesting.T) {
conn, l1 := s.listen1(t)
defer conn.close()
// Non-empty request ID.
id := conn.nextReqID()
resp := conn.reqresp(l1, &v5wire.TalkRequest{ReqID: id, Protocol: "test-protocol"})
switch resp := resp.(type) {
case *v5wire.TalkResponse:
if !bytes.Equal(resp.ReqID, id) {
t.Fatalf("wrong request ID %x in TALKRESP, want %x", resp.ReqID, id)
}
if len(resp.Message) > 0 {
t.Fatalf("non-empty message %x in TALKRESP", resp.Message)
}
default:
t.Fatal("expected TALKRESP, got", resp.Name())
}
// Empty request ID.
resp = conn.reqresp(l1, &v5wire.TalkRequest{Protocol: "test-protocol"})
switch resp := resp.(type) {
case *v5wire.TalkResponse:
if len(resp.ReqID) > 0 {
t.Fatalf("wrong request ID %x in TALKRESP, want empty byte array", resp.ReqID)
}
if len(resp.Message) > 0 {
t.Fatalf("non-empty message %x in TALKRESP", resp.Message)
}
default:
t.Fatal("expected TALKRESP, got", resp.Name())
}
}
// This test checks that the remote node returns itself for FINDNODE with distance zero.
func (s *Suite) TestFindnodeZeroDistance(t *utesting.T) {
conn, l1 := s.listen1(t)
defer conn.close()
nodes, err := conn.findnode(l1, []uint{0})
if err != nil {
t.Fatal(err)
}
if len(nodes) != 1 {
t.Fatalf("remote returned more than one node for FINDNODE [0]")
}
if nodes[0].ID() != conn.remote.ID() {
t.Errorf("ID of response node is %v, want %v", nodes[0].ID(), conn.remote.ID())
}
}
// In this test, multiple nodes ping the node under test. After waiting for them to be
// accepted into the remote table, the test checks that they are returned by FINDNODE.
func (s *Suite) TestFindnodeResults(t *utesting.T) {
// Create bystanders.
nodes := make([]*bystander, 5)
added := make(chan enode.ID, len(nodes))
for i := range nodes {
nodes[i] = newBystander(t, s, added)
defer nodes[i].close()
}
// Get them added to the remote table.
timeout := 60 * time.Second
timeoutCh := time.After(timeout)
for count := 0; count < len(nodes); {
select {
case id := <-added:
t.Logf("bystander node %v added to remote table", id)
count++
case <-timeoutCh:
t.Errorf("remote added %d bystander nodes in %v, need %d to continue", count, timeout, len(nodes))
t.Logf("this can happen if the node has a non-empty table from previous runs")
return
}
}
t.Logf("all %d bystander nodes were added", len(nodes))
// Collect our nodes by distance.
var dists []uint
expect := make(map[enode.ID]*enode.Node)
for _, bn := range nodes {
n := bn.conn.localNode.Node()
expect[n.ID()] = n
d := uint(enode.LogDist(n.ID(), s.Dest.ID()))
if !containsUint(dists, d) {
dists = append(dists, d)
}
}
// Send FINDNODE for all distances.
conn, l1 := s.listen1(t)
defer conn.close()
foundNodes, err := conn.findnode(l1, dists)
if err != nil {
t.Fatal(err)
}
t.Logf("remote returned %d nodes for distance list %v", len(foundNodes), dists)
for _, n := range foundNodes {
delete(expect, n.ID())
}
if len(expect) > 0 {
t.Errorf("missing %d nodes in FINDNODE result", len(expect))
t.Logf("this can happen if the test is run multiple times in quick succession")
t.Logf("and the remote node hasn't removed dead nodes from previous runs yet")
} else {
t.Logf("all %d expected nodes were returned", len(nodes))
}
}
// A bystander is a node whose only purpose is filling a spot in the remote table.
type bystander struct {
dest *enode.Node
conn *conn
l net.PacketConn
addedCh chan enode.ID
done sync.WaitGroup
}
func newBystander(t *utesting.T, s *Suite, added chan enode.ID) *bystander {
conn, l := s.listen1(t)
conn.setEndpoint(l) // bystander nodes need IP/port to get pinged
bn := &bystander{
conn: conn,
l: l,
dest: s.Dest,
addedCh: added,
}
bn.done.Add(1)
go bn.loop()
return bn
}
// id returns the node ID of the bystander.
func (bn *bystander) id() enode.ID {
return bn.conn.localNode.ID()
}
// close shuts down loop.
func (bn *bystander) close() {
bn.conn.close()
bn.done.Wait()
}
// loop answers packets from the remote node until quit.
func (bn *bystander) loop() {
defer bn.done.Done()
var (
lastPing time.Time
wasAdded bool
)
for {
// Ping the remote node.
if !wasAdded && time.Since(lastPing) > 10*time.Second {
bn.conn.reqresp(bn.l, &v5wire.Ping{
ReqID: bn.conn.nextReqID(),
ENRSeq: bn.dest.Seq(),
})
lastPing = time.Now()
}
// Answer packets.
switch p := bn.conn.read(bn.l).(type) {
case *v5wire.Ping:
bn.conn.write(bn.l, &v5wire.Pong{
ReqID: p.ReqID,
ENRSeq: bn.conn.localNode.Seq(),
ToIP: bn.dest.IP(),
ToPort: uint16(bn.dest.UDP()),
}, nil)
wasAdded = true
bn.notifyAdded()
case *v5wire.Findnode:
bn.conn.write(bn.l, &v5wire.Nodes{ReqID: p.ReqID, Total: 1}, nil)
wasAdded = true
bn.notifyAdded()
case *v5wire.TalkRequest:
bn.conn.write(bn.l, &v5wire.TalkResponse{ReqID: p.ReqID}, nil)
case *readError:
if !netutil.IsTemporaryError(p.err) {
bn.conn.logf("shutting down: %v", p.err)
return
}
}
}
}
func (bn *bystander) notifyAdded() {
if bn.addedCh != nil {
bn.addedCh <- bn.id()
bn.addedCh = nil
}
}

@ -0,0 +1,263 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of go-ethereum.
//
// go-ethereum is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// go-ethereum is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
package v5test
import (
"bytes"
"crypto/ecdsa"
"encoding/binary"
"fmt"
"net"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/discover/v5wire"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
)
// readError represents an error during packet reading.
// This exists to facilitate type-switching on the result of conn.read.
type readError struct {
err error
}
func (p *readError) Kind() byte { return 99 }
func (p *readError) Name() string { return fmt.Sprintf("error: %v", p.err) }
func (p *readError) Error() string { return p.err.Error() }
func (p *readError) Unwrap() error { return p.err }
func (p *readError) RequestID() []byte { return nil }
func (p *readError) SetRequestID([]byte) {}
// readErrorf creates a readError with the given text.
func readErrorf(format string, args ...interface{}) *readError {
return &readError{fmt.Errorf(format, args...)}
}
// This is the response timeout used in tests.
const waitTime = 300 * time.Millisecond
// conn is a connection to the node under test.
type conn struct {
localNode *enode.LocalNode
localKey *ecdsa.PrivateKey
remote *enode.Node
remoteAddr *net.UDPAddr
listeners []net.PacketConn
log logger
codec *v5wire.Codec
lastRequest v5wire.Packet
lastChallenge *v5wire.Whoareyou
idCounter uint32
}
type logger interface {
Logf(string, ...interface{})
}
// newConn sets up a connection to the given node.
func newConn(dest *enode.Node, log logger) *conn {
key, err := crypto.GenerateKey()
if err != nil {
panic(err)
}
db, err := enode.OpenDB("")
if err != nil {
panic(err)
}
ln := enode.NewLocalNode(db, key)
return &conn{
localKey: key,
localNode: ln,
remote: dest,
remoteAddr: &net.UDPAddr{IP: dest.IP(), Port: dest.UDP()},
codec: v5wire.NewCodec(ln, key, mclock.System{}),
log: log,
}
}
func (tc *conn) setEndpoint(c net.PacketConn) {
tc.localNode.SetStaticIP(laddr(c).IP)
tc.localNode.SetFallbackUDP(laddr(c).Port)
}
func (tc *conn) listen(ip string) net.PacketConn {
l, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", ip))
if err != nil {
panic(err)
}
tc.listeners = append(tc.listeners, l)
return l
}
// close shuts down all listeners and the local node.
func (tc *conn) close() {
for _, l := range tc.listeners {
l.Close()
}
tc.localNode.Database().Close()
}
// nextReqID creates a request id.
func (tc *conn) nextReqID() []byte {
id := make([]byte, 4)
tc.idCounter++
binary.BigEndian.PutUint32(id, tc.idCounter)
return id
}
// reqresp performs a request/response interaction on the given connection.
// The request is retried if a handshake is requested.
func (tc *conn) reqresp(c net.PacketConn, req v5wire.Packet) v5wire.Packet {
reqnonce := tc.write(c, req, nil)
switch resp := tc.read(c).(type) {
case *v5wire.Whoareyou:
if resp.Nonce != reqnonce {
return readErrorf("wrong nonce %x in WHOAREYOU (want %x)", resp.Nonce[:], reqnonce[:])
}
resp.Node = tc.remote
tc.write(c, req, resp)
return tc.read(c)
default:
return resp
}
}
// findnode sends a FINDNODE request and waits for its responses.
func (tc *conn) findnode(c net.PacketConn, dists []uint) ([]*enode.Node, error) {
var (
findnode = &v5wire.Findnode{ReqID: tc.nextReqID(), Distances: dists}
reqnonce = tc.write(c, findnode, nil)
first = true
total uint8
results []*enode.Node
)
for n := 1; n > 0; {
switch resp := tc.read(c).(type) {
case *v5wire.Whoareyou:
// Handle handshake.
if resp.Nonce == reqnonce {
resp.Node = tc.remote
tc.write(c, findnode, resp)
} else {
return nil, fmt.Errorf("unexpected WHOAREYOU (nonce %x), waiting for NODES", resp.Nonce[:])
}
case *v5wire.Ping:
// Handle ping from remote.
tc.write(c, &v5wire.Pong{
ReqID: resp.ReqID,
ENRSeq: tc.localNode.Seq(),
}, nil)
case *v5wire.Nodes:
// Got NODES! Check request ID.
if !bytes.Equal(resp.ReqID, findnode.ReqID) {
return nil, fmt.Errorf("NODES response has wrong request id %x", resp.ReqID)
}
// Check total count. It should be greater than one
// and needs to be the same across all responses.
if first {
if resp.Total == 0 || resp.Total > 6 {
return nil, fmt.Errorf("invalid NODES response 'total' %d (not in (0,7))", resp.Total)
}
total = resp.Total
n = int(total) - 1
first = false
} else {
n--
if resp.Total != total {
return nil, fmt.Errorf("invalid NODES response 'total' %d (!= %d)", resp.Total, total)
}
}
// Check nodes.
nodes, err := checkRecords(resp.Nodes)
if err != nil {
return nil, fmt.Errorf("invalid node in NODES response: %v", err)
}
results = append(results, nodes...)
default:
return nil, fmt.Errorf("expected NODES, got %v", resp)
}
}
return results, nil
}
// write sends a packet on the given connection.
func (tc *conn) write(c net.PacketConn, p v5wire.Packet, challenge *v5wire.Whoareyou) v5wire.Nonce {
packet, nonce, err := tc.codec.Encode(tc.remote.ID(), tc.remoteAddr.String(), p, challenge)
if err != nil {
panic(fmt.Errorf("can't encode %v packet: %v", p.Name(), err))
}
if _, err := c.WriteTo(packet, tc.remoteAddr); err != nil {
tc.logf("Can't send %s: %v", p.Name(), err)
} else {
tc.logf(">> %s", p.Name())
}
return nonce
}
// read waits for an incoming packet on the given connection.
func (tc *conn) read(c net.PacketConn) v5wire.Packet {
buf := make([]byte, 1280)
if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil {
return &readError{err}
}
n, fromAddr, err := c.ReadFrom(buf)
if err != nil {
return &readError{err}
}
_, _, p, err := tc.codec.Decode(buf[:n], fromAddr.String())
if err != nil {
return &readError{err}
}
tc.logf("<< %s", p.Name())
return p
}
// logf prints to the test log.
func (tc *conn) logf(format string, args ...interface{}) {
if tc.log != nil {
tc.log.Logf("(%s) %s", tc.localNode.ID().TerminalString(), fmt.Sprintf(format, args...))
}
}
func laddr(c net.PacketConn) *net.UDPAddr {
return c.LocalAddr().(*net.UDPAddr)
}
func checkRecords(records []*enr.Record) ([]*enode.Node, error) {
nodes := make([]*enode.Node, len(records))
for i := range records {
n, err := enode.New(enode.ValidSchemes, records[i])
if err != nil {
return nil, err
}
nodes[i] = n
}
return nodes, nil
}
func containsUint(ints []uint, x uint) bool {
for i := range ints {
if ints[i] == x {
return true
}
}
return false
}

@ -65,10 +65,17 @@ func MatchTests(tests []Test, expr string) []Test {
func RunTests(tests []Test, report io.Writer) []Result { func RunTests(tests []Test, report io.Writer) []Result {
results := make([]Result, len(tests)) results := make([]Result, len(tests))
for i, test := range tests { for i, test := range tests {
var output io.Writer
buffer := new(bytes.Buffer)
output = buffer
if report != nil {
output = io.MultiWriter(buffer, report)
}
start := time.Now() start := time.Now()
results[i].Name = test.Name results[i].Name = test.Name
results[i].Failed, results[i].Output = Run(test) results[i].Failed = run(test, output)
results[i].Duration = time.Since(start) results[i].Duration = time.Since(start)
results[i].Output = buffer.String()
if report != nil { if report != nil {
printResult(results[i], report) printResult(results[i], report)
} }
@ -80,7 +87,6 @@ func printResult(r Result, w io.Writer) {
pd := r.Duration.Truncate(100 * time.Microsecond) pd := r.Duration.Truncate(100 * time.Microsecond)
if r.Failed { if r.Failed {
fmt.Fprintf(w, "-- FAIL %s (%v)\n", r.Name, pd) fmt.Fprintf(w, "-- FAIL %s (%v)\n", r.Name, pd)
fmt.Fprintln(w, r.Output)
} else { } else {
fmt.Fprintf(w, "-- OK %s (%v)\n", r.Name, pd) fmt.Fprintf(w, "-- OK %s (%v)\n", r.Name, pd)
} }
@ -99,7 +105,13 @@ func CountFailures(rr []Result) int {
// Run executes a single test. // Run executes a single test.
func Run(test Test) (bool, string) { func Run(test Test) (bool, string) {
t := new(T) output := new(bytes.Buffer)
failed := run(test, output)
return failed, output.String()
}
func run(test Test, output io.Writer) bool {
t := &T{output: output}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
@ -114,7 +126,7 @@ func Run(test Test) (bool, string) {
test.Fn(t) test.Fn(t)
}() }()
<-done <-done
return t.failed, t.output.String() return t.failed
} }
// T is the value given to the test function. The test can signal failures // T is the value given to the test function. The test can signal failures
@ -122,7 +134,7 @@ func Run(test Test) (bool, string) {
type T struct { type T struct {
mu sync.Mutex mu sync.Mutex
failed bool failed bool
output bytes.Buffer output io.Writer
} }
// FailNow marks the test as having failed and stops its execution by calling // FailNow marks the test as having failed and stops its execution by calling
@ -151,7 +163,7 @@ func (t *T) Failed() bool {
func (t *T) Log(vs ...interface{}) { func (t *T) Log(vs ...interface{}) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
fmt.Fprintln(&t.output, vs...) fmt.Fprintln(t.output, vs...)
} }
// Logf formats its arguments according to the format, analogous to Printf, and records // Logf formats its arguments according to the format, analogous to Printf, and records
@ -162,7 +174,7 @@ func (t *T) Logf(format string, vs ...interface{}) {
if len(format) == 0 || format[len(format)-1] != '\n' { if len(format) == 0 || format[len(format)-1] != '\n' {
format += "\n" format += "\n"
} }
fmt.Fprintf(&t.output, format, vs...) fmt.Fprintf(t.output, format, vs...)
} }
// Error is equivalent to Log followed by Fail. // Error is equivalent to Log followed by Fail.

@ -46,7 +46,10 @@ func encodePubkey(key *ecdsa.PublicKey) encPubkey {
return e return e
} }
func decodePubkey(curve elliptic.Curve, e encPubkey) (*ecdsa.PublicKey, error) { func decodePubkey(curve elliptic.Curve, e []byte) (*ecdsa.PublicKey, error) {
if len(e) != len(encPubkey{}) {
return nil, errors.New("wrong size public key data")
}
p := &ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)} p := &ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)}
half := len(e) / 2 half := len(e) / 2
p.X.SetBytes(e[:half]) p.X.SetBytes(e[:half])

@ -146,7 +146,6 @@ func (t *pingRecorder) updateRecord(n *enode.Node) {
func (t *pingRecorder) Self() *enode.Node { return nullNode } func (t *pingRecorder) Self() *enode.Node { return nullNode }
func (t *pingRecorder) lookupSelf() []*enode.Node { return nil } func (t *pingRecorder) lookupSelf() []*enode.Node { return nil }
func (t *pingRecorder) lookupRandom() []*enode.Node { return nil } func (t *pingRecorder) lookupRandom() []*enode.Node { return nil }
func (t *pingRecorder) close() {}
// ping simulates a ping request. // ping simulates a ping request.
func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) { func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) {
@ -188,15 +187,16 @@ func hasDuplicates(slice []*node) bool {
return false return false
} }
// checkNodesEqual checks whether the two given node lists contain the same nodes.
func checkNodesEqual(got, want []*enode.Node) error { func checkNodesEqual(got, want []*enode.Node) error {
if len(got) == len(want) { if len(got) == len(want) {
for i := range got { for i := range got {
if !nodeEqual(got[i], want[i]) { if !nodeEqual(got[i], want[i]) {
goto NotEqual goto NotEqual
} }
return nil
} }
} }
return nil
NotEqual: NotEqual:
output := new(bytes.Buffer) output := new(bytes.Buffer)
@ -227,6 +227,7 @@ func sortedByDistanceTo(distbase enode.ID, slice []*node) bool {
}) })
} }
// hexEncPrivkey decodes h as a private key.
func hexEncPrivkey(h string) *ecdsa.PrivateKey { func hexEncPrivkey(h string) *ecdsa.PrivateKey {
b, err := hex.DecodeString(h) b, err := hex.DecodeString(h)
if err != nil { if err != nil {
@ -239,6 +240,7 @@ func hexEncPrivkey(h string) *ecdsa.PrivateKey {
return key return key
} }
// hexEncPubkey decodes h as a public key.
func hexEncPubkey(h string) (ret encPubkey) { func hexEncPubkey(h string) (ret encPubkey) {
b, err := hex.DecodeString(h) b, err := hex.DecodeString(h)
if err != nil { if err != nil {

@ -34,7 +34,7 @@ func TestUDPv4_Lookup(t *testing.T) {
test := newUDPTest(t) test := newUDPTest(t)
// Lookup on empty table returns no nodes. // Lookup on empty table returns no nodes.
targetKey, _ := decodePubkey(crypto.S256(), lookupTestnet.target) targetKey, _ := decodePubkey(crypto.S256(), lookupTestnet.target[:])
if results := test.udp.LookupPubkey(targetKey); len(results) > 0 { if results := test.udp.LookupPubkey(targetKey); len(results) > 0 {
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
} }
@ -279,17 +279,21 @@ func (tn *preminedTestnet) nodesAtDistance(dist int) []v4wire.Node {
return result return result
} }
func (tn *preminedTestnet) neighborsAtDistance(base *enode.Node, distance uint, elems int) []*enode.Node { func (tn *preminedTestnet) neighborsAtDistances(base *enode.Node, distances []uint, elems int) []*enode.Node {
nodes := nodesByDistance{target: base.ID()} var result []*enode.Node
for d := range lookupTestnet.dists { for d := range lookupTestnet.dists {
for i := range lookupTestnet.dists[d] { for i := range lookupTestnet.dists[d] {
n := lookupTestnet.node(d, i) n := lookupTestnet.node(d, i)
if uint(enode.LogDist(n.ID(), base.ID())) == distance { d := enode.LogDist(base.ID(), n.ID())
nodes.push(wrapNode(n), elems) if containsUint(uint(d), distances) {
result = append(result, n)
if len(result) >= elems {
return result
}
} }
} }
} }
return unwrapNodes(nodes.entries) return result
} }
func (tn *preminedTestnet) closest(n int) (nodes []*enode.Node) { func (tn *preminedTestnet) closest(n int) (nodes []*enode.Node) {

@ -1,659 +0,0 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package discover
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
crand "crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"hash"
"net"
"time"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rlp"
"golang.org/x/crypto/hkdf"
)
// TODO concurrent WHOAREYOU tie-breaker
// TODO deal with WHOAREYOU amplification factor (min packet size?)
// TODO add counter to nonce
// TODO rehandshake after X packets
// Discovery v5 packet types.
const (
p_pingV5 byte = iota + 1
p_pongV5
p_findnodeV5
p_nodesV5
p_requestTicketV5
p_ticketV5
p_regtopicV5
p_regconfirmationV5
p_topicqueryV5
p_unknownV5 = byte(255) // any non-decryptable packet
p_whoareyouV5 = byte(254) // the WHOAREYOU packet
)
// Discovery v5 packet structures.
type (
// unknownV5 represents any packet that can't be decrypted.
unknownV5 struct {
AuthTag []byte
}
// WHOAREYOU contains the handshake challenge.
whoareyouV5 struct {
AuthTag []byte
IDNonce [32]byte // To be signed by recipient.
RecordSeq uint64 // ENR sequence number of recipient
node *enode.Node
sent mclock.AbsTime
}
// PING is sent during liveness checks.
pingV5 struct {
ReqID []byte
ENRSeq uint64
}
// PONG is the reply to PING.
pongV5 struct {
ReqID []byte
ENRSeq uint64
ToIP net.IP // These fields should mirror the UDP envelope address of the ping
ToPort uint16 // packet, which provides a way to discover the the external address (after NAT).
}
// FINDNODE is a query for nodes in the given bucket.
findnodeV5 struct {
ReqID []byte
Distance uint
}
// NODES is the reply to FINDNODE and TOPICQUERY.
nodesV5 struct {
ReqID []byte
Total uint8
Nodes []*enr.Record
}
// REQUESTTICKET requests a ticket for a topic queue.
requestTicketV5 struct {
ReqID []byte
Topic []byte
}
// TICKET is the response to REQUESTTICKET.
ticketV5 struct {
ReqID []byte
Ticket []byte
}
// REGTOPIC registers the sender in a topic queue using a ticket.
regtopicV5 struct {
ReqID []byte
Ticket []byte
ENR *enr.Record
}
// REGCONFIRMATION is the reply to REGTOPIC.
regconfirmationV5 struct {
ReqID []byte
Registered bool
}
// TOPICQUERY asks for nodes with the given topic.
topicqueryV5 struct {
ReqID []byte
Topic []byte
}
)
const (
// Encryption/authentication parameters.
authSchemeName = "gcm"
aesKeySize = 16
gcmNonceSize = 12
idNoncePrefix = "discovery-id-nonce"
handshakeTimeout = time.Second
)
var (
errTooShort = errors.New("packet too short")
errUnexpectedHandshake = errors.New("unexpected auth response, not in handshake")
errHandshakeNonceMismatch = errors.New("wrong nonce in auth response")
errInvalidAuthKey = errors.New("invalid ephemeral pubkey")
errUnknownAuthScheme = errors.New("unknown auth scheme in handshake")
errNoRecord = errors.New("expected ENR in handshake but none sent")
errInvalidNonceSig = errors.New("invalid ID nonce signature")
zeroNonce = make([]byte, gcmNonceSize)
)
// wireCodec encodes and decodes discovery v5 packets.
type wireCodec struct {
sha256 hash.Hash
localnode *enode.LocalNode
privkey *ecdsa.PrivateKey
myChtagHash enode.ID
myWhoareyouMagic []byte
sc *sessionCache
}
type handshakeSecrets struct {
writeKey, readKey, authRespKey []byte
}
type authHeader struct {
authHeaderList
isHandshake bool
}
type authHeaderList struct {
Auth []byte // authentication info of packet
IDNonce [32]byte // IDNonce of WHOAREYOU
Scheme string // name of encryption/authentication scheme
EphemeralKey []byte // ephemeral public key
Response []byte // encrypted authResponse
}
type authResponse struct {
Version uint
Signature []byte
Record *enr.Record `rlp:"nil"` // sender's record
}
func (h *authHeader) DecodeRLP(r *rlp.Stream) error {
k, _, err := r.Kind()
if err != nil {
return err
}
if k == rlp.Byte || k == rlp.String {
return r.Decode(&h.Auth)
}
h.isHandshake = true
return r.Decode(&h.authHeaderList)
}
// ephemeralKey decodes the ephemeral public key in the header.
func (h *authHeaderList) ephemeralKey(curve elliptic.Curve) *ecdsa.PublicKey {
var key encPubkey
copy(key[:], h.EphemeralKey)
pubkey, _ := decodePubkey(curve, key)
return pubkey
}
// newWireCodec creates a wire codec.
func newWireCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock) *wireCodec {
c := &wireCodec{
sha256: sha256.New(),
localnode: ln,
privkey: key,
sc: newSessionCache(1024, clock),
}
// Create magic strings for packet matching.
self := ln.ID()
c.myWhoareyouMagic = c.sha256sum(self[:], []byte("WHOAREYOU"))
copy(c.myChtagHash[:], c.sha256sum(self[:]))
return c
}
// encode encodes a packet to a node. 'id' and 'addr' specify the destination node. The
// 'challenge' parameter should be the most recently received WHOAREYOU packet from that
// node.
func (c *wireCodec) encode(id enode.ID, addr string, packet packetV5, challenge *whoareyouV5) ([]byte, []byte, error) {
if packet.kind() == p_whoareyouV5 {
p := packet.(*whoareyouV5)
enc, err := c.encodeWhoareyou(id, p)
if err == nil {
c.sc.storeSentHandshake(id, addr, p)
}
return enc, nil, err
}
// Ensure calling code sets node if needed.
if challenge != nil && challenge.node == nil {
panic("BUG: missing challenge.node in encode")
}
writeKey := c.sc.writeKey(id, addr)
if writeKey != nil || challenge != nil {
return c.encodeEncrypted(id, addr, packet, writeKey, challenge)
}
return c.encodeRandom(id)
}
// encodeRandom encodes a random packet.
func (c *wireCodec) encodeRandom(toID enode.ID) ([]byte, []byte, error) {
tag := xorTag(c.sha256sum(toID[:]), c.localnode.ID())
r := make([]byte, 44) // TODO randomize size
if _, err := crand.Read(r); err != nil {
return nil, nil, err
}
nonce := make([]byte, gcmNonceSize)
if _, err := crand.Read(nonce); err != nil {
return nil, nil, fmt.Errorf("can't get random data: %v", err)
}
b := new(bytes.Buffer)
b.Write(tag[:])
rlp.Encode(b, nonce)
b.Write(r)
return b.Bytes(), nonce, nil
}
// encodeWhoareyou encodes WHOAREYOU.
func (c *wireCodec) encodeWhoareyou(toID enode.ID, packet *whoareyouV5) ([]byte, error) {
// Sanity check node field to catch misbehaving callers.
if packet.RecordSeq > 0 && packet.node == nil {
panic("BUG: missing node in whoareyouV5 with non-zero seq")
}
b := new(bytes.Buffer)
b.Write(c.sha256sum(toID[:], []byte("WHOAREYOU")))
err := rlp.Encode(b, packet)
return b.Bytes(), err
}
// encodeEncrypted encodes an encrypted packet.
func (c *wireCodec) encodeEncrypted(toID enode.ID, toAddr string, packet packetV5, writeKey []byte, challenge *whoareyouV5) (enc []byte, authTag []byte, err error) {
nonce := make([]byte, gcmNonceSize)
if _, err := crand.Read(nonce); err != nil {
return nil, nil, fmt.Errorf("can't get random data: %v", err)
}
var headEnc []byte
if challenge == nil {
// Regular packet, use existing key and simply encode nonce.
headEnc, _ = rlp.EncodeToBytes(nonce)
} else {
// We're answering WHOAREYOU, generate new keys and encrypt with those.
header, sec, err := c.makeAuthHeader(nonce, challenge)
if err != nil {
return nil, nil, err
}
if headEnc, err = rlp.EncodeToBytes(header); err != nil {
return nil, nil, err
}
c.sc.storeNewSession(toID, toAddr, sec.readKey, sec.writeKey)
writeKey = sec.writeKey
}
// Encode the packet.
body := new(bytes.Buffer)
body.WriteByte(packet.kind())
if err := rlp.Encode(body, packet); err != nil {
return nil, nil, err
}
tag := xorTag(c.sha256sum(toID[:]), c.localnode.ID())
headsize := len(tag) + len(headEnc)
headbuf := make([]byte, headsize)
copy(headbuf[:], tag[:])
copy(headbuf[len(tag):], headEnc)
// Encrypt the body.
enc, err = encryptGCM(headbuf, writeKey, nonce, body.Bytes(), tag[:])
return enc, nonce, err
}
// encodeAuthHeader creates the auth header on a call packet following WHOAREYOU.
func (c *wireCodec) makeAuthHeader(nonce []byte, challenge *whoareyouV5) (*authHeaderList, *handshakeSecrets, error) {
resp := &authResponse{Version: 5}
// Add our record to response if it's newer than what remote
// side has.
ln := c.localnode.Node()
if challenge.RecordSeq < ln.Seq() {
resp.Record = ln.Record()
}
// Create the ephemeral key. This needs to be first because the
// key is part of the ID nonce signature.
var remotePubkey = new(ecdsa.PublicKey)
if err := challenge.node.Load((*enode.Secp256k1)(remotePubkey)); err != nil {
return nil, nil, fmt.Errorf("can't find secp256k1 key for recipient")
}
ephkey, err := crypto.GenerateKey()
if err != nil {
return nil, nil, fmt.Errorf("can't generate ephemeral key")
}
ephpubkey := encodePubkey(&ephkey.PublicKey)
// Add ID nonce signature to response.
idsig, err := c.signIDNonce(challenge.IDNonce[:], ephpubkey[:])
if err != nil {
return nil, nil, fmt.Errorf("can't sign: %v", err)
}
resp.Signature = idsig
// Create session keys.
sec := c.deriveKeys(c.localnode.ID(), challenge.node.ID(), ephkey, remotePubkey, challenge)
if sec == nil {
return nil, nil, fmt.Errorf("key derivation failed")
}
// Encrypt the authentication response and assemble the auth header.
respRLP, err := rlp.EncodeToBytes(resp)
if err != nil {
return nil, nil, fmt.Errorf("can't encode auth response: %v", err)
}
respEnc, err := encryptGCM(nil, sec.authRespKey, zeroNonce, respRLP, nil)
if err != nil {
return nil, nil, fmt.Errorf("can't encrypt auth response: %v", err)
}
head := &authHeaderList{
Auth: nonce,
Scheme: authSchemeName,
IDNonce: challenge.IDNonce,
EphemeralKey: ephpubkey[:],
Response: respEnc,
}
return head, sec, err
}
// deriveKeys generates session keys using elliptic-curve Diffie-Hellman key agreement.
func (c *wireCodec) deriveKeys(n1, n2 enode.ID, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, challenge *whoareyouV5) *handshakeSecrets {
eph := ecdh(priv, pub)
if eph == nil {
return nil
}
info := []byte("discovery v5 key agreement")
info = append(info, n1[:]...)
info = append(info, n2[:]...)
kdf := hkdf.New(sha256.New, eph, challenge.IDNonce[:], info)
sec := handshakeSecrets{
writeKey: make([]byte, aesKeySize),
readKey: make([]byte, aesKeySize),
authRespKey: make([]byte, aesKeySize),
}
kdf.Read(sec.writeKey)
kdf.Read(sec.readKey)
kdf.Read(sec.authRespKey)
for i := range eph {
eph[i] = 0
}
return &sec
}
// signIDNonce creates the ID nonce signature.
func (c *wireCodec) signIDNonce(nonce, ephkey []byte) ([]byte, error) {
idsig, err := crypto.Sign(c.idNonceHash(nonce, ephkey), c.privkey)
if err != nil {
return nil, fmt.Errorf("can't sign: %v", err)
}
return idsig[:len(idsig)-1], nil // remove recovery ID
}
// idNonceHash computes the hash of id nonce with prefix.
func (c *wireCodec) idNonceHash(nonce, ephkey []byte) []byte {
h := c.sha256reset()
h.Write([]byte(idNoncePrefix))
h.Write(nonce)
h.Write(ephkey)
return h.Sum(nil)
}
// decode decodes a discovery packet.
func (c *wireCodec) decode(input []byte, addr string) (enode.ID, *enode.Node, packetV5, error) {
// Delete timed-out handshakes. This must happen before decoding to avoid
// processing the same handshake twice.
c.sc.handshakeGC()
if len(input) < 32 {
return enode.ID{}, nil, nil, errTooShort
}
if bytes.HasPrefix(input, c.myWhoareyouMagic) {
p, err := c.decodeWhoareyou(input)
return enode.ID{}, nil, p, err
}
sender := xorTag(input[:32], c.myChtagHash)
p, n, err := c.decodeEncrypted(sender, addr, input)
return sender, n, p, err
}
// decodeWhoareyou decode a WHOAREYOU packet.
func (c *wireCodec) decodeWhoareyou(input []byte) (packetV5, error) {
packet := new(whoareyouV5)
err := rlp.DecodeBytes(input[32:], packet)
return packet, err
}
// decodeEncrypted decodes an encrypted discovery packet.
func (c *wireCodec) decodeEncrypted(fromID enode.ID, fromAddr string, input []byte) (packetV5, *enode.Node, error) {
// Decode packet header.
var head authHeader
r := bytes.NewReader(input[32:])
err := rlp.Decode(r, &head)
if err != nil {
return nil, nil, err
}
// Decrypt and process auth response.
readKey, node, err := c.decodeAuth(fromID, fromAddr, &head)
if err != nil {
return nil, nil, err
}
// Decrypt and decode the packet body.
headsize := len(input) - r.Len()
bodyEnc := input[headsize:]
body, err := decryptGCM(readKey, head.Auth, bodyEnc, input[:32])
if err != nil {
if !head.isHandshake {
// Can't decrypt, start handshake.
return &unknownV5{AuthTag: head.Auth}, nil, nil
}
return nil, nil, fmt.Errorf("handshake failed: %v", err)
}
if len(body) == 0 {
return nil, nil, errTooShort
}
p, err := decodePacketBodyV5(body[0], body[1:])
return p, node, err
}
// decodeAuth processes an auth header.
func (c *wireCodec) decodeAuth(fromID enode.ID, fromAddr string, head *authHeader) ([]byte, *enode.Node, error) {
if !head.isHandshake {
return c.sc.readKey(fromID, fromAddr), nil, nil
}
// Remote is attempting handshake. Verify against our last WHOAREYOU.
challenge := c.sc.getHandshake(fromID, fromAddr)
if challenge == nil {
return nil, nil, errUnexpectedHandshake
}
if head.IDNonce != challenge.IDNonce {
return nil, nil, errHandshakeNonceMismatch
}
sec, n, err := c.decodeAuthResp(fromID, fromAddr, &head.authHeaderList, challenge)
if err != nil {
return nil, n, err
}
// Swap keys to match remote.
sec.readKey, sec.writeKey = sec.writeKey, sec.readKey
c.sc.storeNewSession(fromID, fromAddr, sec.readKey, sec.writeKey)
c.sc.deleteHandshake(fromID, fromAddr)
return sec.readKey, n, err
}
// decodeAuthResp decodes and verifies an authentication response.
func (c *wireCodec) decodeAuthResp(fromID enode.ID, fromAddr string, head *authHeaderList, challenge *whoareyouV5) (*handshakeSecrets, *enode.Node, error) {
// Decrypt / decode the response.
if head.Scheme != authSchemeName {
return nil, nil, errUnknownAuthScheme
}
ephkey := head.ephemeralKey(c.privkey.Curve)
if ephkey == nil {
return nil, nil, errInvalidAuthKey
}
sec := c.deriveKeys(fromID, c.localnode.ID(), c.privkey, ephkey, challenge)
respPT, err := decryptGCM(sec.authRespKey, zeroNonce, head.Response, nil)
if err != nil {
return nil, nil, fmt.Errorf("can't decrypt auth response header: %v", err)
}
var resp authResponse
if err := rlp.DecodeBytes(respPT, &resp); err != nil {
return nil, nil, fmt.Errorf("invalid auth response: %v", err)
}
// Verify response node record. The remote node should include the record
// if we don't have one or if ours is older than the latest version.
node := challenge.node
if resp.Record != nil {
if node == nil || node.Seq() < resp.Record.Seq() {
n, err := enode.New(enode.ValidSchemes, resp.Record)
if err != nil {
return nil, nil, fmt.Errorf("invalid node record: %v", err)
}
if n.ID() != fromID {
return nil, nil, fmt.Errorf("record in auth respose has wrong ID: %v", n.ID())
}
node = n
}
}
if node == nil {
return nil, nil, errNoRecord
}
// Verify ID nonce signature.
err = c.verifyIDSignature(challenge.IDNonce[:], head.EphemeralKey, resp.Signature, node)
if err != nil {
return nil, nil, err
}
return sec, node, nil
}
// verifyIDSignature checks that signature over idnonce was made by the node with given record.
func (c *wireCodec) verifyIDSignature(nonce, ephkey, sig []byte, n *enode.Node) error {
switch idscheme := n.Record().IdentityScheme(); idscheme {
case "v4":
var pk ecdsa.PublicKey
n.Load((*enode.Secp256k1)(&pk)) // cannot fail because record is valid
if !crypto.VerifySignature(crypto.FromECDSAPub(&pk), c.idNonceHash(nonce, ephkey), sig) {
return errInvalidNonceSig
}
return nil
default:
return fmt.Errorf("can't verify ID nonce signature against scheme %q", idscheme)
}
}
// decodePacketBody decodes the body of an encrypted discovery packet.
func decodePacketBodyV5(ptype byte, body []byte) (packetV5, error) {
var dec packetV5
switch ptype {
case p_pingV5:
dec = new(pingV5)
case p_pongV5:
dec = new(pongV5)
case p_findnodeV5:
dec = new(findnodeV5)
case p_nodesV5:
dec = new(nodesV5)
case p_requestTicketV5:
dec = new(requestTicketV5)
case p_ticketV5:
dec = new(ticketV5)
case p_regtopicV5:
dec = new(regtopicV5)
case p_regconfirmationV5:
dec = new(regconfirmationV5)
case p_topicqueryV5:
dec = new(topicqueryV5)
default:
return nil, fmt.Errorf("unknown packet type %d", ptype)
}
if err := rlp.DecodeBytes(body, dec); err != nil {
return nil, err
}
return dec, nil
}
// sha256reset returns the shared hash instance.
func (c *wireCodec) sha256reset() hash.Hash {
c.sha256.Reset()
return c.sha256
}
// sha256sum computes sha256 on the concatenation of inputs.
func (c *wireCodec) sha256sum(inputs ...[]byte) []byte {
c.sha256.Reset()
for _, b := range inputs {
c.sha256.Write(b)
}
return c.sha256.Sum(nil)
}
func xorTag(a []byte, b enode.ID) enode.ID {
var r enode.ID
for i := range r {
r[i] = a[i] ^ b[i]
}
return r
}
// ecdh creates a shared secret.
func ecdh(privkey *ecdsa.PrivateKey, pubkey *ecdsa.PublicKey) []byte {
secX, secY := pubkey.ScalarMult(pubkey.X, pubkey.Y, privkey.D.Bytes())
if secX == nil {
return nil
}
sec := make([]byte, 33)
sec[0] = 0x02 | byte(secY.Bit(0))
math.ReadBits(secX, sec[1:])
return sec
}
// encryptGCM encrypts pt using AES-GCM with the given key and nonce.
func encryptGCM(dest, key, nonce, pt, authData []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
panic(fmt.Errorf("can't create block cipher: %v", err))
}
aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize)
if err != nil {
panic(fmt.Errorf("can't create GCM: %v", err))
}
return aesgcm.Seal(dest, nonce, pt, authData), nil
}
// decryptGCM decrypts ct using AES-GCM with the given key and nonce.
func decryptGCM(key, nonce, ct, authData []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("can't create block cipher: %v", err)
}
if len(nonce) != gcmNonceSize {
return nil, fmt.Errorf("invalid GCM nonce size: %d", len(nonce))
}
aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize)
if err != nil {
return nil, fmt.Errorf("can't create GCM: %v", err)
}
pt := make([]byte, 0, len(ct))
return aesgcm.Open(pt, nonce, ct, authData)
}

@ -1,373 +0,0 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package discover
import (
"bytes"
"crypto/ecdsa"
"encoding/hex"
"fmt"
"net"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enode"
)
var (
testKeyA, _ = crypto.HexToECDSA("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f")
testKeyB, _ = crypto.HexToECDSA("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628")
testIDnonce = [32]byte{5, 6, 7, 8, 9, 10, 11, 12}
)
func TestDeriveKeysV5(t *testing.T) {
t.Parallel()
var (
n1 = enode.ID{1}
n2 = enode.ID{2}
challenge = &whoareyouV5{}
db, _ = enode.OpenDB("")
ln = enode.NewLocalNode(db, testKeyA)
c = newWireCodec(ln, testKeyA, mclock.System{})
)
defer db.Close()
sec1 := c.deriveKeys(n1, n2, testKeyA, &testKeyB.PublicKey, challenge)
sec2 := c.deriveKeys(n1, n2, testKeyB, &testKeyA.PublicKey, challenge)
if sec1 == nil || sec2 == nil {
t.Fatal("key agreement failed")
}
if !reflect.DeepEqual(sec1, sec2) {
t.Fatalf("keys not equal:\n %+v\n %+v", sec1, sec2)
}
}
// This test checks the basic handshake flow where A talks to B and A has no secrets.
func TestHandshakeV5(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
// A -> B RANDOM PACKET
packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{})
resp := net.nodeB.expectDecode(t, p_unknownV5, packet)
// A <- B WHOAREYOU
challenge := &whoareyouV5{
AuthTag: resp.(*unknownV5).AuthTag,
IDNonce: testIDnonce,
RecordSeq: 0,
}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou)
// A -> B FINDNODE
findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{})
net.nodeB.expectDecode(t, p_findnodeV5, findnode)
if len(net.nodeB.c.sc.handshakes) > 0 {
t.Fatalf("node B didn't remove handshake from challenge map")
}
// A <- B NODES
nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1})
net.nodeA.expectDecode(t, p_nodesV5, nodes)
}
// This test checks that handshake attempts are removed within the timeout.
func TestHandshakeV5_timeout(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
// A -> B RANDOM PACKET
packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{})
resp := net.nodeB.expectDecode(t, p_unknownV5, packet)
// A <- B WHOAREYOU
challenge := &whoareyouV5{
AuthTag: resp.(*unknownV5).AuthTag,
IDNonce: testIDnonce,
RecordSeq: 0,
}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou)
// A -> B FINDNODE after timeout
net.clock.Run(handshakeTimeout + 1)
findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{})
net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, findnode)
}
// This test checks handshake behavior when no record is sent in the auth response.
func TestHandshakeV5_norecord(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
// A -> B RANDOM PACKET
packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{})
resp := net.nodeB.expectDecode(t, p_unknownV5, packet)
// A <- B WHOAREYOU
nodeA := net.nodeA.n()
if nodeA.Seq() == 0 {
t.Fatal("need non-zero sequence number")
}
challenge := &whoareyouV5{
AuthTag: resp.(*unknownV5).AuthTag,
IDNonce: testIDnonce,
RecordSeq: nodeA.Seq(),
node: nodeA,
}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou)
// A -> B FINDNODE
findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{})
net.nodeB.expectDecode(t, p_findnodeV5, findnode)
// A <- B NODES
nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1})
net.nodeA.expectDecode(t, p_nodesV5, nodes)
}
// In this test, A tries to send FINDNODE with existing secrets but B doesn't know
// anything about A.
func TestHandshakeV5_rekey(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
initKeys := &handshakeSecrets{
readKey: []byte("BBBBBBBBBBBBBBBB"),
writeKey: []byte("AAAAAAAAAAAAAAAA"),
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeys.readKey, initKeys.writeKey)
// A -> B FINDNODE (encrypted with zero keys)
findnode, authTag := net.nodeA.encode(t, net.nodeB, &findnodeV5{})
net.nodeB.expectDecode(t, p_unknownV5, findnode)
// A <- B WHOAREYOU
challenge := &whoareyouV5{AuthTag: authTag, IDNonce: testIDnonce}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou)
// Check that new keys haven't been stored yet.
if s := net.nodeA.c.sc.session(net.nodeB.id(), net.nodeB.addr()); !bytes.Equal(s.writeKey, initKeys.writeKey) || !bytes.Equal(s.readKey, initKeys.readKey) {
t.Fatal("node A stored keys too early")
}
if s := net.nodeB.c.sc.session(net.nodeA.id(), net.nodeA.addr()); s != nil {
t.Fatal("node B stored keys too early")
}
// A -> B FINDNODE encrypted with new keys
findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{})
net.nodeB.expectDecode(t, p_findnodeV5, findnode)
// A <- B NODES
nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1})
net.nodeA.expectDecode(t, p_nodesV5, nodes)
}
// In this test A and B have different keys before the handshake.
func TestHandshakeV5_rekey2(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
initKeysA := &handshakeSecrets{
readKey: []byte("BBBBBBBBBBBBBBBB"),
writeKey: []byte("AAAAAAAAAAAAAAAA"),
}
initKeysB := &handshakeSecrets{
readKey: []byte("CCCCCCCCCCCCCCCC"),
writeKey: []byte("DDDDDDDDDDDDDDDD"),
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA.readKey, initKeysA.writeKey)
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB.readKey, initKeysA.writeKey)
// A -> B FINDNODE encrypted with initKeysA
findnode, authTag := net.nodeA.encode(t, net.nodeB, &findnodeV5{Distance: 3})
net.nodeB.expectDecode(t, p_unknownV5, findnode)
// A <- B WHOAREYOU
challenge := &whoareyouV5{AuthTag: authTag, IDNonce: testIDnonce}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou)
// A -> B FINDNODE encrypted with new keys
findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{})
net.nodeB.expectDecode(t, p_findnodeV5, findnode)
// A <- B NODES
nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1})
net.nodeA.expectDecode(t, p_nodesV5, nodes)
}
// This test checks some malformed packets.
func TestDecodeErrorsV5(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
net.nodeA.expectDecodeErr(t, errTooShort, []byte{})
// TODO some more tests would be nice :)
}
// This benchmark checks performance of authHeader decoding, verification and key derivation.
func BenchmarkV5_DecodeAuthSecp256k1(b *testing.B) {
net := newHandshakeTest()
defer net.close()
var (
idA = net.nodeA.id()
addrA = net.nodeA.addr()
challenge = &whoareyouV5{AuthTag: []byte("authresp"), RecordSeq: 0, node: net.nodeB.n()}
nonce = make([]byte, gcmNonceSize)
)
header, _, _ := net.nodeA.c.makeAuthHeader(nonce, challenge)
challenge.node = nil // force ENR signature verification in decoder
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, err := net.nodeB.c.decodeAuthResp(idA, addrA, header, challenge)
if err != nil {
b.Fatal(err)
}
}
}
// This benchmark checks how long it takes to decode an encrypted ping packet.
func BenchmarkV5_DecodePing(b *testing.B) {
net := newHandshakeTest()
defer net.close()
r := []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17}
w := []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), r, w)
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), w, r)
addrB := net.nodeA.addr()
ping := &pingV5{ReqID: []byte("reqid"), ENRSeq: 5}
enc, _, err := net.nodeA.c.encode(net.nodeB.id(), addrB, ping, nil)
if err != nil {
b.Fatalf("can't encode: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, p, _ := net.nodeB.c.decode(enc, addrB)
if _, ok := p.(*pingV5); !ok {
b.Fatalf("wrong packet type %T", p)
}
}
}
var pp = spew.NewDefaultConfig()
type handshakeTest struct {
nodeA, nodeB handshakeTestNode
clock mclock.Simulated
}
type handshakeTestNode struct {
ln *enode.LocalNode
c *wireCodec
}
func newHandshakeTest() *handshakeTest {
t := new(handshakeTest)
t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock)
t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock)
return t
}
func (t *handshakeTest) close() {
t.nodeA.ln.Database().Close()
t.nodeB.ln.Database().Close()
}
func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock) {
db, _ := enode.OpenDB("")
n.ln = enode.NewLocalNode(db, key)
n.ln.SetStaticIP(ip)
n.c = newWireCodec(n.ln, key, clock)
}
func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p packetV5) ([]byte, []byte) {
t.Helper()
return n.encodeWithChallenge(t, to, nil, p)
}
func (n *handshakeTestNode) encodeWithChallenge(t testing.TB, to handshakeTestNode, c *whoareyouV5, p packetV5) ([]byte, []byte) {
t.Helper()
// Copy challenge and add destination node. This avoids sharing 'c' among the two codecs.
var challenge *whoareyouV5
if c != nil {
challengeCopy := *c
challenge = &challengeCopy
challenge.node = to.n()
}
// Encode to destination.
enc, authTag, err := n.c.encode(to.id(), to.addr(), p, challenge)
if err != nil {
t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err))
}
t.Logf("(%s) -> (%s) %s\n%s", n.ln.ID().TerminalString(), to.id().TerminalString(), p.name(), hex.Dump(enc))
return enc, authTag
}
func (n *handshakeTestNode) expectDecode(t *testing.T, ptype byte, p []byte) packetV5 {
t.Helper()
dec, err := n.decode(p)
if err != nil {
t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err))
}
t.Logf("(%s) %#v", n.ln.ID().TerminalString(), pp.NewFormatter(dec))
if dec.kind() != ptype {
t.Fatalf("expected packet type %d, got %d", ptype, dec.kind())
}
return dec
}
func (n *handshakeTestNode) expectDecodeErr(t *testing.T, wantErr error, p []byte) {
t.Helper()
if _, err := n.decode(p); !reflect.DeepEqual(err, wantErr) {
t.Fatal(fmt.Errorf("(%s) got err %q, want %q", n.ln.ID().TerminalString(), err, wantErr))
}
}
func (n *handshakeTestNode) decode(input []byte) (packetV5, error) {
_, _, p, err := n.c.decode(input, "127.0.0.1")
return p, err
}
func (n *handshakeTestNode) n() *enode.Node {
return n.ln.Node()
}
func (n *handshakeTestNode) addr() string {
return n.ln.Node().IP().String()
}
func (n *handshakeTestNode) id() enode.ID {
return n.ln.ID()
}

@ -31,6 +31,7 @@ import (
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/discover/v5wire"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/p2p/netutil"
@ -38,36 +39,24 @@ import (
const ( const (
lookupRequestLimit = 3 // max requests against a single node during lookup lookupRequestLimit = 3 // max requests against a single node during lookup
findnodeResultLimit = 15 // applies in FINDNODE handler findnodeResultLimit = 16 // applies in FINDNODE handler
totalNodesResponseLimit = 5 // applies in waitForNodes totalNodesResponseLimit = 5 // applies in waitForNodes
nodesResponseItemLimit = 3 // applies in sendNodes nodesResponseItemLimit = 3 // applies in sendNodes
respTimeoutV5 = 700 * time.Millisecond respTimeoutV5 = 700 * time.Millisecond
) )
// codecV5 is implemented by wireCodec (and testCodec). // codecV5 is implemented by v5wire.Codec (and testCodec).
// //
// The UDPv5 transport is split into two objects: the codec object deals with // The UDPv5 transport is split into two objects: the codec object deals with
// encoding/decoding and with the handshake; the UDPv5 object handles higher-level concerns. // encoding/decoding and with the handshake; the UDPv5 object handles higher-level concerns.
type codecV5 interface { type codecV5 interface {
// encode encodes a packet. The 'challenge' parameter is non-nil for calls which got a // Encode encodes a packet.
// WHOAREYOU response. Encode(enode.ID, string, v5wire.Packet, *v5wire.Whoareyou) ([]byte, v5wire.Nonce, error)
encode(fromID enode.ID, fromAddr string, p packetV5, challenge *whoareyouV5) (enc []byte, authTag []byte, err error)
// decode decodes a packet. It returns an *unknownV5 packet if decryption fails. // decode decodes a packet. It returns a *v5wire.Unknown packet if decryption fails.
// The fromNode return value is non-nil when the input contains a handshake response. // The *enode.Node return value is non-nil when the input contains a handshake response.
decode(input []byte, fromAddr string) (fromID enode.ID, fromNode *enode.Node, p packetV5, err error) Decode([]byte, string) (enode.ID, *enode.Node, v5wire.Packet, error)
}
// packetV5 is implemented by all discv5 packet type structs.
type packetV5 interface {
// These methods provide information and set the request ID.
name() string
kind() byte
setreqid([]byte)
// handle should perform the appropriate action to handle the packet, i.e. this is the
// place to send the response.
handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr)
} }
// UDPv5 is the implementation of protocol version 5. // UDPv5 is the implementation of protocol version 5.
@ -83,6 +72,10 @@ type UDPv5 struct {
clock mclock.Clock clock mclock.Clock
validSchemes enr.IdentityScheme validSchemes enr.IdentityScheme
// talkreq handler registry
trlock sync.Mutex
trhandlers map[string]func([]byte) []byte
// channels into dispatch // channels into dispatch
packetInCh chan ReadPacket packetInCh chan ReadPacket
readNextCh chan struct{} readNextCh chan struct{}
@ -93,7 +86,7 @@ type UDPv5 struct {
// state of dispatch // state of dispatch
codec codecV5 codec codecV5
activeCallByNode map[enode.ID]*callV5 activeCallByNode map[enode.ID]*callV5
activeCallByAuth map[string]*callV5 activeCallByAuth map[v5wire.Nonce]*callV5
callQueue map[enode.ID][]*callV5 callQueue map[enode.ID][]*callV5
// shutdown stuff // shutdown stuff
@ -106,16 +99,16 @@ type UDPv5 struct {
// callV5 represents a remote procedure call against another node. // callV5 represents a remote procedure call against another node.
type callV5 struct { type callV5 struct {
node *enode.Node node *enode.Node
packet packetV5 packet v5wire.Packet
responseType byte // expected packet type of response responseType byte // expected packet type of response
reqid []byte reqid []byte
ch chan packetV5 // responses sent here ch chan v5wire.Packet // responses sent here
err chan error // errors sent here err chan error // errors sent here
// Valid for active calls only: // Valid for active calls only:
authTag []byte // authTag of request packet nonce v5wire.Nonce // nonce of request packet
handshakeCount int // # times we attempted handshake for this call handshakeCount int // # times we attempted handshake for this call
challenge *whoareyouV5 // last sent handshake challenge challenge *v5wire.Whoareyou // last sent handshake challenge
timeout mclock.Timer timeout mclock.Timer
} }
@ -152,6 +145,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
log: cfg.Log, log: cfg.Log,
validSchemes: cfg.ValidSchemes, validSchemes: cfg.ValidSchemes,
clock: cfg.Clock, clock: cfg.Clock,
trhandlers: make(map[string]func([]byte) []byte),
// channels into dispatch // channels into dispatch
packetInCh: make(chan ReadPacket, 1), packetInCh: make(chan ReadPacket, 1),
readNextCh: make(chan struct{}, 1), readNextCh: make(chan struct{}, 1),
@ -159,9 +153,9 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
callDoneCh: make(chan *callV5), callDoneCh: make(chan *callV5),
respTimeoutCh: make(chan *callTimeout), respTimeoutCh: make(chan *callTimeout),
// state of dispatch // state of dispatch
codec: newWireCodec(ln, cfg.PrivateKey, cfg.Clock), codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock),
activeCallByNode: make(map[enode.ID]*callV5), activeCallByNode: make(map[enode.ID]*callV5),
activeCallByAuth: make(map[string]*callV5), activeCallByAuth: make(map[v5wire.Nonce]*callV5),
callQueue: make(map[enode.ID][]*callV5), callQueue: make(map[enode.ID][]*callV5),
// shutdown // shutdown
closeCtx: closeCtx, closeCtx: closeCtx,
@ -236,6 +230,29 @@ func (t *UDPv5) LocalNode() *enode.LocalNode {
return t.localNode return t.localNode
} }
// RegisterTalkHandler adds a handler for 'talk requests'. The handler function is called
// whenever a request for the given protocol is received and should return the response
// data or nil.
func (t *UDPv5) RegisterTalkHandler(protocol string, handler func([]byte) []byte) {
t.trlock.Lock()
defer t.trlock.Unlock()
t.trhandlers[protocol] = handler
}
// TalkRequest sends a talk request to n and waits for a response.
func (t *UDPv5) TalkRequest(n *enode.Node, protocol string, request []byte) ([]byte, error) {
req := &v5wire.TalkRequest{Protocol: protocol, Message: request}
resp := t.call(n, v5wire.TalkResponseMsg, req)
defer t.callDone(resp)
select {
case respMsg := <-resp.ch:
return respMsg.(*v5wire.TalkResponse).Message, nil
case err := <-resp.err:
return nil, err
}
}
// RandomNodes returns an iterator that finds random nodes in the DHT.
func (t *UDPv5) RandomNodes() enode.Iterator { func (t *UDPv5) RandomNodes() enode.Iterator {
if t.tab.len() == 0 { if t.tab.len() == 0 {
// All nodes were dropped, refresh. The very first query will hit this // All nodes were dropped, refresh. The very first query will hit this
@ -283,16 +300,14 @@ func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) {
nodes = nodesByDistance{target: target} nodes = nodesByDistance{target: target}
err error err error
) )
for i := 0; i < lookupRequestLimit && len(nodes.entries) < findnodeResultLimit; i++ { var r []*enode.Node
var r []*enode.Node r, err = t.findnode(unwrapNode(destNode), dists)
r, err = t.findnode(unwrapNode(destNode), dists[i]) if err == errClosed {
if err == errClosed { return nil, err
return nil, err }
} for _, n := range r {
for _, n := range r { if n.ID() != t.Self().ID() {
if n.ID() != t.Self().ID() { nodes.push(wrapNode(n), findnodeResultLimit)
nodes.push(wrapNode(n), findnodeResultLimit)
}
} }
} }
return nodes.entries, err return nodes.entries, err
@ -301,15 +316,15 @@ func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) {
// lookupDistances computes the distance parameter for FINDNODE calls to dest. // lookupDistances computes the distance parameter for FINDNODE calls to dest.
// It chooses distances adjacent to logdist(target, dest), e.g. for a target // It chooses distances adjacent to logdist(target, dest), e.g. for a target
// with logdist(target, dest) = 255 the result is [255, 256, 254]. // with logdist(target, dest) = 255 the result is [255, 256, 254].
func lookupDistances(target, dest enode.ID) (dists []int) { func lookupDistances(target, dest enode.ID) (dists []uint) {
td := enode.LogDist(target, dest) td := enode.LogDist(target, dest)
dists = append(dists, td) dists = append(dists, uint(td))
for i := 1; len(dists) < lookupRequestLimit; i++ { for i := 1; len(dists) < lookupRequestLimit; i++ {
if td+i < 256 { if td+i < 256 {
dists = append(dists, td+i) dists = append(dists, uint(td+i))
} }
if td-i > 0 { if td-i > 0 {
dists = append(dists, td-i) dists = append(dists, uint(td-i))
} }
} }
return dists return dists
@ -317,11 +332,13 @@ func lookupDistances(target, dest enode.ID) (dists []int) {
// ping calls PING on a node and waits for a PONG response. // ping calls PING on a node and waits for a PONG response.
func (t *UDPv5) ping(n *enode.Node) (uint64, error) { func (t *UDPv5) ping(n *enode.Node) (uint64, error) {
resp := t.call(n, p_pongV5, &pingV5{ENRSeq: t.localNode.Node().Seq()}) req := &v5wire.Ping{ENRSeq: t.localNode.Node().Seq()}
resp := t.call(n, v5wire.PongMsg, req)
defer t.callDone(resp) defer t.callDone(resp)
select { select {
case pong := <-resp.ch: case pong := <-resp.ch:
return pong.(*pongV5).ENRSeq, nil return pong.(*v5wire.Pong).ENRSeq, nil
case err := <-resp.err: case err := <-resp.err:
return 0, err return 0, err
} }
@ -329,7 +346,7 @@ func (t *UDPv5) ping(n *enode.Node) (uint64, error) {
// requestENR requests n's record. // requestENR requests n's record.
func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) { func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) {
nodes, err := t.findnode(n, 0) nodes, err := t.findnode(n, []uint{0})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -339,26 +356,14 @@ func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) {
return nodes[0], nil return nodes[0], nil
} }
// requestTicket calls REQUESTTICKET on a node and waits for a TICKET response.
func (t *UDPv5) requestTicket(n *enode.Node) ([]byte, error) {
resp := t.call(n, p_ticketV5, &pingV5{})
defer t.callDone(resp)
select {
case response := <-resp.ch:
return response.(*ticketV5).Ticket, nil
case err := <-resp.err:
return nil, err
}
}
// findnode calls FINDNODE on a node and waits for responses. // findnode calls FINDNODE on a node and waits for responses.
func (t *UDPv5) findnode(n *enode.Node, distance int) ([]*enode.Node, error) { func (t *UDPv5) findnode(n *enode.Node, distances []uint) ([]*enode.Node, error) {
resp := t.call(n, p_nodesV5, &findnodeV5{Distance: uint(distance)}) resp := t.call(n, v5wire.NodesMsg, &v5wire.Findnode{Distances: distances})
return t.waitForNodes(resp, distance) return t.waitForNodes(resp, distances)
} }
// waitForNodes waits for NODES responses to the given call. // waitForNodes waits for NODES responses to the given call.
func (t *UDPv5) waitForNodes(c *callV5, distance int) ([]*enode.Node, error) { func (t *UDPv5) waitForNodes(c *callV5, distances []uint) ([]*enode.Node, error) {
defer t.callDone(c) defer t.callDone(c)
var ( var (
@ -369,11 +374,11 @@ func (t *UDPv5) waitForNodes(c *callV5, distance int) ([]*enode.Node, error) {
for { for {
select { select {
case responseP := <-c.ch: case responseP := <-c.ch:
response := responseP.(*nodesV5) response := responseP.(*v5wire.Nodes)
for _, record := range response.Nodes { for _, record := range response.Nodes {
node, err := t.verifyResponseNode(c, record, distance, seen) node, err := t.verifyResponseNode(c, record, distances, seen)
if err != nil { if err != nil {
t.log.Debug("Invalid record in "+response.name(), "id", c.node.ID(), "err", err) t.log.Debug("Invalid record in "+response.Name(), "id", c.node.ID(), "err", err)
continue continue
} }
nodes = append(nodes, node) nodes = append(nodes, node)
@ -391,7 +396,7 @@ func (t *UDPv5) waitForNodes(c *callV5, distance int) ([]*enode.Node, error) {
} }
// verifyResponseNode checks validity of a record in a NODES response. // verifyResponseNode checks validity of a record in a NODES response.
func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distance int, seen map[enode.ID]struct{}) (*enode.Node, error) { func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, seen map[enode.ID]struct{}) (*enode.Node, error) {
node, err := enode.New(t.validSchemes, r) node, err := enode.New(t.validSchemes, r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -402,9 +407,10 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distance int, seen
if c.node.UDP() <= 1024 { if c.node.UDP() <= 1024 {
return nil, errLowPort return nil, errLowPort
} }
if distance != -1 { if distances != nil {
if d := enode.LogDist(c.node.ID(), node.ID()); d != distance { nd := enode.LogDist(c.node.ID(), node.ID())
return nil, fmt.Errorf("wrong distance %d", d) if !containsUint(uint(nd), distances) {
return nil, errors.New("does not match any requested distance")
} }
} }
if _, ok := seen[node.ID()]; ok { if _, ok := seen[node.ID()]; ok {
@ -414,20 +420,29 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distance int, seen
return node, nil return node, nil
} }
// call sends the given call and sets up a handler for response packets (of type c.responseType). func containsUint(x uint, xs []uint) bool {
// Responses are dispatched to the call's response channel. for _, v := range xs {
func (t *UDPv5) call(node *enode.Node, responseType byte, packet packetV5) *callV5 { if x == v {
return true
}
}
return false
}
// call sends the given call and sets up a handler for response packets (of message type
// responseType). Responses are dispatched to the call's response channel.
func (t *UDPv5) call(node *enode.Node, responseType byte, packet v5wire.Packet) *callV5 {
c := &callV5{ c := &callV5{
node: node, node: node,
packet: packet, packet: packet,
responseType: responseType, responseType: responseType,
reqid: make([]byte, 8), reqid: make([]byte, 8),
ch: make(chan packetV5, 1), ch: make(chan v5wire.Packet, 1),
err: make(chan error, 1), err: make(chan error, 1),
} }
// Assign request ID. // Assign request ID.
crand.Read(c.reqid) crand.Read(c.reqid)
packet.setreqid(c.reqid) packet.SetRequestID(c.reqid)
// Send call to dispatch. // Send call to dispatch.
select { select {
case t.callCh <- c: case t.callCh <- c:
@ -482,7 +497,7 @@ func (t *UDPv5) dispatch() {
panic("BUG: callDone for inactive call") panic("BUG: callDone for inactive call")
} }
c.timeout.Stop() c.timeout.Stop()
delete(t.activeCallByAuth, string(c.authTag)) delete(t.activeCallByAuth, c.nonce)
delete(t.activeCallByNode, id) delete(t.activeCallByNode, id)
t.sendNextCall(id) t.sendNextCall(id)
@ -502,7 +517,7 @@ func (t *UDPv5) dispatch() {
for id, c := range t.activeCallByNode { for id, c := range t.activeCallByNode {
c.err <- errClosed c.err <- errClosed
delete(t.activeCallByNode, id) delete(t.activeCallByNode, id)
delete(t.activeCallByAuth, string(c.authTag)) delete(t.activeCallByAuth, c.nonce)
} }
return return
} }
@ -548,38 +563,37 @@ func (t *UDPv5) sendNextCall(id enode.ID) {
// sendCall encodes and sends a request packet to the call's recipient node. // sendCall encodes and sends a request packet to the call's recipient node.
// This performs a handshake if needed. // This performs a handshake if needed.
func (t *UDPv5) sendCall(c *callV5) { func (t *UDPv5) sendCall(c *callV5) {
if len(c.authTag) > 0 { // The call might have a nonce from a previous handshake attempt. Remove the entry for
// The call already has an authTag from a previous handshake attempt. Remove the // the old nonce because we're about to generate a new nonce for this call.
// entry for the authTag because we're about to generate a new authTag for this if c.nonce != (v5wire.Nonce{}) {
// call. delete(t.activeCallByAuth, c.nonce)
delete(t.activeCallByAuth, string(c.authTag))
} }
addr := &net.UDPAddr{IP: c.node.IP(), Port: c.node.UDP()} addr := &net.UDPAddr{IP: c.node.IP(), Port: c.node.UDP()}
newTag, _ := t.send(c.node.ID(), addr, c.packet, c.challenge) newNonce, _ := t.send(c.node.ID(), addr, c.packet, c.challenge)
c.authTag = newTag c.nonce = newNonce
t.activeCallByAuth[string(c.authTag)] = c t.activeCallByAuth[newNonce] = c
t.startResponseTimeout(c) t.startResponseTimeout(c)
} }
// sendResponse sends a response packet to the given node. // sendResponse sends a response packet to the given node.
// This doesn't trigger a handshake even if no keys are available. // This doesn't trigger a handshake even if no keys are available.
func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet packetV5) error { func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) error {
_, err := t.send(toID, toAddr, packet, nil) _, err := t.send(toID, toAddr, packet, nil)
return err return err
} }
// send sends a packet to the given node. // send sends a packet to the given node.
func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet packetV5, c *whoareyouV5) ([]byte, error) { func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) {
addr := toAddr.String() addr := toAddr.String()
enc, authTag, err := t.codec.encode(toID, addr, packet, c) enc, nonce, err := t.codec.Encode(toID, addr, packet, c)
if err != nil { if err != nil {
t.log.Warn(">> "+packet.name(), "id", toID, "addr", addr, "err", err) t.log.Warn(">> "+packet.Name(), "id", toID, "addr", addr, "err", err)
return authTag, err return nonce, err
} }
_, err = t.conn.WriteToUDP(enc, toAddr) _, err = t.conn.WriteToUDP(enc, toAddr)
t.log.Trace(">> "+packet.name(), "id", toID, "addr", addr) t.log.Trace(">> "+packet.Name(), "id", toID, "addr", addr)
return authTag, err return nonce, err
} }
// readLoop runs in its own goroutine and reads packets from the network. // readLoop runs in its own goroutine and reads packets from the network.
@ -617,7 +631,7 @@ func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool {
// handlePacket decodes and processes an incoming packet from the network. // handlePacket decodes and processes an incoming packet from the network.
func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error {
addr := fromAddr.String() addr := fromAddr.String()
fromID, fromNode, packet, err := t.codec.decode(rawpacket, addr) fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr)
if err != nil { if err != nil {
t.log.Debug("Bad discv5 packet", "id", fromID, "addr", addr, "err", err) t.log.Debug("Bad discv5 packet", "id", fromID, "addr", addr, "err", err)
return err return err
@ -626,31 +640,32 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error {
// Handshake succeeded, add to table. // Handshake succeeded, add to table.
t.tab.addSeenNode(wrapNode(fromNode)) t.tab.addSeenNode(wrapNode(fromNode))
} }
if packet.kind() != p_whoareyouV5 { if packet.Kind() != v5wire.WhoareyouPacket {
// WHOAREYOU logged separately to report the sender ID. // WHOAREYOU logged separately to report errors.
t.log.Trace("<< "+packet.name(), "id", fromID, "addr", addr) t.log.Trace("<< "+packet.Name(), "id", fromID, "addr", addr)
} }
packet.handle(t, fromID, fromAddr) t.handle(packet, fromID, fromAddr)
return nil return nil
} }
// handleCallResponse dispatches a response packet to the call waiting for it. // handleCallResponse dispatches a response packet to the call waiting for it.
func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, reqid []byte, p packetV5) { func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, p v5wire.Packet) bool {
ac := t.activeCallByNode[fromID] ac := t.activeCallByNode[fromID]
if ac == nil || !bytes.Equal(reqid, ac.reqid) { if ac == nil || !bytes.Equal(p.RequestID(), ac.reqid) {
t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.name()), "id", fromID, "addr", fromAddr) t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.Name()), "id", fromID, "addr", fromAddr)
return return false
} }
if !fromAddr.IP.Equal(ac.node.IP()) || fromAddr.Port != ac.node.UDP() { if !fromAddr.IP.Equal(ac.node.IP()) || fromAddr.Port != ac.node.UDP() {
t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.name()), "id", fromID, "addr", fromAddr) t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.Name()), "id", fromID, "addr", fromAddr)
return return false
} }
if p.kind() != ac.responseType { if p.Kind() != ac.responseType {
t.log.Debug(fmt.Sprintf("Wrong disv5 response type %s", p.name()), "id", fromID, "addr", fromAddr) t.log.Debug(fmt.Sprintf("Wrong discv5 response type %s", p.Name()), "id", fromID, "addr", fromAddr)
return return false
} }
t.startResponseTimeout(ac) t.startResponseTimeout(ac)
ac.ch <- p ac.ch <- p
return true
} }
// getNode looks for a node record in table and database. // getNode looks for a node record in table and database.
@ -664,50 +679,65 @@ func (t *UDPv5) getNode(id enode.ID) *enode.Node {
return nil return nil
} }
// UNKNOWN // handle processes incoming packets according to their message type.
func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) {
switch p := p.(type) {
case *v5wire.Unknown:
t.handleUnknown(p, fromID, fromAddr)
case *v5wire.Whoareyou:
t.handleWhoareyou(p, fromID, fromAddr)
case *v5wire.Ping:
t.handlePing(p, fromID, fromAddr)
case *v5wire.Pong:
if t.handleCallResponse(fromID, fromAddr, p) {
t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)})
}
case *v5wire.Findnode:
t.handleFindnode(p, fromID, fromAddr)
case *v5wire.Nodes:
t.handleCallResponse(fromID, fromAddr, p)
case *v5wire.TalkRequest:
t.handleTalkRequest(p, fromID, fromAddr)
case *v5wire.TalkResponse:
t.handleCallResponse(fromID, fromAddr, p)
}
}
func (p *unknownV5) name() string { return "UNKNOWN/v5" } // handleUnknown initiates a handshake by responding with WHOAREYOU.
func (p *unknownV5) kind() byte { return p_unknownV5 } func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr *net.UDPAddr) {
func (p *unknownV5) setreqid(id []byte) {} challenge := &v5wire.Whoareyou{Nonce: p.Nonce}
func (p *unknownV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
challenge := &whoareyouV5{AuthTag: p.AuthTag}
crand.Read(challenge.IDNonce[:]) crand.Read(challenge.IDNonce[:])
if n := t.getNode(fromID); n != nil { if n := t.getNode(fromID); n != nil {
challenge.node = n challenge.Node = n
challenge.RecordSeq = n.Seq() challenge.RecordSeq = n.Seq()
} }
t.sendResponse(fromID, fromAddr, challenge) t.sendResponse(fromID, fromAddr, challenge)
} }
// WHOAREYOU
func (p *whoareyouV5) name() string { return "WHOAREYOU/v5" }
func (p *whoareyouV5) kind() byte { return p_whoareyouV5 }
func (p *whoareyouV5) setreqid(id []byte) {}
func (p *whoareyouV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
c, err := p.matchWithCall(t, p.AuthTag)
if err != nil {
t.log.Debug("Invalid WHOAREYOU/v5", "addr", fromAddr, "err", err)
return
}
// Resend the call that was answered by WHOAREYOU.
t.log.Trace("<< "+p.name(), "id", c.node.ID(), "addr", fromAddr)
c.handshakeCount++
c.challenge = p
p.node = c.node
t.sendCall(c)
}
var ( var (
errChallengeNoCall = errors.New("no matching call") errChallengeNoCall = errors.New("no matching call")
errChallengeTwice = errors.New("second handshake") errChallengeTwice = errors.New("second handshake")
) )
// matchWithCall checks whether the handshake attempt matches the active call. // handleWhoareyou resends the active call as a handshake packet.
func (p *whoareyouV5) matchWithCall(t *UDPv5, authTag []byte) (*callV5, error) { func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr *net.UDPAddr) {
c := t.activeCallByAuth[string(authTag)] c, err := t.matchWithCall(fromID, p.Nonce)
if err != nil {
t.log.Debug("Invalid "+p.Name(), "addr", fromAddr, "err", err)
return
}
// Resend the call that was answered by WHOAREYOU.
t.log.Trace("<< "+p.Name(), "id", c.node.ID(), "addr", fromAddr)
c.handshakeCount++
c.challenge = p
p.Node = c.node
t.sendCall(c)
}
// matchWithCall checks whether a handshake attempt matches the active call.
func (t *UDPv5) matchWithCall(fromID enode.ID, nonce v5wire.Nonce) (*callV5, error) {
c := t.activeCallByAuth[nonce]
if c == nil { if c == nil {
return nil, errChallengeNoCall return nil, errChallengeNoCall
} }
@ -717,14 +747,9 @@ func (p *whoareyouV5) matchWithCall(t *UDPv5, authTag []byte) (*callV5, error) {
return c, nil return c, nil
} }
// PING // handlePing sends a PONG response.
func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr *net.UDPAddr) {
func (p *pingV5) name() string { return "PING/v5" } t.sendResponse(fromID, fromAddr, &v5wire.Pong{
func (p *pingV5) kind() byte { return p_pingV5 }
func (p *pingV5) setreqid(id []byte) { p.ReqID = id }
func (p *pingV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
t.sendResponse(fromID, fromAddr, &pongV5{
ReqID: p.ReqID, ReqID: p.ReqID,
ToIP: fromAddr.IP, ToIP: fromAddr.IP,
ToPort: uint16(fromAddr.Port), ToPort: uint16(fromAddr.Port),
@ -732,121 +757,81 @@ func (p *pingV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
}) })
} }
// PONG // handleFindnode returns nodes to the requester.
func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr *net.UDPAddr) {
func (p *pongV5) name() string { return "PONG/v5" } nodes := t.collectTableNodes(fromAddr.IP, p.Distances, findnodeResultLimit)
func (p *pongV5) kind() byte { return p_pongV5 } for _, resp := range packNodes(p.ReqID, nodes) {
func (p *pongV5) setreqid(id []byte) { p.ReqID = id } t.sendResponse(fromID, fromAddr, resp)
}
func (p *pongV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)})
t.handleCallResponse(fromID, fromAddr, p.ReqID, p)
} }
// FINDNODE // collectTableNodes creates a FINDNODE result set for the given distances.
func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*enode.Node {
func (p *findnodeV5) name() string { return "FINDNODE/v5" } var nodes []*enode.Node
func (p *findnodeV5) kind() byte { return p_findnodeV5 } var processed = make(map[uint]struct{})
func (p *findnodeV5) setreqid(id []byte) { p.ReqID = id } for _, dist := range distances {
// Reject duplicate / invalid distances.
func (p *findnodeV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { _, seen := processed[dist]
if p.Distance == 0 { if seen || dist > 256 {
t.sendNodes(fromID, fromAddr, p.ReqID, []*enode.Node{t.Self()}) continue
return }
}
if p.Distance > 256 { // Get the nodes.
p.Distance = 256 var bn []*enode.Node
} if dist == 0 {
// Get bucket entries. bn = []*enode.Node{t.Self()}
t.tab.mutex.Lock() } else if dist <= 256 {
nodes := unwrapNodes(t.tab.bucketAtDistance(int(p.Distance)).entries) t.tab.mutex.Lock()
t.tab.mutex.Unlock() bn = unwrapNodes(t.tab.bucketAtDistance(int(dist)).entries)
if len(nodes) > findnodeResultLimit { t.tab.mutex.Unlock()
nodes = nodes[:findnodeResultLimit] }
} processed[dist] = struct{}{}
t.sendNodes(fromID, fromAddr, p.ReqID, nodes)
} // Apply some pre-checks to avoid sending invalid nodes.
for _, n := range bn {
// sendNodes sends the given records in one or more NODES packets. // TODO livenessChecks > 1
func (t *UDPv5) sendNodes(toID enode.ID, toAddr *net.UDPAddr, reqid []byte, nodes []*enode.Node) { if netutil.CheckRelayIP(rip, n.IP()) != nil {
// TODO livenessChecks > 1 continue
// TODO CheckRelayIP }
total := uint8(math.Ceil(float64(len(nodes)) / 3)) nodes = append(nodes, n)
resp := &nodesV5{ReqID: reqid, Total: total, Nodes: make([]*enr.Record, 3)} if len(nodes) >= limit {
sent := false return nodes
for len(nodes) > 0 { }
items := min(nodesResponseItemLimit, len(nodes)) }
resp.Nodes = resp.Nodes[:items] }
for i := 0; i < items; i++ { return nodes
resp.Nodes[i] = nodes[i].Record() }
// packNodes creates NODES response packets for the given node list.
func packNodes(reqid []byte, nodes []*enode.Node) []*v5wire.Nodes {
if len(nodes) == 0 {
return []*v5wire.Nodes{{ReqID: reqid, Total: 1}}
}
total := uint8(math.Ceil(float64(len(nodes)) / 3))
var resp []*v5wire.Nodes
for len(nodes) > 0 {
p := &v5wire.Nodes{ReqID: reqid, Total: total}
items := min(nodesResponseItemLimit, len(nodes))
for i := 0; i < items; i++ {
p.Nodes = append(p.Nodes, nodes[i].Record())
} }
t.sendResponse(toID, toAddr, resp)
nodes = nodes[items:] nodes = nodes[items:]
sent = true resp = append(resp, p)
} }
// Ensure at least one response is sent. return resp
if !sent { }
resp.Total = 1
resp.Nodes = nil // handleTalkRequest runs the talk request handler of the requested protocol.
t.sendResponse(toID, toAddr, resp) func (t *UDPv5) handleTalkRequest(p *v5wire.TalkRequest, fromID enode.ID, fromAddr *net.UDPAddr) {
t.trlock.Lock()
handler := t.trhandlers[p.Protocol]
t.trlock.Unlock()
var response []byte
if handler != nil {
response = handler(p.Message)
} }
} resp := &v5wire.TalkResponse{ReqID: p.ReqID, Message: response}
t.sendResponse(fromID, fromAddr, resp)
// NODES
func (p *nodesV5) name() string { return "NODES/v5" }
func (p *nodesV5) kind() byte { return p_nodesV5 }
func (p *nodesV5) setreqid(id []byte) { p.ReqID = id }
func (p *nodesV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
t.handleCallResponse(fromID, fromAddr, p.ReqID, p)
}
// REQUESTTICKET
func (p *requestTicketV5) name() string { return "REQUESTTICKET/v5" }
func (p *requestTicketV5) kind() byte { return p_requestTicketV5 }
func (p *requestTicketV5) setreqid(id []byte) { p.ReqID = id }
func (p *requestTicketV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
t.sendResponse(fromID, fromAddr, &ticketV5{ReqID: p.ReqID})
}
// TICKET
func (p *ticketV5) name() string { return "TICKET/v5" }
func (p *ticketV5) kind() byte { return p_ticketV5 }
func (p *ticketV5) setreqid(id []byte) { p.ReqID = id }
func (p *ticketV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
t.handleCallResponse(fromID, fromAddr, p.ReqID, p)
}
// REGTOPIC
func (p *regtopicV5) name() string { return "REGTOPIC/v5" }
func (p *regtopicV5) kind() byte { return p_regtopicV5 }
func (p *regtopicV5) setreqid(id []byte) { p.ReqID = id }
func (p *regtopicV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
t.sendResponse(fromID, fromAddr, &regconfirmationV5{ReqID: p.ReqID, Registered: false})
}
// REGCONFIRMATION
func (p *regconfirmationV5) name() string { return "REGCONFIRMATION/v5" }
func (p *regconfirmationV5) kind() byte { return p_regconfirmationV5 }
func (p *regconfirmationV5) setreqid(id []byte) { p.ReqID = id }
func (p *regconfirmationV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
t.handleCallResponse(fromID, fromAddr, p.ReqID, p)
}
// TOPICQUERY
func (p *topicqueryV5) name() string { return "TOPICQUERY/v5" }
func (p *topicqueryV5) kind() byte { return p_topicqueryV5 }
func (p *topicqueryV5) setreqid(id []byte) { p.ReqID = id }
func (p *topicqueryV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) {
} }

@ -24,22 +24,25 @@ import (
"math/rand" "math/rand"
"net" "net"
"reflect" "reflect"
"sort"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/internal/testlog" "github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/discover/v5wire"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
// Real sockets, real crypto: this test checks end-to-end connectivity for UDPv5. // Real sockets, real crypto: this test checks end-to-end connectivity for UDPv5.
func TestEndToEndV5(t *testing.T) { func TestUDPv5_lookupE2E(t *testing.T) {
t.Parallel() t.Parallel()
const N = 5
var nodes []*UDPv5 var nodes []*UDPv5
for i := 0; i < 5; i++ { for i := 0; i < N; i++ {
var cfg Config var cfg Config
if len(nodes) > 0 { if len(nodes) > 0 {
bn := nodes[0].Self() bn := nodes[0].Self()
@ -49,12 +52,22 @@ func TestEndToEndV5(t *testing.T) {
nodes = append(nodes, node) nodes = append(nodes, node)
defer node.Close() defer node.Close()
} }
last := nodes[N-1]
target := nodes[rand.Intn(N-2)].Self()
last := nodes[len(nodes)-1] // It is expected that all nodes can be found.
target := nodes[rand.Intn(len(nodes)-2)].Self() expectedResult := make([]*enode.Node, len(nodes))
for i := range nodes {
expectedResult[i] = nodes[i].Self()
}
sort.Slice(expectedResult, func(i, j int) bool {
return enode.DistCmp(target.ID(), expectedResult[i].ID(), expectedResult[j].ID()) < 0
})
// Do the lookup.
results := last.Lookup(target.ID()) results := last.Lookup(target.ID())
if len(results) == 0 || results[0].ID() != target.ID() { if err := checkNodesEqual(results, expectedResult); err != nil {
t.Fatalf("lookup returned wrong results: %v", results) t.Fatalf("lookup returned wrong results: %v", err)
} }
} }
@ -93,8 +106,8 @@ func TestUDPv5_pingHandling(t *testing.T) {
test := newUDPV5Test(t) test := newUDPV5Test(t)
defer test.close() defer test.close()
test.packetIn(&pingV5{ReqID: []byte("foo")}) test.packetIn(&v5wire.Ping{ReqID: []byte("foo")})
test.waitPacketOut(func(p *pongV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, []byte("foo")) { if !bytes.Equal(p.ReqID, []byte("foo")) {
t.Error("wrong request ID in response:", p.ReqID) t.Error("wrong request ID in response:", p.ReqID)
} }
@ -110,13 +123,13 @@ func TestUDPv5_unknownPacket(t *testing.T) {
test := newUDPV5Test(t) test := newUDPV5Test(t)
defer test.close() defer test.close()
authTag := [12]byte{1, 2, 3} nonce := v5wire.Nonce{1, 2, 3}
check := func(p *whoareyouV5, wantSeq uint64) { check := func(p *v5wire.Whoareyou, wantSeq uint64) {
t.Helper() t.Helper()
if !bytes.Equal(p.AuthTag, authTag[:]) { if p.Nonce != nonce {
t.Error("wrong token in WHOAREYOU:", p.AuthTag, authTag[:]) t.Error("wrong nonce in WHOAREYOU:", p.Nonce, nonce)
} }
if p.IDNonce == ([32]byte{}) { if p.IDNonce == ([16]byte{}) {
t.Error("all zero ID nonce") t.Error("all zero ID nonce")
} }
if p.RecordSeq != wantSeq { if p.RecordSeq != wantSeq {
@ -125,8 +138,8 @@ func TestUDPv5_unknownPacket(t *testing.T) {
} }
// Unknown packet from unknown node. // Unknown packet from unknown node.
test.packetIn(&unknownV5{AuthTag: authTag[:]}) test.packetIn(&v5wire.Unknown{Nonce: nonce})
test.waitPacketOut(func(p *whoareyouV5, addr *net.UDPAddr, _ []byte) { test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) {
check(p, 0) check(p, 0)
}) })
@ -134,8 +147,8 @@ func TestUDPv5_unknownPacket(t *testing.T) {
n := test.getNode(test.remotekey, test.remoteaddr).Node() n := test.getNode(test.remotekey, test.remoteaddr).Node()
test.table.addSeenNode(wrapNode(n)) test.table.addSeenNode(wrapNode(n))
test.packetIn(&unknownV5{AuthTag: authTag[:]}) test.packetIn(&v5wire.Unknown{Nonce: nonce})
test.waitPacketOut(func(p *whoareyouV5, addr *net.UDPAddr, _ []byte) { test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) {
check(p, n.Seq()) check(p, n.Seq())
}) })
} }
@ -147,24 +160,40 @@ func TestUDPv5_findnodeHandling(t *testing.T) {
defer test.close() defer test.close()
// Create test nodes and insert them into the table. // Create test nodes and insert them into the table.
nodes := nodesAtDistance(test.table.self().ID(), 253, 10) nodes253 := nodesAtDistance(test.table.self().ID(), 253, 10)
fillTable(test.table, wrapNodes(nodes)) nodes249 := nodesAtDistance(test.table.self().ID(), 249, 4)
nodes248 := nodesAtDistance(test.table.self().ID(), 248, 10)
fillTable(test.table, wrapNodes(nodes253))
fillTable(test.table, wrapNodes(nodes249))
fillTable(test.table, wrapNodes(nodes248))
// Requesting with distance zero should return the node's own record. // Requesting with distance zero should return the node's own record.
test.packetIn(&findnodeV5{ReqID: []byte{0}, Distance: 0}) test.packetIn(&v5wire.Findnode{ReqID: []byte{0}, Distances: []uint{0}})
test.expectNodes([]byte{0}, 1, []*enode.Node{test.udp.Self()}) test.expectNodes([]byte{0}, 1, []*enode.Node{test.udp.Self()})
// Requesting with distance > 256 caps it at 256. // Requesting with distance > 256 shouldn't crash.
test.packetIn(&findnodeV5{ReqID: []byte{1}, Distance: 4234098}) test.packetIn(&v5wire.Findnode{ReqID: []byte{1}, Distances: []uint{4234098}})
test.expectNodes([]byte{1}, 1, nil) test.expectNodes([]byte{1}, 1, nil)
// This request gets no nodes because the corresponding bucket is empty. // Requesting with empty distance list shouldn't crash either.
test.packetIn(&findnodeV5{ReqID: []byte{2}, Distance: 254}) test.packetIn(&v5wire.Findnode{ReqID: []byte{2}, Distances: []uint{}})
test.expectNodes([]byte{2}, 1, nil) test.expectNodes([]byte{2}, 1, nil)
// This request gets all test nodes. // This request gets no nodes because the corresponding bucket is empty.
test.packetIn(&findnodeV5{ReqID: []byte{3}, Distance: 253}) test.packetIn(&v5wire.Findnode{ReqID: []byte{3}, Distances: []uint{254}})
test.expectNodes([]byte{3}, 4, nodes) test.expectNodes([]byte{3}, 1, nil)
// This request gets all the distance-253 nodes.
test.packetIn(&v5wire.Findnode{ReqID: []byte{4}, Distances: []uint{253}})
test.expectNodes([]byte{4}, 4, nodes253)
// This request gets all the distance-249 nodes and some more at 248 because
// the bucket at 249 is not full.
test.packetIn(&v5wire.Findnode{ReqID: []byte{5}, Distances: []uint{249, 248}})
var nodes []*enode.Node
nodes = append(nodes, nodes249...)
nodes = append(nodes, nodes248[:10]...)
test.expectNodes([]byte{5}, 5, nodes)
} }
func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes []*enode.Node) { func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes []*enode.Node) {
@ -172,16 +201,17 @@ func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes
for _, n := range wantNodes { for _, n := range wantNodes {
nodeSet[n.ID()] = n.Record() nodeSet[n.ID()] = n.Record()
} }
for { for {
test.waitPacketOut(func(p *nodesV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Nodes, addr *net.UDPAddr, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, wantReqID) {
test.t.Fatalf("wrong request ID %v in response, want %v", p.ReqID, wantReqID)
}
if len(p.Nodes) > 3 { if len(p.Nodes) > 3 {
test.t.Fatalf("too many nodes in response") test.t.Fatalf("too many nodes in response")
} }
if p.Total != wantTotal { if p.Total != wantTotal {
test.t.Fatalf("wrong total response count %d", p.Total) test.t.Fatalf("wrong total response count %d, want %d", p.Total, wantTotal)
}
if !bytes.Equal(p.ReqID, wantReqID) {
test.t.Fatalf("wrong request ID in response: %v", p.ReqID)
} }
for _, record := range p.Nodes { for _, record := range p.Nodes {
n, _ := enode.New(enode.ValidSchemesForTesting, record) n, _ := enode.New(enode.ValidSchemesForTesting, record)
@ -215,7 +245,7 @@ func TestUDPv5_pingCall(t *testing.T) {
_, err := test.udp.ping(remote) _, err := test.udp.ping(remote)
done <- err done <- err
}() }()
test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) {}) test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {})
if err := <-done; err != errTimeout { if err := <-done; err != errTimeout {
t.Fatalf("want errTimeout, got %q", err) t.Fatalf("want errTimeout, got %q", err)
} }
@ -225,8 +255,8 @@ func TestUDPv5_pingCall(t *testing.T) {
_, err := test.udp.ping(remote) _, err := test.udp.ping(remote)
done <- err done <- err
}() }()
test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {
test.packetInFrom(test.remotekey, test.remoteaddr, &pongV5{ReqID: p.ReqID}) test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.Pong{ReqID: p.ReqID})
}) })
if err := <-done; err != nil { if err := <-done; err != nil {
t.Fatal(err) t.Fatal(err)
@ -237,9 +267,9 @@ func TestUDPv5_pingCall(t *testing.T) {
_, err := test.udp.ping(remote) _, err := test.udp.ping(remote)
done <- err done <- err
}() }()
test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {
wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 55, 22}, Port: 10101} wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 55, 22}, Port: 10101}
test.packetInFrom(test.remotekey, wrongAddr, &pongV5{ReqID: p.ReqID}) test.packetInFrom(test.remotekey, wrongAddr, &v5wire.Pong{ReqID: p.ReqID})
}) })
if err := <-done; err != errTimeout { if err := <-done; err != errTimeout {
t.Fatalf("want errTimeout for reply from wrong IP, got %q", err) t.Fatalf("want errTimeout for reply from wrong IP, got %q", err)
@ -255,29 +285,29 @@ func TestUDPv5_findnodeCall(t *testing.T) {
// Launch the request: // Launch the request:
var ( var (
distance = 230 distances = []uint{230}
remote = test.getNode(test.remotekey, test.remoteaddr).Node() remote = test.getNode(test.remotekey, test.remoteaddr).Node()
nodes = nodesAtDistance(remote.ID(), distance, 8) nodes = nodesAtDistance(remote.ID(), int(distances[0]), 8)
done = make(chan error, 1) done = make(chan error, 1)
response []*enode.Node response []*enode.Node
) )
go func() { go func() {
var err error var err error
response, err = test.udp.findnode(remote, distance) response, err = test.udp.findnode(remote, distances)
done <- err done <- err
}() }()
// Serve the responses: // Serve the responses:
test.waitPacketOut(func(p *findnodeV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) {
if p.Distance != uint(distance) { if !reflect.DeepEqual(p.Distances, distances) {
t.Fatalf("wrong bucket: %d", p.Distance) t.Fatalf("wrong distances in request: %v", p.Distances)
} }
test.packetIn(&nodesV5{ test.packetIn(&v5wire.Nodes{
ReqID: p.ReqID, ReqID: p.ReqID,
Total: 2, Total: 2,
Nodes: nodesToRecords(nodes[:4]), Nodes: nodesToRecords(nodes[:4]),
}) })
test.packetIn(&nodesV5{ test.packetIn(&v5wire.Nodes{
ReqID: p.ReqID, ReqID: p.ReqID,
Total: 2, Total: 2,
Nodes: nodesToRecords(nodes[4:]), Nodes: nodesToRecords(nodes[4:]),
@ -314,16 +344,16 @@ func TestUDPv5_callResend(t *testing.T) {
}() }()
// Ping answered by WHOAREYOU. // Ping answered by WHOAREYOU.
test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) {
test.packetIn(&whoareyouV5{AuthTag: authTag}) test.packetIn(&v5wire.Whoareyou{Nonce: nonce})
}) })
// Ping should be re-sent. // Ping should be re-sent.
test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {
test.packetIn(&pongV5{ReqID: p.ReqID}) test.packetIn(&v5wire.Pong{ReqID: p.ReqID})
}) })
// Answer the other ping. // Answer the other ping.
test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {
test.packetIn(&pongV5{ReqID: p.ReqID}) test.packetIn(&v5wire.Pong{ReqID: p.ReqID})
}) })
if err := <-done; err != nil { if err := <-done; err != nil {
t.Fatalf("unexpected ping error: %v", err) t.Fatalf("unexpected ping error: %v", err)
@ -347,12 +377,12 @@ func TestUDPv5_multipleHandshakeRounds(t *testing.T) {
}() }()
// Ping answered by WHOAREYOU. // Ping answered by WHOAREYOU.
test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) {
test.packetIn(&whoareyouV5{AuthTag: authTag}) test.packetIn(&v5wire.Whoareyou{Nonce: nonce})
}) })
// Ping answered by WHOAREYOU again. // Ping answered by WHOAREYOU again.
test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) {
test.packetIn(&whoareyouV5{AuthTag: authTag}) test.packetIn(&v5wire.Whoareyou{Nonce: nonce})
}) })
if err := <-done; err != errTimeout { if err := <-done; err != errTimeout {
t.Fatalf("unexpected ping error: %q", err) t.Fatalf("unexpected ping error: %q", err)
@ -367,27 +397,27 @@ func TestUDPv5_callTimeoutReset(t *testing.T) {
// Launch the request: // Launch the request:
var ( var (
distance = 230 distance = uint(230)
remote = test.getNode(test.remotekey, test.remoteaddr).Node() remote = test.getNode(test.remotekey, test.remoteaddr).Node()
nodes = nodesAtDistance(remote.ID(), distance, 8) nodes = nodesAtDistance(remote.ID(), int(distance), 8)
done = make(chan error, 1) done = make(chan error, 1)
) )
go func() { go func() {
_, err := test.udp.findnode(remote, distance) _, err := test.udp.findnode(remote, []uint{distance})
done <- err done <- err
}() }()
// Serve two responses, slowly. // Serve two responses, slowly.
test.waitPacketOut(func(p *findnodeV5, addr *net.UDPAddr, authTag []byte) { test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) {
time.Sleep(respTimeout - 50*time.Millisecond) time.Sleep(respTimeout - 50*time.Millisecond)
test.packetIn(&nodesV5{ test.packetIn(&v5wire.Nodes{
ReqID: p.ReqID, ReqID: p.ReqID,
Total: 2, Total: 2,
Nodes: nodesToRecords(nodes[:4]), Nodes: nodesToRecords(nodes[:4]),
}) })
time.Sleep(respTimeout - 50*time.Millisecond) time.Sleep(respTimeout - 50*time.Millisecond)
test.packetIn(&nodesV5{ test.packetIn(&v5wire.Nodes{
ReqID: p.ReqID, ReqID: p.ReqID,
Total: 2, Total: 2,
Nodes: nodesToRecords(nodes[4:]), Nodes: nodesToRecords(nodes[4:]),
@ -398,6 +428,97 @@ func TestUDPv5_callTimeoutReset(t *testing.T) {
} }
} }
// This test checks that TALKREQ calls the registered handler function.
func TestUDPv5_talkHandling(t *testing.T) {
t.Parallel()
test := newUDPV5Test(t)
defer test.close()
var recvMessage []byte
test.udp.RegisterTalkHandler("test", func(message []byte) []byte {
recvMessage = message
return []byte("test response")
})
// Successful case:
test.packetIn(&v5wire.TalkRequest{
ReqID: []byte("foo"),
Protocol: "test",
Message: []byte("test request"),
})
test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, []byte("foo")) {
t.Error("wrong request ID in response:", p.ReqID)
}
if string(p.Message) != "test response" {
t.Errorf("wrong talk response message: %q", p.Message)
}
if string(recvMessage) != "test request" {
t.Errorf("wrong message received in handler: %q", recvMessage)
}
})
// Check that empty response is returned for unregistered protocols.
recvMessage = nil
test.packetIn(&v5wire.TalkRequest{
ReqID: []byte("2"),
Protocol: "wrong",
Message: []byte("test request"),
})
test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) {
if !bytes.Equal(p.ReqID, []byte("2")) {
t.Error("wrong request ID in response:", p.ReqID)
}
if string(p.Message) != "" {
t.Errorf("wrong talk response message: %q", p.Message)
}
if recvMessage != nil {
t.Errorf("handler was called for wrong protocol: %q", recvMessage)
}
})
}
// This test checks that outgoing TALKREQ calls work.
func TestUDPv5_talkRequest(t *testing.T) {
t.Parallel()
test := newUDPV5Test(t)
defer test.close()
remote := test.getNode(test.remotekey, test.remoteaddr).Node()
done := make(chan error, 1)
// This request times out.
go func() {
_, err := test.udp.TalkRequest(remote, "test", []byte("test request"))
done <- err
}()
test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) {})
if err := <-done; err != errTimeout {
t.Fatalf("want errTimeout, got %q", err)
}
// This request works.
go func() {
_, err := test.udp.TalkRequest(remote, "test", []byte("test request"))
done <- err
}()
test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) {
if p.Protocol != "test" {
t.Errorf("wrong protocol ID in talk request: %q", p.Protocol)
}
if string(p.Message) != "test request" {
t.Errorf("wrong message talk request: %q", p.Message)
}
test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.TalkResponse{
ReqID: p.ReqID,
Message: []byte("test response"),
})
})
if err := <-done; err != nil {
t.Fatal(err)
}
}
// This test checks that lookup works. // This test checks that lookup works.
func TestUDPv5_lookup(t *testing.T) { func TestUDPv5_lookup(t *testing.T) {
t.Parallel() t.Parallel()
@ -417,7 +538,8 @@ func TestUDPv5_lookup(t *testing.T) {
} }
// Seed table with initial node. // Seed table with initial node.
fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}) initialNode := lookupTestnet.node(256, 0)
fillTable(test.table, []*node{wrapNode(initialNode)})
// Start the lookup. // Start the lookup.
resultC := make(chan []*enode.Node, 1) resultC := make(chan []*enode.Node, 1)
@ -427,22 +549,30 @@ func TestUDPv5_lookup(t *testing.T) {
}() }()
// Answer lookup packets. // Answer lookup packets.
asked := make(map[enode.ID]bool)
for done := false; !done; { for done := false; !done; {
done = test.waitPacketOut(func(p packetV5, to *net.UDPAddr, authTag []byte) { done = test.waitPacketOut(func(p v5wire.Packet, to *net.UDPAddr, _ v5wire.Nonce) {
recipient, key := lookupTestnet.nodeByAddr(to) recipient, key := lookupTestnet.nodeByAddr(to)
switch p := p.(type) { switch p := p.(type) {
case *pingV5: case *v5wire.Ping:
test.packetInFrom(key, to, &pongV5{ReqID: p.ReqID}) test.packetInFrom(key, to, &v5wire.Pong{ReqID: p.ReqID})
case *findnodeV5: case *v5wire.Findnode:
nodes := lookupTestnet.neighborsAtDistance(recipient, p.Distance, 3) if asked[recipient.ID()] {
response := &nodesV5{ReqID: p.ReqID, Total: 1, Nodes: nodesToRecords(nodes)} t.Error("Asked node", recipient.ID(), "twice")
test.packetInFrom(key, to, response) }
asked[recipient.ID()] = true
nodes := lookupTestnet.neighborsAtDistances(recipient, p.Distances, 16)
t.Logf("Got FINDNODE for %v, returning %d nodes", p.Distances, len(nodes))
for _, resp := range packNodes(p.ReqID, nodes) {
test.packetInFrom(key, to, resp)
}
} }
}) })
} }
// Verify result nodes. // Verify result nodes.
checkLookupResults(t, lookupTestnet, <-resultC) results := <-resultC
checkLookupResults(t, lookupTestnet, results)
} }
// This test checks the local node can be utilised to set key-values. // This test checks the local node can be utilised to set key-values.
@ -481,6 +611,7 @@ type udpV5Test struct {
nodesByIP map[string]*enode.LocalNode nodesByIP map[string]*enode.LocalNode
} }
// testCodec is the packet encoding used by protocol tests. This codec does not perform encryption.
type testCodec struct { type testCodec struct {
test *udpV5Test test *udpV5Test
id enode.ID id enode.ID
@ -489,46 +620,44 @@ type testCodec struct {
type testCodecFrame struct { type testCodecFrame struct {
NodeID enode.ID NodeID enode.ID
AuthTag []byte AuthTag v5wire.Nonce
Ptype byte Ptype byte
Packet rlp.RawValue Packet rlp.RawValue
} }
func (c *testCodec) encode(toID enode.ID, addr string, p packetV5, _ *whoareyouV5) ([]byte, []byte, error) { func (c *testCodec) Encode(toID enode.ID, addr string, p v5wire.Packet, _ *v5wire.Whoareyou) ([]byte, v5wire.Nonce, error) {
c.ctr++ c.ctr++
authTag := make([]byte, 8) var authTag v5wire.Nonce
binary.BigEndian.PutUint64(authTag, c.ctr) binary.BigEndian.PutUint64(authTag[:], c.ctr)
penc, _ := rlp.EncodeToBytes(p) penc, _ := rlp.EncodeToBytes(p)
frame, err := rlp.EncodeToBytes(testCodecFrame{c.id, authTag, p.kind(), penc}) frame, err := rlp.EncodeToBytes(testCodecFrame{c.id, authTag, p.Kind(), penc})
return frame, authTag, err return frame, authTag, err
} }
func (c *testCodec) decode(input []byte, addr string) (enode.ID, *enode.Node, packetV5, error) { func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5wire.Packet, error) {
frame, p, err := c.decodeFrame(input) frame, p, err := c.decodeFrame(input)
if err != nil { if err != nil {
return enode.ID{}, nil, nil, err return enode.ID{}, nil, nil, err
} }
if p.kind() == p_whoareyouV5 {
frame.NodeID = enode.ID{} // match wireCodec behavior
}
return frame.NodeID, nil, p, nil return frame.NodeID, nil, p, nil
} }
func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p packetV5, err error) { func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) {
if err = rlp.DecodeBytes(input, &frame); err != nil { if err = rlp.DecodeBytes(input, &frame); err != nil {
return frame, nil, fmt.Errorf("invalid frame: %v", err) return frame, nil, fmt.Errorf("invalid frame: %v", err)
} }
switch frame.Ptype { switch frame.Ptype {
case p_unknownV5: case v5wire.UnknownPacket:
dec := new(unknownV5) dec := new(v5wire.Unknown)
err = rlp.DecodeBytes(frame.Packet, &dec) err = rlp.DecodeBytes(frame.Packet, &dec)
p = dec p = dec
case p_whoareyouV5: case v5wire.WhoareyouPacket:
dec := new(whoareyouV5) dec := new(v5wire.Whoareyou)
err = rlp.DecodeBytes(frame.Packet, &dec) err = rlp.DecodeBytes(frame.Packet, &dec)
p = dec p = dec
default: default:
p, err = decodePacketBodyV5(frame.Ptype, frame.Packet) p, err = v5wire.DecodeMessage(frame.Ptype, frame.Packet)
} }
return frame, p, err return frame, p, err
} }
@ -561,20 +690,20 @@ func newUDPV5Test(t *testing.T) *udpV5Test {
} }
// handles a packet as if it had been sent to the transport. // handles a packet as if it had been sent to the transport.
func (test *udpV5Test) packetIn(packet packetV5) { func (test *udpV5Test) packetIn(packet v5wire.Packet) {
test.t.Helper() test.t.Helper()
test.packetInFrom(test.remotekey, test.remoteaddr, packet) test.packetInFrom(test.remotekey, test.remoteaddr, packet)
} }
// handles a packet as if it had been sent to the transport by the key/endpoint. // handles a packet as if it had been sent to the transport by the key/endpoint.
func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet packetV5) { func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet v5wire.Packet) {
test.t.Helper() test.t.Helper()
ln := test.getNode(key, addr) ln := test.getNode(key, addr)
codec := &testCodec{test: test, id: ln.ID()} codec := &testCodec{test: test, id: ln.ID()}
enc, _, err := codec.encode(test.udp.Self().ID(), addr.String(), packet, nil) enc, _, err := codec.Encode(test.udp.Self().ID(), addr.String(), packet, nil)
if err != nil { if err != nil {
test.t.Errorf("%s encode error: %v", packet.name(), err) test.t.Errorf("%s encode error: %v", packet.Name(), err)
} }
if test.udp.dispatchReadPacket(addr, enc) { if test.udp.dispatchReadPacket(addr, enc) {
<-test.udp.readNextCh // unblock UDPv5.dispatch <-test.udp.readNextCh // unblock UDPv5.dispatch
@ -596,8 +725,12 @@ func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr *net.UDPAddr) *enode.
return ln return ln
} }
// waitPacketOut waits for the next output packet and handles it using the given 'validate'
// function. The function must be of type func (X, *net.UDPAddr, v5wire.Nonce) where X is
// assignable to packetV5.
func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) {
test.t.Helper() test.t.Helper()
fn := reflect.ValueOf(validate) fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0) exptype := fn.Type().In(0)

@ -0,0 +1,180 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package v5wire
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
"errors"
"fmt"
"hash"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enode"
"golang.org/x/crypto/hkdf"
)
const (
// Encryption/authentication parameters.
aesKeySize = 16
gcmNonceSize = 12
)
// Nonce represents a nonce used for AES/GCM.
type Nonce [gcmNonceSize]byte
// EncodePubkey encodes a public key.
func EncodePubkey(key *ecdsa.PublicKey) []byte {
switch key.Curve {
case crypto.S256():
return crypto.CompressPubkey(key)
default:
panic("unsupported curve " + key.Curve.Params().Name + " in EncodePubkey")
}
}
// DecodePubkey decodes a public key in compressed format.
func DecodePubkey(curve elliptic.Curve, e []byte) (*ecdsa.PublicKey, error) {
switch curve {
case crypto.S256():
if len(e) != 33 {
return nil, errors.New("wrong size public key data")
}
return crypto.DecompressPubkey(e)
default:
return nil, fmt.Errorf("unsupported curve %s in DecodePubkey", curve.Params().Name)
}
}
// idNonceHash computes the ID signature hash used in the handshake.
func idNonceHash(h hash.Hash, challenge, ephkey []byte, destID enode.ID) []byte {
h.Reset()
h.Write([]byte("discovery v5 identity proof"))
h.Write(challenge)
h.Write(ephkey)
h.Write(destID[:])
return h.Sum(nil)
}
// makeIDSignature creates the ID nonce signature.
func makeIDSignature(hash hash.Hash, key *ecdsa.PrivateKey, challenge, ephkey []byte, destID enode.ID) ([]byte, error) {
input := idNonceHash(hash, challenge, ephkey, destID)
switch key.Curve {
case crypto.S256():
idsig, err := crypto.Sign(input, key)
if err != nil {
return nil, err
}
return idsig[:len(idsig)-1], nil // remove recovery ID
default:
return nil, fmt.Errorf("unsupported curve %s", key.Curve.Params().Name)
}
}
// s256raw is an unparsed secp256k1 public key ENR entry.
type s256raw []byte
func (s256raw) ENRKey() string { return "secp256k1" }
// verifyIDSignature checks that signature over idnonce was made by the given node.
func verifyIDSignature(hash hash.Hash, sig []byte, n *enode.Node, challenge, ephkey []byte, destID enode.ID) error {
switch idscheme := n.Record().IdentityScheme(); idscheme {
case "v4":
var pubkey s256raw
if n.Load(&pubkey) != nil {
return errors.New("no secp256k1 public key in record")
}
input := idNonceHash(hash, challenge, ephkey, destID)
if !crypto.VerifySignature(pubkey, input, sig) {
return errInvalidNonceSig
}
return nil
default:
return fmt.Errorf("can't verify ID nonce signature against scheme %q", idscheme)
}
}
type hashFn func() hash.Hash
// deriveKeys creates the session keys.
func deriveKeys(hash hashFn, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, n1, n2 enode.ID, challenge []byte) *session {
const text = "discovery v5 key agreement"
var info = make([]byte, 0, len(text)+len(n1)+len(n2))
info = append(info, text...)
info = append(info, n1[:]...)
info = append(info, n2[:]...)
eph := ecdh(priv, pub)
if eph == nil {
return nil
}
kdf := hkdf.New(hash, eph, challenge, info)
sec := session{writeKey: make([]byte, aesKeySize), readKey: make([]byte, aesKeySize)}
kdf.Read(sec.writeKey)
kdf.Read(sec.readKey)
for i := range eph {
eph[i] = 0
}
return &sec
}
// ecdh creates a shared secret.
func ecdh(privkey *ecdsa.PrivateKey, pubkey *ecdsa.PublicKey) []byte {
secX, secY := pubkey.ScalarMult(pubkey.X, pubkey.Y, privkey.D.Bytes())
if secX == nil {
return nil
}
sec := make([]byte, 33)
sec[0] = 0x02 | byte(secY.Bit(0))
math.ReadBits(secX, sec[1:])
return sec
}
// encryptGCM encrypts pt using AES-GCM with the given key and nonce. The ciphertext is
// appended to dest, which must not overlap with plaintext. The resulting ciphertext is 16
// bytes longer than plaintext because it contains an authentication tag.
func encryptGCM(dest, key, nonce, plaintext, authData []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
panic(fmt.Errorf("can't create block cipher: %v", err))
}
aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize)
if err != nil {
panic(fmt.Errorf("can't create GCM: %v", err))
}
return aesgcm.Seal(dest, nonce, plaintext, authData), nil
}
// decryptGCM decrypts ct using AES-GCM with the given key and nonce.
func decryptGCM(key, nonce, ct, authData []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("can't create block cipher: %v", err)
}
if len(nonce) != gcmNonceSize {
return nil, fmt.Errorf("invalid GCM nonce size: %d", len(nonce))
}
aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize)
if err != nil {
return nil, fmt.Errorf("can't create GCM: %v", err)
}
pt := make([]byte, 0, len(ct))
return aesgcm.Open(pt, nonce, ct, authData)
}

@ -0,0 +1,124 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package v5wire
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/sha256"
"reflect"
"strings"
"testing"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enode"
)
func TestVector_ECDH(t *testing.T) {
var (
staticKey = hexPrivkey("0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736")
publicKey = hexPubkey(crypto.S256(), "0x039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231")
want = hexutil.MustDecode("0x033b11a2a1f214567e1537ce5e509ffd9b21373247f2a3ff6841f4976f53165e7e")
)
result := ecdh(staticKey, publicKey)
check(t, "shared-secret", result, want)
}
func TestVector_KDF(t *testing.T) {
var (
ephKey = hexPrivkey("0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736")
cdata = hexutil.MustDecode("0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000")
net = newHandshakeTest()
)
defer net.close()
destKey := &testKeyB.PublicKey
s := deriveKeys(sha256.New, ephKey, destKey, net.nodeA.id(), net.nodeB.id(), cdata)
t.Logf("ephemeral-key = %#x", ephKey.D)
t.Logf("dest-pubkey = %#x", EncodePubkey(destKey))
t.Logf("node-id-a = %#x", net.nodeA.id().Bytes())
t.Logf("node-id-b = %#x", net.nodeB.id().Bytes())
t.Logf("challenge-data = %#x", cdata)
check(t, "initiator-key", s.writeKey, hexutil.MustDecode("0xdccc82d81bd610f4f76d3ebe97a40571"))
check(t, "recipient-key", s.readKey, hexutil.MustDecode("0xac74bb8773749920b0d3a8881c173ec5"))
}
func TestVector_IDSignature(t *testing.T) {
var (
key = hexPrivkey("0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736")
destID = enode.HexID("0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9")
ephkey = hexutil.MustDecode("0x039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231")
cdata = hexutil.MustDecode("0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000")
)
sig, err := makeIDSignature(sha256.New(), key, cdata, ephkey, destID)
if err != nil {
t.Fatal(err)
}
t.Logf("static-key = %#x", key.D)
t.Logf("challenge-data = %#x", cdata)
t.Logf("ephemeral-pubkey = %#x", ephkey)
t.Logf("node-id-B = %#x", destID.Bytes())
expected := "0x94852a1e2318c4e5e9d422c98eaf19d1d90d876b29cd06ca7cb7546d0fff7b484fe86c09a064fe72bdbef73ba8e9c34df0cd2b53e9d65528c2c7f336d5dfc6e6"
check(t, "id-signature", sig, hexutil.MustDecode(expected))
}
func TestDeriveKeys(t *testing.T) {
t.Parallel()
var (
n1 = enode.ID{1}
n2 = enode.ID{2}
cdata = []byte{1, 2, 3, 4}
)
sec1 := deriveKeys(sha256.New, testKeyA, &testKeyB.PublicKey, n1, n2, cdata)
sec2 := deriveKeys(sha256.New, testKeyB, &testKeyA.PublicKey, n1, n2, cdata)
if sec1 == nil || sec2 == nil {
t.Fatal("key agreement failed")
}
if !reflect.DeepEqual(sec1, sec2) {
t.Fatalf("keys not equal:\n %+v\n %+v", sec1, sec2)
}
}
func check(t *testing.T, what string, x, y []byte) {
t.Helper()
if !bytes.Equal(x, y) {
t.Errorf("wrong %s: %#x != %#x", what, x, y)
} else {
t.Logf("%s = %#x", what, x)
}
}
func hexPrivkey(input string) *ecdsa.PrivateKey {
key, err := crypto.HexToECDSA(strings.TrimPrefix(input, "0x"))
if err != nil {
panic(err)
}
return key
}
func hexPubkey(curve elliptic.Curve, input string) *ecdsa.PublicKey {
key, err := DecodePubkey(curve, hexutil.MustDecode(input))
if err != nil {
panic(err)
}
return key
}

@ -0,0 +1,648 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package v5wire
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
crand "crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"hash"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rlp"
)
// TODO concurrent WHOAREYOU tie-breaker
// TODO rehandshake after X packets
// Header represents a packet header.
type Header struct {
IV [sizeofMaskingIV]byte
StaticHeader
AuthData []byte
src enode.ID // used by decoder
}
// StaticHeader contains the static fields of a packet header.
type StaticHeader struct {
ProtocolID [6]byte
Version uint16
Flag byte
Nonce Nonce
AuthSize uint16
}
// Authdata layouts.
type (
whoareyouAuthData struct {
IDNonce [16]byte // ID proof data
RecordSeq uint64 // highest known ENR sequence of requester
}
handshakeAuthData struct {
h struct {
SrcID enode.ID
SigSize byte // ignature data
PubkeySize byte // offset of
}
// Trailing variable-size data.
signature, pubkey, record []byte
}
messageAuthData struct {
SrcID enode.ID
}
)
// Packet header flag values.
const (
flagMessage = iota
flagWhoareyou
flagHandshake
)
// Protocol constants.
const (
version = 1
minVersion = 1
sizeofMaskingIV = 16
minMessageSize = 48 // this refers to data after static headers
randomPacketMsgSize = 20
)
var protocolID = [6]byte{'d', 'i', 's', 'c', 'v', '5'}
// Errors.
var (
errTooShort = errors.New("packet too short")
errInvalidHeader = errors.New("invalid packet header")
errInvalidFlag = errors.New("invalid flag value in header")
errMinVersion = errors.New("version of packet header below minimum")
errMsgTooShort = errors.New("message/handshake packet below minimum size")
errAuthSize = errors.New("declared auth size is beyond packet length")
errUnexpectedHandshake = errors.New("unexpected auth response, not in handshake")
errInvalidAuthKey = errors.New("invalid ephemeral pubkey")
errNoRecord = errors.New("expected ENR in handshake but none sent")
errInvalidNonceSig = errors.New("invalid ID nonce signature")
errMessageTooShort = errors.New("message contains no data")
errMessageDecrypt = errors.New("cannot decrypt message")
)
// Public errors.
var (
ErrInvalidReqID = errors.New("request ID larger than 8 bytes")
)
// Packet sizes.
var (
sizeofStaticHeader = binary.Size(StaticHeader{})
sizeofWhoareyouAuthData = binary.Size(whoareyouAuthData{})
sizeofHandshakeAuthData = binary.Size(handshakeAuthData{}.h)
sizeofMessageAuthData = binary.Size(messageAuthData{})
sizeofStaticPacketData = sizeofMaskingIV + sizeofStaticHeader
)
// Codec encodes and decodes Discovery v5 packets.
// This type is not safe for concurrent use.
type Codec struct {
sha256 hash.Hash
localnode *enode.LocalNode
privkey *ecdsa.PrivateKey
sc *SessionCache
// encoder buffers
buf bytes.Buffer // whole packet
headbuf bytes.Buffer // packet header
msgbuf bytes.Buffer // message RLP plaintext
msgctbuf []byte // message data ciphertext
// decoder buffer
reader bytes.Reader
}
// NewCodec creates a wire codec.
func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock) *Codec {
c := &Codec{
sha256: sha256.New(),
localnode: ln,
privkey: key,
sc: NewSessionCache(1024, clock),
}
return c
}
// Encode encodes a packet to a node. 'id' and 'addr' specify the destination node. The
// 'challenge' parameter should be the most recently received WHOAREYOU packet from that
// node.
func (c *Codec) Encode(id enode.ID, addr string, packet Packet, challenge *Whoareyou) ([]byte, Nonce, error) {
// Create the packet header.
var (
head Header
session *session
msgData []byte
err error
)
switch {
case packet.Kind() == WhoareyouPacket:
head, err = c.encodeWhoareyou(id, packet.(*Whoareyou))
case challenge != nil:
// We have an unanswered challenge, send handshake.
head, session, err = c.encodeHandshakeHeader(id, addr, challenge)
default:
session = c.sc.session(id, addr)
if session != nil {
// There is a session, use it.
head, err = c.encodeMessageHeader(id, session)
} else {
// No keys, send random data to kick off the handshake.
head, msgData, err = c.encodeRandom(id)
}
}
if err != nil {
return nil, Nonce{}, err
}
// Generate masking IV.
if err := c.sc.maskingIVGen(head.IV[:]); err != nil {
return nil, Nonce{}, fmt.Errorf("can't generate masking IV: %v", err)
}
// Encode header data.
c.writeHeaders(&head)
// Store sent WHOAREYOU challenges.
if challenge, ok := packet.(*Whoareyou); ok {
challenge.ChallengeData = bytesCopy(&c.buf)
c.sc.storeSentHandshake(id, addr, challenge)
} else if msgData == nil {
headerData := c.buf.Bytes()
msgData, err = c.encryptMessage(session, packet, &head, headerData)
if err != nil {
return nil, Nonce{}, err
}
}
enc, err := c.EncodeRaw(id, head, msgData)
return enc, head.Nonce, err
}
// EncodeRaw encodes a packet with the given header.
func (c *Codec) EncodeRaw(id enode.ID, head Header, msgdata []byte) ([]byte, error) {
c.writeHeaders(&head)
// Apply masking.
masked := c.buf.Bytes()[sizeofMaskingIV:]
mask := head.mask(id)
mask.XORKeyStream(masked[:], masked[:])
// Write message data.
c.buf.Write(msgdata)
return c.buf.Bytes(), nil
}
func (c *Codec) writeHeaders(head *Header) {
c.buf.Reset()
c.buf.Write(head.IV[:])
binary.Write(&c.buf, binary.BigEndian, &head.StaticHeader)
c.buf.Write(head.AuthData)
}
// makeHeader creates a packet header.
func (c *Codec) makeHeader(toID enode.ID, flag byte, authsizeExtra int) Header {
var authsize int
switch flag {
case flagMessage:
authsize = sizeofMessageAuthData
case flagWhoareyou:
authsize = sizeofWhoareyouAuthData
case flagHandshake:
authsize = sizeofHandshakeAuthData
default:
panic(fmt.Errorf("BUG: invalid packet header flag %x", flag))
}
authsize += authsizeExtra
if authsize > int(^uint16(0)) {
panic(fmt.Errorf("BUG: auth size %d overflows uint16", authsize))
}
return Header{
StaticHeader: StaticHeader{
ProtocolID: protocolID,
Version: version,
Flag: flag,
AuthSize: uint16(authsize),
},
}
}
// encodeRandom encodes a packet with random content.
func (c *Codec) encodeRandom(toID enode.ID) (Header, []byte, error) {
head := c.makeHeader(toID, flagMessage, 0)
// Encode auth data.
auth := messageAuthData{SrcID: c.localnode.ID()}
if _, err := crand.Read(head.Nonce[:]); err != nil {
return head, nil, fmt.Errorf("can't get random data: %v", err)
}
c.headbuf.Reset()
binary.Write(&c.headbuf, binary.BigEndian, auth)
head.AuthData = c.headbuf.Bytes()
// Fill message ciphertext buffer with random bytes.
c.msgctbuf = append(c.msgctbuf[:0], make([]byte, randomPacketMsgSize)...)
crand.Read(c.msgctbuf)
return head, c.msgctbuf, nil
}
// encodeWhoareyou encodes a WHOAREYOU packet.
func (c *Codec) encodeWhoareyou(toID enode.ID, packet *Whoareyou) (Header, error) {
// Sanity check node field to catch misbehaving callers.
if packet.RecordSeq > 0 && packet.Node == nil {
panic("BUG: missing node in whoareyou with non-zero seq")
}
// Create header.
head := c.makeHeader(toID, flagWhoareyou, 0)
head.AuthData = bytesCopy(&c.buf)
head.Nonce = packet.Nonce
// Encode auth data.
auth := &whoareyouAuthData{
IDNonce: packet.IDNonce,
RecordSeq: packet.RecordSeq,
}
c.headbuf.Reset()
binary.Write(&c.headbuf, binary.BigEndian, auth)
head.AuthData = c.headbuf.Bytes()
return head, nil
}
// encodeHandshakeMessage encodes the handshake message packet header.
func (c *Codec) encodeHandshakeHeader(toID enode.ID, addr string, challenge *Whoareyou) (Header, *session, error) {
// Ensure calling code sets challenge.node.
if challenge.Node == nil {
panic("BUG: missing challenge.Node in encode")
}
// Generate new secrets.
auth, session, err := c.makeHandshakeAuth(toID, addr, challenge)
if err != nil {
return Header{}, nil, err
}
// Generate nonce for message.
nonce, err := c.sc.nextNonce(session)
if err != nil {
return Header{}, nil, fmt.Errorf("can't generate nonce: %v", err)
}
// TODO: this should happen when the first authenticated message is received
c.sc.storeNewSession(toID, addr, session)
// Encode the auth header.
var (
authsizeExtra = len(auth.pubkey) + len(auth.signature) + len(auth.record)
head = c.makeHeader(toID, flagHandshake, authsizeExtra)
)
c.headbuf.Reset()
binary.Write(&c.headbuf, binary.BigEndian, &auth.h)
c.headbuf.Write(auth.signature)
c.headbuf.Write(auth.pubkey)
c.headbuf.Write(auth.record)
head.AuthData = c.headbuf.Bytes()
head.Nonce = nonce
return head, session, err
}
// encodeAuthHeader creates the auth header on a request packet following WHOAREYOU.
func (c *Codec) makeHandshakeAuth(toID enode.ID, addr string, challenge *Whoareyou) (*handshakeAuthData, *session, error) {
auth := new(handshakeAuthData)
auth.h.SrcID = c.localnode.ID()
// Create the ephemeral key. This needs to be first because the
// key is part of the ID nonce signature.
var remotePubkey = new(ecdsa.PublicKey)
if err := challenge.Node.Load((*enode.Secp256k1)(remotePubkey)); err != nil {
return nil, nil, fmt.Errorf("can't find secp256k1 key for recipient")
}
ephkey, err := c.sc.ephemeralKeyGen()
if err != nil {
return nil, nil, fmt.Errorf("can't generate ephemeral key")
}
ephpubkey := EncodePubkey(&ephkey.PublicKey)
auth.pubkey = ephpubkey[:]
auth.h.PubkeySize = byte(len(auth.pubkey))
// Add ID nonce signature to response.
cdata := challenge.ChallengeData
idsig, err := makeIDSignature(c.sha256, c.privkey, cdata, ephpubkey[:], toID)
if err != nil {
return nil, nil, fmt.Errorf("can't sign: %v", err)
}
auth.signature = idsig
auth.h.SigSize = byte(len(auth.signature))
// Add our record to response if it's newer than what remote side has.
ln := c.localnode.Node()
if challenge.RecordSeq < ln.Seq() {
auth.record, _ = rlp.EncodeToBytes(ln.Record())
}
// Create session keys.
sec := deriveKeys(sha256.New, ephkey, remotePubkey, c.localnode.ID(), challenge.Node.ID(), cdata)
if sec == nil {
return nil, nil, fmt.Errorf("key derivation failed")
}
return auth, sec, err
}
// encodeMessage encodes an encrypted message packet.
func (c *Codec) encodeMessageHeader(toID enode.ID, s *session) (Header, error) {
head := c.makeHeader(toID, flagMessage, 0)
// Create the header.
nonce, err := c.sc.nextNonce(s)
if err != nil {
return Header{}, fmt.Errorf("can't generate nonce: %v", err)
}
auth := messageAuthData{SrcID: c.localnode.ID()}
c.buf.Reset()
binary.Write(&c.buf, binary.BigEndian, &auth)
head.AuthData = bytesCopy(&c.buf)
head.Nonce = nonce
return head, err
}
func (c *Codec) encryptMessage(s *session, p Packet, head *Header, headerData []byte) ([]byte, error) {
// Encode message plaintext.
c.msgbuf.Reset()
c.msgbuf.WriteByte(p.Kind())
if err := rlp.Encode(&c.msgbuf, p); err != nil {
return nil, err
}
messagePT := c.msgbuf.Bytes()
// Encrypt into message ciphertext buffer.
messageCT, err := encryptGCM(c.msgctbuf[:0], s.writeKey, head.Nonce[:], messagePT, headerData)
if err == nil {
c.msgctbuf = messageCT
}
return messageCT, err
}
// Decode decodes a discovery packet.
func (c *Codec) Decode(input []byte, addr string) (src enode.ID, n *enode.Node, p Packet, err error) {
// Unmask the static header.
if len(input) < sizeofStaticPacketData {
return enode.ID{}, nil, nil, errTooShort
}
var head Header
copy(head.IV[:], input[:sizeofMaskingIV])
mask := head.mask(c.localnode.ID())
staticHeader := input[sizeofMaskingIV:sizeofStaticPacketData]
mask.XORKeyStream(staticHeader, staticHeader)
// Decode and verify the static header.
c.reader.Reset(staticHeader)
binary.Read(&c.reader, binary.BigEndian, &head.StaticHeader)
remainingInput := len(input) - sizeofStaticPacketData
if err := head.checkValid(remainingInput); err != nil {
return enode.ID{}, nil, nil, err
}
// Unmask auth data.
authDataEnd := sizeofStaticPacketData + int(head.AuthSize)
authData := input[sizeofStaticPacketData:authDataEnd]
mask.XORKeyStream(authData, authData)
head.AuthData = authData
// Delete timed-out handshakes. This must happen before decoding to avoid
// processing the same handshake twice.
c.sc.handshakeGC()
// Decode auth part and message.
headerData := input[:authDataEnd]
msgData := input[authDataEnd:]
switch head.Flag {
case flagWhoareyou:
p, err = c.decodeWhoareyou(&head, headerData)
case flagHandshake:
n, p, err = c.decodeHandshakeMessage(addr, &head, headerData, msgData)
case flagMessage:
p, err = c.decodeMessage(addr, &head, headerData, msgData)
default:
err = errInvalidFlag
}
return head.src, n, p, err
}
// decodeWhoareyou reads packet data after the header as a WHOAREYOU packet.
func (c *Codec) decodeWhoareyou(head *Header, headerData []byte) (Packet, error) {
if len(head.AuthData) != sizeofWhoareyouAuthData {
return nil, fmt.Errorf("invalid auth size %d for WHOAREYOU", len(head.AuthData))
}
var auth whoareyouAuthData
c.reader.Reset(head.AuthData)
binary.Read(&c.reader, binary.BigEndian, &auth)
p := &Whoareyou{
Nonce: head.Nonce,
IDNonce: auth.IDNonce,
RecordSeq: auth.RecordSeq,
ChallengeData: make([]byte, len(headerData)),
}
copy(p.ChallengeData, headerData)
return p, nil
}
func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData, msgData []byte) (n *enode.Node, p Packet, err error) {
node, auth, session, err := c.decodeHandshake(fromAddr, head)
if err != nil {
c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
return nil, nil, err
}
// Decrypt the message using the new session keys.
msg, err := c.decryptMessage(msgData, head.Nonce[:], headerData, session.readKey)
if err != nil {
c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
return node, msg, err
}
// Handshake OK, drop the challenge and store the new session keys.
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session)
c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
return node, msg, nil
}
func (c *Codec) decodeHandshake(fromAddr string, head *Header) (n *enode.Node, auth handshakeAuthData, s *session, err error) {
if auth, err = c.decodeHandshakeAuthData(head); err != nil {
return nil, auth, nil, err
}
// Verify against our last WHOAREYOU.
challenge := c.sc.getHandshake(auth.h.SrcID, fromAddr)
if challenge == nil {
return nil, auth, nil, errUnexpectedHandshake
}
// Get node record.
n, err = c.decodeHandshakeRecord(challenge.Node, auth.h.SrcID, auth.record)
if err != nil {
return nil, auth, nil, err
}
// Verify ID nonce signature.
sig := auth.signature
cdata := challenge.ChallengeData
err = verifyIDSignature(c.sha256, sig, n, cdata, auth.pubkey, c.localnode.ID())
if err != nil {
return nil, auth, nil, err
}
// Verify ephemeral key is on curve.
ephkey, err := DecodePubkey(c.privkey.Curve, auth.pubkey)
if err != nil {
return nil, auth, nil, errInvalidAuthKey
}
// Derive sesssion keys.
session := deriveKeys(sha256.New, c.privkey, ephkey, auth.h.SrcID, c.localnode.ID(), cdata)
session = session.keysFlipped()
return n, auth, session, nil
}
// decodeHandshakeAuthData reads the authdata section of a handshake packet.
func (c *Codec) decodeHandshakeAuthData(head *Header) (auth handshakeAuthData, err error) {
// Decode fixed size part.
if len(head.AuthData) < sizeofHandshakeAuthData {
return auth, fmt.Errorf("header authsize %d too low for handshake", head.AuthSize)
}
c.reader.Reset(head.AuthData)
binary.Read(&c.reader, binary.BigEndian, &auth.h)
head.src = auth.h.SrcID
// Decode variable-size part.
var (
vardata = head.AuthData[sizeofHandshakeAuthData:]
sigAndKeySize = int(auth.h.SigSize) + int(auth.h.PubkeySize)
keyOffset = int(auth.h.SigSize)
recOffset = keyOffset + int(auth.h.PubkeySize)
)
if len(vardata) < sigAndKeySize {
return auth, errTooShort
}
auth.signature = vardata[:keyOffset]
auth.pubkey = vardata[keyOffset:recOffset]
auth.record = vardata[recOffset:]
return auth, nil
}
// decodeHandshakeRecord verifies the node record contained in a handshake packet. The
// remote node should include the record if we don't have one or if ours is older than the
// latest sequence number.
func (c *Codec) decodeHandshakeRecord(local *enode.Node, wantID enode.ID, remote []byte) (*enode.Node, error) {
node := local
if len(remote) > 0 {
var record enr.Record
if err := rlp.DecodeBytes(remote, &record); err != nil {
return nil, err
}
if local == nil || local.Seq() < record.Seq() {
n, err := enode.New(enode.ValidSchemes, &record)
if err != nil {
return nil, fmt.Errorf("invalid node record: %v", err)
}
if n.ID() != wantID {
return nil, fmt.Errorf("record in handshake has wrong ID: %v", n.ID())
}
node = n
}
}
if node == nil {
return nil, errNoRecord
}
return node, nil
}
// decodeMessage reads packet data following the header as an ordinary message packet.
func (c *Codec) decodeMessage(fromAddr string, head *Header, headerData, msgData []byte) (Packet, error) {
if len(head.AuthData) != sizeofMessageAuthData {
return nil, fmt.Errorf("invalid auth size %d for message packet", len(head.AuthData))
}
var auth messageAuthData
c.reader.Reset(head.AuthData)
binary.Read(&c.reader, binary.BigEndian, &auth)
head.src = auth.SrcID
// Try decrypting the message.
key := c.sc.readKey(auth.SrcID, fromAddr)
msg, err := c.decryptMessage(msgData, head.Nonce[:], headerData, key)
if err == errMessageDecrypt {
// It didn't work. Start the handshake since this is an ordinary message packet.
return &Unknown{Nonce: head.Nonce}, nil
}
return msg, err
}
func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet, error) {
msgdata, err := decryptGCM(readKey, nonce, input, headerData)
if err != nil {
return nil, errMessageDecrypt
}
if len(msgdata) == 0 {
return nil, errMessageTooShort
}
return DecodeMessage(msgdata[0], msgdata[1:])
}
// checkValid performs some basic validity checks on the header.
// The packetLen here is the length remaining after the static header.
func (h *StaticHeader) checkValid(packetLen int) error {
if h.ProtocolID != protocolID {
return errInvalidHeader
}
if h.Version < minVersion {
return errMinVersion
}
if h.Flag != flagWhoareyou && packetLen < minMessageSize {
return errMsgTooShort
}
if int(h.AuthSize) > packetLen {
return errAuthSize
}
return nil
}
// headerMask returns a cipher for 'masking' / 'unmasking' packet headers.
func (h *Header) mask(destID enode.ID) cipher.Stream {
block, err := aes.NewCipher(destID[:16])
if err != nil {
panic("can't create cipher")
}
return cipher.NewCTR(block, h.IV[:])
}
func bytesCopy(r *bytes.Buffer) []byte {
b := make([]byte, r.Len())
copy(b, r.Bytes())
return b
}

@ -0,0 +1,636 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package v5wire
import (
"bytes"
"crypto/ecdsa"
"encoding/hex"
"flag"
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enode"
)
// To regenerate discv5 test vectors, run
//
// go test -run TestVectors -write-test-vectors
//
var writeTestVectorsFlag = flag.Bool("write-test-vectors", false, "Overwrite discv5 test vectors in testdata/")
var (
testKeyA, _ = crypto.HexToECDSA("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f")
testKeyB, _ = crypto.HexToECDSA("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628")
testEphKey, _ = crypto.HexToECDSA("0288ef00023598499cb6c940146d050d2b1fb914198c327f76aad590bead68b6")
testIDnonce = [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
)
// This test checks that the minPacketSize and randomPacketMsgSize constants are well-defined.
func TestMinSizes(t *testing.T) {
var (
gcmTagSize = 16
emptyMsg = sizeofMessageAuthData + gcmTagSize
)
t.Log("static header size", sizeofStaticPacketData)
t.Log("whoareyou size", sizeofStaticPacketData+sizeofWhoareyouAuthData)
t.Log("empty msg size", sizeofStaticPacketData+emptyMsg)
if want := emptyMsg; minMessageSize != want {
t.Fatalf("wrong minMessageSize %d, want %d", minMessageSize, want)
}
if sizeofMessageAuthData+randomPacketMsgSize < minMessageSize {
t.Fatalf("randomPacketMsgSize %d too small", randomPacketMsgSize)
}
}
// This test checks the basic handshake flow where A talks to B and A has no secrets.
func TestHandshake(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
// A -> B RANDOM PACKET
packet, _ := net.nodeA.encode(t, net.nodeB, &Findnode{})
resp := net.nodeB.expectDecode(t, UnknownPacket, packet)
// A <- B WHOAREYOU
challenge := &Whoareyou{
Nonce: resp.(*Unknown).Nonce,
IDNonce: testIDnonce,
RecordSeq: 0,
}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou)
// A -> B FINDNODE (handshake packet)
findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{})
net.nodeB.expectDecode(t, FindnodeMsg, findnode)
if len(net.nodeB.c.sc.handshakes) > 0 {
t.Fatalf("node B didn't remove handshake from challenge map")
}
// A <- B NODES
nodes, _ := net.nodeB.encode(t, net.nodeA, &Nodes{Total: 1})
net.nodeA.expectDecode(t, NodesMsg, nodes)
}
// This test checks that handshake attempts are removed within the timeout.
func TestHandshake_timeout(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
// A -> B RANDOM PACKET
packet, _ := net.nodeA.encode(t, net.nodeB, &Findnode{})
resp := net.nodeB.expectDecode(t, UnknownPacket, packet)
// A <- B WHOAREYOU
challenge := &Whoareyou{
Nonce: resp.(*Unknown).Nonce,
IDNonce: testIDnonce,
RecordSeq: 0,
}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou)
// A -> B FINDNODE (handshake packet) after timeout
net.clock.Run(handshakeTimeout + 1)
findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{})
net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, findnode)
}
// This test checks handshake behavior when no record is sent in the auth response.
func TestHandshake_norecord(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
// A -> B RANDOM PACKET
packet, _ := net.nodeA.encode(t, net.nodeB, &Findnode{})
resp := net.nodeB.expectDecode(t, UnknownPacket, packet)
// A <- B WHOAREYOU
nodeA := net.nodeA.n()
if nodeA.Seq() == 0 {
t.Fatal("need non-zero sequence number")
}
challenge := &Whoareyou{
Nonce: resp.(*Unknown).Nonce,
IDNonce: testIDnonce,
RecordSeq: nodeA.Seq(),
Node: nodeA,
}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou)
// A -> B FINDNODE
findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{})
net.nodeB.expectDecode(t, FindnodeMsg, findnode)
// A <- B NODES
nodes, _ := net.nodeB.encode(t, net.nodeA, &Nodes{Total: 1})
net.nodeA.expectDecode(t, NodesMsg, nodes)
}
// In this test, A tries to send FINDNODE with existing secrets but B doesn't know
// anything about A.
func TestHandshake_rekey(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
session := &session{
readKey: []byte("BBBBBBBBBBBBBBBB"),
writeKey: []byte("AAAAAAAAAAAAAAAA"),
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
// A -> B FINDNODE (encrypted with zero keys)
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{})
net.nodeB.expectDecode(t, UnknownPacket, findnode)
// A <- B WHOAREYOU
challenge := &Whoareyou{Nonce: authTag, IDNonce: testIDnonce}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou)
// Check that new keys haven't been stored yet.
sa := net.nodeA.c.sc.session(net.nodeB.id(), net.nodeB.addr())
if !bytes.Equal(sa.writeKey, session.writeKey) || !bytes.Equal(sa.readKey, session.readKey) {
t.Fatal("node A stored keys too early")
}
if s := net.nodeB.c.sc.session(net.nodeA.id(), net.nodeA.addr()); s != nil {
t.Fatal("node B stored keys too early")
}
// A -> B FINDNODE encrypted with new keys
findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{})
net.nodeB.expectDecode(t, FindnodeMsg, findnode)
// A <- B NODES
nodes, _ := net.nodeB.encode(t, net.nodeA, &Nodes{Total: 1})
net.nodeA.expectDecode(t, NodesMsg, nodes)
}
// In this test A and B have different keys before the handshake.
func TestHandshake_rekey2(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
initKeysA := &session{
readKey: []byte("BBBBBBBBBBBBBBBB"),
writeKey: []byte("AAAAAAAAAAAAAAAA"),
}
initKeysB := &session{
readKey: []byte("CCCCCCCCCCCCCCCC"),
writeKey: []byte("DDDDDDDDDDDDDDDD"),
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA)
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB)
// A -> B FINDNODE encrypted with initKeysA
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}})
net.nodeB.expectDecode(t, UnknownPacket, findnode)
// A <- B WHOAREYOU
challenge := &Whoareyou{Nonce: authTag, IDNonce: testIDnonce}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou)
// A -> B FINDNODE (handshake packet)
findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{})
net.nodeB.expectDecode(t, FindnodeMsg, findnode)
// A <- B NODES
nodes, _ := net.nodeB.encode(t, net.nodeA, &Nodes{Total: 1})
net.nodeA.expectDecode(t, NodesMsg, nodes)
}
func TestHandshake_BadHandshakeAttack(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
// A -> B RANDOM PACKET
packet, _ := net.nodeA.encode(t, net.nodeB, &Findnode{})
resp := net.nodeB.expectDecode(t, UnknownPacket, packet)
// A <- B WHOAREYOU
challenge := &Whoareyou{
Nonce: resp.(*Unknown).Nonce,
IDNonce: testIDnonce,
RecordSeq: 0,
}
whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge)
net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou)
// A -> B FINDNODE
incorrect_challenge := &Whoareyou{
IDNonce: [16]byte{5, 6, 7, 8, 9, 6, 11, 12},
RecordSeq: challenge.RecordSeq,
Node: challenge.Node,
sent: challenge.sent,
}
incorrect_findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, incorrect_challenge, &Findnode{})
incorrect_findnode2 := make([]byte, len(incorrect_findnode))
copy(incorrect_findnode2, incorrect_findnode)
net.nodeB.expectDecodeErr(t, errInvalidNonceSig, incorrect_findnode)
// Reject new findnode as previous handshake is now deleted.
net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, incorrect_findnode2)
// The findnode packet is again rejected even with a valid challenge this time.
findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{})
net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, findnode)
}
// This test checks some malformed packets.
func TestDecodeErrorsV5(t *testing.T) {
t.Parallel()
net := newHandshakeTest()
defer net.close()
net.nodeA.expectDecodeErr(t, errTooShort, []byte{})
// TODO some more tests would be nice :)
// - check invalid authdata sizes
// - check invalid handshake data sizes
}
// This test checks that all test vectors can be decoded.
func TestTestVectorsV5(t *testing.T) {
var (
idA = enode.PubkeyToIDV4(&testKeyA.PublicKey)
idB = enode.PubkeyToIDV4(&testKeyB.PublicKey)
addr = "127.0.0.1"
session = &session{
writeKey: hexutil.MustDecode("0x00000000000000000000000000000000"),
readKey: hexutil.MustDecode("0x01010101010101010101010101010101"),
}
challenge0A, challenge1A, challenge0B Whoareyou
)
// Create challenge packets.
c := Whoareyou{
Nonce: Nonce{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
IDNonce: testIDnonce,
}
challenge0A, challenge1A, challenge0B = c, c, c
challenge1A.RecordSeq = 1
net := newHandshakeTest()
challenge0A.Node = net.nodeA.n()
challenge0B.Node = net.nodeB.n()
challenge1A.Node = net.nodeA.n()
net.close()
type testVectorTest struct {
name string // test vector name
packet Packet // the packet to be encoded
challenge *Whoareyou // handshake challenge passed to encoder
prep func(*handshakeTest) // called before encode/decode
}
tests := []testVectorTest{
{
name: "v5.1-whoareyou",
packet: &challenge0B,
},
{
name: "v5.1-ping-message",
packet: &Ping{
ReqID: []byte{0, 0, 0, 1},
ENRSeq: 2,
},
prep: func(net *handshakeTest) {
net.nodeA.c.sc.storeNewSession(idB, addr, session)
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped())
},
},
{
name: "v5.1-ping-handshake-enr",
packet: &Ping{
ReqID: []byte{0, 0, 0, 1},
ENRSeq: 1,
},
challenge: &challenge0A,
prep: func(net *handshakeTest) {
// Update challenge.Header.AuthData.
net.nodeA.c.Encode(idB, "", &challenge0A, nil)
net.nodeB.c.sc.storeSentHandshake(idA, addr, &challenge0A)
},
},
{
name: "v5.1-ping-handshake",
packet: &Ping{
ReqID: []byte{0, 0, 0, 1},
ENRSeq: 1,
},
challenge: &challenge1A,
prep: func(net *handshakeTest) {
// Update challenge data.
net.nodeA.c.Encode(idB, "", &challenge1A, nil)
net.nodeB.c.sc.storeSentHandshake(idA, addr, &challenge1A)
},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
net := newHandshakeTest()
defer net.close()
// Override all random inputs.
net.nodeA.c.sc.nonceGen = func(counter uint32) (Nonce, error) {
return Nonce{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, nil
}
net.nodeA.c.sc.maskingIVGen = func(buf []byte) error {
return nil // all zero
}
net.nodeA.c.sc.ephemeralKeyGen = func() (*ecdsa.PrivateKey, error) {
return testEphKey, nil
}
// Prime the codec for encoding/decoding.
if test.prep != nil {
test.prep(net)
}
file := filepath.Join("testdata", test.name+".txt")
if *writeTestVectorsFlag {
// Encode the packet.
d, nonce := net.nodeA.encodeWithChallenge(t, net.nodeB, test.challenge, test.packet)
comment := testVectorComment(net, test.packet, test.challenge, nonce)
writeTestVector(file, comment, d)
}
enc := hexFile(file)
net.nodeB.expectDecode(t, test.packet.Kind(), enc)
})
}
}
// testVectorComment creates the commentary for discv5 test vector files.
func testVectorComment(net *handshakeTest, p Packet, challenge *Whoareyou, nonce Nonce) string {
o := new(strings.Builder)
printWhoareyou := func(p *Whoareyou) {
fmt.Fprintf(o, "whoareyou.challenge-data = %#x\n", p.ChallengeData)
fmt.Fprintf(o, "whoareyou.request-nonce = %#x\n", p.Nonce[:])
fmt.Fprintf(o, "whoareyou.id-nonce = %#x\n", p.IDNonce[:])
fmt.Fprintf(o, "whoareyou.enr-seq = %d\n", p.RecordSeq)
}
fmt.Fprintf(o, "src-node-id = %#x\n", net.nodeA.id().Bytes())
fmt.Fprintf(o, "dest-node-id = %#x\n", net.nodeB.id().Bytes())
switch p := p.(type) {
case *Whoareyou:
// WHOAREYOU packet.
printWhoareyou(p)
case *Ping:
fmt.Fprintf(o, "nonce = %#x\n", nonce[:])
fmt.Fprintf(o, "read-key = %#x\n", net.nodeA.c.sc.session(net.nodeB.id(), net.nodeB.addr()).writeKey)
fmt.Fprintf(o, "ping.req-id = %#x\n", p.ReqID)
fmt.Fprintf(o, "ping.enr-seq = %d\n", p.ENRSeq)
if challenge != nil {
// Handshake message packet.
fmt.Fprint(o, "\nhandshake inputs:\n\n")
printWhoareyou(challenge)
fmt.Fprintf(o, "ephemeral-key = %#x\n", testEphKey.D.Bytes())
fmt.Fprintf(o, "ephemeral-pubkey = %#x\n", crypto.CompressPubkey(&testEphKey.PublicKey))
}
default:
panic(fmt.Errorf("unhandled packet type %T", p))
}
return o.String()
}
// This benchmark checks performance of handshake packet decoding.
func BenchmarkV5_DecodeHandshakePingSecp256k1(b *testing.B) {
net := newHandshakeTest()
defer net.close()
var (
idA = net.nodeA.id()
challenge = &Whoareyou{Node: net.nodeB.n()}
message = &Ping{ReqID: []byte("reqid")}
)
enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), "", message, challenge)
if err != nil {
b.Fatal("can't encode handshake packet")
}
challenge.Node = nil // force ENR signature verification in decoder
b.ResetTimer()
input := make([]byte, len(enc))
for i := 0; i < b.N; i++ {
copy(input, enc)
net.nodeB.c.sc.storeSentHandshake(idA, "", challenge)
_, _, _, err := net.nodeB.c.Decode(input, "")
if err != nil {
b.Fatal(err)
}
}
}
// This benchmark checks how long it takes to decode an encrypted ping packet.
func BenchmarkV5_DecodePing(b *testing.B) {
net := newHandshakeTest()
defer net.close()
session := &session{
readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17},
writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134},
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped())
addrB := net.nodeA.addr()
ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5}
enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil)
if err != nil {
b.Fatalf("can't encode: %v", err)
}
b.ResetTimer()
input := make([]byte, len(enc))
for i := 0; i < b.N; i++ {
copy(input, enc)
_, _, packet, _ := net.nodeB.c.Decode(input, addrB)
if _, ok := packet.(*Ping); !ok {
b.Fatalf("wrong packet type %T", packet)
}
}
}
var pp = spew.NewDefaultConfig()
type handshakeTest struct {
nodeA, nodeB handshakeTestNode
clock mclock.Simulated
}
type handshakeTestNode struct {
ln *enode.LocalNode
c *Codec
}
func newHandshakeTest() *handshakeTest {
t := new(handshakeTest)
t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock)
t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock)
return t
}
func (t *handshakeTest) close() {
t.nodeA.ln.Database().Close()
t.nodeB.ln.Database().Close()
}
func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock) {
db, _ := enode.OpenDB("")
n.ln = enode.NewLocalNode(db, key)
n.ln.SetStaticIP(ip)
if n.ln.Node().Seq() != 1 {
panic(fmt.Errorf("unexpected seq %d", n.ln.Node().Seq()))
}
n.c = NewCodec(n.ln, key, clock)
}
func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p Packet) ([]byte, Nonce) {
t.Helper()
return n.encodeWithChallenge(t, to, nil, p)
}
func (n *handshakeTestNode) encodeWithChallenge(t testing.TB, to handshakeTestNode, c *Whoareyou, p Packet) ([]byte, Nonce) {
t.Helper()
// Copy challenge and add destination node. This avoids sharing 'c' among the two codecs.
var challenge *Whoareyou
if c != nil {
challengeCopy := *c
challenge = &challengeCopy
challenge.Node = to.n()
}
// Encode to destination.
enc, nonce, err := n.c.Encode(to.id(), to.addr(), p, challenge)
if err != nil {
t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err))
}
t.Logf("(%s) -> (%s) %s\n%s", n.ln.ID().TerminalString(), to.id().TerminalString(), p.Name(), hex.Dump(enc))
return enc, nonce
}
func (n *handshakeTestNode) expectDecode(t *testing.T, ptype byte, p []byte) Packet {
t.Helper()
dec, err := n.decode(p)
if err != nil {
t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err))
}
t.Logf("(%s) %#v", n.ln.ID().TerminalString(), pp.NewFormatter(dec))
if dec.Kind() != ptype {
t.Fatalf("expected packet type %d, got %d", ptype, dec.Kind())
}
return dec
}
func (n *handshakeTestNode) expectDecodeErr(t *testing.T, wantErr error, p []byte) {
t.Helper()
if _, err := n.decode(p); !reflect.DeepEqual(err, wantErr) {
t.Fatal(fmt.Errorf("(%s) got err %q, want %q", n.ln.ID().TerminalString(), err, wantErr))
}
}
func (n *handshakeTestNode) decode(input []byte) (Packet, error) {
_, _, p, err := n.c.Decode(input, "127.0.0.1")
return p, err
}
func (n *handshakeTestNode) n() *enode.Node {
return n.ln.Node()
}
func (n *handshakeTestNode) addr() string {
return n.ln.Node().IP().String()
}
func (n *handshakeTestNode) id() enode.ID {
return n.ln.ID()
}
// hexFile reads the given file and decodes the hex data contained in it.
// Whitespace and any lines beginning with the # character are ignored.
func hexFile(file string) []byte {
fileContent, err := ioutil.ReadFile(file)
if err != nil {
panic(err)
}
// Gather hex data, ignore comments.
var text []byte
for _, line := range bytes.Split(fileContent, []byte("\n")) {
line = bytes.TrimSpace(line)
if len(line) > 0 && line[0] == '#' {
continue
}
text = append(text, line...)
}
// Parse the hex.
if bytes.HasPrefix(text, []byte("0x")) {
text = text[2:]
}
data := make([]byte, hex.DecodedLen(len(text)))
if _, err := hex.Decode(data, text); err != nil {
panic("invalid hex in " + file)
}
return data
}
// writeTestVector writes a test vector file with the given commentary and binary data.
func writeTestVector(file, comment string, data []byte) {
fd, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
panic(err)
}
defer fd.Close()
if len(comment) > 0 {
for _, line := range strings.Split(strings.TrimSpace(comment), "\n") {
fmt.Fprintf(fd, "# %s\n", line)
}
fmt.Fprintln(fd)
}
for len(data) > 0 {
var chunk []byte
if len(data) < 32 {
chunk = data
} else {
chunk = data[:32]
}
data = data[len(chunk):]
fmt.Fprintf(fd, "%x\n", chunk)
}
}

249
p2p/discover/v5wire/msg.go Normal file

@ -0,0 +1,249 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package v5wire
import (
"fmt"
"net"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rlp"
)
// Packet is implemented by all message types.
type Packet interface {
Name() string // Name returns a string corresponding to the message type.
Kind() byte // Kind returns the message type.
RequestID() []byte // Returns the request ID.
SetRequestID([]byte) // Sets the request ID.
}
// Message types.
const (
PingMsg byte = iota + 1
PongMsg
FindnodeMsg
NodesMsg
TalkRequestMsg
TalkResponseMsg
RequestTicketMsg
TicketMsg
RegtopicMsg
RegconfirmationMsg
TopicQueryMsg
UnknownPacket = byte(255) // any non-decryptable packet
WhoareyouPacket = byte(254) // the WHOAREYOU packet
)
// Protocol messages.
type (
// Unknown represents any packet that can't be decrypted.
Unknown struct {
Nonce Nonce
}
// WHOAREYOU contains the handshake challenge.
Whoareyou struct {
ChallengeData []byte // Encoded challenge
Nonce Nonce // Nonce of request packet
IDNonce [16]byte // Identity proof data
RecordSeq uint64 // ENR sequence number of recipient
// Node is the locally known node record of recipient.
// This must be set by the caller of Encode.
Node *enode.Node
sent mclock.AbsTime // for handshake GC.
}
// PING is sent during liveness checks.
Ping struct {
ReqID []byte
ENRSeq uint64
}
// PONG is the reply to PING.
Pong struct {
ReqID []byte
ENRSeq uint64
ToIP net.IP // These fields should mirror the UDP envelope address of the ping
ToPort uint16 // packet, which provides a way to discover the the external address (after NAT).
}
// FINDNODE is a query for nodes in the given bucket.
Findnode struct {
ReqID []byte
Distances []uint
}
// NODES is the reply to FINDNODE and TOPICQUERY.
Nodes struct {
ReqID []byte
Total uint8
Nodes []*enr.Record
}
// TALKREQ is an application-level request.
TalkRequest struct {
ReqID []byte
Protocol string
Message []byte
}
// TALKRESP is the reply to TALKREQ.
TalkResponse struct {
ReqID []byte
Message []byte
}
// REQUESTTICKET requests a ticket for a topic queue.
RequestTicket struct {
ReqID []byte
Topic []byte
}
// TICKET is the response to REQUESTTICKET.
Ticket struct {
ReqID []byte
Ticket []byte
}
// REGTOPIC registers the sender in a topic queue using a ticket.
Regtopic struct {
ReqID []byte
Ticket []byte
ENR *enr.Record
}
// REGCONFIRMATION is the reply to REGTOPIC.
Regconfirmation struct {
ReqID []byte
Registered bool
}
// TOPICQUERY asks for nodes with the given topic.
TopicQuery struct {
ReqID []byte
Topic []byte
}
)
// DecodeMessage decodes the message body of a packet.
func DecodeMessage(ptype byte, body []byte) (Packet, error) {
var dec Packet
switch ptype {
case PingMsg:
dec = new(Ping)
case PongMsg:
dec = new(Pong)
case FindnodeMsg:
dec = new(Findnode)
case NodesMsg:
dec = new(Nodes)
case TalkRequestMsg:
dec = new(TalkRequest)
case TalkResponseMsg:
dec = new(TalkResponse)
case RequestTicketMsg:
dec = new(RequestTicket)
case TicketMsg:
dec = new(Ticket)
case RegtopicMsg:
dec = new(Regtopic)
case RegconfirmationMsg:
dec = new(Regconfirmation)
case TopicQueryMsg:
dec = new(TopicQuery)
default:
return nil, fmt.Errorf("unknown packet type %d", ptype)
}
if err := rlp.DecodeBytes(body, dec); err != nil {
return nil, err
}
if dec.RequestID() != nil && len(dec.RequestID()) > 8 {
return nil, ErrInvalidReqID
}
return dec, nil
}
func (*Whoareyou) Name() string { return "WHOAREYOU/v5" }
func (*Whoareyou) Kind() byte { return WhoareyouPacket }
func (*Whoareyou) RequestID() []byte { return nil }
func (*Whoareyou) SetRequestID([]byte) {}
func (*Unknown) Name() string { return "UNKNOWN/v5" }
func (*Unknown) Kind() byte { return UnknownPacket }
func (*Unknown) RequestID() []byte { return nil }
func (*Unknown) SetRequestID([]byte) {}
func (*Ping) Name() string { return "PING/v5" }
func (*Ping) Kind() byte { return PingMsg }
func (p *Ping) RequestID() []byte { return p.ReqID }
func (p *Ping) SetRequestID(id []byte) { p.ReqID = id }
func (*Pong) Name() string { return "PONG/v5" }
func (*Pong) Kind() byte { return PongMsg }
func (p *Pong) RequestID() []byte { return p.ReqID }
func (p *Pong) SetRequestID(id []byte) { p.ReqID = id }
func (*Findnode) Name() string { return "FINDNODE/v5" }
func (*Findnode) Kind() byte { return FindnodeMsg }
func (p *Findnode) RequestID() []byte { return p.ReqID }
func (p *Findnode) SetRequestID(id []byte) { p.ReqID = id }
func (*Nodes) Name() string { return "NODES/v5" }
func (*Nodes) Kind() byte { return NodesMsg }
func (p *Nodes) RequestID() []byte { return p.ReqID }
func (p *Nodes) SetRequestID(id []byte) { p.ReqID = id }
func (*TalkRequest) Name() string { return "TALKREQ/v5" }
func (*TalkRequest) Kind() byte { return TalkRequestMsg }
func (p *TalkRequest) RequestID() []byte { return p.ReqID }
func (p *TalkRequest) SetRequestID(id []byte) { p.ReqID = id }
func (*TalkResponse) Name() string { return "TALKRESP/v5" }
func (*TalkResponse) Kind() byte { return TalkResponseMsg }
func (p *TalkResponse) RequestID() []byte { return p.ReqID }
func (p *TalkResponse) SetRequestID(id []byte) { p.ReqID = id }
func (*RequestTicket) Name() string { return "REQTICKET/v5" }
func (*RequestTicket) Kind() byte { return RequestTicketMsg }
func (p *RequestTicket) RequestID() []byte { return p.ReqID }
func (p *RequestTicket) SetRequestID(id []byte) { p.ReqID = id }
func (*Regtopic) Name() string { return "REGTOPIC/v5" }
func (*Regtopic) Kind() byte { return RegtopicMsg }
func (p *Regtopic) RequestID() []byte { return p.ReqID }
func (p *Regtopic) SetRequestID(id []byte) { p.ReqID = id }
func (*Ticket) Name() string { return "TICKET/v5" }
func (*Ticket) Kind() byte { return TicketMsg }
func (p *Ticket) RequestID() []byte { return p.ReqID }
func (p *Ticket) SetRequestID(id []byte) { p.ReqID = id }
func (*Regconfirmation) Name() string { return "REGCONFIRMATION/v5" }
func (*Regconfirmation) Kind() byte { return RegconfirmationMsg }
func (p *Regconfirmation) RequestID() []byte { return p.ReqID }
func (p *Regconfirmation) SetRequestID(id []byte) { p.ReqID = id }
func (*TopicQuery) Name() string { return "TOPICQUERY/v5" }
func (*TopicQuery) Kind() byte { return TopicQueryMsg }
func (p *TopicQuery) RequestID() []byte { return p.ReqID }
func (p *TopicQuery) SetRequestID(id []byte) { p.ReqID = id }

@ -14,22 +14,33 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package discover package v5wire
import ( import (
"crypto/ecdsa"
crand "crypto/rand" crand "crypto/rand"
"encoding/binary"
"time"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/hashicorp/golang-lru/simplelru" "github.com/hashicorp/golang-lru/simplelru"
) )
// The sessionCache keeps negotiated encryption keys and const handshakeTimeout = time.Second
// The SessionCache keeps negotiated encryption keys and
// state for in-progress handshakes in the Discovery v5 wire protocol. // state for in-progress handshakes in the Discovery v5 wire protocol.
type sessionCache struct { type SessionCache struct {
sessions *simplelru.LRU sessions *simplelru.LRU
handshakes map[sessionID]*whoareyouV5 handshakes map[sessionID]*Whoareyou
clock mclock.Clock clock mclock.Clock
// hooks for overriding randomness.
nonceGen func(uint32) (Nonce, error)
maskingIVGen func([]byte) error
ephemeralKeyGen func() (*ecdsa.PrivateKey, error)
} }
// sessionID identifies a session or handshake. // sessionID identifies a session or handshake.
@ -45,27 +56,45 @@ type session struct {
nonceCounter uint32 nonceCounter uint32
} }
func newSessionCache(maxItems int, clock mclock.Clock) *sessionCache { // keysFlipped returns a copy of s with the read and write keys flipped.
func (s *session) keysFlipped() *session {
return &session{s.readKey, s.writeKey, s.nonceCounter}
}
func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache {
cache, err := simplelru.NewLRU(maxItems, nil) cache, err := simplelru.NewLRU(maxItems, nil)
if err != nil { if err != nil {
panic("can't create session cache") panic("can't create session cache")
} }
return &sessionCache{ return &SessionCache{
sessions: cache, sessions: cache,
handshakes: make(map[sessionID]*whoareyouV5), handshakes: make(map[sessionID]*Whoareyou),
clock: clock, clock: clock,
nonceGen: generateNonce,
maskingIVGen: generateMaskingIV,
ephemeralKeyGen: crypto.GenerateKey,
} }
} }
func generateNonce(counter uint32) (n Nonce, err error) {
binary.BigEndian.PutUint32(n[:4], counter)
_, err = crand.Read(n[4:])
return n, err
}
func generateMaskingIV(buf []byte) error {
_, err := crand.Read(buf)
return err
}
// nextNonce creates a nonce for encrypting a message to the given session. // nextNonce creates a nonce for encrypting a message to the given session.
func (sc *sessionCache) nextNonce(id enode.ID, addr string) []byte { func (sc *SessionCache) nextNonce(s *session) (Nonce, error) {
n := make([]byte, gcmNonceSize) s.nonceCounter++
crand.Read(n) return sc.nonceGen(s.nonceCounter)
return n
} }
// session returns the current session for the given node, if any. // session returns the current session for the given node, if any.
func (sc *sessionCache) session(id enode.ID, addr string) *session { func (sc *SessionCache) session(id enode.ID, addr string) *session {
item, ok := sc.sessions.Get(sessionID{id, addr}) item, ok := sc.sessions.Get(sessionID{id, addr})
if !ok { if !ok {
return nil return nil
@ -74,46 +103,36 @@ func (sc *sessionCache) session(id enode.ID, addr string) *session {
} }
// readKey returns the current read key for the given node. // readKey returns the current read key for the given node.
func (sc *sessionCache) readKey(id enode.ID, addr string) []byte { func (sc *SessionCache) readKey(id enode.ID, addr string) []byte {
if s := sc.session(id, addr); s != nil { if s := sc.session(id, addr); s != nil {
return s.readKey return s.readKey
} }
return nil return nil
} }
// writeKey returns the current read key for the given node.
func (sc *sessionCache) writeKey(id enode.ID, addr string) []byte {
if s := sc.session(id, addr); s != nil {
return s.writeKey
}
return nil
}
// storeNewSession stores new encryption keys in the cache. // storeNewSession stores new encryption keys in the cache.
func (sc *sessionCache) storeNewSession(id enode.ID, addr string, r, w []byte) { func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session) {
sc.sessions.Add(sessionID{id, addr}, &session{ sc.sessions.Add(sessionID{id, addr}, s)
readKey: r, writeKey: w,
})
} }
// getHandshake gets the handshake challenge we previously sent to the given remote node. // getHandshake gets the handshake challenge we previously sent to the given remote node.
func (sc *sessionCache) getHandshake(id enode.ID, addr string) *whoareyouV5 { func (sc *SessionCache) getHandshake(id enode.ID, addr string) *Whoareyou {
return sc.handshakes[sessionID{id, addr}] return sc.handshakes[sessionID{id, addr}]
} }
// storeSentHandshake stores the handshake challenge sent to the given remote node. // storeSentHandshake stores the handshake challenge sent to the given remote node.
func (sc *sessionCache) storeSentHandshake(id enode.ID, addr string, challenge *whoareyouV5) { func (sc *SessionCache) storeSentHandshake(id enode.ID, addr string, challenge *Whoareyou) {
challenge.sent = sc.clock.Now() challenge.sent = sc.clock.Now()
sc.handshakes[sessionID{id, addr}] = challenge sc.handshakes[sessionID{id, addr}] = challenge
} }
// deleteHandshake deletes handshake data for the given node. // deleteHandshake deletes handshake data for the given node.
func (sc *sessionCache) deleteHandshake(id enode.ID, addr string) { func (sc *SessionCache) deleteHandshake(id enode.ID, addr string) {
delete(sc.handshakes, sessionID{id, addr}) delete(sc.handshakes, sessionID{id, addr})
} }
// handshakeGC deletes timed-out handshakes. // handshakeGC deletes timed-out handshakes.
func (sc *sessionCache) handshakeGC() { func (sc *SessionCache) handshakeGC() {
deadline := sc.clock.Now().Add(-handshakeTimeout) deadline := sc.clock.Now().Add(-handshakeTimeout)
for key, challenge := range sc.handshakes { for key, challenge := range sc.handshakes {
if challenge.sent < deadline { if challenge.sent < deadline {

@ -0,0 +1,27 @@
# src-node-id = 0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb
# dest-node-id = 0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9
# nonce = 0xffffffffffffffffffffffff
# read-key = 0x53b1c075f41876423154e157470c2f48
# ping.req-id = 0x00000001
# ping.enr-seq = 1
#
# handshake inputs:
#
# whoareyou.challenge-data = 0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000
# whoareyou.request-nonce = 0x0102030405060708090a0b0c
# whoareyou.id-nonce = 0x0102030405060708090a0b0c0d0e0f10
# whoareyou.enr-seq = 0
# ephemeral-key = 0x0288ef00023598499cb6c940146d050d2b1fb914198c327f76aad590bead68b6
# ephemeral-pubkey = 0x039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5
00000000000000000000000000000000088b3d4342774649305f313964a39e55
ea96c005ad539c8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3
4c4f53245d08da4bb23698868350aaad22e3ab8dd034f548a1c43cd246be9856
2fafa0a1fa86d8e7a3b95ae78cc2b988ded6a5b59eb83ad58097252188b902b2
1481e30e5e285f19735796706adff216ab862a9186875f9494150c4ae06fa4d1
f0396c93f215fa4ef524e0ed04c3c21e39b1868e1ca8105e585ec17315e755e6
cfc4dd6cb7fd8e1a1f55e49b4b5eb024221482105346f3c82b15fdaae36a3bb1
2a494683b4a3c7f2ae41306252fed84785e2bbff3b022812d0882f06978df84a
80d443972213342d04b9048fc3b1d5fcb1df0f822152eced6da4d3f6df27e70e
4539717307a0208cd208d65093ccab5aa596a34d7511401987662d8cf62b1394
71

@ -0,0 +1,23 @@
# src-node-id = 0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb
# dest-node-id = 0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9
# nonce = 0xffffffffffffffffffffffff
# read-key = 0x4f9fac6de7567d1e3b1241dffe90f662
# ping.req-id = 0x00000001
# ping.enr-seq = 1
#
# handshake inputs:
#
# whoareyou.challenge-data = 0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000001
# whoareyou.request-nonce = 0x0102030405060708090a0b0c
# whoareyou.id-nonce = 0x0102030405060708090a0b0c0d0e0f10
# whoareyou.enr-seq = 1
# ephemeral-key = 0x0288ef00023598499cb6c940146d050d2b1fb914198c327f76aad590bead68b6
# ephemeral-pubkey = 0x039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5
00000000000000000000000000000000088b3d4342774649305f313964a39e55
ea96c005ad521d8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3
4c4f53245d08da4bb252012b2cba3f4f374a90a75cff91f142fa9be3e0a5f3ef
268ccb9065aeecfd67a999e7fdc137e062b2ec4a0eb92947f0d9a74bfbf44dfb
a776b21301f8b65efd5796706adff216ab862a9186875f9494150c4ae06fa4d1
f0396c93f215fa4ef524f1eadf5f0f4126b79336671cbcf7a885b1f8bd2a5d83
9cf8

@ -0,0 +1,10 @@
# src-node-id = 0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb
# dest-node-id = 0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9
# nonce = 0xffffffffffffffffffffffff
# read-key = 0x00000000000000000000000000000000
# ping.req-id = 0x00000001
# ping.enr-seq = 2
00000000000000000000000000000000088b3d4342774649325f313964a39e55
ea96c005ad52be8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3
4c4f53245d08dab84102ed931f66d1492acb308fa1c6715b9d139b81acbdcc

@ -0,0 +1,9 @@
# src-node-id = 0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb
# dest-node-id = 0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9
# whoareyou.challenge-data = 0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000
# whoareyou.request-nonce = 0x0102030405060708090a0b0c
# whoareyou.id-nonce = 0x0102030405060708090a0b0c0d0e0f10
# whoareyou.enr-seq = 0
00000000000000000000000000000000088b3d434277464933a1ccc59f5967ad
1d6035f15e528627dde75cd68292f9e6c27d6b66c8100a873fcbaed4e16b8d

@ -23,3 +23,11 @@ func IsTemporaryError(err error) bool {
}) })
return ok && tempErr.Temporary() || isPacketTooBig(err) return ok && tempErr.Temporary() || isPacketTooBig(err)
} }
// IsTimeout checks whether the given error is a timeout.
func IsTimeout(err error) bool {
timeoutErr, ok := err.(interface {
Timeout() bool
})
return ok && timeoutErr.Timeout()
}