p2p/discover: implement ENR node filtering (#1320)

This commit is contained in:
Matus Kysel 2023-05-08 09:54:12 +02:00 committed by GitHub
parent acaafde156
commit 031fce3c92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 203 additions and 49 deletions

@ -45,8 +45,10 @@ func main() {
runv5 = flag.Bool("v5", false, "run a v5 topic discovery bootnode")
verbosity = flag.Int("verbosity", int(log.LvlInfo), "log verbosity (0-5)")
vmodule = flag.String("vmodule", "", "log verbosity pattern")
networkFilter = flag.String("network", "", "<bsc/chapel/rialto> filters nodes by eth ENR entry")
nodeKey *ecdsa.PrivateKey
filterFunction discover.NodeFilterFunc
err error
)
flag.Parse()
@ -86,6 +88,12 @@ func main() {
}
}
if *networkFilter != "" {
if filterFunction, err = discover.ParseEthFilter(*networkFilter); err != nil {
utils.Fatalf("-network: %v", err)
}
}
if *writeAddr {
fmt.Printf("%x\n", crypto.FromECDSAPub(&nodeKey.PublicKey)[1:])
os.Exit(0)
@ -125,6 +133,7 @@ func main() {
cfg := discover.Config{
PrivateKey: nodeKey,
NetRestrict: restrictList,
FilterFunction: filterFunction,
}
if *runv5 {
if _, err := discover.ListenV5(conn, ln, cfg); err != nil {

@ -633,6 +633,7 @@ func (s *Ethereum) Protocols() []p2p.Protocol {
// Start implements node.Lifecycle, starting all internal goroutines needed by the
// Ethereum protocol implementation.
func (s *Ethereum) Start() error {
eth.StartENRFilter(s.blockchain, s.p2pServer)
eth.StartENRUpdater(s.blockchain, s.p2pServer.LocalNode())
// Start the bloom bits servicing goroutines

@ -19,6 +19,7 @@ package eth
import (
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rlp"
)
@ -57,6 +58,11 @@ func StartENRUpdater(chain *core.BlockChain, ln *enode.LocalNode) {
}()
}
func StartENRFilter(chain *core.BlockChain, p2p *p2p.Server) {
forkFilter := forkid.NewFilter(chain)
p2p.SetFilter(forkFilter)
}
// currentENREntry constructs an `eth` ENR entry based on the current state of the chain.
func currentENREntry(chain *core.BlockChain) *enrEntry {
return &enrEntry{

@ -18,13 +18,17 @@ package discover
import (
"crypto/ecdsa"
"fmt"
"net"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp"
)
// UDPConn is a network connection on which discovery can operate.
@ -35,6 +39,34 @@ type UDPConn interface {
LocalAddr() net.Addr
}
type NodeFilterFunc func(*enr.Record) bool
func ParseEthFilter(chain string) (NodeFilterFunc, error) {
var filter forkid.Filter
switch chain {
case "bsc":
filter = forkid.NewStaticFilter(params.BSCChainConfig, params.BSCGenesisHash)
case "chapel":
filter = forkid.NewStaticFilter(params.ChapelChainConfig, params.ChapelGenesisHash)
case "rialto":
filter = forkid.NewStaticFilter(params.RialtoChainConfig, params.RialtoGenesisHash)
default:
return nil, fmt.Errorf("unknown network %q", chain)
}
f := func(r *enr.Record) bool {
var eth struct {
ForkID forkid.ID
Tail []rlp.RawValue `rlp:"tail"`
}
if r.Load(enr.WithEntry("eth", &eth)) != nil {
return false
}
return filter(eth.ForkID) == nil
}
return f, nil
}
// Config holds settings for the discovery listener.
type Config struct {
// These settings are required and configure the UDP listener:
@ -47,6 +79,7 @@ type Config struct {
Log log.Logger // if set, log messages go here
ValidSchemes enr.IdentityScheme // allowed identity schemes
Clock mclock.Clock
FilterFunction NodeFilterFunc // function for filtering ENR entries
}
func (cfg Config) withDefaults() Config {

@ -80,6 +80,8 @@ type Table struct {
closeReq chan struct{}
closed chan struct{}
enrFilter NodeFilterFunc
nodeAddedHook func(*node) // for testing
}
@ -100,7 +102,7 @@ type bucket struct {
ips netutil.DistinctNetSet
}
func newTable(t transport, db *enode.DB, bootnodes []*enode.Node, log log.Logger) (*Table, error) {
func newTable(t transport, db *enode.DB, bootnodes []*enode.Node, log log.Logger, filter NodeFilterFunc) (*Table, error) {
tab := &Table{
net: t,
db: db,
@ -111,6 +113,7 @@ func newTable(t transport, db *enode.DB, bootnodes []*enode.Node, log log.Logger
rand: mrand.New(mrand.NewSource(0)),
ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
log: log,
enrFilter: filter,
}
if err := tab.setFallbackNodes(bootnodes); err != nil {
return nil, err
@ -339,10 +342,16 @@ func (tab *Table) doRevalidate(done chan<- struct{}) {
// Also fetch record if the node replied and returned a higher sequence number.
if last.Seq() < remoteSeq {
n, err := tab.net.RequestENR(unwrapNode(last))
if err != nil {
tab.log.Debug("ENR request failed", "id", last.ID(), "addr", last.addr(), "err", err)
n, enrErr := tab.net.RequestENR(unwrapNode(last))
if enrErr != nil {
tab.log.Debug("ENR request failed", "id", last.ID(), "addr", last.addr(), "err", enrErr)
} else {
if tab.enrFilter != nil {
if !tab.enrFilter(n.Record()) {
tab.log.Trace("ENR record filter out", "id", last.ID(), "addr", last.addr())
err = fmt.Errorf("filtered node")
}
}
last = &node{Node: *n, addedAt: last.addedAt, livenessChecks: last.livenessChecks}
}
}
@ -473,10 +482,20 @@ func (tab *Table) bucketAtDistance(d int) *bucket {
//
// The caller must not hold tab.mutex.
func (tab *Table) addSeenNode(n *node) {
gopool.Submit(func() {
tab.addSeenNodeSync(n)
})
}
func (tab *Table) addSeenNodeSync(n *node) {
if n.ID() == tab.self().ID() {
return
}
if tab.filterNode(n) {
return
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
b := tab.bucket(n.ID())
@ -502,6 +521,20 @@ func (tab *Table) addSeenNode(n *node) {
}
}
func (tab *Table) filterNode(n *node) bool {
if tab.enrFilter == nil {
return false
}
if node, err := tab.net.RequestENR(unwrapNode(n)); err != nil {
tab.log.Debug("ENR request failed", "id", n.ID(), "addr", n.addr(), "err", err)
return false
} else if !tab.enrFilter(node.Record()) {
tab.log.Trace("ENR record filter out", "id", n.ID(), "addr", n.addr())
return true
}
return false
}
// addVerifiedNode adds a node whose existence has been verified recently to the front of a
// bucket. If the node is already in the bucket, it is moved to the front. If the bucket
// has no space, the node is added to the replacements list.
@ -511,14 +544,23 @@ func (tab *Table) addSeenNode(n *node) {
// ping repeatedly.
//
// The caller must not hold tab.mutex.
func (tab *Table) addVerifiedNode(n *node) {
gopool.Submit(func() {
tab.addVerifiedNodeSync(n)
})
}
func (tab *Table) addVerifiedNodeSync(n *node) {
if !tab.isInitDone() {
return
}
if n.ID() == tab.self().ID() {
return
}
if tab.filterNode(n) {
return
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
b := tab.bucket(n.ID())

@ -27,10 +27,13 @@ import (
"testing/quick"
"time"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp"
)
func TestTable_pingReplace(t *testing.T) {
@ -65,7 +68,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding
// its bucket if it is unresponsive. Revalidate again to ensure that
transport.dead[last.ID()] = !lastInBucketIsResponding
transport.dead[pingSender.ID()] = !newNodeIsResponding
tab.addSeenNode(pingSender)
tab.addSeenNodeSync(pingSender)
tab.doRevalidate(make(chan struct{}, 1))
tab.doRevalidate(make(chan struct{}, 1))
@ -148,7 +151,7 @@ func TestTable_IPLimit(t *testing.T) {
for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
tab.addSeenNode(n)
tab.addSeenNodeSync(n)
}
if tab.len() > tableIPLimit {
t.Errorf("too many nodes in table")
@ -314,8 +317,8 @@ func TestTable_addVerifiedNode(t *testing.T) {
// Insert two nodes.
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addSeenNode(n1)
tab.addSeenNode(n2)
tab.addSeenNodeSync(n1)
tab.addSeenNodeSync(n2)
// Verify bucket content:
bcontent := []*node{n1, n2}
@ -327,7 +330,7 @@ func TestTable_addVerifiedNode(t *testing.T) {
newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99})
newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
tab.addVerifiedNode(newn2)
tab.addVerifiedNodeSync(newn2)
// Check that bucket is updated correctly.
newBcontent := []*node{newn2, n1}
@ -346,8 +349,8 @@ func TestTable_addSeenNode(t *testing.T) {
// Insert two nodes.
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addSeenNode(n1)
tab.addSeenNode(n2)
tab.addSeenNodeSync(n1)
tab.addSeenNodeSync(n2)
// Verify bucket content:
bcontent := []*node{n1, n2}
@ -359,7 +362,7 @@ func TestTable_addSeenNode(t *testing.T) {
newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99})
newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
tab.addSeenNode(newn2)
tab.addSeenNodeSync(newn2)
// Check that bucket content is unchanged.
if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) {
@ -382,7 +385,7 @@ func TestTable_revalidateSyncRecord(t *testing.T) {
r.Set(enr.IP(net.IP{127, 0, 0, 1}))
id := enode.ID{1}
n1 := wrapNode(enode.SignNull(&r, id))
tab.addSeenNode(n1)
tab.addSeenNodeSync(n1)
// Update the node record.
r.Set(enr.WithEntry("foo", "bar"))
@ -396,6 +399,41 @@ func TestTable_revalidateSyncRecord(t *testing.T) {
}
}
// This test checks that ENR filtering is working properly
func TestTable_filterNode(t *testing.T) {
// Create ENR filter
type eth struct {
ForkID forkid.ID
Tail []rlp.RawValue `rlp:"tail"`
}
enrFilter, _ := ParseEthFilter("bsc")
// Check test ENR record
var r1 enr.Record
r1.Set(enr.WithEntry("foo", "bar"))
if enrFilter(&r1) {
t.Fatalf("filterNode doesn't work correctly for entry")
}
t.Logf("Check test ENR record - passed")
// Check wrong genesis ENR record
var r2 enr.Record
r2.Set(enr.WithEntry("eth", eth{ForkID: forkid.NewID(params.BSCChainConfig, params.ChapelGenesisHash, uint64(0))}))
if enrFilter(&r2) {
t.Fatalf("filterNode doesn't work correctly for wrong genesis entry")
}
t.Logf("Check wrong genesis ENR record - passed")
// Check correct genesis ENR record
var r3 enr.Record
r3.Set(enr.WithEntry("eth", eth{ForkID: forkid.NewID(params.BSCChainConfig, params.BSCGenesisHash, uint64(0))}))
if !enrFilter(&r3) {
t.Fatalf("filterNode doesn't work correctly for correct genesis entry")
}
t.Logf("Check correct genesis ENR record - passed")
}
// gen wraps quick.Value so it's easier to use.
// it generates a random value of the given value's type.
func gen(typ interface{}, rand *rand.Rand) interface{} {

@ -43,7 +43,7 @@ func init() {
func newTestTable(t transport) (*Table, *enode.DB) {
db, _ := enode.OpenDB("")
tab, _ := newTable(t, db, nil, log.Root())
tab, _ := newTable(t, db, nil, log.Root(), nil)
go tab.loop()
return tab, db
}
@ -110,7 +110,7 @@ func fillBucket(tab *Table, n *node) (last *node) {
// if the bucket is not full. The caller must not hold tab.mutex.
func fillTable(tab *Table, nodes []*node) {
for _, n := range nodes {
tab.addSeenNode(n)
tab.addSeenNodeSync(n)
}
}

@ -42,7 +42,7 @@ var (
errExpired = errors.New("expired")
errUnsolicitedReply = errors.New("unsolicited reply")
errUnknownNode = errors.New("unknown node")
errTimeout = errors.New("RPC timeout")
errTimeout = errors.New("udp timeout")
errClockWarp = errors.New("reply deadline too far in the future")
errClosed = errors.New("socket closed")
errLowPort = errors.New("low port")
@ -143,7 +143,7 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
log: cfg.Log,
}
tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log)
tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log, cfg.FilterFunction)
if err != nil {
return nil, err
}

@ -164,7 +164,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
closeCtx: closeCtx,
cancelCloseCtx: cancelCloseCtx,
}
tab, err := newTable(t, t.db, cfg.Bootnodes, cfg.Log)
tab, err := newTable(t, t.db, cfg.Bootnodes, cfg.Log, cfg.FilterFunction)
if err != nil {
return nil, err
}

@ -145,7 +145,7 @@ func TestUDPv5_unknownPacket(t *testing.T) {
// Make node known.
n := test.getNode(test.remotekey, test.remoteaddr).Node()
test.table.addSeenNode(wrapNode(n))
test.table.addSeenNodeSync(wrapNode(n))
test.packetIn(&v5wire.Unknown{Nonce: nonce})
test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) {

@ -32,6 +32,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/gopool"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
@ -40,6 +41,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/ethereum/go-ethereum/rlp"
)
const (
@ -193,6 +195,8 @@ type Server struct {
discmix *enode.FairMix
dialsched *dialScheduler
forkFilter forkid.Filter
// Channels into the run loop.
quit chan struct{}
addtrusted chan *enode.Node
@ -593,6 +597,21 @@ func (srv *Server) setupDiscovery() error {
}
srv.localnode.SetFallbackUDP(realaddr.Port)
// ENR filter function
f := func(r *enr.Record) bool {
if srv.forkFilter == nil {
return true
}
var eth struct {
ForkID forkid.ID
Tail []rlp.RawValue `rlp:"tail"`
}
if r.Load(enr.WithEntry("eth", &eth)) != nil {
return false
}
return srv.forkFilter(eth.ForkID) == nil
}
// Discovery V4
var unhandled chan discover.ReadPacket
var sconn *sharedUDPConn
@ -607,6 +626,7 @@ func (srv *Server) setupDiscovery() error {
Bootnodes: srv.BootstrapNodes,
Unhandled: unhandled,
Log: srv.log,
FilterFunction: f,
}
ntab, err := discover.ListenV4(conn, srv.localnode, cfg)
if err != nil {
@ -623,6 +643,7 @@ func (srv *Server) setupDiscovery() error {
NetRestrict: srv.NetRestrict,
Bootnodes: srv.BootstrapNodesV5,
Log: srv.log,
FilterFunction: f,
}
var err error
if sconn != nil {
@ -666,6 +687,10 @@ func (srv *Server) maxInboundConns() int {
return srv.MaxPeers - srv.maxDialedConns()
}
func (srv *Server) SetFilter(f forkid.Filter) {
srv.forkFilter = f
}
func (srv *Server) maxDialedConns() (limit int) {
if srv.NoDial || srv.MaxPeers == 0 {
return 0