p2p/discover: pass invalid discv5 packets to Unhandled channel (#26699)

This makes it possible to run another protocol alongside discv5, by reading 
unhandled packets from the channel.
This commit is contained in:
Martin Holst Swende 2023-03-14 07:40:40 -04:00 committed by GitHub
parent c8a6b7100c
commit eca3d39c31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 2 deletions

@ -83,6 +83,7 @@ type UDPv5 struct {
callCh chan *callV5 callCh chan *callV5
callDoneCh chan *callV5 callDoneCh chan *callV5
respTimeoutCh chan *callTimeout respTimeoutCh chan *callTimeout
unhandled chan<- ReadPacket
// state of dispatch // state of dispatch
codec codecV5 codec codecV5
@ -156,6 +157,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
callCh: make(chan *callV5), callCh: make(chan *callV5),
callDoneCh: make(chan *callV5), callDoneCh: make(chan *callV5),
respTimeoutCh: make(chan *callTimeout), respTimeoutCh: make(chan *callTimeout),
unhandled: cfg.Unhandled,
// state of dispatch // state of dispatch
codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock, cfg.V5ProtocolID), codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock, cfg.V5ProtocolID),
activeCallByNode: make(map[enode.ID]*callV5), activeCallByNode: make(map[enode.ID]*callV5),
@ -657,6 +659,14 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error {
addr := fromAddr.String() addr := fromAddr.String()
fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr) fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr)
if err != nil { if err != nil {
if t.unhandled != nil && v5wire.IsInvalidHeader(err) {
// The packet seems unrelated to discv5, send it to the next protocol.
// t.log.Trace("Unhandled discv5 packet", "id", fromID, "addr", addr, "err", err)
up := ReadPacket{Data: make([]byte, len(rawpacket)), Addr: fromAddr}
copy(up.Data, rawpacket)
t.unhandled <- up
return nil
}
t.log.Debug("Bad discv5 packet", "id", fromID, "addr", addr, "err", err) t.log.Debug("Bad discv5 packet", "id", fromID, "addr", addr, "err", err)
return err return err
} }

@ -94,6 +94,8 @@ const (
// Should reject packets smaller than minPacketSize. // Should reject packets smaller than minPacketSize.
minPacketSize = 63 minPacketSize = 63
maxPacketSize = 1280
minMessageSize = 48 // this refers to data after static headers minMessageSize = 48 // this refers to data after static headers
randomPacketMsgSize = 20 randomPacketMsgSize = 20
) )
@ -122,6 +124,13 @@ var (
ErrInvalidReqID = errors.New("request ID larger than 8 bytes") ErrInvalidReqID = errors.New("request ID larger than 8 bytes")
) )
// IsInvalidHeader reports whether 'err' is related to an invalid packet header. When it
// returns false, it is pretty certain that the packet causing the error does not belong
// to discv5.
func IsInvalidHeader(err error) bool {
return err == errTooShort || err == errInvalidHeader || err == errMsgTooShort
}
// Packet sizes. // Packet sizes.
var ( var (
sizeofStaticHeader = binary.Size(StaticHeader{}) sizeofStaticHeader = binary.Size(StaticHeader{})
@ -147,6 +156,7 @@ type Codec struct {
msgctbuf []byte // message data ciphertext msgctbuf []byte // message data ciphertext
// decoder buffer // decoder buffer
decbuf []byte
reader bytes.Reader reader bytes.Reader
} }
@ -158,6 +168,7 @@ func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock, pr
privkey: key, privkey: key,
sc: NewSessionCache(1024, clock), sc: NewSessionCache(1024, clock),
protocolID: DefaultProtocolID, protocolID: DefaultProtocolID,
decbuf: make([]byte, maxPacketSize),
} }
if protocolID != nil { if protocolID != nil {
c.protocolID = *protocolID c.protocolID = *protocolID
@ -424,10 +435,13 @@ func (c *Codec) encryptMessage(s *session, p Packet, head *Header, headerData []
} }
// Decode decodes a discovery packet. // Decode decodes a discovery packet.
func (c *Codec) Decode(input []byte, addr string) (src enode.ID, n *enode.Node, p Packet, err error) { func (c *Codec) Decode(inputData []byte, addr string) (src enode.ID, n *enode.Node, p Packet, err error) {
if len(input) < minPacketSize { if len(inputData) < minPacketSize {
return enode.ID{}, nil, nil, errTooShort return enode.ID{}, nil, nil, errTooShort
} }
// Copy the packet to a tmp buffer to avoid modifying it.
c.decbuf = append(c.decbuf[:0], inputData...)
input := c.decbuf
// Unmask the static header. // Unmask the static header.
var head Header var head Header
copy(head.IV[:], input[:sizeofMaskingIV]) copy(head.IV[:], input[:sizeofMaskingIV])