Merge pull request #3325 from fjl/p2p-netrestrict

Prevent relay of invalid IPs, add --netrestrict
This commit is contained in:
Felix Lange 2016-11-25 13:59:18 +01:00 committed by GitHub
commit d1a95c643e
25 changed files with 643 additions and 230 deletions

@ -29,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
) )
func main() { func main() {
@ -39,6 +40,7 @@ func main() {
nodeKeyFile = flag.String("nodekey", "", "private key filename") nodeKeyFile = flag.String("nodekey", "", "private key filename")
nodeKeyHex = flag.String("nodekeyhex", "", "private key as hex (for testing)") nodeKeyHex = flag.String("nodekeyhex", "", "private key as hex (for testing)")
natdesc = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)") natdesc = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
netrestrict = flag.String("netrestrict", "", "restrict network communication to the given IP networks (CIDR masks)")
runv5 = flag.Bool("v5", false, "run a v5 topic discovery bootnode") runv5 = flag.Bool("v5", false, "run a v5 topic discovery bootnode")
nodeKey *ecdsa.PrivateKey nodeKey *ecdsa.PrivateKey
@ -81,12 +83,20 @@ func main() {
os.Exit(0) os.Exit(0)
} }
var restrictList *netutil.Netlist
if *netrestrict != "" {
restrictList, err = netutil.ParseNetlist(*netrestrict)
if err != nil {
utils.Fatalf("-netrestrict: %v", err)
}
}
if *runv5 { if *runv5 {
if _, err := discv5.ListenUDP(nodeKey, *listenAddr, natm, ""); err != nil { if _, err := discv5.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil {
utils.Fatalf("%v", err) utils.Fatalf("%v", err)
} }
} else { } else {
if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, ""); err != nil { if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil {
utils.Fatalf("%v", err) utils.Fatalf("%v", err)
} }
} }

