diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go index a779aeebbe..361285f06e 100644 --- a/p2p/testing/protocolsession.go +++ b/p2p/testing/protocolsession.go @@ -19,13 +19,17 @@ package testing import ( "errors" "fmt" + "sync" "time" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/simulations/adapters" ) +var errTimedOut = errors.New("timed out") + // ProtocolSession is a quasi simulation of a pivot node running // a service and a number of dummy peers that can send (trigger) or // receive (expect) messages @@ -46,6 +50,7 @@ type Exchange struct { Label string Triggers []Trigger Expects []Expect + Timeout time.Duration } // Trigger is part of the exchange, incoming message for the pivot node @@ -102,78 +107,147 @@ func (self *ProtocolSession) trigger(trig Trigger) error { } // expect checks an expectation of a message sent out by the pivot node -func (self *ProtocolSession) expect(exp Expect) error { - if exp.Msg == nil { - return errors.New("no message to expect") - } - simNode, ok := self.adapter.GetNode(exp.Peer) - if !ok { - return fmt.Errorf("trigger: peer %v does not exist (1- %v)", exp.Peer, len(self.IDs)) - } - mockNode, ok := simNode.Services()[0].(*mockNode) - if !ok { - return fmt.Errorf("trigger: peer %v is not a mock", exp.Peer) +func (self *ProtocolSession) expect(exps []Expect) error { + // construct a map of expectations for each node + peerExpects := make(map[discover.NodeID][]Expect) + for _, exp := range exps { + if exp.Msg == nil { + return errors.New("no message to expect") + } + peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp) } + // construct a map of mockNodes for each node + mockNodes := make(map[discover.NodeID]*mockNode) + for nodeID := range peerExpects { + simNode, ok := self.adapter.GetNode(nodeID) + if !ok { + return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs)) + } + mockNode, ok := simNode.Services()[0].(*mockNode) + if !ok { + return fmt.Errorf("trigger: peer %v is not a mock", nodeID) + } + mockNodes[nodeID] = mockNode + } + + // done chanell cancels all created goroutines when function returns + done := make(chan struct{}) + defer close(done) + // errc catches the first error from errc := make(chan error) + + wg := &sync.WaitGroup{} + wg.Add(len(mockNodes)) + for nodeID, mockNode := range mockNodes { + nodeID := nodeID + mockNode := mockNode + go func() { + defer wg.Done() + + // Sum all Expect timeouts to give the maximum + // time for all expectations to finish. + // mockNode.Expect checks all received messages against + // a list of expected messages and timeout for each + // of them can not be checked separately. + var t time.Duration + for _, exp := range peerExpects[nodeID] { + if exp.Timeout == time.Duration(0) { + t += 2000 * time.Millisecond + } else { + t += exp.Timeout + } + } + alarm := time.NewTimer(t) + defer alarm.Stop() + + // expectErrc is used to check if error returned + // from mockNode.Expect is not nil and to send it to + // errc only in that case. + // done channel will be closed when function + expectErrc := make(chan error) + go func() { + select { + case expectErrc <- mockNode.Expect(peerExpects[nodeID]...): + case <-done: + case <-alarm.C: + } + }() + + select { + case err := <-expectErrc: + if err != nil { + select { + case errc <- err: + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + } + case <-done: + case <-alarm.C: + errc <- errTimedOut + } + + }() + } + go func() { - errc <- mockNode.Expect(&exp) + wg.Wait() + // close errc when all goroutines finish to return nill err from errc + close(errc) }() - t := exp.Timeout - if t == time.Duration(0) { - t = 2000 * time.Millisecond - } - select { - case err := <-errc: - return err - case <-time.After(t): - return fmt.Errorf("timout expecting %v sent to peer %v", exp.Msg, exp.Peer) - } + return <-errc } // TestExchanges tests a series of exchanges against the session func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error { - // launch all triggers of this exchanges - - for _, e := range exchanges { - errc := make(chan error, len(e.Triggers)+len(e.Expects)) - for _, trig := range e.Triggers { - errc <- self.trigger(trig) - } - - // each expectation is spawned in separate go-routine - // expectations of an exchange are conjunctive but unordered, i.e., - // only all of them arriving constitutes a pass - // each expectation is meant to be for a different peer, otherwise they are expected to panic - // testing of an exchange blocks until all expectations are decided - // an expectation is decided if - // expected message arrives OR - // an unexpected message arrives (panic) - // times out on their individual timeout - for _, ex := range e.Expects { - // expect msg spawned to separate go routine - go func(exp Expect) { - errc <- self.expect(exp) - }(ex) - } - - // time out globally or finish when all expectations satisfied - timeout := time.After(5 * time.Second) - for i := 0; i < len(e.Triggers)+len(e.Expects); i++ { - select { - case err := <-errc: - if err != nil { - return fmt.Errorf("exchange failed with: %v", err) - } - case <-timeout: - return fmt.Errorf("exchange %v: '%v' timed out", i, e.Label) - } + for i, e := range exchanges { + if err := self.testExchange(e); err != nil { + return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err) } + log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label)) } return nil } +// testExchange tests a single Exchange. +// Default timeout value is 2 seconds. +func (self *ProtocolSession) testExchange(e Exchange) error { + errc := make(chan error) + done := make(chan struct{}) + defer close(done) + + go func() { + for _, trig := range e.Triggers { + err := self.trigger(trig) + if err != nil { + errc <- err + return + } + } + + select { + case errc <- self.expect(e.Expects): + case <-done: + } + }() + + // time out globally or finish when all expectations satisfied + t := e.Timeout + if t == 0 { + t = 2000 * time.Millisecond + } + alarm := time.NewTimer(t) + select { + case err := <-errc: + return err + case <-alarm.C: + return errTimedOut + } +} + // TestDisconnected tests the disconnections given as arguments // the disconnect structs describe what disconnect error is expected on which peer func (self *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error { diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go index ea5b106ff8..a797412d60 100644 --- a/p2p/testing/protocoltester.go +++ b/p2p/testing/protocoltester.go @@ -24,7 +24,11 @@ that can be used to send and receive messages package testing import ( + "bytes" "fmt" + "io" + "io/ioutil" + "strings" "sync" "testing" @@ -34,6 +38,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/simulations" "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" ) @@ -152,7 +157,7 @@ type mockNode struct { testNode trigger chan *Trigger - expect chan *Expect + expect chan []Expect err chan error stop chan struct{} stopOnce sync.Once @@ -161,7 +166,7 @@ type mockNode struct { func newMockNode() *mockNode { mock := &mockNode{ trigger: make(chan *Trigger), - expect: make(chan *Expect), + expect: make(chan []Expect), err: make(chan error), stop: make(chan struct{}), } @@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { select { case trig := <-m.trigger: m.err <- p2p.Send(rw, trig.Code, trig.Msg) - case exp := <-m.expect: - m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg) + case exps := <-m.expect: + m.err <- expectMsgs(rw, exps) case <-m.stop: return nil } @@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error { return <-m.err } -func (m *mockNode) Expect(exp *Expect) error { +func (m *mockNode) Expect(exp ...Expect) error { m.expect <- exp return <-m.err } @@ -198,3 +203,67 @@ func (m *mockNode) Stop() error { m.stopOnce.Do(func() { close(m.stop) }) return nil } + +func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { + matched := make([]bool, len(exps)) + for { + msg, err := rw.ReadMsg() + if err != nil { + if err == io.EOF { + break + } + return err + } + actualContent, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return err + } + var found bool + for i, exp := range exps { + if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) { + if matched[i] { + return fmt.Errorf("message #%d received two times", i) + } + matched[i] = true + found = true + break + } + } + if !found { + expected := make([]string, 0) + for i, exp := range exps { + if matched[i] { + continue + } + expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg))) + } + return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or ")) + } + done := true + for _, m := range matched { + if !m { + done = false + break + } + } + if done { + return nil + } + } + for i, m := range matched { + if !m { + return fmt.Errorf("expected message #%d not received", i) + } + } + return nil +} + +// mustEncodeMsg uses rlp to encode a message. +// In case of error it panics. +func mustEncodeMsg(msg interface{}) []byte { + contentEnc, err := rlp.EncodeToBytes(msg) + if err != nil { + panic("content encode error: " + err.Error()) + } + return contentEnc +}