diff --git a/p2p/protocols/accounting.go b/p2p/protocols/accounting.go new file mode 100644 index 0000000000..06a1a58454 --- /dev/null +++ b/p2p/protocols/accounting.go @@ -0,0 +1,172 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package protocols + +import "github.com/ethereum/go-ethereum/metrics" + +//define some metrics +var ( + //NOTE: these metrics just define the interfaces and are currently *NOT persisted* over sessions + //All metrics are cumulative + + //total amount of units credited + mBalanceCredit = metrics.NewRegisteredCounterForced("account.balance.credit", nil) + //total amount of units debited + mBalanceDebit = metrics.NewRegisteredCounterForced("account.balance.debit", nil) + //total amount of bytes credited + mBytesCredit = metrics.NewRegisteredCounterForced("account.bytes.credit", nil) + //total amount of bytes debited + mBytesDebit = metrics.NewRegisteredCounterForced("account.bytes.debit", nil) + //total amount of credited messages + mMsgCredit = metrics.NewRegisteredCounterForced("account.msg.credit", nil) + //total amount of debited messages + mMsgDebit = metrics.NewRegisteredCounterForced("account.msg.debit", nil) + //how many times local node had to drop remote peers + mPeerDrops = metrics.NewRegisteredCounterForced("account.peerdrops", nil) + //how many times local node overdrafted and dropped + mSelfDrops = metrics.NewRegisteredCounterForced("account.selfdrops", nil) +) + +//Prices defines how prices are being passed on to the accounting instance +type Prices interface { + //Return the Price for a message + Price(interface{}) *Price +} + +type Payer bool + +const ( + Sender = Payer(true) + Receiver = Payer(false) +) + +//Price represents the costs of a message +type Price struct { + Value uint64 // + PerByte bool //True if the price is per byte or for unit + Payer Payer +} + +//For gives back the price for a message +//A protocol provides the message price in absolute value +//This method then returns the correct signed amount, +//depending on who pays, which is identified by the `payer` argument: +//`Send` will pass a `Sender` payer, `Receive` will pass the `Receiver` argument. +//Thus: If Sending and sender pays, amount positive, otherwise negative +//If Receiving, and receiver pays, amount positive, otherwise negative +func (p *Price) For(payer Payer, size uint32) int64 { + price := p.Value + if p.PerByte { + price *= uint64(size) + } + if p.Payer == payer { + return 0 - int64(price) + } + return int64(price) +} + +//Balance is the actual accounting instance +//Balance defines the operations needed for accounting +//Implementations internally maintain the balance for every peer +type Balance interface { + //Adds amount to the local balance with remote node `peer`; + //positive amount = credit local node + //negative amount = debit local node + Add(amount int64, peer *Peer) error +} + +//Accounting implements the Hook interface +//It interfaces to the balances through the Balance interface, +//while interfacing with protocols and its prices through the Prices interface +type Accounting struct { + Balance //interface to accounting logic + Prices //interface to prices logic +} + +func NewAccounting(balance Balance, po Prices) *Accounting { + ah := &Accounting{ + Prices: po, + Balance: balance, + } + return ah +} + +//Implement Hook.Send +// Send takes a peer, a size and a msg and +// - calculates the cost for the local node sending a msg of size to peer using the Prices interface +// - credits/debits local node using balance interface +func (ah *Accounting) Send(peer *Peer, size uint32, msg interface{}) error { + //get the price for a message (through the protocol spec) + price := ah.Price(msg) + //this message doesn't need accounting + if price == nil { + return nil + } + //evaluate the price for sending messages + costToLocalNode := price.For(Sender, size) + //do the accounting + err := ah.Add(costToLocalNode, peer) + //record metrics: just increase counters for user-facing metrics + ah.doMetrics(costToLocalNode, size, err) + return err +} + +//Implement Hook.Receive +// Receive takes a peer, a size and a msg and +// - calculates the cost for the local node receiving a msg of size from peer using the Prices interface +// - credits/debits local node using balance interface +func (ah *Accounting) Receive(peer *Peer, size uint32, msg interface{}) error { + //get the price for a message (through the protocol spec) + price := ah.Price(msg) + //this message doesn't need accounting + if price == nil { + return nil + } + //evaluate the price for receiving messages + costToLocalNode := price.For(Receiver, size) + //do the accounting + err := ah.Add(costToLocalNode, peer) + //record metrics: just increase counters for user-facing metrics + ah.doMetrics(costToLocalNode, size, err) + return err +} + +//record some metrics +//this is not an error handling. `err` is returned by both `Send` and `Receive` +//`err` will only be non-nil if a limit has been violated (overdraft), in which case the peer has been dropped. +//if the limit has been violated and `err` is thus not nil: +// * if the price is positive, local node has been credited; thus `err` implicitly signals the REMOTE has been dropped +// * if the price is negative, local node has been debited, thus `err` implicitly signals LOCAL node "overdraft" +func (ah *Accounting) doMetrics(price int64, size uint32, err error) { + if price > 0 { + mBalanceCredit.Inc(price) + mBytesCredit.Inc(int64(size)) + mMsgCredit.Inc(1) + if err != nil { + //increase the number of times a remote node has been dropped due to "overdraft" + mPeerDrops.Inc(1) + } + } else { + mBalanceDebit.Inc(price) + mBytesDebit.Inc(int64(size)) + mMsgDebit.Inc(1) + if err != nil { + //increase the number of times the local node has done an "overdraft" in respect to other nodes + mSelfDrops.Inc(1) + } + } +} diff --git a/p2p/protocols/accounting_simulation_test.go b/p2p/protocols/accounting_simulation_test.go new file mode 100644 index 0000000000..65b737abed --- /dev/null +++ b/p2p/protocols/accounting_simulation_test.go @@ -0,0 +1,310 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package protocols + +import ( + "context" + "flag" + "fmt" + "math/rand" + "reflect" + "sync" + "testing" + "time" + + "github.com/mattn/go-colorable" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rpc" + + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/simulations" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" +) + +const ( + content = "123456789" +) + +var ( + nodes = flag.Int("nodes", 30, "number of nodes to create (default 30)") + msgs = flag.Int("msgs", 100, "number of messages sent by node (default 100)") + loglevel = flag.Int("loglevel", 0, "verbosity of logs") + rawlog = flag.Bool("rawlog", false, "remove terminal formatting from logs") +) + +func init() { + flag.Parse() + log.PrintOrigins(true) + log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(!*rawlog)))) +} + +//TestAccountingSimulation runs a p2p/simulations simulation +//It creates a *nodes number of nodes, connects each one with each other, +//then sends out a random selection of messages up to *msgs amount of messages +//from the test protocol spec. +//The spec has some accounted messages defined through the Prices interface. +//The test does accounting for all the message exchanged, and then checks +//that every node has the same balance with a peer, but with opposite signs. +//Balance(AwithB) = 0 - Balance(BwithA) or Abs|Balance(AwithB)| == Abs|Balance(BwithA)| +func TestAccountingSimulation(t *testing.T) { + //setup the balances objects for every node + bal := newBalances(*nodes) + //define the node.Service for this test + services := adapters.Services{ + "accounting": func(ctx *adapters.ServiceContext) (node.Service, error) { + return bal.newNode(), nil + }, + } + //setup the simulation + adapter := adapters.NewSimAdapter(services) + net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{DefaultService: "accounting"}) + defer net.Shutdown() + + // we send msgs messages per node, wait for all messages to arrive + bal.wg.Add(*nodes * *msgs) + trigger := make(chan enode.ID) + go func() { + // wait for all of them to arrive + bal.wg.Wait() + // then trigger a check + // the selected node for the trigger is irrelevant, + // we just want to trigger the end of the simulation + trigger <- net.Nodes[0].ID() + }() + + // create nodes and start them + for i := 0; i < *nodes; i++ { + conf := adapters.RandomNodeConfig() + bal.id2n[conf.ID] = i + if _, err := net.NewNodeWithConfig(conf); err != nil { + t.Fatal(err) + } + if err := net.Start(conf.ID); err != nil { + t.Fatal(err) + } + } + // fully connect nodes + for i, n := range net.Nodes { + for _, m := range net.Nodes[i+1:] { + if err := net.Connect(n.ID(), m.ID()); err != nil { + t.Fatal(err) + } + } + } + + // empty action + action := func(ctx context.Context) error { + return nil + } + // check always checks out + check := func(ctx context.Context, id enode.ID) (bool, error) { + return true, nil + } + + // run simulation + timeout := 30 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + result := simulations.NewSimulation(net).Run(ctx, &simulations.Step{ + Action: action, + Trigger: trigger, + Expect: &simulations.Expectation{ + Nodes: []enode.ID{net.Nodes[0].ID()}, + Check: check, + }, + }) + + if result.Error != nil { + t.Fatal(result.Error) + } + + // check if balance matrix is symmetric + if err := bal.symmetric(); err != nil { + t.Fatal(err) + } +} + +// matrix is a matrix of nodes and its balances +// matrix is in fact a linear array of size n*n, +// so the balance for any node A with B is at index +// A*n + B, while the balance of node B with A is at +// B*n + A +// (n entries in the array will not be filled - +// the balance of a node with itself) +type matrix struct { + n int //number of nodes + m []int64 //array of balances +} + +// create a new matrix +func newMatrix(n int) *matrix { + return &matrix{ + n: n, + m: make([]int64, n*n), + } +} + +// called from the testBalance's Add accounting function: register balance change +func (m *matrix) add(i, j int, v int64) error { + // index for the balance of local node i with remote nodde j is + // i * number of nodes + remote node + mi := i*m.n + j + // register that balance + m.m[mi] += v + return nil +} + +// check that the balances are symmetric: +// balance of node i with node j is the same as j with i but with inverted signs +func (m *matrix) symmetric() error { + //iterate all nodes + for i := 0; i < m.n; i++ { + //iterate starting +1 + for j := i + 1; j < m.n; j++ { + log.Debug("bal", "1", i, "2", j, "i,j", m.m[i*m.n+j], "j,i", m.m[j*m.n+i]) + if m.m[i*m.n+j] != -m.m[j*m.n+i] { + return fmt.Errorf("value mismatch. m[%v, %v] = %v; m[%v, %v] = %v", i, j, m.m[i*m.n+j], j, i, m.m[j*m.n+i]) + } + } + } + return nil +} + +// all the balances +type balances struct { + i int + *matrix + id2n map[enode.ID]int + wg *sync.WaitGroup +} + +func newBalances(n int) *balances { + return &balances{ + matrix: newMatrix(n), + id2n: make(map[enode.ID]int), + wg: &sync.WaitGroup{}, + } +} + +// create a new testNode for every node created as part of the service +func (b *balances) newNode() *testNode { + defer func() { b.i++ }() + return &testNode{ + bal: b, + i: b.i, + peers: make([]*testPeer, b.n), //a node will be connected to n-1 peers + } +} + +type testNode struct { + bal *balances + i int + lock sync.Mutex + peers []*testPeer + peerCount int +} + +// do the accounting for the peer's test protocol +// testNode implements protocols.Balance +func (t *testNode) Add(a int64, p *Peer) error { + //get the index for the remote peer + remote := t.bal.id2n[p.ID()] + log.Debug("add", "local", t.i, "remote", remote, "amount", a) + return t.bal.add(t.i, remote, a) +} + +//run the p2p protocol +//for every node, represented by testNode, create a remote testPeer +func (t *testNode) run(p *p2p.Peer, rw p2p.MsgReadWriter) error { + spec := createTestSpec() + //create accounting hook + spec.Hook = NewAccounting(t, &dummyPrices{}) + + //create a peer for this node + tp := &testPeer{NewPeer(p, rw, spec), t.i, t.bal.id2n[p.ID()], t.bal.wg} + t.lock.Lock() + t.peers[t.bal.id2n[p.ID()]] = tp + t.peerCount++ + if t.peerCount == t.bal.n-1 { + //when all peer connections are established, start sending messages from this peer + go t.send() + } + t.lock.Unlock() + return tp.Run(tp.handle) +} + +// p2p message receive handler function +func (tp *testPeer) handle(ctx context.Context, msg interface{}) error { + tp.wg.Done() + log.Debug("receive", "from", tp.remote, "to", tp.local, "type", reflect.TypeOf(msg), "msg", msg) + return nil +} + +type testPeer struct { + *Peer + local, remote int + wg *sync.WaitGroup +} + +func (t *testNode) send() { + log.Debug("start sending") + for i := 0; i < *msgs; i++ { + //determine randomly to which peer to send + whom := rand.Intn(t.bal.n - 1) + if whom >= t.i { + whom++ + } + t.lock.Lock() + p := t.peers[whom] + t.lock.Unlock() + + //determine a random message from the spec's messages to be sent + which := rand.Intn(len(p.spec.Messages)) + msg := p.spec.Messages[which] + switch msg.(type) { + case *perBytesMsgReceiverPays: + msg = &perBytesMsgReceiverPays{Content: content[:rand.Intn(len(content))]} + case *perBytesMsgSenderPays: + msg = &perBytesMsgSenderPays{Content: content[:rand.Intn(len(content))]} + } + log.Debug("send", "from", t.i, "to", whom, "type", reflect.TypeOf(msg), "msg", msg) + p.Send(context.TODO(), msg) + } +} + +// define the protocol +func (t *testNode) Protocols() []p2p.Protocol { + return []p2p.Protocol{{ + Length: 100, + Run: t.run, + }} +} + +func (t *testNode) APIs() []rpc.API { + return nil +} + +func (t *testNode) Start(server *p2p.Server) error { + return nil +} + +func (t *testNode) Stop() error { + return nil +} diff --git a/p2p/protocols/accounting_test.go b/p2p/protocols/accounting_test.go new file mode 100644 index 0000000000..3810ae2c9b --- /dev/null +++ b/p2p/protocols/accounting_test.go @@ -0,0 +1,223 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package protocols + +import ( + "testing" + + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rlp" +) + +//dummy Balance implementation +type dummyBalance struct { + amount int64 + peer *Peer +} + +//dummy Prices implementation +type dummyPrices struct{} + +//a dummy message which needs size based accounting +//sender pays +type perBytesMsgSenderPays struct { + Content string +} + +//a dummy message which needs size based accounting +//receiver pays +type perBytesMsgReceiverPays struct { + Content string +} + +//a dummy message which is paid for per unit +//sender pays +type perUnitMsgSenderPays struct{} + +//receiver pays +type perUnitMsgReceiverPays struct{} + +//a dummy message which has zero as its price +type zeroPriceMsg struct{} + +//a dummy message which has no accounting +type nilPriceMsg struct{} + +//return the price for the defined messages +func (d *dummyPrices) Price(msg interface{}) *Price { + switch msg.(type) { + //size based message cost, receiver pays + case *perBytesMsgReceiverPays: + return &Price{ + PerByte: true, + Value: uint64(100), + Payer: Receiver, + } + //size based message cost, sender pays + case *perBytesMsgSenderPays: + return &Price{ + PerByte: true, + Value: uint64(100), + Payer: Sender, + } + //unitary cost, receiver pays + case *perUnitMsgReceiverPays: + return &Price{ + PerByte: false, + Value: uint64(99), + Payer: Receiver, + } + //unitary cost, sender pays + case *perUnitMsgSenderPays: + return &Price{ + PerByte: false, + Value: uint64(99), + Payer: Sender, + } + case *zeroPriceMsg: + return &Price{ + PerByte: false, + Value: uint64(0), + Payer: Sender, + } + case *nilPriceMsg: + return nil + } + return nil +} + +//dummy accounting implementation, only stores values for later check +func (d *dummyBalance) Add(amount int64, peer *Peer) error { + d.amount = amount + d.peer = peer + return nil +} + +type testCase struct { + msg interface{} + size uint32 + sendResult int64 + recvResult int64 +} + +//lowest level unit test +func TestBalance(t *testing.T) { + //create instances + balance := &dummyBalance{} + prices := &dummyPrices{} + //create the spec + spec := createTestSpec() + //create the accounting hook for the spec + acc := NewAccounting(balance, prices) + //create a peer + id := adapters.RandomNodeConfig().ID + p := p2p.NewPeer(id, "testPeer", nil) + peer := NewPeer(p, &dummyRW{}, spec) + //price depends on size, receiver pays + msg := &perBytesMsgReceiverPays{Content: "testBalance"} + size, _ := rlp.EncodeToBytes(msg) + + testCases := []testCase{ + { + msg, + uint32(len(size)), + int64(len(size) * 100), + int64(len(size) * -100), + }, + { + &perBytesMsgSenderPays{Content: "testBalance"}, + uint32(len(size)), + int64(len(size) * -100), + int64(len(size) * 100), + }, + { + &perUnitMsgSenderPays{}, + 0, + int64(-99), + int64(99), + }, + { + &perUnitMsgReceiverPays{}, + 0, + int64(99), + int64(-99), + }, + { + &zeroPriceMsg{}, + 0, + int64(0), + int64(0), + }, + { + &nilPriceMsg{}, + 0, + int64(0), + int64(0), + }, + } + checkAccountingTestCases(t, testCases, acc, peer, balance, true) + checkAccountingTestCases(t, testCases, acc, peer, balance, false) +} + +func checkAccountingTestCases(t *testing.T, cases []testCase, acc *Accounting, peer *Peer, balance *dummyBalance, send bool) { + for _, c := range cases { + var err error + var expectedResult int64 + //reset balance before every check + balance.amount = 0 + if send { + err = acc.Send(peer, c.size, c.msg) + expectedResult = c.sendResult + } else { + err = acc.Receive(peer, c.size, c.msg) + expectedResult = c.recvResult + } + + checkResults(t, err, balance, peer, expectedResult) + } +} + +func checkResults(t *testing.T, err error, balance *dummyBalance, peer *Peer, result int64) { + if err != nil { + t.Fatal(err) + } + if balance.peer != peer { + t.Fatalf("expected Add to be called with peer %v, got %v", peer, balance.peer) + } + if balance.amount != result { + t.Fatalf("Expected balance to be %d but is %d", result, balance.amount) + } +} + +//create a test spec +func createTestSpec() *Spec { + spec := &Spec{ + Name: "test", + Version: 42, + MaxMsgSize: 10 * 1024, + Messages: []interface{}{ + &perBytesMsgReceiverPays{}, + &perBytesMsgSenderPays{}, + &perUnitMsgReceiverPays{}, + &perUnitMsgSenderPays{}, + &zeroPriceMsg{}, + &nilPriceMsg{}, + }, + } + return spec +} diff --git a/p2p/protocols/protocol.go b/p2p/protocols/protocol.go index 615f74b569..7dddd852fa 100644 --- a/p2p/protocols/protocol.go +++ b/p2p/protocols/protocol.go @@ -122,6 +122,16 @@ type WrappedMsg struct { Payload []byte } +//For accounting, the design is to allow the Spec to describe which and how its messages are priced +//To access this functionality, we provide a Hook interface which will call accounting methods +//NOTE: there could be more such (horizontal) hooks in the future +type Hook interface { + //A hook for sending messages + Send(peer *Peer, size uint32, msg interface{}) error + //A hook for receiving messages + Receive(peer *Peer, size uint32, msg interface{}) error +} + // Spec is a protocol specification including its name and version as well as // the types of messages which are exchanged type Spec struct { @@ -141,6 +151,9 @@ type Spec struct { // each message must have a single unique data type Messages []interface{} + //hook for accounting (could be extended to multiple hooks in the future) + Hook Hook + initOnce sync.Once codes map[reflect.Type]uint64 types map[uint64]reflect.Type @@ -274,6 +287,15 @@ func (p *Peer) Send(ctx context.Context, msg interface{}) error { Payload: r, } + //if the accounting hook is set, call it + if p.spec.Hook != nil { + err := p.spec.Hook.Send(p, wmsg.Size, msg) + if err != nil { + p.Drop(err) + return err + } + } + code, found := p.spec.GetCode(msg) if !found { return errorf(ErrInvalidMsgType, "%v", code) @@ -336,6 +358,14 @@ func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{}) return errorf(ErrDecode, "<= %v: %v", msg, err) } + //if the accounting hook is set, call it + if p.spec.Hook != nil { + err := p.spec.Hook.Receive(p, wmsg.Size, val) + if err != nil { + return err + } + } + // call the registered handler callbacks // a registered callback take the decoded message as argument as an interface // which the handler is supposed to cast to the appropriate type diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go index 4755db3e62..2874af48d2 100644 --- a/p2p/protocols/protocol_test.go +++ b/p2p/protocols/protocol_test.go @@ -17,12 +17,15 @@ package protocols import ( + "bytes" "context" "errors" "fmt" "testing" "time" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/simulations/adapters" @@ -185,6 +188,169 @@ func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) { } } +type dummyHook struct { + peer *Peer + size uint32 + msg interface{} + send bool + err error + waitC chan struct{} +} + +type dummyMsg struct { + Content string +} + +func (d *dummyHook) Send(peer *Peer, size uint32, msg interface{}) error { + d.peer = peer + d.size = size + d.msg = msg + d.send = true + return d.err +} + +func (d *dummyHook) Receive(peer *Peer, size uint32, msg interface{}) error { + d.peer = peer + d.size = size + d.msg = msg + d.send = false + d.waitC <- struct{}{} + return d.err +} + +func TestProtocolHook(t *testing.T) { + testHook := &dummyHook{ + waitC: make(chan struct{}, 1), + } + spec := &Spec{ + Name: "test", + Version: 42, + MaxMsgSize: 10 * 1024, + Messages: []interface{}{ + dummyMsg{}, + }, + Hook: testHook, + } + + runFunc := func(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := NewPeer(p, rw, spec) + ctx := context.TODO() + err := peer.Send(ctx, &dummyMsg{ + Content: "handshake"}) + + if err != nil { + t.Fatal(err) + } + + handle := func(ctx context.Context, msg interface{}) error { + return nil + } + + return peer.Run(handle) + } + + conf := adapters.RandomNodeConfig() + tester := p2ptest.NewProtocolTester(t, conf.ID, 2, runFunc) + err := tester.TestExchanges(p2ptest.Exchange{ + Expects: []p2ptest.Expect{ + { + Code: 0, + Msg: &dummyMsg{Content: "handshake"}, + Peer: tester.Nodes[0].ID(), + }, + }, + }) + if err != nil { + t.Fatal(err) + } + if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "handshake" { + t.Fatal("Expected msg to be set, but it is not") + } + if !testHook.send { + t.Fatal("Expected a send message, but it is not") + } + if testHook.peer == nil || testHook.peer.ID() != tester.Nodes[0].ID() { + t.Fatal("Expected peer ID to be set correctly, but it is not") + } + if testHook.size != 11 { //11 is the length of the encoded message + t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size) + } + + err = tester.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + { + Code: 0, + Msg: &dummyMsg{Content: "response"}, + Peer: tester.Nodes[1].ID(), + }, + }, + }) + + <-testHook.waitC + + if err != nil { + t.Fatal(err) + } + if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "response" { + t.Fatal("Expected msg to be set, but it is not") + } + if testHook.send { + t.Fatal("Expected a send message, but it is not") + } + if testHook.peer == nil || testHook.peer.ID() != tester.Nodes[1].ID() { + t.Fatal("Expected peer ID to be set correctly, but it is not") + } + if testHook.size != 10 { //11 is the length of the encoded message + t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size) + } + + testHook.err = fmt.Errorf("dummy error") + err = tester.TestExchanges(p2ptest.Exchange{ + Triggers: []p2ptest.Trigger{ + { + Code: 0, + Msg: &dummyMsg{Content: "response"}, + Peer: tester.Nodes[1].ID(), + }, + }, + }) + + <-testHook.waitC + + time.Sleep(100 * time.Millisecond) + err = tester.TestDisconnected(&p2ptest.Disconnect{tester.Nodes[1].ID(), testHook.err}) + if err != nil { + t.Fatalf("Expected a specific disconnect error, but got different one: %v", err) + } + +} + +//We need to test that if the hook is not defined, then message infrastructure +//(send,receive) still works +func TestNoHook(t *testing.T) { + //create a test spec + spec := createTestSpec() + //a random node + id := adapters.RandomNodeConfig().ID + //a peer + p := p2p.NewPeer(id, "testPeer", nil) + rw := &dummyRW{} + peer := NewPeer(p, rw, spec) + ctx := context.TODO() + msg := &perBytesMsgSenderPays{Content: "testBalance"} + //send a message + err := peer.Send(ctx, msg) + if err != nil { + t.Fatal(err) + } + //simulate receiving a message + rw.msg = msg + peer.handleIncoming(func(ctx context.Context, msg interface{}) error { + return nil + }) + //all should just work and not result in any error +} + func TestProtoHandshakeVersionMismatch(t *testing.T) { runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error())) } @@ -386,3 +552,39 @@ func XTestMultiplePeersDropOther(t *testing.T) { fmt.Errorf("subprotocol error"), ) } + +//dummy implementation of a MsgReadWriter +//this allows for quick and easy unit tests without +//having to build up the complete protocol +type dummyRW struct { + msg interface{} + size uint32 + code uint64 +} + +func (d *dummyRW) WriteMsg(msg p2p.Msg) error { + return nil +} + +func (d *dummyRW) ReadMsg() (p2p.Msg, error) { + enc := bytes.NewReader(d.getDummyMsg()) + return p2p.Msg{ + Code: d.code, + Size: d.size, + Payload: enc, + ReceivedAt: time.Now(), + }, nil +} + +func (d *dummyRW) getDummyMsg() []byte { + r, _ := rlp.EncodeToBytes(d.msg) + var b bytes.Buffer + wmsg := WrappedMsg{ + Context: b.Bytes(), + Size: uint32(len(r)), + Payload: r, + } + rr, _ := rlp.EncodeToBytes(wmsg) + d.size = uint32(len(rr)) + return rr +}