@ -96,6 +96,7 @@ func init() {
utils.BootnodesFlag, utils.BootnodesFlag,
utils.KeyStoreDirFlag, utils.KeyStoreDirFlag,
utils.ListenPortFlag, utils.ListenPortFlag,
utils.NetrestrictFlag,
utils.MaxPeersFlag, utils.MaxPeersFlag,
utils.NATFlag, utils.NATFlag,
utils.NodeKeyFileFlag, utils.NodeKeyFileFlag,

@ -148,6 +148,7 @@ participating.
utils.NatspecEnabledFlag, utils.NatspecEnabledFlag,
utils.NoDiscoverFlag, utils.NoDiscoverFlag,
utils.DiscoveryV5Flag, utils.DiscoveryV5Flag,
utils.NetrestrictFlag,
utils.NodeKeyFileFlag, utils.NodeKeyFileFlag,
utils.NodeKeyHexFlag, utils.NodeKeyHexFlag,
utils.RPCEnabledFlag, utils.RPCEnabledFlag,

@ -45,6 +45,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/pow" "github.com/ethereum/go-ethereum/pow"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
@ -366,10 +367,16 @@ var (
Name: "v5disc", Name: "v5disc",
Usage: "Enables the experimental RLPx V5 (Topic Discovery) mechanism", Usage: "Enables the experimental RLPx V5 (Topic Discovery) mechanism",
} }
NetrestrictFlag = cli.StringFlag{
Name: "netrestrict",
Usage: "Restricts network communication to the given IP networks (CIDR masks)",
}
WhisperEnabledFlag = cli.BoolFlag{ WhisperEnabledFlag = cli.BoolFlag{
Name: "shh", Name: "shh",
Usage: "Enable Whisper", Usage: "Enable Whisper",
} }
// ATM the url is left to the user and deployment to // ATM the url is left to the user and deployment to
JSpathFlag = cli.StringFlag{ JSpathFlag = cli.StringFlag{
Name: "jspath", Name: "jspath",
@ -693,6 +700,14 @@ func MakeNode(ctx *cli.Context, name, gitCommit string) *node.Node {
config.MaxPeers = 0 config.MaxPeers = 0
config.ListenAddr = ":0" config.ListenAddr = ":0"
} }
if netrestrict := ctx.GlobalString(NetrestrictFlag.Name); netrestrict != "" {
list, err := netutil.ParseNetlist(netrestrict)
if err != nil {
Fatalf("Option %q: %v", NetrestrictFlag.Name, err)
}
config.NetRestrict = list
}
stack, err := node.New(config) stack, err := node.New(config)
if err != nil { if err != nil {
Fatalf("Failed to create the protocol stack: %v", err) Fatalf("Failed to create the protocol stack: %v", err)

@ -34,6 +34,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
) )
var ( var (
@ -103,6 +104,10 @@ type Config struct {
// Listener address for the V5 discovery protocol UDP traffic. // Listener address for the V5 discovery protocol UDP traffic.
DiscoveryV5Addr string DiscoveryV5Addr string
// Restrict communication to white listed IP networks.
// The whitelist only applies when non-nil.
NetRestrict *netutil.Netlist
// BootstrapNodes used to establish connectivity with the rest of the network. // BootstrapNodes used to establish connectivity with the rest of the network.
BootstrapNodes []*discover.Node BootstrapNodes []*discover.Node

@ -165,6 +165,7 @@ func (n *Node) Start() error {
TrustedNodes: n.config.TrusterNodes(), TrustedNodes: n.config.TrusterNodes(),
NodeDatabase: n.config.NodeDB(), NodeDatabase: n.config.NodeDB(),
ListenAddr: n.config.ListenAddr, ListenAddr: n.config.ListenAddr,
NetRestrict: n.config.NetRestrict,
NAT: n.config.NAT, NAT: n.config.NAT,
Dialer: n.config.Dialer, Dialer: n.config.Dialer,
NoDial: n.config.NoDial, NoDial: n.config.NoDial,

@ -19,6 +19,7 @@ package p2p
import ( import (
"container/heap" "container/heap"
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
@ -26,6 +27,7 @@ import (
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/netutil"
) )
const ( const (
@ -48,6 +50,7 @@ const (
type dialstate struct { type dialstate struct {
maxDynDials int maxDynDials int
ntab discoverTable ntab discoverTable
netrestrict *netutil.Netlist
lookupRunning bool lookupRunning bool
dialing map[discover.NodeID]connFlag dialing map[discover.NodeID]connFlag
@ -100,10 +103,11 @@ type waitExpireTask struct {
time.Duration time.Duration
} }
func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate { func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
s := &dialstate{ s := &dialstate{
maxDynDials: maxdyn, maxDynDials: maxdyn,
ntab: ntab, ntab: ntab,
netrestrict: netrestrict,
static: make(map[discover.NodeID]*dialTask), static: make(map[discover.NodeID]*dialTask),
dialing: make(map[discover.NodeID]connFlag), dialing: make(map[discover.NodeID]connFlag),
randomNodes: make([]*discover.Node, maxdyn/2), randomNodes: make([]*discover.Node, maxdyn/2),
@ -128,12 +132,9 @@ func (s *dialstate) removeStatic(n *discover.Node) {
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
var newtasks []task var newtasks []task
isDialing := func(id discover.NodeID) bool {
_, found := s.dialing[id]
return found || peers[id] != nil || s.hist.contains(id)
}
addDial := func(flag connFlag, n *discover.Node) bool { addDial := func(flag connFlag, n *discover.Node) bool {
if isDialing(n.ID) { if err := s.checkDial(n, peers); err != nil {
glog.V(logger.Debug).Infof("skipping dial candidate %x@%v:%d: %v", n.ID[:8], n.IP, n.TCP, err)
return false return false
} }
s.dialing[n.ID] = flag s.dialing[n.ID] = flag
@ -159,7 +160,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
// Create dials for static nodes if they are not connected. // Create dials for static nodes if they are not connected.
for id, t := range s.static { for id, t := range s.static {
if !isDialing(id) { err := s.checkDial(t.dest, peers)
switch err {
case errNotWhitelisted, errSelf:
glog.V(logger.Debug).Infof("removing static dial candidate %x@%v:%d: %v", t.dest.ID[:8], t.dest.IP, t.dest.TCP, err)
delete(s.static, t.dest.ID)
case nil:
s.dialing[id] = t.flags s.dialing[id] = t.flags
newtasks = append(newtasks, t) newtasks = append(newtasks, t)
} }
@ -202,6 +208,31 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now
return newtasks return newtasks
} }
var (
errSelf = errors.New("is self")
errAlreadyDialing = errors.New("already dialing")
errAlreadyConnected = errors.New("already connected")
errRecentlyDialed = errors.New("recently dialed")
errNotWhitelisted = errors.New("not contained in netrestrict whitelist")
)
func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error {
_, dialing := s.dialing[n.ID]
switch {
case dialing:
return errAlreadyDialing
case peers[n.ID] != nil:
return errAlreadyConnected
case s.ntab != nil && n.ID == s.ntab.Self().ID:
return errSelf
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP):
return errNotWhitelisted
case s.hist.contains(n.ID):
return errRecentlyDialed
}
return nil
}
func (s *dialstate) taskDone(t task, now time.Time) { func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) { switch t := t.(type) {
case *dialTask: case *dialTask:

@ -25,6 +25,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/netutil"
) )
func init() { func init() {
@ -86,7 +87,7 @@ func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf,
// This test checks that dynamic dials are launched from discovery results. // This test checks that dynamic dials are launched from discovery results.
func TestDialStateDynDial(t *testing.T) { func TestDialStateDynDial(t *testing.T) {
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(nil, fakeTable{}, 5), init: newDialState(nil, fakeTable{}, 5, nil),
rounds: []round{ rounds: []round{
// A discovery query is launched. // A discovery query is launched.
{ {
@ -233,7 +234,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(nil, table, 10), init: newDialState(nil, table, 10, nil),
rounds: []round{ rounds: []round{
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
{ {
@ -313,6 +314,36 @@ func TestDialStateDynDialFromTable(t *testing.T) {
}) })
} }
// This test checks that candidates that do not match the netrestrict list are not dialed.
func TestDialStateNetRestrict(t *testing.T) {
// This table always returns the same random nodes
// in the order given below.
table := fakeTable{
{ID: uintID(1), IP: net.ParseIP("127.0.0.1")},
{ID: uintID(2), IP: net.ParseIP("127.0.0.2")},
{ID: uintID(3), IP: net.ParseIP("127.0.0.3")},
{ID: uintID(4), IP: net.ParseIP("127.0.0.4")},
{ID: uintID(5), IP: net.ParseIP("127.0.2.5")},
{ID: uintID(6), IP: net.ParseIP("127.0.2.6")},
{ID: uintID(7), IP: net.ParseIP("127.0.2.7")},
{ID: uintID(8), IP: net.ParseIP("127.0.2.8")},
}
restrict := new(netutil.Netlist)
restrict.Add("127.0.2.0/24")
runDialTest(t, dialtest{
init: newDialState(nil, table, 10, restrict),
rounds: []round{
{
new: []task{
&dialTask{flags: dynDialedConn, dest: table[4]},
&discoverTask{},
},
},
},
})
}
// This test checks that static dials are launched. // This test checks that static dials are launched.
func TestDialStateStaticDial(t *testing.T) { func TestDialStateStaticDial(t *testing.T) {
wantStatic := []*discover.Node{ wantStatic := []*discover.Node{
@ -324,7 +355,7 @@ func TestDialStateStaticDial(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(wantStatic, fakeTable{}, 0), init: newDialState(wantStatic, fakeTable{}, 0, nil),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -405,7 +436,7 @@ func TestDialStateCache(t *testing.T) {
} }
runDialTest(t, dialtest{ runDialTest(t, dialtest{
init: newDialState(wantStatic, fakeTable{}, 0), init: newDialState(wantStatic, fakeTable{}, 0, nil),
rounds: []round{ rounds: []round{
// Static dials are launched for the nodes that // Static dials are launched for the nodes that
// aren't yet connected. // aren't yet connected.
@ -467,7 +498,7 @@ func TestDialStateCache(t *testing.T) {
func TestDialResolve(t *testing.T) { func TestDialResolve(t *testing.T) {
resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444) resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444)
table := &resolveMock{answer: resolved} table := &resolveMock{answer: resolved}
state := newDialState(nil, table, 0) state := newDialState(nil, table, 0, nil)
// Check that the task is generated with an incomplete ID. // Check that the task is generated with an incomplete ID.
dest := discover.NewNode(uintID(1), nil, 0, 0) dest := discover.NewNode(uintID(1), nil, 0, 0)

@ -146,6 +146,7 @@ func fillBucket(tab *Table, ld int) (last *Node) {
func nodeAtDistance(base common.Hash, ld int) (n *Node) { func nodeAtDistance(base common.Hash, ld int) (n *Node) {
n = new(Node) n = new(Node)
n.sha = hashAtDistance(base, ld) n.sha = hashAtDistance(base, ld)
n.IP = net.IP{10, 0, 2, byte(ld)}
copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
return n return n
} }

@ -29,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -126,8 +127,16 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
} }
func nodeFromRPC(rn rpcNode) (*Node, error) { func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
// TODO: don't accept localhost, LAN addresses from internet hosts if rn.UDP <= 1024 {
return nil, errors.New("low port")
}
if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
return nil, err
}
if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) {
return nil, errors.New("not contained in netrestrict whitelist")
}
n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
err := n.validateComplete() err := n.validateComplete()
return n, err return n, err
@ -151,6 +160,7 @@ type conn interface {
// udp implements the RPC protocol. // udp implements the RPC protocol.
type udp struct { type udp struct {
conn conn conn conn
netrestrict *netutil.Netlist
priv *ecdsa.PrivateKey priv *ecdsa.PrivateKey
ourEndpoint rpcEndpoint ourEndpoint rpcEndpoint
@ -201,7 +211,7 @@ type reply struct {
} }
// ListenUDP returns a new table that listens for UDP packets on laddr. // ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) { func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) {
addr, err := net.ResolveUDPAddr("udp", laddr) addr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -210,7 +220,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
if err != nil { if err != nil {
return nil, err return nil, err
} }
tab, _, err := newUDP(priv, conn, natm, nodeDBPath) tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -218,10 +228,11 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP
return tab, nil return tab, nil
} }
func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) { func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) {
udp := &udp{ udp := &udp{
conn: c, conn: c,
priv: priv, priv: priv,
netrestrict: netrestrict,
closing: make(chan struct{}), closing: make(chan struct{}),
gotreply: make(chan reply), gotreply: make(chan reply),
addpending: make(chan *pending), addpending: make(chan *pending),
@ -281,9 +292,12 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node
reply := r.(*neighbors) reply := r.(*neighbors)
for _, rn := range reply.Nodes { for _, rn := range reply.Nodes {
nreceived++ nreceived++
if n, err := nodeFromRPC(rn); err == nil { n, err := t.nodeFromRPC(toaddr, rn)
nodes = append(nodes, n) if err != nil {
glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err)
continue
} }
nodes = append(nodes, n)
} }
return nreceived >= bucketSize return nreceived >= bucketSize
}) })
@ -479,13 +493,6 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte,
return packet, nil return packet, nil
} }
func isTemporaryError(err error) bool {
tempErr, ok := err.(interface {
Temporary() bool
})
return ok && tempErr.Temporary() || isPacketTooBig(err)
}
// readLoop runs in its own goroutine. it handles incoming UDP packets. // readLoop runs in its own goroutine. it handles incoming UDP packets.
func (t *udp) readLoop() { func (t *udp) readLoop() {
defer t.conn.Close() defer t.conn.Close()
@ -495,7 +502,7 @@ func (t *udp) readLoop() {
buf := make([]byte, 1280) buf := make([]byte, 1280)
for { for {
nbytes, from, err := t.conn.ReadFromUDP(buf) nbytes, from, err := t.conn.ReadFromUDP(buf)
if isTemporaryError(err) { if netutil.IsTemporaryError(err) {
// Ignore temporary read errors. // Ignore temporary read errors.
glog.V(logger.Debug).Infof("Temporary read error: %v", err) glog.V(logger.Debug).Infof("Temporary read error: %v", err)
continue continue
@ -602,6 +609,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
// Send neighbors in chunks with at most maxNeighbors per packet // Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit. // to stay below the 1280 byte limit.
for i, n := range closest { for i, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP) != nil {
continue
}
p.Nodes = append(p.Nodes, nodeToRPC(n)) p.Nodes = append(p.Nodes, nodeToRPC(n))
if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
t.send(from, neighborsPacket, p) t.send(from, neighborsPacket, p)

@ -43,56 +43,6 @@ func init() {
spew.Config.DisableMethods = true spew.Config.DisableMethods = true
} }
// This test checks that isPacketTooBig correctly identifies
// errors that result from receiving a UDP packet larger
// than the supplied receive buffer.
func TestIsPacketTooBig(t *testing.T) {
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
sender, err := net.Dial("udp", listener.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
defer sender.Close()
sendN := 1800
recvN := 300
for i := 0; i < 20; i++ {
go func() {
buf := make([]byte, sendN)
for i := range buf {
buf[i] = byte(i)
}
sender.Write(buf)
}()
buf := make([]byte, recvN)
listener.SetDeadline(time.Now().Add(1 * time.Second))
n, _, err := listener.ReadFrom(buf)
if err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
continue
}
if !isPacketTooBig(err) {
t.Fatal("unexpected read error:", spew.Sdump(err))
}
continue
}
if n != recvN {
t.Fatalf("short read: %d, want %d", n, recvN)
}
for i := range buf {
if buf[i] != byte(i) {
t.Fatalf("error in pattern")
break
}
}
}
}
// shared test variables // shared test variables
var ( var (
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix()) futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
@ -118,9 +68,9 @@ func newUDPTest(t *testing.T) *udpTest {
pipe: newpipe(), pipe: newpipe(),
localkey: newkey(), localkey: newkey(),
remotekey: newkey(), remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303}, remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
} }
test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "") test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil)
return test return test
} }
@ -362,8 +312,9 @@ func TestUDP_findnodeMultiReply(t *testing.T) {
// check that the sent neighbors are all returned by findnode // check that the sent neighbors are all returned by findnode
select { select {
case result := <-resultc: case result := <-resultc:
if !reflect.DeepEqual(result, list) { want := append(list[:2], list[3:]...)
t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list) if !reflect.DeepEqual(result, want) {
t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, want)
} }
case err := <-errc: case err := <-errc:
t.Errorf("findnode error: %v", err) t.Errorf("findnode error: %v", err)

@ -31,6 +31,7 @@ import (
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -45,6 +46,7 @@ const (
bucketRefreshInterval = 1 * time.Minute bucketRefreshInterval = 1 * time.Minute
seedCount = 30 seedCount = 30
seedMaxAge = 5 * 24 * time.Hour seedMaxAge = 5 * 24 * time.Hour
lowPort = 1024
) )
const testTopic = "foo" const testTopic = "foo"
@ -64,6 +66,7 @@ func debugLog(s string) {
type Network struct { type Network struct {
db *nodeDB // database of known nodes db *nodeDB // database of known nodes
conn transport conn transport
netrestrict *netutil.Netlist
closed chan struct{} // closed when loop is done closed chan struct{} // closed when loop is done
closeReq chan struct{} // 'request to close' closeReq chan struct{} // 'request to close'
@ -132,7 +135,7 @@ type timeoutEvent struct {
node *Node node *Node
} }
func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string) (*Network, error) { func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
ourID := PubkeyID(&ourPubkey) ourID := PubkeyID(&ourPubkey)
var db *nodeDB var db *nodeDB
@ -147,6 +150,7 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, d
net := &Network{ net := &Network{
db: db, db: db,
conn: conn, conn: conn,
netrestrict: netrestrict,
tab: tab, tab: tab,
topictab: newTopicTable(db, tab.self), topictab: newTopicTable(db, tab.self),
ticketStore: newTicketStore(), ticketStore: newTicketStore(),
@ -684,16 +688,22 @@ func (net *Network) internNodeFromDB(dbn *Node) *Node {
return n return n
} }
func (net *Network) internNodeFromNeighbours(rn rpcNode) (n *Node, err error) { func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) {
if rn.ID == net.tab.self.ID { if rn.ID == net.tab.self.ID {
return nil, errors.New("is self") return nil, errors.New("is self")
} }
if rn.UDP <= lowPort {
return nil, errors.New("low port")
}
n = net.nodes[rn.ID] n = net.nodes[rn.ID]
if n == nil { if n == nil {
// We haven't seen this node before. // We haven't seen this node before.
n, err = nodeFromRPC(rn) n, err = nodeFromRPC(sender, rn)
n.state = unknown if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
return n, errors.New("not contained in netrestrict whitelist")
}
if err == nil { if err == nil {
n.state = unknown
net.nodes[n.ID] = n net.nodes[n.ID] = n
} }
return n, err return n, err
@ -1095,7 +1105,7 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket)
net.conn.sendNeighbours(n, results) net.conn.sendNeighbours(n, results)
return n.state, nil return n.state, nil
case neighborsPacket: case neighborsPacket:
err := net.handleNeighboursPacket(n, pkt.data.(*neighbors)) err := net.handleNeighboursPacket(n, pkt)
return n.state, err return n.state, err
case neighboursTimeout: case neighboursTimeout:
if n.pendingNeighbours != nil { if n.pendingNeighbours != nil {
@ -1182,17 +1192,18 @@ func rlpHash(x interface{}) (h common.Hash) {
return h return h
} }
func (net *Network) handleNeighboursPacket(n *Node, req *neighbors) error { func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error {
if n.pendingNeighbours == nil { if n.pendingNeighbours == nil {
return errNoQuery return errNoQuery
} }
net.abortTimedEvent(n, neighboursTimeout) net.abortTimedEvent(n, neighboursTimeout)
req := pkt.data.(*neighbors)
nodes := make([]*Node, len(req.Nodes)) nodes := make([]*Node, len(req.Nodes))
for i, rn := range req.Nodes { for i, rn := range req.Nodes {
nn, err := net.internNodeFromNeighbours(rn) nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn)
if err != nil { if err != nil {
glog.V(logger.Debug).Infof("invalid neighbour from %x: %v", n.ID[:8], err) glog.V(logger.Debug).Infof("invalid neighbour (%v) from %x@%v: %v", rn.IP, n.ID[:8], pkt.remoteAddr, err)
continue continue
} }
nodes[i] = nn nodes[i] = nn

@ -28,7 +28,7 @@ import (
func TestNetwork_Lookup(t *testing.T) { func TestNetwork_Lookup(t *testing.T) {
key, _ := crypto.GenerateKey() key, _ := crypto.GenerateKey()
network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "") network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -40,7 +40,7 @@ func TestNetwork_Lookup(t *testing.T) {
// t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) // t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
// } // }
// seed table with initial node (otherwise lookup will terminate immediately) // seed table with initial node (otherwise lookup will terminate immediately)
seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{}, 256, 999)} seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{10, 0, 2, 99}, lowPort+256, 999)}
if err := network.SetFallbackNodes(seeds); err != nil { if err := network.SetFallbackNodes(seeds); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -272,13 +272,13 @@ func (tn *preminedTestnet) sendFindnode(to *Node, target NodeID) {
func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) { func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) {
// current log distance is encoded in port number // current log distance is encoded in port number
// fmt.Println("findnode query at dist", toaddr.Port) // fmt.Println("findnode query at dist", toaddr.Port)
if to.UDP == 0 { if to.UDP <= lowPort {
panic("query to node at distance 0") panic("query to node at or below distance 0")
} }
next := to.UDP - 1 next := to.UDP - 1
var result []rpcNode var result []rpcNode
for i, id := range tn.dists[to.UDP] { for i, id := range tn.dists[to.UDP-lowPort] {
result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1))) result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort)))
} }
injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
} }
@ -296,14 +296,14 @@ func (tn *preminedTestnet) send(to *Node, ptype nodeEvent, data interface{}) (ha
// ignored // ignored
case findnodeHashPacket: case findnodeHashPacket:
// current log distance is encoded in port number // current log distance is encoded in port number
// fmt.Println("findnode query at dist", toaddr.Port) // fmt.Println("findnode query at dist", toaddr.Port-lowPort)
if to.UDP == 0 { if to.UDP <= lowPort {
panic("query to node at distance 0") panic("query to node at or below distance 0")
} }
next := to.UDP - 1 next := to.UDP - 1
var result []rpcNode var result []rpcNode
for i, id := range tn.dists[to.UDP] { for i, id := range tn.dists[to.UDP-lowPort] {
result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1))) result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort)))
} }
injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result})
default: default:
@ -329,7 +329,10 @@ func (tn *preminedTestnet) sendTopicRegister(to *Node, topics []Topic, idx int,
} }
func (*preminedTestnet) Close() {} func (*preminedTestnet) Close() {}
func (*preminedTestnet) localAddr() *net.UDPAddr { return new(net.UDPAddr) }
func (*preminedTestnet) localAddr() *net.UDPAddr {
return &net.UDPAddr{IP: net.ParseIP("10.0.1.1"), Port: 40000}
}
// mine generates a testnet struct literal with nodes at // mine generates a testnet struct literal with nodes at
// various distances to the given target. // various distances to the given target.

@ -290,7 +290,7 @@ func (s *simulation) launchNode(log bool) *Network {
addr := &net.UDPAddr{IP: ip, Port: 30303} addr := &net.UDPAddr{IP: ip, Port: 30303}
transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key}
net, err := newNetwork(transport, key.PublicKey, nil, "<no database>") net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil)
if err != nil { if err != nil {
panic("cannot launch new node: " + err.Error()) panic("cannot launch new node: " + err.Error())
} }

@ -29,6 +29,7 @@ import (
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -198,8 +199,10 @@ func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP) return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP)
} }
func nodeFromRPC(rn rpcNode) (*Node, error) { func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
// TODO: don't accept localhost, LAN addresses from internet hosts if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
return nil, err
}
n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
err := n.validateComplete() err := n.validateComplete()
return n, err return n, err
@ -235,12 +238,12 @@ type udp struct {
} }
// ListenUDP returns a new table that listens for UDP packets on laddr. // ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Network, error) { func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
transport, err := listenUDP(priv, laddr) transport, err := listenUDP(priv, laddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath) net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -327,6 +330,9 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)
return return
} }
for i, result := range nodes { for i, result := range nodes {
if netutil.CheckRelayIP(remote.IP, result.IP) != nil {
continue
}
p.Nodes = append(p.Nodes, nodeToRPC(result)) p.Nodes = append(p.Nodes, nodeToRPC(result))
if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 { if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 {
t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
@ -385,7 +391,7 @@ func (t *udp) readLoop() {
buf := make([]byte, 1280) buf := make([]byte, 1280)
for { for {
nbytes, from, err := t.conn.ReadFromUDP(buf) nbytes, from, err := t.conn.ReadFromUDP(buf)
if isTemporaryError(err) { if netutil.IsTemporaryError(err) {
// Ignore temporary read errors. // Ignore temporary read errors.
glog.V(logger.Debug).Infof("Temporary read error: %v", err) glog.V(logger.Debug).Infof("Temporary read error: %v", err)
continue continue
@ -398,13 +404,6 @@ func (t *udp) readLoop() {
} }
} }
func isTemporaryError(err error) bool {
tempErr, ok := err.(interface {
Temporary() bool
})
return ok && tempErr.Temporary() || isPacketTooBig(err)
}
func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
pkt := ingressPacket{remoteAddr: from} pkt := ingressPacket{remoteAddr: from}
if err := decodePacket(buf, &pkt); err != nil { if err := decodePacket(buf, &pkt); err != nil {

@ -36,56 +36,6 @@ func init() {
spew.Config.DisableMethods = true spew.Config.DisableMethods = true
} }
// This test checks that isPacketTooBig correctly identifies
// errors that result from receiving a UDP packet larger
// than the supplied receive buffer.
func TestIsPacketTooBig(t *testing.T) {
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
sender, err := net.Dial("udp", listener.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
defer sender.Close()
sendN := 1800
recvN := 300
for i := 0; i < 20; i++ {
go func() {
buf := make([]byte, sendN)
for i := range buf {
buf[i] = byte(i)
}
sender.Write(buf)
}()
buf := make([]byte, recvN)
listener.SetDeadline(time.Now().Add(1 * time.Second))
n, _, err := listener.ReadFrom(buf)
if err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
continue
}
if !isPacketTooBig(err) {
t.Fatal("unexpected read error:", spew.Sdump(err))
}
continue
}
if n != recvN {
t.Fatalf("short read: %d, want %d", n, recvN)
}
for i := range buf {
if buf[i] != byte(i) {
t.Fatalf("error in pattern")
break
}
}
}
}
// shared test variables // shared test variables
var ( var (
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix()) futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())

@ -1,40 +0,0 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
//+build windows
package discv5
import (
"net"
"os"
"syscall"
)
const _WSAEMSGSIZE = syscall.Errno(10040)
// reports whether err indicates that a UDP packet didn't
// fit the receive buffer. On Windows, WSARecvFrom returns
// code WSAEMSGSIZE and no data if this happens.
func isPacketTooBig(err error) bool {
if opErr, ok := err.(*net.OpError); ok {
if scErr, ok := opErr.Err.(*os.SyscallError); ok {
return scErr.Err == _WSAEMSGSIZE
}
return opErr.Err == _WSAEMSGSIZE
}
return false
}

@ -14,13 +14,12 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
//+build !windows package netutil
package discv5 // IsTemporaryError checks whether the given error should be considered temporary.
func IsTemporaryError(err error) bool {
// reports whether err indicates that a UDP packet didn't tempErr, ok := err.(interface {
// fit the receive buffer. There is no such error on Temporary() bool
// non-Windows platforms. })
func isPacketTooBig(err error) bool { return ok && tempErr.Temporary() || isPacketTooBig(err)
return false
} }

73
p2p/netutil/error_test.go Normal file

@ -0,0 +1,73 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package netutil
import (
"net"
"testing"
"time"
)
// This test checks that isPacketTooBig correctly identifies
// errors that result from receiving a UDP packet larger
// than the supplied receive buffer.
func TestIsPacketTooBig(t *testing.T) {
listener, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
sender, err := net.Dial("udp", listener.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
defer sender.Close()
sendN := 1800
recvN := 300
for i := 0; i < 20; i++ {
go func() {
buf := make([]byte, sendN)
for i := range buf {
buf[i] = byte(i)
}
sender.Write(buf)
}()
buf := make([]byte, recvN)
listener.SetDeadline(time.Now().Add(1 * time.Second))
n, _, err := listener.ReadFrom(buf)
if err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
continue
}
if !isPacketTooBig(err) {
t.Fatalf("unexpected read error: %v", err)
}
continue
}
if n != recvN {
t.Fatalf("short read: %d, want %d", n, recvN)
}
for i := range buf {
if buf[i] != byte(i) {
t.Fatalf("error in pattern")
break
}
}
}
}

166
p2p/netutil/net.go Normal file

@ -0,0 +1,166 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// Package netutil contains extensions to the net package.
package netutil
import (
"errors"
"net"
"strings"
)
var lan4, lan6, special4, special6 Netlist
func init() {
// Lists from RFC 5735, RFC 5156,
// https://www.iana.org/assignments/iana-ipv4-special-registry/
lan4.Add("0.0.0.0/8") // "This" network
lan4.Add("10.0.0.0/8") // Private Use
lan4.Add("172.16.0.0/12") // Private Use
lan4.Add("192.168.0.0/16") // Private Use
lan6.Add("fe80::/10") // Link-Local
lan6.Add("fc00::/7") // Unique-Local
special4.Add("192.0.0.0/29") // IPv4 Service Continuity
special4.Add("192.0.0.9/32") // PCP Anycast
special4.Add("192.0.0.170/32") // NAT64/DNS64 Discovery
special4.Add("192.0.0.171/32") // NAT64/DNS64 Discovery
special4.Add("192.0.2.0/24") // TEST-NET-1
special4.Add("192.31.196.0/24") // AS112
special4.Add("192.52.193.0/24") // AMT
special4.Add("192.88.99.0/24") // 6to4 Relay Anycast
special4.Add("192.175.48.0/24") // AS112
special4.Add("198.18.0.0/15") // Device Benchmark Testing
special4.Add("198.51.100.0/24") // TEST-NET-2
special4.Add("203.0.113.0/24") // TEST-NET-3
special4.Add("255.255.255.255/32") // Limited Broadcast
// http://www.iana.org/assignments/iana-ipv6-special-registry/
special6.Add("100::/64")
special6.Add("2001::/32")
special6.Add("2001:1::1/128")
special6.Add("2001:2::/48")
special6.Add("2001:3::/32")
special6.Add("2001:4:112::/48")
special6.Add("2001:5::/32")
special6.Add("2001:10::/28")
special6.Add("2001:20::/28")
special6.Add("2001:db8::/32")
special6.Add("2002::/16")
}
// Netlist is a list of IP networks.
type Netlist []net.IPNet
// ParseNetlist parses a comma-separated list of CIDR masks.
// Whitespace and extra commas are ignored.
func ParseNetlist(s string) (*Netlist, error) {
ws := strings.NewReplacer(" ", "", "\n", "", "\t", "")
masks := strings.Split(ws.Replace(s), ",")
l := make(Netlist, 0)
for _, mask := range masks {
if mask == "" {
continue
}
_, n, err := net.ParseCIDR(mask)
if err != nil {
return nil, err
}
l = append(l, *n)
}
return &l, nil
}
// Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is
// intended to be used for setting up static lists.
func (l *Netlist) Add(cidr string) {
_, n, err := net.ParseCIDR(cidr)
if err != nil {
panic(err)
}
*l = append(*l, *n)
}
// Contains reports whether the given IP is contained in the list.
func (l *Netlist) Contains(ip net.IP) bool {
if l == nil {
return false
}
for _, net := range *l {
if net.Contains(ip) {
return true
}
}
return false
}
// IsLAN reports whether an IP is a local network address.
func IsLAN(ip net.IP) bool {
if ip.IsLoopback() {
return true
}
if v4 := ip.To4(); v4 != nil {
return lan4.Contains(v4)
}
return lan6.Contains(ip)
}
// IsSpecialNetwork reports whether an IP is located in a special-use network range
// This includes broadcast, multicast and documentation addresses.
func IsSpecialNetwork(ip net.IP) bool {
if ip.IsMulticast() {
return true
}
if v4 := ip.To4(); v4 != nil {
return special4.Contains(v4)
}
return special6.Contains(ip)
}
var (
errInvalid = errors.New("invalid IP")
errUnspecified = errors.New("zero address")
errSpecial = errors.New("special network")
errLoopback = errors.New("loopback address from non-loopback host")
errLAN = errors.New("LAN address from WAN host")
)
// CheckRelayIP reports whether an IP relayed from the given sender IP
// is a valid connection target.
//
// There are four rules:
// - Special network addresses are never valid.
// - Loopback addresses are OK if relayed by a loopback host.
// - LAN addresses are OK if relayed by a LAN host.
// - All other addresses are always acceptable.
func CheckRelayIP(sender, addr net.IP) error {
if len(addr) != net.IPv4len && len(addr) != net.IPv6len {
return errInvalid
}
if addr.IsUnspecified() {
return errUnspecified
}
if IsSpecialNetwork(addr) {
return errSpecial
}
if addr.IsLoopback() && !sender.IsLoopback() {
return errLoopback
}
if IsLAN(addr) && !IsLAN(sender) {
return errLAN
}
return nil
}

173
p2p/netutil/net_test.go Normal file

@ -0,0 +1,173 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package netutil
import (
"net"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
)
func TestParseNetlist(t *testing.T) {
var tests = []struct {
input string
wantErr error
wantList *Netlist
}{
{
input: "",
wantList: &Netlist{},
},
{
input: "127.0.0.0/8",
wantErr: nil,
wantList: &Netlist{{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}},
},
{
input: "127.0.0.0/44",
wantErr: &net.ParseError{Type: "CIDR address", Text: "127.0.0.0/44"},
},
{
input: "127.0.0.0/16, 23.23.23.23/24,",
wantList: &Netlist{
{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(16, 32)},
{IP: net.IP{23, 23, 23, 0}, Mask: net.CIDRMask(24, 32)},
},
},
}
for _, test := range tests {
l, err := ParseNetlist(test.input)
if !reflect.DeepEqual(err, test.wantErr) {
t.Errorf("%q: got error %q, want %q", test.input, err, test.wantErr)
continue
}
if !reflect.DeepEqual(l, test.wantList) {
spew.Dump(l)
spew.Dump(test.wantList)
t.Errorf("%q: got %v, want %v", test.input, l, test.wantList)
}
}
}
func TestNilNetListContains(t *testing.T) {
var list *Netlist
checkContains(t, list.Contains, nil, []string{"1.2.3.4"})
}
func TestIsLAN(t *testing.T) {
checkContains(t, IsLAN,
[]string{ // included
"0.0.0.0",
"0.2.0.8",
"127.0.0.1",
"10.0.1.1",
"10.22.0.3",
"172.31.252.251",
"192.168.1.4",
"fe80::f4a1:8eff:fec5:9d9d",
"febf::ab32:2233",
"fc00::4",
},
[]string{ // excluded
"192.0.2.1",
"1.0.0.0",
"172.32.0.1",
"fec0::2233",
},
)
}
func TestIsSpecialNetwork(t *testing.T) {
checkContains(t, IsSpecialNetwork,
[]string{ // included
"192.0.2.1",
"192.0.2.44",
"2001:db8:85a3:8d3:1319:8a2e:370:7348",
"255.255.255.255",
"224.0.0.22", // IPv4 multicast
"ff05::1:3", // IPv6 multicast
},
[]string{ // excluded
"192.0.3.1",
"1.0.0.0",
"172.32.0.1",
"fec0::2233",
},
)
}
func checkContains(t *testing.T, fn func(net.IP) bool, inc, exc []string) {
for _, s := range inc {
if !fn(parseIP(s)) {
t.Error("returned false for included address", s)
}
}
for _, s := range exc {
if fn(parseIP(s)) {
t.Error("returned true for excluded address", s)
}
}
}
func parseIP(s string) net.IP {
ip := net.ParseIP(s)
if ip == nil {
panic("invalid " + s)
}
return ip
}
func TestCheckRelayIP(t *testing.T) {
tests := []struct {
sender, addr string
want error
}{
{"127.0.0.1", "0.0.0.0", errUnspecified},
{"192.168.0.1", "0.0.0.0", errUnspecified},
{"23.55.1.242", "0.0.0.0", errUnspecified},
{"127.0.0.1", "255.255.255.255", errSpecial},
{"192.168.0.1", "255.255.255.255", errSpecial},
{"23.55.1.242", "255.255.255.255", errSpecial},
{"192.168.0.1", "127.0.2.19", errLoopback},
{"23.55.1.242", "192.168.0.1", errLAN},
{"127.0.0.1", "127.0.2.19", nil},
{"127.0.0.1", "192.168.0.1", nil},
{"127.0.0.1", "23.55.1.242", nil},
{"192.168.0.1", "192.168.0.1", nil},
{"192.168.0.1", "23.55.1.242", nil},
{"23.55.1.242", "23.55.1.242", nil},
}
for _, test := range tests {
err := CheckRelayIP(parseIP(test.sender), parseIP(test.addr))
if err != test.want {
t.Errorf("%s from %s: got %q, want %q", test.addr, test.sender, err, test.want)
}
}
}
func BenchmarkCheckRelayIP(b *testing.B) {
sender := parseIP("23.55.1.242")
addr := parseIP("23.55.1.2")
for i := 0; i < b.N; i++ {
CheckRelayIP(sender, addr)
}
}

@ -16,9 +16,9 @@
//+build !windows //+build !windows
package discover package netutil
// reports whether err indicates that a UDP packet didn't // isPacketTooBig reports whether err indicates that a UDP packet didn't
// fit the receive buffer. There is no such error on // fit the receive buffer. There is no such error on
// non-Windows platforms. // non-Windows platforms.
func isPacketTooBig(err error) bool { func isPacketTooBig(err error) bool {

@ -16,7 +16,7 @@
//+build windows //+build windows
package discover package netutil
import ( import (
"net" "net"
@ -26,7 +26,7 @@ import (
const _WSAEMSGSIZE = syscall.Errno(10040) const _WSAEMSGSIZE = syscall.Errno(10040)
// reports whether err indicates that a UDP packet didn't // isPacketTooBig reports whether err indicates that a UDP packet didn't
// fit the receive buffer. On Windows, WSARecvFrom returns // fit the receive buffer. On Windows, WSARecvFrom returns
// code WSAEMSGSIZE and no data if this happens. // code WSAEMSGSIZE and no data if this happens.
func isPacketTooBig(err error) bool { func isPacketTooBig(err error) bool {

@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
) )
const ( const (
@ -101,6 +102,11 @@ type Config struct {
// allowed to connect, even above the peer limit. // allowed to connect, even above the peer limit.
TrustedNodes []*discover.Node TrustedNodes []*discover.Node
// Connectivity can be restricted to certain IP networks.
// If this option is set to a non-nil value, only hosts which match one of the
// IP networks contained in the list are considered.
NetRestrict *netutil.Netlist
// NodeDatabase is the path to the database containing the previously seen // NodeDatabase is the path to the database containing the previously seen
// live nodes in the network. // live nodes in the network.
NodeDatabase string NodeDatabase string
@ -356,7 +362,7 @@ func (srv *Server) Start() (err error) {
// node table // node table
if srv.Discovery { if srv.Discovery {
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict)
if err != nil { if err != nil {
return err return err
} }
@ -367,7 +373,7 @@ func (srv *Server) Start() (err error) {
} }
if srv.DiscoveryV5 { if srv.DiscoveryV5 {
ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "") //srv.NodeDatabase) ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase)
if err != nil { if err != nil {
return err return err
} }
@ -381,7 +387,7 @@ func (srv *Server) Start() (err error) {
if !srv.Discovery { if !srv.Discovery {
dynPeers = 0 dynPeers = 0
} }
dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers) dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake // handshake
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)}
@ -634,8 +640,19 @@ func (srv *Server) listenLoop() {
} }
break break
} }
// Reject connections that do not match NetRestrict.
if srv.NetRestrict != nil {
if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
glog.V(logger.Debug).Infof("Rejected conn %v because it is not whitelisted in NetRestrict", fd.RemoteAddr())
fd.Close()
slots <- struct{}{}
continue
}
}
fd = newMeteredConn(fd, true) fd = newMeteredConn(fd, true)
glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr()) glog.V(logger.Debug).Infof("Accepted conn %v", fd.RemoteAddr())
// Spawn the handler. It will give the slot back when the connection // Spawn the handler. It will give the slot back when the connection
// has been established. // has been established.

@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/swarm/network/kademlia" "github.com/ethereum/go-ethereum/swarm/network/kademlia"
"github.com/ethereum/go-ethereum/swarm/storage" "github.com/ethereum/go-ethereum/swarm/storage"
) )
@ -288,6 +289,10 @@ func newNodeRecord(addr *peerAddr) *kademlia.NodeRecord {
func (self *Hive) HandlePeersMsg(req *peersMsgData, from *peer) { func (self *Hive) HandlePeersMsg(req *peersMsgData, from *peer) {
var nrs []*kademlia.NodeRecord var nrs []*kademlia.NodeRecord
for _, p := range req.Peers { for _, p := range req.Peers {
if err := netutil.CheckRelayIP(from.remoteAddr.IP, p.IP); err != nil {
glog.V(logger.Detail).Infof("invalid peer IP %v from %v: %v", from.remoteAddr.IP, p.IP, err)
continue
}
nrs = append(nrs, newNodeRecord(p)) nrs = append(nrs, newNodeRecord(p))
} }
self.kad.Add(nrs) self.kad.Add(nrs)