p2p: use errors.Is for error comparison (#24882)

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
Håvard Anda Estensen 2022-06-07 17:27:21 +02:00 committed by GitHub
parent 41e75480df
commit 138f0d7494
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 24 additions and 13 deletions

@ -18,6 +18,7 @@ package discover
import ( import (
"context" "context"
"errors"
"time" "time"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
@ -141,7 +142,7 @@ func (it *lookup) slowdown() {
func (it *lookup) query(n *node, reply chan<- []*node) { func (it *lookup) query(n *node, reply chan<- []*node) {
fails := it.tab.db.FindFails(n.ID(), n.IP()) fails := it.tab.db.FindFails(n.ID(), n.IP())
r, err := it.queryfunc(n) r, err := it.queryfunc(n)
if err == errClosed { if errors.Is(err, errClosed) {
// Avoid recording failures on shutdown. // Avoid recording failures on shutdown.
reply <- nil reply <- nil
return return

@ -328,7 +328,7 @@ func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubke
// enough nodes the reply matcher will time out waiting for the second reply, but // enough nodes the reply matcher will time out waiting for the second reply, but
// there's no need for an error in that case. // there's no need for an error in that case.
err := <-rm.errc err := <-rm.errc
if err == errTimeout && rm.reply != nil { if errors.Is(err, errTimeout) && rm.reply != nil {
err = nil err = nil
} }
return nodes, err return nodes, err
@ -526,7 +526,7 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) {
continue continue
} else if err != nil { } else if err != nil {
// Shut down the loop for permament errors. // Shut down the loop for permament errors.
if err != io.EOF { if !errors.Is(err, io.EOF) {
t.log.Debug("UDP read error", "err", err) t.log.Debug("UDP read error", "err", err)
} }
return return

@ -305,7 +305,7 @@ func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) {
) )
var r []*enode.Node var r []*enode.Node
r, err = t.findnode(unwrapNode(destNode), dists) r, err = t.findnode(unwrapNode(destNode), dists)
if err == errClosed { if errors.Is(err, errClosed) {
return nil, err return nil, err
} }
for _, n := range r { for _, n := range r {
@ -623,7 +623,7 @@ func (t *UDPv5) readLoop() {
continue continue
} else if err != nil { } else if err != nil {
// Shut down the loop for permament errors. // Shut down the loop for permament errors.
if err != io.EOF { if !errors.Is(err, io.EOF) {
t.log.Debug("UDP read error", "err", err) t.log.Debug("UDP read error", "err", err)
} }
return return

@ -596,7 +596,7 @@ func (c *Codec) decodeMessage(fromAddr string, head *Header, headerData, msgData
// Try decrypting the message. // Try decrypting the message.
key := c.sc.readKey(auth.SrcID, fromAddr) key := c.sc.readKey(auth.SrcID, fromAddr)
msg, err := c.decryptMessage(msgData, head.Nonce[:], headerData, key) msg, err := c.decryptMessage(msgData, head.Nonce[:], headerData, key)
if err == errMessageDecrypt { if errors.Is(err, errMessageDecrypt) {
// It didn't work. Start the handshake since this is an ordinary message packet. // It didn't work. Start the handshake since this is an ordinary message packet.
return &Unknown{Nonce: head.Nonce}, nil return &Unknown{Nonce: head.Nonce}, nil
} }

@ -19,6 +19,7 @@ package dnsdisc
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
@ -204,7 +205,7 @@ func (c *Client) doResolveEntry(ctx context.Context, domain, hash string) (entry
} }
for _, txt := range txts { for _, txt := range txts {
e, err := parseEntry(txt, c.cfg.ValidSchemes) e, err := parseEntry(txt, c.cfg.ValidSchemes)
if err == errUnknownEntry { if errors.Is(err, errUnknownEntry) {
continue continue
} }
if !bytes.HasPrefix(crypto.Keccak256([]byte(txt)), wantHash) { if !bytes.HasPrefix(crypto.Keccak256([]byte(txt)), wantHash) {
@ -281,7 +282,7 @@ func (it *randomIterator) nextNode() *enode.Node {
} }
n, err := ct.syncRandom(it.ctx) n, err := ct.syncRandom(it.ctx)
if err != nil { if err != nil {
if err == it.ctx.Err() { if errors.Is(err, it.ctx.Err()) {
return nil // context canceled. return nil // context canceled.
} }
it.c.cfg.Logger.Debug("Error in DNS random node sync", "tree", ct.loc.domain, "err", err) it.c.cfg.Logger.Debug("Error in DNS random node sync", "tree", ct.loc.domain, "err", err)

@ -17,6 +17,7 @@
package enr package enr
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -180,9 +181,16 @@ func (err *KeyError) Error() string {
return fmt.Sprintf("ENR key %q: %v", err.Key, err.Err) return fmt.Sprintf("ENR key %q: %v", err.Key, err.Err)
} }
func (err *KeyError) Unwrap() error {
return err.Err
}
// IsNotFound reports whether the given error means that a key/value pair is // IsNotFound reports whether the given error means that a key/value pair is
// missing from a record. // missing from a record.
func IsNotFound(err error) bool { func IsNotFound(err error) bool {
kerr, ok := err.(*KeyError) var ke *KeyError
return ok && kerr.Err == errNotFound if errors.As(err, &ke) {
return ke.Err == errNotFound
}
return false
} }

@ -416,7 +416,7 @@ func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error)
if err == nil { if err == nil {
p.log.Trace(fmt.Sprintf("Protocol %s/%d returned", proto.Name, proto.Version)) p.log.Trace(fmt.Sprintf("Protocol %s/%d returned", proto.Name, proto.Version))
err = errProtocolReturned err = errProtocolReturned
} else if err != io.EOF { } else if !errors.Is(err, io.EOF) {
p.log.Trace(fmt.Sprintf("Protocol %s/%d failed", proto.Name, proto.Version), "err", err) p.log.Trace(fmt.Sprintf("Protocol %s/%d failed", proto.Name, proto.Version), "err", err)
} }
p.protoErr <- err p.protoErr <- err

@ -103,7 +103,7 @@ func discReasonForError(err error) DiscReason {
if reason, ok := err.(DiscReason); ok { if reason, ok := err.(DiscReason); ok {
return reason return reason
} }
if err == errProtocolReturned { if errors.Is(err, errProtocolReturned) {
return DiscQuitting return DiscQuitting
} }
peerError, ok := err.(*peerError) peerError, ok := err.(*peerError)

@ -21,6 +21,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"html" "html"
"io" "io"
@ -559,7 +560,7 @@ func (s *Server) CreateNode(w http.ResponseWriter, req *http.Request) {
config := &adapters.NodeConfig{} config := &adapters.NodeConfig{}
err := json.NewDecoder(req.Body).Decode(config) err := json.NewDecoder(req.Body).Decode(config)
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }