diff --git a/eth/peer.go b/eth/peer.go index 68ce903a6a..695e910f64 100644 --- a/eth/peer.go +++ b/eth/peer.go @@ -21,6 +21,7 @@ import ( "fmt" "math/big" "sync" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -38,8 +39,9 @@ var ( ) const ( - maxKnownTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS) - maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS) + maxKnownTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS) + maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS) + handshakeTimeout = 5 * time.Second ) type peer struct { @@ -267,8 +269,8 @@ func (p *peer) RequestReceipts(hashes []common.Hash) error { // Handshake executes the eth protocol handshake, negotiating version number, // network IDs, difficulties, head and genesis blocks. func (p *peer) Handshake(td *big.Int, head common.Hash, genesis common.Hash) error { - // Send out own handshake in a new thread - errc := make(chan error, 1) + errc := make(chan error, 2) + var status statusData // safe to read after two values have been received from errc go func() { errc <- p2p.Send(p.rw, StatusMsg, &statusData{ ProtocolVersion: uint32(p.version), @@ -278,7 +280,26 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, genesis common.Hash) err GenesisBlock: genesis, }) }() - // In the mean time retrieve the remote status message + go func() { + errc <- p.readStatus(&status, genesis) + }() + timeout := time.NewTimer(handshakeTimeout) + defer timeout.Stop() + for i := 0; i < 2; i++ { + select { + case err := <-errc: + if err != nil { + return err + } + case <-timeout.C: + return p2p.DiscReadTimeout + } + } + p.td, p.head = status.TD, status.CurrentBlock + return nil +} + +func (p *peer) readStatus(status *statusData, genesis common.Hash) (err error) { msg, err := p.rw.ReadMsg() if err != nil { return err @@ -290,7 +311,6 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, genesis common.Hash) err return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) } // Decode the handshake and make sure everything matches - var status statusData if err := msg.Decode(&status); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) } @@ -303,9 +323,7 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, genesis common.Hash) err if int(status.ProtocolVersion) != p.version { return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, p.version) } - // Configure the remote peer, and sanity check out handshake too - p.td, p.head = status.TD, status.CurrentBlock - return <-errc + return nil } // String implements fmt.Stringer.