From 49bde05a55260469abcbeb64fd3c7b85c7536b5c Mon Sep 17 00:00:00 2001 From: rene <41963722+renaynay@users.noreply.github.com> Date: Tue, 25 May 2021 23:09:11 +0200 Subject: [PATCH] cmd/devp2p: refactor eth test suite (#22843) This PR refactors the eth test suite to make it more readable and easier to use. Some notable differences: - A new file helpers.go stores all of the methods used between both eth66 and eth65 and below tests, as well as methods shared among many test functions. - suite.go now contains all of the test functions for both eth65 tests and eth66 tests. - The utesting.T object doesn't get passed through to other helper methods, but is instead only used within the scope of the test function, whereas helper methods return errors, so only the test function itself can fatal out in the case of an error. - The full test suite now only takes 13.5 seconds to run. --- cmd/devp2p/internal/ethtest/chain.go | 20 +- cmd/devp2p/internal/ethtest/eth66_suite.go | 521 ---------- .../internal/ethtest/eth66_suiteHelpers.go | 333 ------- cmd/devp2p/internal/ethtest/helpers.go | 635 ++++++++++++ cmd/devp2p/internal/ethtest/suite.go | 920 +++++++++++------- cmd/devp2p/internal/ethtest/transaction.go | 298 ++++-- cmd/devp2p/internal/ethtest/types.go | 240 ++--- 7 files changed, 1506 insertions(+), 1461 deletions(-) delete mode 100644 cmd/devp2p/internal/ethtest/eth66_suite.go delete mode 100644 cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go create mode 100644 cmd/devp2p/internal/ethtest/helpers.go diff --git a/cmd/devp2p/internal/ethtest/chain.go b/cmd/devp2p/internal/ethtest/chain.go index 83c55181ad..34a20c515b 100644 --- a/cmd/devp2p/internal/ethtest/chain.go +++ b/cmd/devp2p/internal/ethtest/chain.go @@ -54,10 +54,24 @@ func (c *Chain) Len() int { return len(c.blocks) } -// TD calculates the total difficulty of the chain. -func (c *Chain) TD(height int) *big.Int { // TODO later on channge scheme so that the height is included in range +// TD calculates the total difficulty of the chain at the +// chain head. +func (c *Chain) TD() *big.Int { sum := big.NewInt(0) - for _, block := range c.blocks[:height] { + for _, block := range c.blocks[:c.Len()] { + sum.Add(sum, block.Difficulty()) + } + return sum +} + +// TotalDifficultyAt calculates the total difficulty of the chain +// at the given block height. +func (c *Chain) TotalDifficultyAt(height int) *big.Int { + sum := big.NewInt(0) + if height >= c.Len() { + return sum + } + for _, block := range c.blocks[:height+1] { sum.Add(sum, block.Difficulty()) } return sum diff --git a/cmd/devp2p/internal/ethtest/eth66_suite.go b/cmd/devp2p/internal/ethtest/eth66_suite.go deleted file mode 100644 index 903a90c7eb..0000000000 --- a/cmd/devp2p/internal/ethtest/eth66_suite.go +++ /dev/null @@ -1,521 +0,0 @@ -// Copyright 2021 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package ethtest - -import ( - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/eth/protocols/eth" - "github.com/ethereum/go-ethereum/internal/utesting" - "github.com/ethereum/go-ethereum/p2p" -) - -// Is_66 checks if the node supports the eth66 protocol version, -// and if not, exists the test suite -func (s *Suite) Is_66(t *utesting.T) { - conn := s.dial66(t) - conn.handshake(t) - if conn.negotiatedProtoVersion < 66 { - t.Fail() - } -} - -// TestStatus_66 attempts to connect to the given node and exchange -// a status message with it on the eth66 protocol, and then check to -// make sure the chain head is correct. -func (s *Suite) TestStatus_66(t *utesting.T) { - conn := s.dial66(t) - defer conn.Close() - // get protoHandshake - conn.handshake(t) - // get status - switch msg := conn.statusExchange66(t, s.chain).(type) { - case *Status: - status := *msg - if status.ProtocolVersion != uint32(66) { - t.Fatalf("mismatch in version: wanted 66, got %d", status.ProtocolVersion) - } - t.Logf("got status message: %s", pretty.Sdump(msg)) - default: - t.Fatalf("unexpected: %s", pretty.Sdump(msg)) - } -} - -// TestGetBlockHeaders_66 tests whether the given node can respond to -// an eth66 `GetBlockHeaders` request and that the response is accurate. -func (s *Suite) TestGetBlockHeaders_66(t *utesting.T) { - conn := s.setupConnection66(t) - defer conn.Close() - // get block headers - req := ð.GetBlockHeadersPacket66{ - RequestId: 3, - GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ - Origin: eth.HashOrNumber{ - Hash: s.chain.blocks[1].Hash(), - }, - Amount: 2, - Skip: 1, - Reverse: false, - }, - } - // write message - headers, err := s.getBlockHeaders66(conn, req, req.RequestId) - if err != nil { - t.Fatalf("could not get block headers: %v", err) - } - // check for correct headers - if !headersMatch(t, s.chain, headers) { - t.Fatal("received wrong header(s)") - } -} - -// TestSimultaneousRequests_66 sends two simultaneous `GetBlockHeader` requests -// with different request IDs and checks to make sure the node responds with the correct -// headers per request. -func (s *Suite) TestSimultaneousRequests_66(t *utesting.T) { - // create two connections - conn := s.setupConnection66(t) - defer conn.Close() - // create two requests - req1 := ð.GetBlockHeadersPacket66{ - RequestId: 111, - GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ - Origin: eth.HashOrNumber{ - Hash: s.chain.blocks[1].Hash(), - }, - Amount: 2, - Skip: 1, - Reverse: false, - }, - } - req2 := ð.GetBlockHeadersPacket66{ - RequestId: 222, - GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ - Origin: eth.HashOrNumber{ - Hash: s.chain.blocks[1].Hash(), - }, - Amount: 4, - Skip: 1, - Reverse: false, - }, - } - // write first request - if err := conn.write66(req1, GetBlockHeaders{}.Code()); err != nil { - t.Fatalf("failed to write to connection: %v", err) - } - // write second request - if err := conn.write66(req2, GetBlockHeaders{}.Code()); err != nil { - t.Fatalf("failed to write to connection: %v", err) - } - // wait for responses - headers1, err := s.waitForBlockHeadersResponse66(conn, req1.RequestId) - if err != nil { - t.Fatalf("error while waiting for block headers: %v", err) - } - headers2, err := s.waitForBlockHeadersResponse66(conn, req2.RequestId) - if err != nil { - t.Fatalf("error while waiting for block headers: %v", err) - } - // check headers of both responses - if !headersMatch(t, s.chain, headers1) { - t.Fatalf("wrong header(s) in response to req1: got %v", headers1) - } - if !headersMatch(t, s.chain, headers2) { - t.Fatalf("wrong header(s) in response to req2: got %v", headers2) - } -} - -// TestBroadcast_66 tests whether a block announcement is correctly -// propagated to the given node's peer(s) on the eth66 protocol. -func (s *Suite) TestBroadcast_66(t *utesting.T) { - s.sendNextBlock66(t) -} - -// TestGetBlockBodies_66 tests whether the given node can respond to -// a `GetBlockBodies` request and that the response is accurate over -// the eth66 protocol. -func (s *Suite) TestGetBlockBodies_66(t *utesting.T) { - conn := s.setupConnection66(t) - defer conn.Close() - // create block bodies request - id := uint64(55) - req := ð.GetBlockBodiesPacket66{ - RequestId: id, - GetBlockBodiesPacket: eth.GetBlockBodiesPacket{ - s.chain.blocks[54].Hash(), - s.chain.blocks[75].Hash(), - }, - } - if err := conn.write66(req, GetBlockBodies{}.Code()); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - - reqID, msg := conn.readAndServe66(s.chain, timeout) - switch msg := msg.(type) { - case BlockBodies: - if reqID != req.RequestId { - t.Fatalf("request ID mismatch: wanted %d, got %d", req.RequestId, reqID) - } - t.Logf("received %d block bodies", len(msg)) - default: - t.Fatalf("unexpected: %s", pretty.Sdump(msg)) - } -} - -// TestLargeAnnounce_66 tests the announcement mechanism with a large block. -func (s *Suite) TestLargeAnnounce_66(t *utesting.T) { - nextBlock := len(s.chain.blocks) - blocks := []*NewBlock{ - { - Block: largeBlock(), - TD: s.fullChain.TD(nextBlock + 1), - }, - { - Block: s.fullChain.blocks[nextBlock], - TD: largeNumber(2), - }, - { - Block: largeBlock(), - TD: largeNumber(2), - }, - { - Block: s.fullChain.blocks[nextBlock], - TD: s.fullChain.TD(nextBlock + 1), - }, - } - - for i, blockAnnouncement := range blocks[0:3] { - t.Logf("Testing malicious announcement: %v\n", i) - sendConn := s.setupConnection66(t) - if err := sendConn.Write(blockAnnouncement); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - // Invalid announcement, check that peer disconnected - switch msg := sendConn.ReadAndServe(s.chain, time.Second*8).(type) { - case *Disconnect: - case *Error: - break - default: - t.Fatalf("unexpected: %s wanted disconnect", pretty.Sdump(msg)) - } - sendConn.Close() - } - // Test the last block as a valid block - s.sendNextBlock66(t) -} - -func (s *Suite) TestOldAnnounce_66(t *utesting.T) { - sendConn, recvConn := s.setupConnection66(t), s.setupConnection66(t) - defer sendConn.Close() - defer recvConn.Close() - - s.oldAnnounce(t, sendConn, recvConn) -} - -// TestMaliciousHandshake_66 tries to send malicious data during the handshake. -func (s *Suite) TestMaliciousHandshake_66(t *utesting.T) { - conn := s.dial66(t) - defer conn.Close() - // write hello to client - pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:] - handshakes := []*Hello{ - { - Version: 5, - Caps: []p2p.Cap{ - {Name: largeString(2), Version: 66}, - }, - ID: pub0, - }, - { - Version: 5, - Caps: []p2p.Cap{ - {Name: "eth", Version: 64}, - {Name: "eth", Version: 65}, - {Name: "eth", Version: 66}, - }, - ID: append(pub0, byte(0)), - }, - { - Version: 5, - Caps: []p2p.Cap{ - {Name: "eth", Version: 64}, - {Name: "eth", Version: 65}, - {Name: "eth", Version: 66}, - }, - ID: append(pub0, pub0...), - }, - { - Version: 5, - Caps: []p2p.Cap{ - {Name: "eth", Version: 64}, - {Name: "eth", Version: 65}, - {Name: "eth", Version: 66}, - }, - ID: largeBuffer(2), - }, - { - Version: 5, - Caps: []p2p.Cap{ - {Name: largeString(2), Version: 66}, - }, - ID: largeBuffer(2), - }, - } - for i, handshake := range handshakes { - t.Logf("Testing malicious handshake %v\n", i) - // Init the handshake - if err := conn.Write(handshake); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - // check that the peer disconnected - timeout := 20 * time.Second - // Discard one hello - for i := 0; i < 2; i++ { - switch msg := conn.ReadAndServe(s.chain, timeout).(type) { - case *Disconnect: - case *Error: - case *Hello: - // Hello's are sent concurrently, so ignore them - continue - default: - t.Fatalf("unexpected: %s", pretty.Sdump(msg)) - } - } - // Dial for the next round - conn = s.dial66(t) - } -} - -// TestMaliciousStatus_66 sends a status package with a large total difficulty. -func (s *Suite) TestMaliciousStatus_66(t *utesting.T) { - conn := s.dial66(t) - defer conn.Close() - // get protoHandshake - conn.handshake(t) - status := &Status{ - ProtocolVersion: uint32(66), - NetworkID: s.chain.chainConfig.ChainID.Uint64(), - TD: largeNumber(2), - Head: s.chain.blocks[s.chain.Len()-1].Hash(), - Genesis: s.chain.blocks[0].Hash(), - ForkID: s.chain.ForkID(), - } - // get status - switch msg := conn.statusExchange(t, s.chain, status).(type) { - case *Status: - t.Logf("%+v\n", msg) - default: - t.Fatalf("expected status, got: %#v ", msg) - } - // wait for disconnect - switch msg := conn.ReadAndServe(s.chain, timeout).(type) { - case *Disconnect: - case *Error: - return - default: - t.Fatalf("expected disconnect, got: %s", pretty.Sdump(msg)) - } -} - -func (s *Suite) TestTransaction_66(t *utesting.T) { - tests := []*types.Transaction{ - getNextTxFromChain(t, s), - unknownTx(t, s), - } - for i, tx := range tests { - t.Logf("Testing tx propagation: %v\n", i) - sendSuccessfulTx66(t, s, tx) - } -} - -func (s *Suite) TestMaliciousTx_66(t *utesting.T) { - badTxs := []*types.Transaction{ - getOldTxFromChain(t, s), - invalidNonceTx(t, s), - hugeAmount(t, s), - hugeGasPrice(t, s), - hugeData(t, s), - } - sendConn := s.setupConnection66(t) - defer sendConn.Close() - // set up receiving connection before sending txs to make sure - // no announcements are missed - recvConn := s.setupConnection66(t) - defer recvConn.Close() - - for i, tx := range badTxs { - t.Logf("Testing malicious tx propagation: %v\n", i) - if err := sendConn.Write(&Transactions{tx}); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - - } - // check to make sure bad txs aren't propagated - waitForTxPropagation(t, s, badTxs, recvConn) -} - -// TestZeroRequestID_66 checks that a request ID of zero is still handled -// by the node. -func (s *Suite) TestZeroRequestID_66(t *utesting.T) { - conn := s.setupConnection66(t) - defer conn.Close() - - req := ð.GetBlockHeadersPacket66{ - RequestId: 0, - GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ - Origin: eth.HashOrNumber{ - Number: 0, - }, - Amount: 2, - }, - } - headers, err := s.getBlockHeaders66(conn, req, req.RequestId) - if err != nil { - t.Fatalf("could not get block headers: %v", err) - } - if !headersMatch(t, s.chain, headers) { - t.Fatal("received wrong header(s)") - } -} - -// TestSameRequestID_66 sends two requests with the same request ID -// concurrently to a single node. -func (s *Suite) TestSameRequestID_66(t *utesting.T) { - conn := s.setupConnection66(t) - // create two requests with the same request ID - reqID := uint64(1234) - request1 := ð.GetBlockHeadersPacket66{ - RequestId: reqID, - GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ - Origin: eth.HashOrNumber{ - Number: 1, - }, - Amount: 2, - }, - } - request2 := ð.GetBlockHeadersPacket66{ - RequestId: reqID, - GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ - Origin: eth.HashOrNumber{ - Number: 33, - }, - Amount: 2, - }, - } - // write the first request - err := conn.write66(request1, GetBlockHeaders{}.Code()) - if err != nil { - t.Fatalf("could not write to connection: %v", err) - } - // perform second request - headers2, err := s.getBlockHeaders66(conn, request2, reqID) - if err != nil { - t.Fatalf("could not get block headers: %v", err) - return - } - // wait for response to first request - headers1, err := s.waitForBlockHeadersResponse66(conn, reqID) - if err != nil { - t.Fatalf("could not get BlockHeaders response: %v", err) - } - // check if headers match - if !headersMatch(t, s.chain, headers1) || !headersMatch(t, s.chain, headers2) { - t.Fatal("received wrong header(s)") - } -} - -// TestLargeTxRequest_66 tests whether a node can fulfill a large GetPooledTransactions -// request. -func (s *Suite) TestLargeTxRequest_66(t *utesting.T) { - // send the next block to ensure the node is no longer syncing and is able to accept - // txs - s.sendNextBlock66(t) - // send 2000 transactions to the node - hashMap, txs := generateTxs(t, s, 2000) - sendConn := s.setupConnection66(t) - defer sendConn.Close() - - sendMultipleSuccessfulTxs(t, s, sendConn, txs) - // set up connection to receive to ensure node is peered with the receiving connection - // before tx request is sent - recvConn := s.setupConnection66(t) - defer recvConn.Close() - // create and send pooled tx request - hashes := make([]common.Hash, 0) - for _, hash := range hashMap { - hashes = append(hashes, hash) - } - getTxReq := ð.GetPooledTransactionsPacket66{ - RequestId: 1234, - GetPooledTransactionsPacket: hashes, - } - if err := recvConn.write66(getTxReq, GetPooledTransactions{}.Code()); err != nil { - t.Fatalf("could not write to conn: %v", err) - } - // check that all received transactions match those that were sent to node - switch msg := recvConn.waitForResponse(s.chain, timeout, getTxReq.RequestId).(type) { - case PooledTransactions: - for _, gotTx := range msg { - if _, exists := hashMap[gotTx.Hash()]; !exists { - t.Fatalf("unexpected tx received: %v", gotTx.Hash()) - } - } - default: - t.Fatalf("unexpected %s", pretty.Sdump(msg)) - } -} - -// TestNewPooledTxs_66 tests whether a node will do a GetPooledTransactions -// request upon receiving a NewPooledTransactionHashes announcement. -func (s *Suite) TestNewPooledTxs_66(t *utesting.T) { - // send the next block to ensure the node is no longer syncing and is able to accept - // txs - s.sendNextBlock66(t) - // generate 50 txs - hashMap, _ := generateTxs(t, s, 50) - // create new pooled tx hashes announcement - hashes := make([]common.Hash, 0) - for _, hash := range hashMap { - hashes = append(hashes, hash) - } - announce := NewPooledTransactionHashes(hashes) - // send announcement - conn := s.setupConnection66(t) - defer conn.Close() - if err := conn.Write(announce); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - // wait for GetPooledTxs request - for { - _, msg := conn.readAndServe66(s.chain, timeout) - switch msg := msg.(type) { - case GetPooledTransactions: - if len(msg) != len(hashes) { - t.Fatalf("unexpected number of txs requested: wanted %d, got %d", len(hashes), len(msg)) - } - return - case *NewPooledTransactionHashes, *NewBlock, *NewBlockHashes: - // ignore propagated txs and blocks from old tests - continue - default: - t.Fatalf("unexpected %s", pretty.Sdump(msg)) - } - } -} diff --git a/cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go b/cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go deleted file mode 100644 index 3c5b22f0b5..0000000000 --- a/cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go +++ /dev/null @@ -1,333 +0,0 @@ -// Copyright 2021 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package ethtest - -import ( - "fmt" - "reflect" - "time" - - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/eth/protocols/eth" - "github.com/ethereum/go-ethereum/internal/utesting" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/rlp" - "github.com/stretchr/testify/assert" -) - -func (c *Conn) statusExchange66(t *utesting.T, chain *Chain) Message { - status := &Status{ - ProtocolVersion: uint32(66), - NetworkID: chain.chainConfig.ChainID.Uint64(), - TD: chain.TD(chain.Len()), - Head: chain.blocks[chain.Len()-1].Hash(), - Genesis: chain.blocks[0].Hash(), - ForkID: chain.ForkID(), - } - return c.statusExchange(t, chain, status) -} - -func (s *Suite) dial66(t *utesting.T) *Conn { - conn, err := s.dial() - if err != nil { - t.Fatalf("could not dial: %v", err) - } - conn.caps = append(conn.caps, p2p.Cap{Name: "eth", Version: 66}) - conn.ourHighestProtoVersion = 66 - return conn -} - -func (c *Conn) write66(req eth.Packet, code int) error { - payload, err := rlp.EncodeToBytes(req) - if err != nil { - return err - } - _, err = c.Conn.Write(uint64(code), payload) - return err -} - -func (c *Conn) read66() (uint64, Message) { - code, rawData, _, err := c.Conn.Read() - if err != nil { - return 0, errorf("could not read from connection: %v", err) - } - - var msg Message - - switch int(code) { - case (Hello{}).Code(): - msg = new(Hello) - - case (Ping{}).Code(): - msg = new(Ping) - case (Pong{}).Code(): - msg = new(Pong) - case (Disconnect{}).Code(): - msg = new(Disconnect) - case (Status{}).Code(): - msg = new(Status) - case (GetBlockHeaders{}).Code(): - ethMsg := new(eth.GetBlockHeadersPacket66) - if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { - return 0, errorf("could not rlp decode message: %v", err) - } - return ethMsg.RequestId, GetBlockHeaders(*ethMsg.GetBlockHeadersPacket) - case (BlockHeaders{}).Code(): - ethMsg := new(eth.BlockHeadersPacket66) - if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { - return 0, errorf("could not rlp decode message: %v", err) - } - return ethMsg.RequestId, BlockHeaders(ethMsg.BlockHeadersPacket) - case (GetBlockBodies{}).Code(): - ethMsg := new(eth.GetBlockBodiesPacket66) - if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { - return 0, errorf("could not rlp decode message: %v", err) - } - return ethMsg.RequestId, GetBlockBodies(ethMsg.GetBlockBodiesPacket) - case (BlockBodies{}).Code(): - ethMsg := new(eth.BlockBodiesPacket66) - if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { - return 0, errorf("could not rlp decode message: %v", err) - } - return ethMsg.RequestId, BlockBodies(ethMsg.BlockBodiesPacket) - case (NewBlock{}).Code(): - msg = new(NewBlock) - case (NewBlockHashes{}).Code(): - msg = new(NewBlockHashes) - case (Transactions{}).Code(): - msg = new(Transactions) - case (NewPooledTransactionHashes{}).Code(): - msg = new(NewPooledTransactionHashes) - case (GetPooledTransactions{}.Code()): - ethMsg := new(eth.GetPooledTransactionsPacket66) - if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { - return 0, errorf("could not rlp decode message: %v", err) - } - return ethMsg.RequestId, GetPooledTransactions(ethMsg.GetPooledTransactionsPacket) - case (PooledTransactions{}.Code()): - ethMsg := new(eth.PooledTransactionsPacket66) - if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { - return 0, errorf("could not rlp decode message: %v", err) - } - return ethMsg.RequestId, PooledTransactions(ethMsg.PooledTransactionsPacket) - default: - msg = errorf("invalid message code: %d", code) - } - - if msg != nil { - if err := rlp.DecodeBytes(rawData, msg); err != nil { - return 0, errorf("could not rlp decode message: %v", err) - } - return 0, msg - } - return 0, errorf("invalid message: %s", string(rawData)) -} - -func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID uint64) Message { - for { - id, msg := c.readAndServe66(chain, timeout) - if id == requestID { - return msg - } - } -} - -// ReadAndServe serves GetBlockHeaders requests while waiting -// on another message from the node. -func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) { - start := time.Now() - for time.Since(start) < timeout { - c.SetReadDeadline(time.Now().Add(10 * time.Second)) - - reqID, msg := c.read66() - - switch msg := msg.(type) { - case *Ping: - c.Write(&Pong{}) - case *GetBlockHeaders: - headers, err := chain.GetHeaders(*msg) - if err != nil { - return 0, errorf("could not get headers for inbound header request: %v", err) - } - resp := ð.BlockHeadersPacket66{ - RequestId: reqID, - BlockHeadersPacket: eth.BlockHeadersPacket(headers), - } - if err := c.write66(resp, BlockHeaders{}.Code()); err != nil { - return 0, errorf("could not write to connection: %v", err) - } - default: - return reqID, msg - } - } - return 0, errorf("no message received within %v", timeout) -} - -func (s *Suite) setupConnection66(t *utesting.T) *Conn { - // create conn - sendConn := s.dial66(t) - sendConn.handshake(t) - sendConn.statusExchange66(t, s.chain) - return sendConn -} - -func (s *Suite) testAnnounce66(t *utesting.T, sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) { - // Announce the block. - if err := sendConn.Write(blockAnnouncement); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - s.waitAnnounce66(t, receiveConn, blockAnnouncement) -} - -func (s *Suite) waitAnnounce66(t *utesting.T, conn *Conn, blockAnnouncement *NewBlock) { - for { - _, msg := conn.readAndServe66(s.chain, timeout) - switch msg := msg.(type) { - case *NewBlock: - t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block)) - assert.Equal(t, - blockAnnouncement.Block.Header(), msg.Block.Header(), - "wrong block header in announcement", - ) - assert.Equal(t, - blockAnnouncement.TD, msg.TD, - "wrong TD in announcement", - ) - return - case *NewBlockHashes: - blockHashes := *msg - t.Logf("received NewBlockHashes message: %s", pretty.Sdump(blockHashes)) - assert.Equal(t, blockAnnouncement.Block.Hash(), blockHashes[0].Hash, - "wrong block hash in announcement", - ) - return - case *NewPooledTransactionHashes: - // ignore old txs being propagated - continue - default: - t.Fatalf("unexpected: %s", pretty.Sdump(msg)) - } - } -} - -// waitForBlock66 waits for confirmation from the client that it has -// imported the given block. -func (c *Conn) waitForBlock66(block *types.Block) error { - defer c.SetReadDeadline(time.Time{}) - - c.SetReadDeadline(time.Now().Add(20 * time.Second)) - // note: if the node has not yet imported the block, it will respond - // to the GetBlockHeaders request with an empty BlockHeaders response, - // so the GetBlockHeaders request must be sent again until the BlockHeaders - // response contains the desired header. - for { - req := eth.GetBlockHeadersPacket66{ - RequestId: 54, - GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ - Origin: eth.HashOrNumber{ - Hash: block.Hash(), - }, - Amount: 1, - }, - } - if err := c.write66(req, GetBlockHeaders{}.Code()); err != nil { - return err - } - - reqID, msg := c.read66() - // check message - switch msg := msg.(type) { - case BlockHeaders: - // check request ID - if reqID != req.RequestId { - return fmt.Errorf("request ID mismatch: wanted %d, got %d", req.RequestId, reqID) - } - for _, header := range msg { - if header.Number.Uint64() == block.NumberU64() { - return nil - } - } - time.Sleep(100 * time.Millisecond) - case *NewPooledTransactionHashes: - // ignore old announcements - continue - default: - return fmt.Errorf("invalid message: %s", pretty.Sdump(msg)) - } - } -} - -func sendSuccessfulTx66(t *utesting.T, s *Suite, tx *types.Transaction) { - sendConn := s.setupConnection66(t) - defer sendConn.Close() - sendSuccessfulTxWithConn(t, s, tx, sendConn) -} - -// waitForBlockHeadersResponse66 waits for a BlockHeaders message with the given expected request ID -func (s *Suite) waitForBlockHeadersResponse66(conn *Conn, expectedID uint64) (BlockHeaders, error) { - reqID, msg := conn.readAndServe66(s.chain, timeout) - switch msg := msg.(type) { - case BlockHeaders: - if reqID != expectedID { - return nil, fmt.Errorf("request ID mismatch: wanted %d, got %d", expectedID, reqID) - } - return msg, nil - default: - return nil, fmt.Errorf("unexpected: %s", pretty.Sdump(msg)) - } -} - -func (s *Suite) getBlockHeaders66(conn *Conn, req eth.Packet, expectedID uint64) (BlockHeaders, error) { - if err := conn.write66(req, GetBlockHeaders{}.Code()); err != nil { - return nil, fmt.Errorf("could not write to connection: %v", err) - } - return s.waitForBlockHeadersResponse66(conn, expectedID) -} - -func headersMatch(t *utesting.T, chain *Chain, headers BlockHeaders) bool { - mismatched := 0 - for _, header := range headers { - num := header.Number.Uint64() - t.Logf("received header (%d): %s", num, pretty.Sdump(header.Hash())) - if !reflect.DeepEqual(chain.blocks[int(num)].Header(), header) { - mismatched += 1 - t.Logf("received wrong header: %v", pretty.Sdump(header)) - } - } - return mismatched == 0 -} - -func (s *Suite) sendNextBlock66(t *utesting.T) { - sendConn, receiveConn := s.setupConnection66(t), s.setupConnection66(t) - defer sendConn.Close() - defer receiveConn.Close() - - // create new block announcement - nextBlock := len(s.chain.blocks) - blockAnnouncement := &NewBlock{ - Block: s.fullChain.blocks[nextBlock], - TD: s.fullChain.TD(nextBlock + 1), - } - // send announcement and wait for node to request the header - s.testAnnounce66(t, sendConn, receiveConn, blockAnnouncement) - // wait for client to update its chain - if err := receiveConn.waitForBlock66(s.fullChain.blocks[nextBlock]); err != nil { - t.Fatal(err) - } - // update test suite chain - s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock]) -} diff --git a/cmd/devp2p/internal/ethtest/helpers.go b/cmd/devp2p/internal/ethtest/helpers.go new file mode 100644 index 0000000000..d99376124d --- /dev/null +++ b/cmd/devp2p/internal/ethtest/helpers.go @@ -0,0 +1,635 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package ethtest + +import ( + "fmt" + "net" + "reflect" + "strings" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/eth/protocols/eth" + "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/rlpx" +) + +var ( + pretty = spew.ConfigState{ + Indent: " ", + DisableCapacities: true, + DisablePointerAddresses: true, + SortKeys: true, + } + timeout = 20 * time.Second +) + +// Is_66 checks if the node supports the eth66 protocol version, +// and if not, exists the test suite +func (s *Suite) Is_66(t *utesting.T) { + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + if err := conn.handshake(); err != nil { + t.Fatalf("handshake failed: %v", err) + } + if conn.negotiatedProtoVersion < 66 { + t.Fail() + } +} + +// dial attempts to dial the given node and perform a handshake, +// returning the created Conn if successful. +func (s *Suite) dial() (*Conn, error) { + // dial + fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP())) + if err != nil { + return nil, err + } + conn := Conn{Conn: rlpx.NewConn(fd, s.Dest.Pubkey())} + // do encHandshake + conn.ourKey, _ = crypto.GenerateKey() + _, err = conn.Handshake(conn.ourKey) + if err != nil { + conn.Close() + return nil, err + } + // set default p2p capabilities + conn.caps = []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + } + conn.ourHighestProtoVersion = 65 + return &conn, nil +} + +// dial66 attempts to dial the given node and perform a handshake, +// returning the created Conn with additional eth66 capabilities if +// successful +func (s *Suite) dial66() (*Conn, error) { + conn, err := s.dial() + if err != nil { + return nil, fmt.Errorf("dial failed: %v", err) + } + conn.caps = append(conn.caps, p2p.Cap{Name: "eth", Version: 66}) + conn.ourHighestProtoVersion = 66 + return conn, nil +} + +// peer performs both the protocol handshake and the status message +// exchange with the node in order to peer with it. +func (c *Conn) peer(chain *Chain, status *Status) error { + if err := c.handshake(); err != nil { + return fmt.Errorf("handshake failed: %v", err) + } + if _, err := c.statusExchange(chain, status); err != nil { + return fmt.Errorf("status exchange failed: %v", err) + } + return nil +} + +// handshake performs a protocol handshake with the node. +func (c *Conn) handshake() error { + defer c.SetDeadline(time.Time{}) + c.SetDeadline(time.Now().Add(10 * time.Second)) + // write hello to client + pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:] + ourHandshake := &Hello{ + Version: 5, + Caps: c.caps, + ID: pub0, + } + if err := c.Write(ourHandshake); err != nil { + return fmt.Errorf("write to connection failed: %v", err) + } + // read hello from client + switch msg := c.Read().(type) { + case *Hello: + // set snappy if version is at least 5 + if msg.Version >= 5 { + c.SetSnappy(true) + } + c.negotiateEthProtocol(msg.Caps) + if c.negotiatedProtoVersion == 0 { + return fmt.Errorf("unexpected eth protocol version") + } + return nil + default: + return fmt.Errorf("bad handshake: %#v", msg) + } +} + +// negotiateEthProtocol sets the Conn's eth protocol version to highest +// advertised capability from peer. +func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) { + var highestEthVersion uint + for _, capability := range caps { + if capability.Name != "eth" { + continue + } + if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion { + highestEthVersion = capability.Version + } + } + c.negotiatedProtoVersion = highestEthVersion +} + +// statusExchange performs a `Status` message exchange with the given node. +func (c *Conn) statusExchange(chain *Chain, status *Status) (Message, error) { + defer c.SetDeadline(time.Time{}) + c.SetDeadline(time.Now().Add(20 * time.Second)) + + // read status message from client + var message Message +loop: + for { + switch msg := c.Read().(type) { + case *Status: + if have, want := msg.Head, chain.blocks[chain.Len()-1].Hash(); have != want { + return nil, fmt.Errorf("wrong head block in status, want: %#x (block %d) have %#x", + want, chain.blocks[chain.Len()-1].NumberU64(), have) + } + if have, want := msg.TD.Cmp(chain.TD()), 0; have != want { + return nil, fmt.Errorf("wrong TD in status: have %v want %v", have, want) + } + if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) { + return nil, fmt.Errorf("wrong fork ID in status: have %v, want %v", have, want) + } + if have, want := msg.ProtocolVersion, c.ourHighestProtoVersion; have != uint32(want) { + return nil, fmt.Errorf("wrong protocol version: have %v, want %v", have, want) + } + message = msg + break loop + case *Disconnect: + return nil, fmt.Errorf("disconnect received: %v", msg.Reason) + case *Ping: + c.Write(&Pong{}) // TODO (renaynay): in the future, this should be an error + // (PINGs should not be a response upon fresh connection) + default: + return nil, fmt.Errorf("bad status message: %s", pretty.Sdump(msg)) + } + } + // make sure eth protocol version is set for negotiation + if c.negotiatedProtoVersion == 0 { + return nil, fmt.Errorf("eth protocol version must be set in Conn") + } + if status == nil { + // default status message + status = &Status{ + ProtocolVersion: uint32(c.negotiatedProtoVersion), + NetworkID: chain.chainConfig.ChainID.Uint64(), + TD: chain.TD(), + Head: chain.blocks[chain.Len()-1].Hash(), + Genesis: chain.blocks[0].Hash(), + ForkID: chain.ForkID(), + } + } + if err := c.Write(status); err != nil { + return nil, fmt.Errorf("write to connection failed: %v", err) + } + return message, nil +} + +// createSendAndRecvConns creates two connections, one for sending messages to the +// node, and one for receiving messages from the node. +func (s *Suite) createSendAndRecvConns(isEth66 bool) (*Conn, *Conn, error) { + var ( + sendConn *Conn + recvConn *Conn + err error + ) + if isEth66 { + sendConn, err = s.dial66() + if err != nil { + return nil, nil, fmt.Errorf("dial failed: %v", err) + } + recvConn, err = s.dial66() + if err != nil { + sendConn.Close() + return nil, nil, fmt.Errorf("dial failed: %v", err) + } + } else { + sendConn, err = s.dial() + if err != nil { + return nil, nil, fmt.Errorf("dial failed: %v", err) + } + recvConn, err = s.dial() + if err != nil { + sendConn.Close() + return nil, nil, fmt.Errorf("dial failed: %v", err) + } + } + return sendConn, recvConn, nil +} + +// readAndServe serves GetBlockHeaders requests while waiting +// on another message from the node. +func (c *Conn) readAndServe(chain *Chain, timeout time.Duration) Message { + start := time.Now() + for time.Since(start) < timeout { + c.SetReadDeadline(time.Now().Add(5 * time.Second)) + switch msg := c.Read().(type) { + case *Ping: + c.Write(&Pong{}) + case *GetBlockHeaders: + req := *msg + headers, err := chain.GetHeaders(req) + if err != nil { + return errorf("could not get headers for inbound header request: %v", err) + } + if err := c.Write(headers); err != nil { + return errorf("could not write to connection: %v", err) + } + default: + return msg + } + } + return errorf("no message received within %v", timeout) +} + +// readAndServe66 serves eth66 GetBlockHeaders requests while waiting +// on another message from the node. +func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) { + start := time.Now() + for time.Since(start) < timeout { + c.SetReadDeadline(time.Now().Add(10 * time.Second)) + + reqID, msg := c.Read66() + + switch msg := msg.(type) { + case *Ping: + c.Write(&Pong{}) + case *GetBlockHeaders: + headers, err := chain.GetHeaders(*msg) + if err != nil { + return 0, errorf("could not get headers for inbound header request: %v", err) + } + resp := ð.BlockHeadersPacket66{ + RequestId: reqID, + BlockHeadersPacket: eth.BlockHeadersPacket(headers), + } + if err := c.Write66(resp, BlockHeaders{}.Code()); err != nil { + return 0, errorf("could not write to connection: %v", err) + } + default: + return reqID, msg + } + } + return 0, errorf("no message received within %v", timeout) +} + +// headersRequest executes the given `GetBlockHeaders` request. +func (c *Conn) headersRequest(request *GetBlockHeaders, chain *Chain, isEth66 bool, reqID uint64) (BlockHeaders, error) { + defer c.SetReadDeadline(time.Time{}) + c.SetReadDeadline(time.Now().Add(20 * time.Second)) + // if on eth66 connection, perform eth66 GetBlockHeaders request + if isEth66 { + return getBlockHeaders66(chain, c, request, reqID) + } + if err := c.Write(request); err != nil { + return nil, err + } + switch msg := c.readAndServe(chain, timeout).(type) { + case *BlockHeaders: + return *msg, nil + default: + return nil, fmt.Errorf("invalid message: %s", pretty.Sdump(msg)) + } +} + +// getBlockHeaders66 executes the given `GetBlockHeaders` request over the eth66 protocol. +func getBlockHeaders66(chain *Chain, conn *Conn, request *GetBlockHeaders, id uint64) (BlockHeaders, error) { + // write request + packet := eth.GetBlockHeadersPacket(*request) + req := ð.GetBlockHeadersPacket66{ + RequestId: id, + GetBlockHeadersPacket: &packet, + } + if err := conn.Write66(req, GetBlockHeaders{}.Code()); err != nil { + return nil, fmt.Errorf("could not write to connection: %v", err) + } + // wait for response + msg := conn.waitForResponse(chain, timeout, req.RequestId) + headers, ok := msg.(BlockHeaders) + if !ok { + return nil, fmt.Errorf("unexpected message received: %s", pretty.Sdump(msg)) + } + return headers, nil +} + +// headersMatch returns whether the received headers match the given request +func headersMatch(expected BlockHeaders, headers BlockHeaders) bool { + return reflect.DeepEqual(expected, headers) +} + +// waitForResponse reads from the connection until a response with the expected +// request ID is received. +func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID uint64) Message { + for { + id, msg := c.readAndServe66(chain, timeout) + if id == requestID { + return msg + } + } +} + +// sendNextBlock broadcasts the next block in the chain and waits +// for the node to propagate the block and import it into its chain. +func (s *Suite) sendNextBlock(isEth66 bool) error { + // set up sending and receiving connections + sendConn, recvConn, err := s.createSendAndRecvConns(isEth66) + if err != nil { + return err + } + defer sendConn.Close() + defer recvConn.Close() + if err = sendConn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) + } + if err = recvConn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) + } + // create new block announcement + nextBlock := s.fullChain.blocks[s.chain.Len()] + blockAnnouncement := &NewBlock{ + Block: nextBlock, + TD: s.fullChain.TotalDifficultyAt(s.chain.Len()), + } + // send announcement and wait for node to request the header + if err = s.testAnnounce(sendConn, recvConn, blockAnnouncement); err != nil { + return fmt.Errorf("failed to announce block: %v", err) + } + // wait for client to update its chain + if err = s.waitForBlockImport(recvConn, nextBlock, isEth66); err != nil { + return fmt.Errorf("failed to receive confirmation of block import: %v", err) + } + // update test suite chain + s.chain.blocks = append(s.chain.blocks, nextBlock) + return nil +} + +// testAnnounce writes a block announcement to the node and waits for the node +// to propagate it. +func (s *Suite) testAnnounce(sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) error { + if err := sendConn.Write(blockAnnouncement); err != nil { + return fmt.Errorf("could not write to connection: %v", err) + } + return s.waitAnnounce(receiveConn, blockAnnouncement) +} + +// waitAnnounce waits for a NewBlock or NewBlockHashes announcement from the node. +func (s *Suite) waitAnnounce(conn *Conn, blockAnnouncement *NewBlock) error { + for { + switch msg := conn.readAndServe(s.chain, timeout).(type) { + case *NewBlock: + if !reflect.DeepEqual(blockAnnouncement.Block.Header(), msg.Block.Header()) { + return fmt.Errorf("wrong header in block announcement: \nexpected %v "+ + "\ngot %v", blockAnnouncement.Block.Header(), msg.Block.Header()) + } + if !reflect.DeepEqual(blockAnnouncement.TD, msg.TD) { + return fmt.Errorf("wrong TD in announcement: expected %v, got %v", blockAnnouncement.TD, msg.TD) + } + return nil + case *NewBlockHashes: + hashes := *msg + if blockAnnouncement.Block.Hash() != hashes[0].Hash { + return fmt.Errorf("wrong block hash in announcement: expected %v, got %v", blockAnnouncement.Block.Hash(), hashes[0].Hash) + } + return nil + case *NewPooledTransactionHashes: + // ignore tx announcements from previous tests + continue + default: + return fmt.Errorf("unexpected: %s", pretty.Sdump(msg)) + } + } +} + +func (s *Suite) waitForBlockImport(conn *Conn, block *types.Block, isEth66 bool) error { + defer conn.SetReadDeadline(time.Time{}) + conn.SetReadDeadline(time.Now().Add(20 * time.Second)) + // create request + req := &GetBlockHeaders{ + Origin: eth.HashOrNumber{ + Hash: block.Hash(), + }, + Amount: 1, + } + // loop until BlockHeaders response contains desired block, confirming the + // node imported the block + for { + var ( + headers BlockHeaders + err error + ) + if isEth66 { + requestID := uint64(54) + headers, err = conn.headersRequest(req, s.chain, eth66, requestID) + } else { + headers, err = conn.headersRequest(req, s.chain, eth65, 0) + } + if err != nil { + return fmt.Errorf("GetBlockHeader request failed: %v", err) + } + // if headers response is empty, node hasn't imported block yet, try again + if len(headers) == 0 { + time.Sleep(100 * time.Millisecond) + continue + } + if !reflect.DeepEqual(block.Header(), headers[0]) { + return fmt.Errorf("wrong header returned: wanted %v, got %v", block.Header(), headers[0]) + } + return nil + } +} + +func (s *Suite) oldAnnounce(isEth66 bool) error { + sendConn, receiveConn, err := s.createSendAndRecvConns(isEth66) + if err != nil { + return err + } + defer sendConn.Close() + defer receiveConn.Close() + if err := sendConn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) + } + if err := receiveConn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) + } + // create old block announcement + oldBlockAnnounce := &NewBlock{ + Block: s.chain.blocks[len(s.chain.blocks)/2], + TD: s.chain.blocks[len(s.chain.blocks)/2].Difficulty(), + } + if err := sendConn.Write(oldBlockAnnounce); err != nil { + return fmt.Errorf("could not write to connection: %v", err) + } + // wait to see if the announcement is propagated + switch msg := receiveConn.readAndServe(s.chain, time.Second*8).(type) { + case *NewBlock: + block := *msg + if block.Block.Hash() == oldBlockAnnounce.Block.Hash() { + return fmt.Errorf("unexpected: block propagated: %s", pretty.Sdump(msg)) + } + case *NewBlockHashes: + hashes := *msg + for _, hash := range hashes { + if hash.Hash == oldBlockAnnounce.Block.Hash() { + return fmt.Errorf("unexpected: block announced: %s", pretty.Sdump(msg)) + } + } + case *Error: + errMsg := *msg + // check to make sure error is timeout (propagation didn't come through == test successful) + if !strings.Contains(errMsg.String(), "timeout") { + return fmt.Errorf("unexpected error: %v", pretty.Sdump(msg)) + } + default: + return fmt.Errorf("unexpected: %s", pretty.Sdump(msg)) + } + return nil +} + +func (s *Suite) maliciousHandshakes(t *utesting.T, isEth66 bool) error { + var ( + conn *Conn + err error + ) + if isEth66 { + conn, err = s.dial66() + if err != nil { + return fmt.Errorf("dial failed: %v", err) + } + } else { + conn, err = s.dial() + if err != nil { + return fmt.Errorf("dial failed: %v", err) + } + } + defer conn.Close() + // write hello to client + pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:] + handshakes := []*Hello{ + { + Version: 5, + Caps: []p2p.Cap{ + {Name: largeString(2), Version: 64}, + }, + ID: pub0, + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: append(pub0, byte(0)), + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: append(pub0, pub0...), + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: largeBuffer(2), + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: largeString(2), Version: 64}, + }, + ID: largeBuffer(2), + }, + } + for i, handshake := range handshakes { + t.Logf("Testing malicious handshake %v\n", i) + if err := conn.Write(handshake); err != nil { + return fmt.Errorf("could not write to connection: %v", err) + } + // check that the peer disconnected + for i := 0; i < 2; i++ { + switch msg := conn.readAndServe(s.chain, 20*time.Second).(type) { + case *Disconnect: + case *Error: + case *Hello: + // Discard one hello as Hello's are sent concurrently + continue + default: + return fmt.Errorf("unexpected: %s", pretty.Sdump(msg)) + } + } + // dial for the next round + if isEth66 { + conn, err = s.dial66() + if err != nil { + return fmt.Errorf("dial failed: %v", err) + } + } else { + conn, err = s.dial() + if err != nil { + return fmt.Errorf("dial failed: %v", err) + } + } + } + return nil +} + +func (s *Suite) maliciousStatus(conn *Conn) error { + if err := conn.handshake(); err != nil { + return fmt.Errorf("handshake failed: %v", err) + } + status := &Status{ + ProtocolVersion: uint32(conn.negotiatedProtoVersion), + NetworkID: s.chain.chainConfig.ChainID.Uint64(), + TD: largeNumber(2), + Head: s.chain.blocks[s.chain.Len()-1].Hash(), + Genesis: s.chain.blocks[0].Hash(), + ForkID: s.chain.ForkID(), + } + // get status + msg, err := conn.statusExchange(s.chain, status) + if err != nil { + return fmt.Errorf("status exchange failed: %v", err) + } + switch msg := msg.(type) { + case *Status: + default: + return fmt.Errorf("expected status, got: %#v ", msg) + } + // wait for disconnect + switch msg := conn.readAndServe(s.chain, timeout).(type) { + case *Disconnect: + return nil + case *Error: + return nil + default: + return fmt.Errorf("expected disconnect, got: %s", pretty.Sdump(msg)) + } +} diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go index abc6bcddce..ad832dddd2 100644 --- a/cmd/devp2p/internal/ethtest/suite.go +++ b/cmd/devp2p/internal/ethtest/suite.go @@ -17,33 +17,16 @@ package ethtest import ( - "fmt" - "net" - "strings" "time" - "github.com/davecgh/go-spew/spew" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/internal/utesting" - "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/rlpx" - "github.com/stretchr/testify/assert" ) -var pretty = spew.ConfigState{ - Indent: " ", - DisableCapacities: true, - DisablePointerAddresses: true, - SortKeys: true, -} - -var timeout = 20 * time.Second - -// Suite represents a structure used to test the eth -// protocol of a node(s). +// Suite represents a structure used to test a node's conformance +// to the eth protocol. type Suite struct { Dest *enode.Node @@ -70,35 +53,35 @@ func (s *Suite) AllEthTests() []utesting.Test { return []utesting.Test{ // status {Name: "TestStatus", Fn: s.TestStatus}, - {Name: "TestStatus_66", Fn: s.TestStatus_66}, + {Name: "TestStatus66", Fn: s.TestStatus66}, // get block headers {Name: "TestGetBlockHeaders", Fn: s.TestGetBlockHeaders}, - {Name: "TestGetBlockHeaders_66", Fn: s.TestGetBlockHeaders_66}, - {Name: "TestSimultaneousRequests_66", Fn: s.TestSimultaneousRequests_66}, - {Name: "TestSameRequestID_66", Fn: s.TestSameRequestID_66}, - {Name: "TestZeroRequestID_66", Fn: s.TestZeroRequestID_66}, + {Name: "TestGetBlockHeaders66", Fn: s.TestGetBlockHeaders66}, + {Name: "TestSimultaneousRequests66", Fn: s.TestSimultaneousRequests66}, + {Name: "TestSameRequestID66", Fn: s.TestSameRequestID66}, + {Name: "TestZeroRequestID66", Fn: s.TestZeroRequestID66}, // get block bodies {Name: "TestGetBlockBodies", Fn: s.TestGetBlockBodies}, - {Name: "TestGetBlockBodies_66", Fn: s.TestGetBlockBodies_66}, + {Name: "TestGetBlockBodies66", Fn: s.TestGetBlockBodies66}, // broadcast {Name: "TestBroadcast", Fn: s.TestBroadcast}, - {Name: "TestBroadcast_66", Fn: s.TestBroadcast_66}, + {Name: "TestBroadcast66", Fn: s.TestBroadcast66}, {Name: "TestLargeAnnounce", Fn: s.TestLargeAnnounce}, - {Name: "TestLargeAnnounce_66", Fn: s.TestLargeAnnounce_66}, + {Name: "TestLargeAnnounce66", Fn: s.TestLargeAnnounce66}, {Name: "TestOldAnnounce", Fn: s.TestOldAnnounce}, - {Name: "TestOldAnnounce_66", Fn: s.TestOldAnnounce_66}, + {Name: "TestOldAnnounce66", Fn: s.TestOldAnnounce66}, // malicious handshakes + status {Name: "TestMaliciousHandshake", Fn: s.TestMaliciousHandshake}, {Name: "TestMaliciousStatus", Fn: s.TestMaliciousStatus}, - {Name: "TestMaliciousHandshake_66", Fn: s.TestMaliciousHandshake_66}, - {Name: "TestMaliciousStatus_66", Fn: s.TestMaliciousStatus_66}, + {Name: "TestMaliciousHandshake66", Fn: s.TestMaliciousHandshake66}, + {Name: "TestMaliciousStatus66", Fn: s.TestMaliciousStatus66}, // test transactions {Name: "TestTransaction", Fn: s.TestTransaction}, - {Name: "TestTransaction_66", Fn: s.TestTransaction_66}, + {Name: "TestTransaction66", Fn: s.TestTransaction66}, {Name: "TestMaliciousTx", Fn: s.TestMaliciousTx}, - {Name: "TestMaliciousTx_66", Fn: s.TestMaliciousTx_66}, - {Name: "TestLargeTxRequest_66", Fn: s.TestLargeTxRequest_66}, - {Name: "TestNewPooledTxs_66", Fn: s.TestNewPooledTxs_66}, + {Name: "TestMaliciousTx66", Fn: s.TestMaliciousTx66}, + {Name: "TestLargeTxRequest66", Fn: s.TestLargeTxRequest66}, + {Name: "TestNewPooledTxs66", Fn: s.TestNewPooledTxs66}, } } @@ -109,6 +92,7 @@ func (s *Suite) EthTests() []utesting.Test { {Name: "TestGetBlockBodies", Fn: s.TestGetBlockBodies}, {Name: "TestBroadcast", Fn: s.TestBroadcast}, {Name: "TestLargeAnnounce", Fn: s.TestLargeAnnounce}, + {Name: "TestOldAnnounce", Fn: s.TestOldAnnounce}, {Name: "TestMaliciousHandshake", Fn: s.TestMaliciousHandshake}, {Name: "TestMaliciousStatus", Fn: s.TestMaliciousStatus}, {Name: "TestTransaction", Fn: s.TestTransaction}, @@ -119,90 +103,67 @@ func (s *Suite) EthTests() []utesting.Test { func (s *Suite) Eth66Tests() []utesting.Test { return []utesting.Test{ // only proceed with eth66 test suite if node supports eth 66 protocol - {Name: "TestStatus_66", Fn: s.TestStatus_66}, - {Name: "TestGetBlockHeaders_66", Fn: s.TestGetBlockHeaders_66}, - {Name: "TestSimultaneousRequests_66", Fn: s.TestSimultaneousRequests_66}, - {Name: "TestSameRequestID_66", Fn: s.TestSameRequestID_66}, - {Name: "TestZeroRequestID_66", Fn: s.TestZeroRequestID_66}, - {Name: "TestGetBlockBodies_66", Fn: s.TestGetBlockBodies_66}, - {Name: "TestBroadcast_66", Fn: s.TestBroadcast_66}, - {Name: "TestLargeAnnounce_66", Fn: s.TestLargeAnnounce_66}, - {Name: "TestMaliciousHandshake_66", Fn: s.TestMaliciousHandshake_66}, - {Name: "TestMaliciousStatus_66", Fn: s.TestMaliciousStatus_66}, - {Name: "TestTransaction_66", Fn: s.TestTransaction_66}, - {Name: "TestMaliciousTx_66", Fn: s.TestMaliciousTx_66}, - {Name: "TestLargeTxRequest_66", Fn: s.TestLargeTxRequest_66}, - {Name: "TestNewPooledTxs_66", Fn: s.TestNewPooledTxs_66}, + {Name: "TestStatus66", Fn: s.TestStatus66}, + {Name: "TestGetBlockHeaders66", Fn: s.TestGetBlockHeaders66}, + {Name: "TestSimultaneousRequests66", Fn: s.TestSimultaneousRequests66}, + {Name: "TestSameRequestID66", Fn: s.TestSameRequestID66}, + {Name: "TestZeroRequestID66", Fn: s.TestZeroRequestID66}, + {Name: "TestGetBlockBodies66", Fn: s.TestGetBlockBodies66}, + {Name: "TestBroadcast66", Fn: s.TestBroadcast66}, + {Name: "TestLargeAnnounce66", Fn: s.TestLargeAnnounce66}, + {Name: "TestOldAnnounce66", Fn: s.TestOldAnnounce66}, + {Name: "TestMaliciousHandshake66", Fn: s.TestMaliciousHandshake66}, + {Name: "TestMaliciousStatus66", Fn: s.TestMaliciousStatus66}, + {Name: "TestTransaction66", Fn: s.TestTransaction66}, + {Name: "TestMaliciousTx66", Fn: s.TestMaliciousTx66}, + {Name: "TestLargeTxRequest66", Fn: s.TestLargeTxRequest66}, + {Name: "TestNewPooledTxs66", Fn: s.TestNewPooledTxs66}, } } +var ( + eth66 = true // indicates whether suite should negotiate eth66 connection + eth65 = false // indicates whether suite should negotiate eth65 connection or below. +) + // TestStatus attempts to connect to the given node and exchange -// a status message with it, and then check to make sure -// the chain head is correct. +// a status message with it. func (s *Suite) TestStatus(t *utesting.T) { conn, err := s.dial() if err != nil { - t.Fatalf("could not dial: %v", err) + t.Fatalf("dial failed: %v", err) } defer conn.Close() - // get protoHandshake - conn.handshake(t) - // get status - switch msg := conn.statusExchange(t, s.chain, nil).(type) { - case *Status: - t.Logf("got status message: %s", pretty.Sdump(msg)) - default: - t.Fatalf("unexpected: %s", pretty.Sdump(msg)) + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) } } -// TestMaliciousStatus sends a status package with a large total difficulty. -func (s *Suite) TestMaliciousStatus(t *utesting.T) { - conn, err := s.dial() +// TestStatus66 attempts to connect to the given node and exchange +// a status message with it on the eth66 protocol. +func (s *Suite) TestStatus66(t *utesting.T) { + conn, err := s.dial66() if err != nil { - t.Fatalf("could not dial: %v", err) + t.Fatalf("dial failed: %v", err) } defer conn.Close() - // get protoHandshake - conn.handshake(t) - status := &Status{ - ProtocolVersion: uint32(conn.negotiatedProtoVersion), - NetworkID: s.chain.chainConfig.ChainID.Uint64(), - TD: largeNumber(2), - Head: s.chain.blocks[s.chain.Len()-1].Hash(), - Genesis: s.chain.blocks[0].Hash(), - ForkID: s.chain.ForkID(), - } - // get status - switch msg := conn.statusExchange(t, s.chain, status).(type) { - case *Status: - t.Logf("%+v\n", msg) - default: - t.Fatalf("expected status, got: %#v ", msg) - } - // wait for disconnect - switch msg := conn.ReadAndServe(s.chain, timeout).(type) { - case *Disconnect: - case *Error: - return - default: - t.Fatalf("expected disconnect, got: %s", pretty.Sdump(msg)) + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) } } // TestGetBlockHeaders tests whether the given node can respond to -// a `GetBlockHeaders` request and that the response is accurate. +// a `GetBlockHeaders` request accurately. func (s *Suite) TestGetBlockHeaders(t *utesting.T) { conn, err := s.dial() if err != nil { - t.Fatalf("could not dial: %v", err) + t.Fatalf("dial failed: %v", err) } defer conn.Close() - - conn.handshake(t) - conn.statusExchange(t, s.chain, nil) - - // get block headers + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("handshake(s) failed: %v", err) + } + // write request req := &GetBlockHeaders{ Origin: eth.HashOrNumber{ Hash: s.chain.blocks[1].Hash(), @@ -211,21 +172,219 @@ func (s *Suite) TestGetBlockHeaders(t *utesting.T) { Skip: 1, Reverse: false, } - - if err := conn.Write(req); err != nil { - t.Fatalf("could not write to connection: %v", err) + headers, err := conn.headersRequest(req, s.chain, eth65, 0) + if err != nil { + t.Fatalf("GetBlockHeaders request failed: %v", err) } + // check for correct headers + expected, err := s.chain.GetHeaders(*req) + if err != nil { + t.Fatalf("failed to get headers for given request: %v", err) + } + if !headersMatch(expected, headers) { + t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers) + } +} - switch msg := conn.ReadAndServe(s.chain, timeout).(type) { - case *BlockHeaders: - headers := *msg - for _, header := range headers { - num := header.Number.Uint64() - t.Logf("received header (%d): %s", num, pretty.Sdump(header.Hash())) - assert.Equal(t, s.chain.blocks[int(num)].Header(), header) - } - default: - t.Fatalf("unexpected: %s", pretty.Sdump(msg)) +// TestGetBlockHeaders66 tests whether the given node can respond to +// an eth66 `GetBlockHeaders` request and that the response is accurate. +func (s *Suite) TestGetBlockHeaders66(t *utesting.T) { + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + if err = conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } + // write request + req := &GetBlockHeaders{ + Origin: eth.HashOrNumber{ + Hash: s.chain.blocks[1].Hash(), + }, + Amount: 2, + Skip: 1, + Reverse: false, + } + headers, err := conn.headersRequest(req, s.chain, eth66, 33) + if err != nil { + t.Fatalf("could not get block headers: %v", err) + } + // check for correct headers + expected, err := s.chain.GetHeaders(*req) + if err != nil { + t.Fatalf("failed to get headers for given request: %v", err) + } + if !headersMatch(expected, headers) { + t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers) + } +} + +// TestSimultaneousRequests66 sends two simultaneous `GetBlockHeader` requests from +// the same connection with different request IDs and checks to make sure the node +// responds with the correct headers per request. +func (s *Suite) TestSimultaneousRequests66(t *utesting.T) { + // create a connection + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } + // create two requests + req1 := ð.GetBlockHeadersPacket66{ + RequestId: uint64(111), + GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ + Origin: eth.HashOrNumber{ + Hash: s.chain.blocks[1].Hash(), + }, + Amount: 2, + Skip: 1, + Reverse: false, + }, + } + req2 := ð.GetBlockHeadersPacket66{ + RequestId: uint64(222), + GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ + Origin: eth.HashOrNumber{ + Hash: s.chain.blocks[1].Hash(), + }, + Amount: 4, + Skip: 1, + Reverse: false, + }, + } + // write the first request + if err := conn.Write66(req1, GetBlockHeaders{}.Code()); err != nil { + t.Fatalf("failed to write to connection: %v", err) + } + // write the second request + if err := conn.Write66(req2, GetBlockHeaders{}.Code()); err != nil { + t.Fatalf("failed to write to connection: %v", err) + } + // wait for responses + msg := conn.waitForResponse(s.chain, timeout, req1.RequestId) + headers1, ok := msg.(BlockHeaders) + if !ok { + t.Fatalf("unexpected %s", pretty.Sdump(msg)) + } + msg = conn.waitForResponse(s.chain, timeout, req2.RequestId) + headers2, ok := msg.(BlockHeaders) + if !ok { + t.Fatalf("unexpected %s", pretty.Sdump(msg)) + } + // check received headers for accuracy + expected1, err := s.chain.GetHeaders(GetBlockHeaders(*req1.GetBlockHeadersPacket)) + if err != nil { + t.Fatalf("failed to get expected headers for request 1: %v", err) + } + expected2, err := s.chain.GetHeaders(GetBlockHeaders(*req2.GetBlockHeadersPacket)) + if err != nil { + t.Fatalf("failed to get expected headers for request 2: %v", err) + } + if !headersMatch(expected1, headers1) { + t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected1, headers1) + } + if !headersMatch(expected2, headers2) { + t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected2, headers2) + } +} + +// TestSameRequestID66 sends two requests with the same request ID to a +// single node. +func (s *Suite) TestSameRequestID66(t *utesting.T) { + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } + // create requests + reqID := uint64(1234) + request1 := ð.GetBlockHeadersPacket66{ + RequestId: reqID, + GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ + Origin: eth.HashOrNumber{ + Number: 1, + }, + Amount: 2, + }, + } + request2 := ð.GetBlockHeadersPacket66{ + RequestId: reqID, + GetBlockHeadersPacket: ð.GetBlockHeadersPacket{ + Origin: eth.HashOrNumber{ + Number: 33, + }, + Amount: 2, + }, + } + // write the requests + if err = conn.Write66(request1, GetBlockHeaders{}.Code()); err != nil { + t.Fatalf("failed to write to connection: %v", err) + } + if err = conn.Write66(request2, GetBlockHeaders{}.Code()); err != nil { + t.Fatalf("failed to write to connection: %v", err) + } + // wait for responses + msg := conn.waitForResponse(s.chain, timeout, reqID) + headers1, ok := msg.(BlockHeaders) + if !ok { + t.Fatalf("unexpected %s", pretty.Sdump(msg)) + } + msg = conn.waitForResponse(s.chain, timeout, reqID) + headers2, ok := msg.(BlockHeaders) + if !ok { + t.Fatalf("unexpected %s", pretty.Sdump(msg)) + } + // check if headers match + expected1, err := s.chain.GetHeaders(GetBlockHeaders(*request1.GetBlockHeadersPacket)) + if err != nil { + t.Fatalf("failed to get expected block headers: %v", err) + } + expected2, err := s.chain.GetHeaders(GetBlockHeaders(*request2.GetBlockHeadersPacket)) + if err != nil { + t.Fatalf("failed to get expected block headers: %v", err) + } + if !headersMatch(expected1, headers1) { + t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected1, headers1) + } + if !headersMatch(expected2, headers2) { + t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected2, headers2) + } +} + +// TestZeroRequestID_66 checks that a message with a request ID of zero is still handled +// by the node. +func (s *Suite) TestZeroRequestID66(t *utesting.T) { + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } + req := &GetBlockHeaders{ + Origin: eth.HashOrNumber{ + Number: 0, + }, + Amount: 2, + } + headers, err := conn.headersRequest(req, s.chain, eth66, 0) + if err != nil { + t.Fatalf("failed to get block headers: %v", err) + } + expected, err := s.chain.GetHeaders(*req) + if err != nil { + t.Fatalf("failed to get expected block headers: %v", err) + } + if !headersMatch(expected, headers) { + t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers) } } @@ -234,12 +393,12 @@ func (s *Suite) TestGetBlockHeaders(t *utesting.T) { func (s *Suite) TestGetBlockBodies(t *utesting.T) { conn, err := s.dial() if err != nil { - t.Fatalf("could not dial: %v", err) + t.Fatalf("dial failed: %v", err) } defer conn.Close() - - conn.handshake(t) - conn.statusExchange(t, s.chain, nil) + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } // create block bodies request req := &GetBlockBodies{ s.chain.blocks[54].Hash(), @@ -248,116 +407,68 @@ func (s *Suite) TestGetBlockBodies(t *utesting.T) { if err := conn.Write(req); err != nil { t.Fatalf("could not write to connection: %v", err) } - - switch msg := conn.ReadAndServe(s.chain, timeout).(type) { + // wait for response + switch msg := conn.readAndServe(s.chain, timeout).(type) { case *BlockBodies: t.Logf("received %d block bodies", len(*msg)) + if len(*msg) != len(*req) { + t.Fatalf("wrong bodies in response: expected %d bodies, "+ + "got %d", len(*req), len(*msg)) + } default: t.Fatalf("unexpected: %s", pretty.Sdump(msg)) } } +// TestGetBlockBodies66 tests whether the given node can respond to +// a `GetBlockBodies` request and that the response is accurate over +// the eth66 protocol. +func (s *Suite) TestGetBlockBodies66(t *utesting.T) { + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } + // create block bodies request + req := ð.GetBlockBodiesPacket66{ + RequestId: uint64(55), + GetBlockBodiesPacket: eth.GetBlockBodiesPacket{ + s.chain.blocks[54].Hash(), + s.chain.blocks[75].Hash(), + }, + } + if err := conn.Write66(req, GetBlockBodies{}.Code()); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + // wait for block bodies response + msg := conn.waitForResponse(s.chain, timeout, req.RequestId) + blockBodies, ok := msg.(BlockBodies) + if !ok { + t.Fatalf("unexpected: %s", pretty.Sdump(msg)) + } + t.Logf("received %d block bodies", len(blockBodies)) + if len(blockBodies) != len(req.GetBlockBodiesPacket) { + t.Fatalf("wrong bodies in response: expected %d bodies, "+ + "got %d", len(req.GetBlockBodiesPacket), len(blockBodies)) + } +} + // TestBroadcast tests whether a block announcement is correctly // propagated to the given node's peer(s). func (s *Suite) TestBroadcast(t *utesting.T) { - s.sendNextBlock(t) + if err := s.sendNextBlock(eth65); err != nil { + t.Fatalf("block broadcast failed: %v", err) + } } -func (s *Suite) sendNextBlock(t *utesting.T) { - sendConn, receiveConn := s.setupConnection(t), s.setupConnection(t) - defer sendConn.Close() - defer receiveConn.Close() - - // create new block announcement - nextBlock := len(s.chain.blocks) - blockAnnouncement := &NewBlock{ - Block: s.fullChain.blocks[nextBlock], - TD: s.fullChain.TD(nextBlock + 1), - } - // send announcement and wait for node to request the header - s.testAnnounce(t, sendConn, receiveConn, blockAnnouncement) - // wait for client to update its chain - if err := receiveConn.waitForBlock(s.fullChain.blocks[nextBlock]); err != nil { - t.Fatal(err) - } - // update test suite chain - s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock]) -} - -// TestMaliciousHandshake tries to send malicious data during the handshake. -func (s *Suite) TestMaliciousHandshake(t *utesting.T) { - conn, err := s.dial() - if err != nil { - t.Fatalf("could not dial: %v", err) - } - defer conn.Close() - // write hello to client - pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:] - handshakes := []*Hello{ - { - Version: 5, - Caps: []p2p.Cap{ - {Name: largeString(2), Version: 64}, - }, - ID: pub0, - }, - { - Version: 5, - Caps: []p2p.Cap{ - {Name: "eth", Version: 64}, - {Name: "eth", Version: 65}, - }, - ID: append(pub0, byte(0)), - }, - { - Version: 5, - Caps: []p2p.Cap{ - {Name: "eth", Version: 64}, - {Name: "eth", Version: 65}, - }, - ID: append(pub0, pub0...), - }, - { - Version: 5, - Caps: []p2p.Cap{ - {Name: "eth", Version: 64}, - {Name: "eth", Version: 65}, - }, - ID: largeBuffer(2), - }, - { - Version: 5, - Caps: []p2p.Cap{ - {Name: largeString(2), Version: 64}, - }, - ID: largeBuffer(2), - }, - } - for i, handshake := range handshakes { - t.Logf("Testing malicious handshake %v\n", i) - // Init the handshake - if err := conn.Write(handshake); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - // check that the peer disconnected - timeout := 20 * time.Second - // Discard one hello - for i := 0; i < 2; i++ { - switch msg := conn.ReadAndServe(s.chain, timeout).(type) { - case *Disconnect: - case *Error: - case *Hello: - // Hello's are send concurrently, so ignore them - continue - default: - t.Fatalf("unexpected: %s", pretty.Sdump(msg)) - } - } - // Dial for the next round - conn, err = s.dial() - if err != nil { - t.Fatalf("could not dial: %v", err) - } +// TestBroadcast66 tests whether a block announcement is correctly +// propagated to the given node's peer(s) on the eth66 protocol. +func (s *Suite) TestBroadcast66(t *utesting.T) { + if err := s.sendNextBlock(eth66); err != nil { + t.Fatalf("block broadcast failed: %v", err) } } @@ -367,7 +478,7 @@ func (s *Suite) TestLargeAnnounce(t *utesting.T) { blocks := []*NewBlock{ { Block: largeBlock(), - TD: s.fullChain.TD(nextBlock + 1), + TD: s.fullChain.TotalDifficultyAt(nextBlock), }, { Block: s.fullChain.blocks[nextBlock], @@ -377,174 +488,267 @@ func (s *Suite) TestLargeAnnounce(t *utesting.T) { Block: largeBlock(), TD: largeNumber(2), }, - { - Block: s.fullChain.blocks[nextBlock], - TD: s.fullChain.TD(nextBlock + 1), - }, } - for i, blockAnnouncement := range blocks[0:3] { + for i, blockAnnouncement := range blocks { t.Logf("Testing malicious announcement: %v\n", i) - sendConn := s.setupConnection(t) - if err := sendConn.Write(blockAnnouncement); err != nil { + conn, err := s.dial() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + if err = conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } + if err = conn.Write(blockAnnouncement); err != nil { t.Fatalf("could not write to connection: %v", err) } // Invalid announcement, check that peer disconnected - switch msg := sendConn.ReadAndServe(s.chain, time.Second*8).(type) { + switch msg := conn.readAndServe(s.chain, time.Second*8).(type) { case *Disconnect: case *Error: break default: t.Fatalf("unexpected: %s wanted disconnect", pretty.Sdump(msg)) } - sendConn.Close() + conn.Close() } // Test the last block as a valid block - s.sendNextBlock(t) + if err := s.sendNextBlock(eth65); err != nil { + t.Fatalf("failed to broadcast next block: %v", err) + } } -func (s *Suite) TestOldAnnounce(t *utesting.T) { - sendConn, recvConn := s.setupConnection(t), s.setupConnection(t) - defer sendConn.Close() - defer recvConn.Close() - - s.oldAnnounce(t, sendConn, recvConn) -} - -func (s *Suite) oldAnnounce(t *utesting.T, sendConn, receiveConn *Conn) { - oldBlockAnnounce := &NewBlock{ - Block: s.chain.blocks[len(s.chain.blocks)/2], - TD: s.chain.blocks[len(s.chain.blocks)/2].Difficulty(), +// TestLargeAnnounce66 tests the announcement mechanism with a large +// block over the eth66 protocol. +func (s *Suite) TestLargeAnnounce66(t *utesting.T) { + nextBlock := len(s.chain.blocks) + blocks := []*NewBlock{ + { + Block: largeBlock(), + TD: s.fullChain.TotalDifficultyAt(nextBlock), + }, + { + Block: s.fullChain.blocks[nextBlock], + TD: largeNumber(2), + }, + { + Block: largeBlock(), + TD: largeNumber(2), + }, } - if err := sendConn.Write(oldBlockAnnounce); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - - switch msg := receiveConn.ReadAndServe(s.chain, time.Second*8).(type) { - case *NewBlock: - block := *msg - if block.Block.Hash() == oldBlockAnnounce.Block.Hash() { - t.Fatalf("unexpected: block propagated: %s", pretty.Sdump(msg)) + for i, blockAnnouncement := range blocks[0:3] { + t.Logf("Testing malicious announcement: %v\n", i) + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) } - case *NewBlockHashes: - hashes := *msg - for _, hash := range hashes { - if hash.Hash == oldBlockAnnounce.Block.Hash() { - t.Fatalf("unexpected: block announced: %s", pretty.Sdump(msg)) - } + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) } - case *Error: - errMsg := *msg - // check to make sure error is timeout (propagation didn't come through == test successful) - if !strings.Contains(errMsg.String(), "timeout") { - t.Fatalf("unexpected error: %v", pretty.Sdump(msg)) - } - default: - t.Fatalf("unexpected: %s", pretty.Sdump(msg)) - } -} - -func (s *Suite) testAnnounce(t *utesting.T, sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) { - // Announce the block. - if err := sendConn.Write(blockAnnouncement); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - s.waitAnnounce(t, receiveConn, blockAnnouncement) -} - -func (s *Suite) waitAnnounce(t *utesting.T, conn *Conn, blockAnnouncement *NewBlock) { - switch msg := conn.ReadAndServe(s.chain, timeout).(type) { - case *NewBlock: - t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block)) - assert.Equal(t, - blockAnnouncement.Block.Header(), msg.Block.Header(), - "wrong block header in announcement", - ) - assert.Equal(t, - blockAnnouncement.TD, msg.TD, - "wrong TD in announcement", - ) - case *NewBlockHashes: - message := *msg - t.Logf("received NewBlockHashes message: %s", pretty.Sdump(message)) - assert.Equal(t, blockAnnouncement.Block.Hash(), message[0].Hash, - "wrong block hash in announcement", - ) - default: - t.Fatalf("unexpected: %s", pretty.Sdump(msg)) - } -} - -func (s *Suite) setupConnection(t *utesting.T) *Conn { - // create conn - sendConn, err := s.dial() - if err != nil { - t.Fatalf("could not dial: %v", err) - } - sendConn.handshake(t) - sendConn.statusExchange(t, s.chain, nil) - return sendConn -} - -// dial attempts to dial the given node and perform a handshake, -// returning the created Conn if successful. -func (s *Suite) dial() (*Conn, error) { - var conn Conn - // dial - fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP())) - if err != nil { - return nil, err - } - conn.Conn = rlpx.NewConn(fd, s.Dest.Pubkey()) - // do encHandshake - conn.ourKey, _ = crypto.GenerateKey() - _, err = conn.Handshake(conn.ourKey) - if err != nil { - return nil, err - } - // set default p2p capabilities - conn.caps = []p2p.Cap{ - {Name: "eth", Version: 64}, - {Name: "eth", Version: 65}, - } - conn.ourHighestProtoVersion = 65 - return &conn, nil -} - -func (s *Suite) TestTransaction(t *utesting.T) { - tests := []*types.Transaction{ - getNextTxFromChain(t, s), - unknownTx(t, s), - } - for i, tx := range tests { - t.Logf("Testing tx propagation: %v\n", i) - sendSuccessfulTx(t, s, tx) - } -} - -func (s *Suite) TestMaliciousTx(t *utesting.T) { - badTxs := []*types.Transaction{ - getOldTxFromChain(t, s), - invalidNonceTx(t, s), - hugeAmount(t, s), - hugeGasPrice(t, s), - hugeData(t, s), - } - sendConn := s.setupConnection(t) - defer sendConn.Close() - // set up receiving connection before sending txs to make sure - // no announcements are missed - recvConn := s.setupConnection(t) - defer recvConn.Close() - - for i, tx := range badTxs { - t.Logf("Testing malicious tx propagation: %v\n", i) - if err := sendConn.Write(&Transactions{tx}); err != nil { + if err := conn.Write(blockAnnouncement); err != nil { t.Fatalf("could not write to connection: %v", err) } - + // Invalid announcement, check that peer disconnected + switch msg := conn.readAndServe(s.chain, time.Second*8).(type) { + case *Disconnect: + case *Error: + break + default: + t.Fatalf("unexpected: %s wanted disconnect", pretty.Sdump(msg)) + } + conn.Close() + } + // Test the last block as a valid block + if err := s.sendNextBlock(eth66); err != nil { + t.Fatalf("failed to broadcast next block: %v", err) + } +} + +// TestOldAnnounce tests the announcement mechanism with an old block. +func (s *Suite) TestOldAnnounce(t *utesting.T) { + if err := s.oldAnnounce(eth65); err != nil { + t.Fatal(err) + } +} + +// TestOldAnnounce66 tests the announcement mechanism with an old block, +// over the eth66 protocol. +func (s *Suite) TestOldAnnounce66(t *utesting.T) { + if err := s.oldAnnounce(eth66); err != nil { + t.Fatal(err) + } +} + +// TestMaliciousHandshake tries to send malicious data during the handshake. +func (s *Suite) TestMaliciousHandshake(t *utesting.T) { + if err := s.maliciousHandshakes(t, eth65); err != nil { + t.Fatal(err) + } +} + +// TestMaliciousHandshake66 tries to send malicious data during the handshake. +func (s *Suite) TestMaliciousHandshake66(t *utesting.T) { + if err := s.maliciousHandshakes(t, eth66); err != nil { + t.Fatal(err) + } +} + +// TestMaliciousStatus sends a status package with a large total difficulty. +func (s *Suite) TestMaliciousStatus(t *utesting.T) { + conn, err := s.dial() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + + if err := s.maliciousStatus(conn); err != nil { + t.Fatal(err) + } +} + +// TestMaliciousStatus66 sends a status package with a large total +// difficulty over the eth66 protocol. +func (s *Suite) TestMaliciousStatus66(t *utesting.T) { + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + + if err := s.maliciousStatus(conn); err != nil { + t.Fatal(err) + } +} + +// TestTransaction sends a valid transaction to the node and +// checks if the transaction gets propagated. +func (s *Suite) TestTransaction(t *utesting.T) { + if err := s.sendSuccessfulTxs(t, eth65); err != nil { + t.Fatal(err) + } +} + +// TestTransaction66 sends a valid transaction to the node and +// checks if the transaction gets propagated. +func (s *Suite) TestTransaction66(t *utesting.T) { + if err := s.sendSuccessfulTxs(t, eth66); err != nil { + t.Fatal(err) + } +} + +// TestMaliciousTx sends several invalid transactions and tests whether +// the node will propagate them. +func (s *Suite) TestMaliciousTx(t *utesting.T) { + if err := s.sendMaliciousTxs(t, eth65); err != nil { + t.Fatal(err) + } +} + +// TestMaliciousTx66 sends several invalid transactions and tests whether +// the node will propagate them. +func (s *Suite) TestMaliciousTx66(t *utesting.T) { + if err := s.sendMaliciousTxs(t, eth66); err != nil { + t.Fatal(err) + } +} + +// TestLargeTxRequest66 tests whether a node can fulfill a large GetPooledTransactions +// request. +func (s *Suite) TestLargeTxRequest66(t *utesting.T) { + // send the next block to ensure the node is no longer syncing and + // is able to accept txs + if err := s.sendNextBlock(eth66); err != nil { + t.Fatalf("failed to send next block: %v", err) + } + // send 2000 transactions to the node + hashMap, txs, err := generateTxs(s, 2000) + if err != nil { + t.Fatalf("failed to generate transactions: %v", err) + } + if err = sendMultipleSuccessfulTxs(t, s, txs); err != nil { + t.Fatalf("failed to send multiple txs: %v", err) + } + // set up connection to receive to ensure node is peered with the receiving connection + // before tx request is sent + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + if err = conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } + // create and send pooled tx request + hashes := make([]common.Hash, 0) + for _, hash := range hashMap { + hashes = append(hashes, hash) + } + getTxReq := ð.GetPooledTransactionsPacket66{ + RequestId: 1234, + GetPooledTransactionsPacket: hashes, + } + if err = conn.Write66(getTxReq, GetPooledTransactions{}.Code()); err != nil { + t.Fatalf("could not write to conn: %v", err) + } + // check that all received transactions match those that were sent to node + switch msg := conn.waitForResponse(s.chain, timeout, getTxReq.RequestId).(type) { + case PooledTransactions: + for _, gotTx := range msg { + if _, exists := hashMap[gotTx.Hash()]; !exists { + t.Fatalf("unexpected tx received: %v", gotTx.Hash()) + } + } + default: + t.Fatalf("unexpected %s", pretty.Sdump(msg)) + } +} + +// TestNewPooledTxs_66 tests whether a node will do a GetPooledTransactions +// request upon receiving a NewPooledTransactionHashes announcement. +func (s *Suite) TestNewPooledTxs66(t *utesting.T) { + // send the next block to ensure the node is no longer syncing and + // is able to accept txs + if err := s.sendNextBlock(eth66); err != nil { + t.Fatalf("failed to send next block: %v", err) + } + // generate 50 txs + hashMap, _, err := generateTxs(s, 50) + if err != nil { + t.Fatalf("failed to generate transactions: %v", err) + } + // create new pooled tx hashes announcement + hashes := make([]common.Hash, 0) + for _, hash := range hashMap { + hashes = append(hashes, hash) + } + announce := NewPooledTransactionHashes(hashes) + // send announcement + conn, err := s.dial66() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + if err = conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } + if err = conn.Write(announce); err != nil { + t.Fatalf("failed to write to connection: %v", err) + } + // wait for GetPooledTxs request + for { + _, msg := conn.readAndServe66(s.chain, timeout) + switch msg := msg.(type) { + case GetPooledTransactions: + if len(msg) != len(hashes) { + t.Fatalf("unexpected number of txs requested: wanted %d, got %d", len(hashes), len(msg)) + } + return + case *NewPooledTransactionHashes: + // ignore propagated txs from old tests + continue + default: + t.Fatalf("unexpected %s", pretty.Sdump(msg)) + } } - // check to make sure bad txs aren't propagated - waitForTxPropagation(t, s, badTxs, recvConn) } diff --git a/cmd/devp2p/internal/ethtest/transaction.go b/cmd/devp2p/internal/ethtest/transaction.go index a6166bd2e3..d2dbe0a7d6 100644 --- a/cmd/devp2p/internal/ethtest/transaction.go +++ b/cmd/devp2p/internal/ethtest/transaction.go @@ -17,6 +17,7 @@ package ethtest import ( + "fmt" "math/big" "strings" "time" @@ -31,58 +32,171 @@ import ( //var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7") var faucetKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") -func sendSuccessfulTx(t *utesting.T, s *Suite, tx *types.Transaction) { - sendConn := s.setupConnection(t) - defer sendConn.Close() - sendSuccessfulTxWithConn(t, s, tx, sendConn) +func (s *Suite) sendSuccessfulTxs(t *utesting.T, isEth66 bool) error { + tests := []*types.Transaction{ + getNextTxFromChain(s), + unknownTx(s), + } + for i, tx := range tests { + if tx == nil { + return fmt.Errorf("could not find tx to send") + } + t.Logf("Testing tx propagation %d: sending tx %v %v %v\n", i, tx.Hash().String(), tx.GasPrice(), tx.Gas()) + // get previous tx if exists for reference in case of old tx propagation + var prevTx *types.Transaction + if i != 0 { + prevTx = tests[i-1] + } + // write tx to connection + if err := sendSuccessfulTx(s, tx, prevTx, isEth66); err != nil { + return fmt.Errorf("send successful tx test failed: %v", err) + } + } + return nil } -func sendSuccessfulTxWithConn(t *utesting.T, s *Suite, tx *types.Transaction, sendConn *Conn) { - t.Logf("sending tx: %v %v %v\n", tx.Hash().String(), tx.GasPrice(), tx.Gas()) +func sendSuccessfulTx(s *Suite, tx *types.Transaction, prevTx *types.Transaction, isEth66 bool) error { + sendConn, recvConn, err := s.createSendAndRecvConns(isEth66) + if err != nil { + return err + } + defer sendConn.Close() + defer recvConn.Close() + if err = sendConn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) + } // Send the transaction - if err := sendConn.Write(&Transactions{tx}); err != nil { - t.Fatal(err) + if err = sendConn.Write(&Transactions{tx}); err != nil { + return fmt.Errorf("failed to write to connection: %v", err) + } + // peer receiving connection to node + if err = recvConn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) } // update last nonce seen nonce = tx.Nonce() - - recvConn := s.setupConnection(t) // Wait for the transaction announcement - switch msg := recvConn.ReadAndServe(s.chain, timeout).(type) { - case *Transactions: - recTxs := *msg - for _, gotTx := range recTxs { - if gotTx.Hash() == tx.Hash() { - // Ok - return + for { + switch msg := recvConn.readAndServe(s.chain, timeout).(type) { + case *Transactions: + recTxs := *msg + // if you receive an old tx propagation, read from connection again + if len(recTxs) == 1 && prevTx != nil { + if recTxs[0] == prevTx { + continue + } } - } - t.Fatalf("missing transaction: got %v missing %v", recTxs, tx.Hash()) - case *NewPooledTransactionHashes: - txHashes := *msg - for _, gotHash := range txHashes { - if gotHash == tx.Hash() { - return + for _, gotTx := range recTxs { + if gotTx.Hash() == tx.Hash() { + // Ok + return nil + } } + return fmt.Errorf("missing transaction: got %v missing %v", recTxs, tx.Hash()) + case *NewPooledTransactionHashes: + txHashes := *msg + // if you receive an old tx propagation, read from connection again + if len(txHashes) == 1 && prevTx != nil { + if txHashes[0] == prevTx.Hash() { + continue + } + } + for _, gotHash := range txHashes { + if gotHash == tx.Hash() { + // Ok + return nil + } + } + return fmt.Errorf("missing transaction announcement: got %v missing %v", txHashes, tx.Hash()) + default: + return fmt.Errorf("unexpected message in sendSuccessfulTx: %s", pretty.Sdump(msg)) } - t.Fatalf("missing transaction announcement: got %v missing %v", txHashes, tx.Hash()) - default: - t.Fatalf("unexpected message in sendSuccessfulTx: %s", pretty.Sdump(msg)) } } +func (s *Suite) sendMaliciousTxs(t *utesting.T, isEth66 bool) error { + badTxs := []*types.Transaction{ + getOldTxFromChain(s), + invalidNonceTx(s), + hugeAmount(s), + hugeGasPrice(s), + hugeData(s), + } + // setup receiving connection before sending malicious txs + var ( + recvConn *Conn + err error + ) + if isEth66 { + recvConn, err = s.dial66() + } else { + recvConn, err = s.dial() + } + if err != nil { + return fmt.Errorf("dial failed: %v", err) + } + defer recvConn.Close() + if err = recvConn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) + } + for i, tx := range badTxs { + t.Logf("Testing malicious tx propagation: %v\n", i) + if err = sendMaliciousTx(s, tx, isEth66); err != nil { + return fmt.Errorf("malicious tx test failed:\ntx: %v\nerror: %v", tx, err) + } + } + // check to make sure bad txs aren't propagated + return checkMaliciousTxPropagation(s, badTxs, recvConn) +} + +func sendMaliciousTx(s *Suite, tx *types.Transaction, isEth66 bool) error { + // setup connection + var ( + conn *Conn + err error + ) + if isEth66 { + conn, err = s.dial66() + } else { + conn, err = s.dial() + } + if err != nil { + return fmt.Errorf("dial failed: %v", err) + } + defer conn.Close() + if err = conn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) + } + // write malicious tx + if err = conn.Write(&Transactions{tx}); err != nil { + return fmt.Errorf("failed to write to connection: %v", err) + } + return nil +} + var nonce = uint64(99) -func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*types.Transaction) { +// sendMultipleSuccessfulTxs sends the given transactions to the node and +// expects the node to accept and propagate them. +func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, txs []*types.Transaction) error { txMsg := Transactions(txs) t.Logf("sending %d txs\n", len(txs)) - recvConn := s.setupConnection(t) + sendConn, recvConn, err := s.createSendAndRecvConns(true) + if err != nil { + return err + } + defer sendConn.Close() defer recvConn.Close() - + if err = sendConn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) + } + if err = recvConn.peer(s.chain, nil); err != nil { + return fmt.Errorf("peering failed: %v", err) + } // Send the transactions - if err := sendConn.Write(&txMsg); err != nil { - t.Fatal(err) + if err = sendConn.Write(&txMsg); err != nil { + return fmt.Errorf("failed to write message to connection: %v", err) } // update nonce nonce = txs[len(txs)-1].Nonce() @@ -90,7 +204,7 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*t recvHashes := make([]common.Hash, 0) // all txs should be announced within 3 announcements for i := 0; i < 3; i++ { - switch msg := recvConn.ReadAndServe(s.chain, timeout).(type) { + switch msg := recvConn.readAndServe(s.chain, timeout).(type) { case *Transactions: for _, tx := range *msg { recvHashes = append(recvHashes, tx.Hash()) @@ -99,7 +213,7 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*t recvHashes = append(recvHashes, *msg...) default: if !strings.Contains(pretty.Sdump(msg), "i/o timeout") { - t.Fatalf("unexpected message while waiting to receive txs: %s", pretty.Sdump(msg)) + return fmt.Errorf("unexpected message while waiting to receive txs: %s", pretty.Sdump(msg)) } } // break once all 2000 txs have been received @@ -112,7 +226,7 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*t continue } else { t.Logf("successfully received all %d txs", len(txs)) - return + return nil } } } @@ -121,13 +235,15 @@ func sendMultipleSuccessfulTxs(t *utesting.T, s *Suite, sendConn *Conn, txs []*t for _, missing := range missingTxs { t.Logf("missing tx: %v", missing.Hash()) } - t.Fatalf("missing %d txs", len(missingTxs)) + return fmt.Errorf("missing %d txs", len(missingTxs)) } + return nil } -func waitForTxPropagation(t *utesting.T, s *Suite, txs []*types.Transaction, recvConn *Conn) { - // Wait for another transaction announcement - switch msg := recvConn.ReadAndServe(s.chain, time.Second*8).(type) { +// checkMaliciousTxPropagation checks whether the given malicious transactions were +// propagated by the node. +func checkMaliciousTxPropagation(s *Suite, txs []*types.Transaction, conn *Conn) error { + switch msg := conn.readAndServe(s.chain, time.Second*8).(type) { case *Transactions: // check to see if any of the failing txs were in the announcement recvTxs := make([]common.Hash, len(*msg)) @@ -136,25 +252,20 @@ func waitForTxPropagation(t *utesting.T, s *Suite, txs []*types.Transaction, rec } badTxs, _ := compareReceivedTxs(recvTxs, txs) if len(badTxs) > 0 { - for _, tx := range badTxs { - t.Logf("received bad tx: %v", tx) - } - t.Fatalf("received %d bad txs", len(badTxs)) + return fmt.Errorf("received %d bad txs: \n%v", len(badTxs), badTxs) } case *NewPooledTransactionHashes: badTxs, _ := compareReceivedTxs(*msg, txs) if len(badTxs) > 0 { - for _, tx := range badTxs { - t.Logf("received bad tx: %v", tx) - } - t.Fatalf("received %d bad txs", len(badTxs)) + return fmt.Errorf("received %d bad txs: \n%v", len(badTxs), badTxs) } case *Error: // Transaction should not be announced -> wait for timeout - return + return nil default: - t.Fatalf("unexpected message in sendFailingTx: %s", pretty.Sdump(msg)) + return fmt.Errorf("unexpected message in sendFailingTx: %s", pretty.Sdump(msg)) } + return nil } // compareReceivedTxs compares the received set of txs against the given set of txs, @@ -180,118 +291,129 @@ func compareReceivedTxs(recvTxs []common.Hash, txs []*types.Transaction) (presen return present, missing } -func unknownTx(t *utesting.T, s *Suite) *types.Transaction { - tx := getNextTxFromChain(t, s) +func unknownTx(s *Suite) *types.Transaction { + tx := getNextTxFromChain(s) + if tx == nil { + return nil + } var to common.Address if tx.To() != nil { to = *tx.To() } txNew := types.NewTransaction(tx.Nonce()+1, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data()) - return signWithFaucet(t, s.chain.chainConfig, txNew) + return signWithFaucet(s.chain.chainConfig, txNew) } -func getNextTxFromChain(t *utesting.T, s *Suite) *types.Transaction { +func getNextTxFromChain(s *Suite) *types.Transaction { // Get a new transaction - var tx *types.Transaction for _, blocks := range s.fullChain.blocks[s.chain.Len():] { txs := blocks.Transactions() if txs.Len() != 0 { - tx = txs[0] - break + return txs[0] } } - if tx == nil { - t.Fatal("could not find transaction") - } - return tx + return nil } -func generateTxs(t *utesting.T, s *Suite, numTxs int) (map[common.Hash]common.Hash, []*types.Transaction) { +func generateTxs(s *Suite, numTxs int) (map[common.Hash]common.Hash, []*types.Transaction, error) { txHashMap := make(map[common.Hash]common.Hash, numTxs) txs := make([]*types.Transaction, numTxs) - nextTx := getNextTxFromChain(t, s) + nextTx := getNextTxFromChain(s) + if nextTx == nil { + return nil, nil, fmt.Errorf("failed to get the next transaction") + } gas := nextTx.Gas() nonce = nonce + 1 // generate txs for i := 0; i < numTxs; i++ { - tx := generateTx(t, s.chain.chainConfig, nonce, gas) + tx := generateTx(s.chain.chainConfig, nonce, gas) + if tx == nil { + return nil, nil, fmt.Errorf("failed to get the next transaction") + } txHashMap[tx.Hash()] = tx.Hash() txs[i] = tx nonce = nonce + 1 } - return txHashMap, txs + return txHashMap, txs, nil } -func generateTx(t *utesting.T, chainConfig *params.ChainConfig, nonce uint64, gas uint64) *types.Transaction { +func generateTx(chainConfig *params.ChainConfig, nonce uint64, gas uint64) *types.Transaction { var to common.Address tx := types.NewTransaction(nonce, to, big.NewInt(1), gas, big.NewInt(1), []byte{}) - return signWithFaucet(t, chainConfig, tx) + return signWithFaucet(chainConfig, tx) } -func getOldTxFromChain(t *utesting.T, s *Suite) *types.Transaction { - var tx *types.Transaction +func getOldTxFromChain(s *Suite) *types.Transaction { for _, blocks := range s.fullChain.blocks[:s.chain.Len()-1] { txs := blocks.Transactions() if txs.Len() != 0 { - tx = txs[0] - break + return txs[0] } } - if tx == nil { - t.Fatal("could not find transaction") - } - return tx + return nil } -func invalidNonceTx(t *utesting.T, s *Suite) *types.Transaction { - tx := getNextTxFromChain(t, s) +func invalidNonceTx(s *Suite) *types.Transaction { + tx := getNextTxFromChain(s) + if tx == nil { + return nil + } var to common.Address if tx.To() != nil { to = *tx.To() } txNew := types.NewTransaction(tx.Nonce()-2, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data()) - return signWithFaucet(t, s.chain.chainConfig, txNew) + return signWithFaucet(s.chain.chainConfig, txNew) } -func hugeAmount(t *utesting.T, s *Suite) *types.Transaction { - tx := getNextTxFromChain(t, s) +func hugeAmount(s *Suite) *types.Transaction { + tx := getNextTxFromChain(s) + if tx == nil { + return nil + } amount := largeNumber(2) var to common.Address if tx.To() != nil { to = *tx.To() } txNew := types.NewTransaction(tx.Nonce(), to, amount, tx.Gas(), tx.GasPrice(), tx.Data()) - return signWithFaucet(t, s.chain.chainConfig, txNew) + return signWithFaucet(s.chain.chainConfig, txNew) } -func hugeGasPrice(t *utesting.T, s *Suite) *types.Transaction { - tx := getNextTxFromChain(t, s) +func hugeGasPrice(s *Suite) *types.Transaction { + tx := getNextTxFromChain(s) + if tx == nil { + return nil + } gasPrice := largeNumber(2) var to common.Address if tx.To() != nil { to = *tx.To() } txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), gasPrice, tx.Data()) - return signWithFaucet(t, s.chain.chainConfig, txNew) + return signWithFaucet(s.chain.chainConfig, txNew) } -func hugeData(t *utesting.T, s *Suite) *types.Transaction { - tx := getNextTxFromChain(t, s) +func hugeData(s *Suite) *types.Transaction { + tx := getNextTxFromChain(s) + if tx == nil { + return nil + } var to common.Address if tx.To() != nil { to = *tx.To() } txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), tx.GasPrice(), largeBuffer(2)) - return signWithFaucet(t, s.chain.chainConfig, txNew) + return signWithFaucet(s.chain.chainConfig, txNew) } -func signWithFaucet(t *utesting.T, chainConfig *params.ChainConfig, tx *types.Transaction) *types.Transaction { +func signWithFaucet(chainConfig *params.ChainConfig, tx *types.Transaction) *types.Transaction { signer := types.LatestSigner(chainConfig) signedTx, err := types.SignTx(tx, signer, faucetKey) if err != nil { - t.Fatalf("could not sign tx: %v\n", err) + return nil } return signedTx } diff --git a/cmd/devp2p/internal/ethtest/types.go b/cmd/devp2p/internal/ethtest/types.go index 50a69b9418..e49ea284e9 100644 --- a/cmd/devp2p/internal/ethtest/types.go +++ b/cmd/devp2p/internal/ethtest/types.go @@ -19,13 +19,8 @@ package ethtest import ( "crypto/ecdsa" "fmt" - "reflect" - "time" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/protocols/eth" - "github.com/ethereum/go-ethereum/internal/utesting" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/rlpx" "github.com/ethereum/go-ethereum/rlp" @@ -137,6 +132,7 @@ type Conn struct { caps []p2p.Cap } +// Read reads an eth packet from the connection. func (c *Conn) Read() Message { code, rawData, _, err := c.Conn.Read() if err != nil { @@ -185,32 +181,83 @@ func (c *Conn) Read() Message { return msg } -// ReadAndServe serves GetBlockHeaders requests while waiting -// on another message from the node. -func (c *Conn) ReadAndServe(chain *Chain, timeout time.Duration) Message { - start := time.Now() - for time.Since(start) < timeout { - c.SetReadDeadline(time.Now().Add(5 * time.Second)) - switch msg := c.Read().(type) { - case *Ping: - c.Write(&Pong{}) - case *GetBlockHeaders: - req := *msg - headers, err := chain.GetHeaders(req) - if err != nil { - return errorf("could not get headers for inbound header request: %v", err) - } - - if err := c.Write(headers); err != nil { - return errorf("could not write to connection: %v", err) - } - default: - return msg - } +// Read66 reads an eth66 packet from the connection. +func (c *Conn) Read66() (uint64, Message) { + code, rawData, _, err := c.Conn.Read() + if err != nil { + return 0, errorf("could not read from connection: %v", err) } - return errorf("no message received within %v", timeout) + + var msg Message + switch int(code) { + case (Hello{}).Code(): + msg = new(Hello) + case (Ping{}).Code(): + msg = new(Ping) + case (Pong{}).Code(): + msg = new(Pong) + case (Disconnect{}).Code(): + msg = new(Disconnect) + case (Status{}).Code(): + msg = new(Status) + case (GetBlockHeaders{}).Code(): + ethMsg := new(eth.GetBlockHeadersPacket66) + if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { + return 0, errorf("could not rlp decode message: %v", err) + } + return ethMsg.RequestId, GetBlockHeaders(*ethMsg.GetBlockHeadersPacket) + case (BlockHeaders{}).Code(): + ethMsg := new(eth.BlockHeadersPacket66) + if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { + return 0, errorf("could not rlp decode message: %v", err) + } + return ethMsg.RequestId, BlockHeaders(ethMsg.BlockHeadersPacket) + case (GetBlockBodies{}).Code(): + ethMsg := new(eth.GetBlockBodiesPacket66) + if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { + return 0, errorf("could not rlp decode message: %v", err) + } + return ethMsg.RequestId, GetBlockBodies(ethMsg.GetBlockBodiesPacket) + case (BlockBodies{}).Code(): + ethMsg := new(eth.BlockBodiesPacket66) + if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { + return 0, errorf("could not rlp decode message: %v", err) + } + return ethMsg.RequestId, BlockBodies(ethMsg.BlockBodiesPacket) + case (NewBlock{}).Code(): + msg = new(NewBlock) + case (NewBlockHashes{}).Code(): + msg = new(NewBlockHashes) + case (Transactions{}).Code(): + msg = new(Transactions) + case (NewPooledTransactionHashes{}).Code(): + msg = new(NewPooledTransactionHashes) + case (GetPooledTransactions{}.Code()): + ethMsg := new(eth.GetPooledTransactionsPacket66) + if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { + return 0, errorf("could not rlp decode message: %v", err) + } + return ethMsg.RequestId, GetPooledTransactions(ethMsg.GetPooledTransactionsPacket) + case (PooledTransactions{}.Code()): + ethMsg := new(eth.PooledTransactionsPacket66) + if err := rlp.DecodeBytes(rawData, ethMsg); err != nil { + return 0, errorf("could not rlp decode message: %v", err) + } + return ethMsg.RequestId, PooledTransactions(ethMsg.PooledTransactionsPacket) + default: + msg = errorf("invalid message code: %d", code) + } + + if msg != nil { + if err := rlp.DecodeBytes(rawData, msg); err != nil { + return 0, errorf("could not rlp decode message: %v", err) + } + return 0, msg + } + return 0, errorf("invalid message: %s", string(rawData)) } +// Write writes a eth packet to the connection. func (c *Conn) Write(msg Message) error { // check if message is eth protocol message var ( @@ -225,135 +272,12 @@ func (c *Conn) Write(msg Message) error { return err } -// handshake checks to make sure a `HELLO` is received. -func (c *Conn) handshake(t *utesting.T) Message { - defer c.SetDeadline(time.Time{}) - c.SetDeadline(time.Now().Add(10 * time.Second)) - - // write hello to client - pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:] - ourHandshake := &Hello{ - Version: 5, - Caps: c.caps, - ID: pub0, - } - if err := c.Write(ourHandshake); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - // read hello from client - switch msg := c.Read().(type) { - case *Hello: - // set snappy if version is at least 5 - if msg.Version >= 5 { - c.SetSnappy(true) - } - c.negotiateEthProtocol(msg.Caps) - if c.negotiatedProtoVersion == 0 { - t.Fatalf("unexpected eth protocol version") - } - return msg - default: - t.Fatalf("bad handshake: %#v", msg) - return nil - } -} - -// negotiateEthProtocol sets the Conn's eth protocol version -// to highest advertised capability from peer -func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) { - var highestEthVersion uint - for _, capability := range caps { - if capability.Name != "eth" { - continue - } - if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion { - highestEthVersion = capability.Version - } - } - c.negotiatedProtoVersion = highestEthVersion -} - -// statusExchange performs a `Status` message exchange with the given -// node. -func (c *Conn) statusExchange(t *utesting.T, chain *Chain, status *Status) Message { - defer c.SetDeadline(time.Time{}) - c.SetDeadline(time.Now().Add(20 * time.Second)) - - // read status message from client - var message Message -loop: - for { - switch msg := c.Read().(type) { - case *Status: - if have, want := msg.Head, chain.blocks[chain.Len()-1].Hash(); have != want { - t.Fatalf("wrong head block in status, want: %#x (block %d) have %#x", - want, chain.blocks[chain.Len()-1].NumberU64(), have) - } - if have, want := msg.TD.Cmp(chain.TD(chain.Len())), 0; have != want { - t.Fatalf("wrong TD in status: have %v want %v", have, want) - } - if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) { - t.Fatalf("wrong fork ID in status: have %v, want %v", have, want) - } - message = msg - break loop - case *Disconnect: - t.Fatalf("disconnect received: %v", msg.Reason) - case *Ping: - c.Write(&Pong{}) // TODO (renaynay): in the future, this should be an error - // (PINGs should not be a response upon fresh connection) - default: - t.Fatalf("bad status message: %s", pretty.Sdump(msg)) - } - } - // make sure eth protocol version is set for negotiation - if c.negotiatedProtoVersion == 0 { - t.Fatalf("eth protocol version must be set in Conn") - } - if status == nil { - // write status message to client - status = &Status{ - ProtocolVersion: uint32(c.negotiatedProtoVersion), - NetworkID: chain.chainConfig.ChainID.Uint64(), - TD: chain.TD(chain.Len()), - Head: chain.blocks[chain.Len()-1].Hash(), - Genesis: chain.blocks[0].Hash(), - ForkID: chain.ForkID(), - } - } - - if err := c.Write(status); err != nil { - t.Fatalf("could not write to connection: %v", err) - } - - return message -} - -// waitForBlock waits for confirmation from the client that it has -// imported the given block. -func (c *Conn) waitForBlock(block *types.Block) error { - defer c.SetReadDeadline(time.Time{}) - - c.SetReadDeadline(time.Now().Add(20 * time.Second)) - // note: if the node has not yet imported the block, it will respond - // to the GetBlockHeaders request with an empty BlockHeaders response, - // so the GetBlockHeaders request must be sent again until the BlockHeaders - // response contains the desired header. - for { - req := &GetBlockHeaders{Origin: eth.HashOrNumber{Hash: block.Hash()}, Amount: 1} - if err := c.Write(req); err != nil { - return err - } - switch msg := c.Read().(type) { - case *BlockHeaders: - for _, header := range *msg { - if header.Number.Uint64() == block.NumberU64() { - return nil - } - } - time.Sleep(100 * time.Millisecond) - default: - return fmt.Errorf("invalid message: %s", pretty.Sdump(msg)) - } +// Write66 writes an eth66 packet to the connection. +func (c *Conn) Write66(req eth.Packet, code int) error { + payload, err := rlp.EncodeToBytes(req) + if err != nil { + return err } + _, err = c.Conn.Write(uint64(code), payload) + return err }