diff --git a/swarm/network/simulation/node.go b/swarm/network/simulation/node.go index 46c2bb8660..f66b0afd0d 100644 --- a/swarm/network/simulation/node.go +++ b/swarm/network/simulation/node.go @@ -234,9 +234,9 @@ func (s *Simulation) UploadSnapshot(ctx context.Context, snapshotFile string, op if err != nil { return err } - defer f.Close() jsonbyte, err := ioutil.ReadAll(f) + f.Close() if err != nil { return err } diff --git a/swarm/pss/prox_test.go b/swarm/pss/prox_test.go index bc32e612d3..908a0d3302 100644 --- a/swarm/pss/prox_test.go +++ b/swarm/pss/prox_test.go @@ -4,10 +4,7 @@ import ( "context" "crypto/ecdsa" "encoding/binary" - "errors" "fmt" - "strconv" - "strings" "sync" "testing" "time" @@ -39,24 +36,20 @@ type handlerNotification struct { } type testData struct { - mu sync.Mutex - sim *simulation.Simulation - handlerDone bool // set to true on termination of the simulation run - requiredMessages int - allowedMessages int - messageCount int - kademlias map[enode.ID]*network.Kademlia - nodeAddrs map[enode.ID][]byte // make predictable overlay addresses from the generated random enode ids - recipients map[int][]enode.ID // for logging output only - allowed map[int][]enode.ID // allowed recipients - expectedMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive - allowedMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive - senders map[int]enode.ID // originating nodes of the messages (intention is to choose as far as possible from the receiving neighborhood) - handlerC chan handlerNotification // passes message from pss message handler to simulation driver - doneC chan struct{} // terminates the handler channel listener - errC chan error // error to pass to main sim thread - msgC chan handlerNotification // message receipt notification to main sim thread - msgs [][]byte // recipient addresses of messages + sim *simulation.Simulation + kademlias map[enode.ID]*network.Kademlia + nodeAddresses map[enode.ID][]byte // make predictable overlay addresses from the generated random enode ids + senders map[int]enode.ID // originating nodes of the messages (intention is to choose as far as possible from the receiving neighborhood) + recipientAddresses [][]byte + + requiredMsgCount int + requiredMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive + allowedMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive + + notifications []handlerNotification // notification queue + totalMsgCount int + handlerDone bool // set to true on termination of the simulation run + mu sync.Mutex } var ( @@ -64,67 +57,60 @@ var ( topic = BytesToTopic([]byte{0xf3, 0x9e, 0x06, 0x82}) ) -func (d *testData) getMsgCount() int { - d.mu.Lock() - defer d.mu.Unlock() - return d.messageCount +func (td *testData) pushNotification(val handlerNotification) { + td.mu.Lock() + td.notifications = append(td.notifications, val) + td.mu.Unlock() } -func (d *testData) incrementMsgCount() int { - d.mu.Lock() - defer d.mu.Unlock() - d.messageCount++ - return d.messageCount -} - -func (d *testData) isDone() bool { - d.mu.Lock() - defer d.mu.Unlock() - return d.handlerDone -} - -func (d *testData) setDone() { - d.mu.Lock() - defer d.mu.Unlock() - d.handlerDone = true -} - -func getCmdParams(t *testing.T) (int, int, time.Duration) { - args := strings.Split(t.Name(), "/") - msgCount, err := strconv.ParseInt(args[2], 10, 16) - if err != nil { - t.Fatal(err) +func (td *testData) popNotification() (first handlerNotification, exist bool) { + td.mu.Lock() + if len(td.notifications) > 0 { + exist = true + first = td.notifications[0] + td.notifications = td.notifications[1:] } - nodeCount, err := strconv.ParseInt(args[1], 10, 16) - if err != nil { - t.Fatal(err) - } - timeoutStr := fmt.Sprintf("%ss", args[3]) - timeoutDur, err := time.ParseDuration(timeoutStr) - if err != nil { - t.Fatal(err) - } - return int(msgCount), int(nodeCount), timeoutDur + td.mu.Unlock() + return first, exist +} + +func (td *testData) getMsgCount() int { + td.mu.Lock() + defer td.mu.Unlock() + return td.totalMsgCount +} + +func (td *testData) incrementMsgCount() int { + td.mu.Lock() + defer td.mu.Unlock() + td.totalMsgCount++ + return td.totalMsgCount +} + +func (td *testData) isDone() bool { + td.mu.Lock() + defer td.mu.Unlock() + return td.handlerDone +} + +func (td *testData) setDone() { + td.mu.Lock() + defer td.mu.Unlock() + td.handlerDone = true } func newTestData() *testData { return &testData{ - kademlias: make(map[enode.ID]*network.Kademlia), - nodeAddrs: make(map[enode.ID][]byte), - recipients: make(map[int][]enode.ID), - allowed: make(map[int][]enode.ID), - expectedMsgs: make(map[enode.ID][]uint64), - allowedMsgs: make(map[enode.ID][]uint64), - senders: make(map[int]enode.ID), - handlerC: make(chan handlerNotification), - doneC: make(chan struct{}), - errC: make(chan error), - msgC: make(chan handlerNotification), + kademlias: make(map[enode.ID]*network.Kademlia), + nodeAddresses: make(map[enode.ID][]byte), + requiredMsgs: make(map[enode.ID][]uint64), + allowedMsgs: make(map[enode.ID][]uint64), + senders: make(map[int]enode.ID), } } -func (d *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) { - kadif, ok := d.sim.NodeItem(*nodeId, simulation.BucketKeyKademlia) +func (td *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) { + kadif, ok := td.sim.NodeItem(*nodeId, simulation.BucketKeyKademlia) if !ok { return nil, fmt.Errorf("no kademlia entry for %v", nodeId) } @@ -135,29 +121,29 @@ func (d *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) { return kad, nil } -func (d *testData) init(msgCount int) error { +func (td *testData) init(msgCount int) error { log.Debug("TestProxNetwork start") - for _, nodeId := range d.sim.NodeIDs() { - kad, err := d.getKademlia(&nodeId) + for _, nodeId := range td.sim.NodeIDs() { + kad, err := td.getKademlia(&nodeId) if err != nil { return err } - d.nodeAddrs[nodeId] = kad.BaseAddr() + td.nodeAddresses[nodeId] = kad.BaseAddr() } for i := 0; i < int(msgCount); i++ { msgAddr := pot.RandomAddress() // we choose message addresses randomly - d.msgs = append(d.msgs, msgAddr.Bytes()) + td.recipientAddresses = append(td.recipientAddresses, msgAddr.Bytes()) smallestPo := 256 var targets []enode.ID var closestPO int // loop through all nodes and find the required and allowed recipients of each message // (for more information, please see the comment to the main test function) - for _, nod := range d.sim.Net.GetNodes() { - po, _ := pof(d.msgs[i], d.nodeAddrs[nod.ID()], 0) - depth := d.kademlias[nod.ID()].NeighbourhoodDepth() + for _, nod := range td.sim.Net.GetNodes() { + po, _ := pof(td.recipientAddresses[i], td.nodeAddresses[nod.ID()], 0) + depth := td.kademlias[nod.ID()].NeighbourhoodDepth() // only nodes with closest IDs (wrt the msg address) will be required recipients if po > closestPO { @@ -169,28 +155,25 @@ func (d *testData) init(msgCount int) error { } if po >= depth { - d.allowedMessages++ - d.allowed[i] = append(d.allowed[i], nod.ID()) - d.allowedMsgs[nod.ID()] = append(d.allowedMsgs[nod.ID()], uint64(i)) + td.allowedMsgs[nod.ID()] = append(td.allowedMsgs[nod.ID()], uint64(i)) } // a node with the smallest PO (wrt msg) will be the sender, // in order to increase the distance the msg must travel if po < smallestPo { smallestPo = po - d.senders[i] = nod.ID() + td.senders[i] = nod.ID() } } - d.requiredMessages += len(targets) + td.requiredMsgCount += len(targets) for _, id := range targets { - d.recipients[i] = append(d.recipients[i], id) - d.expectedMsgs[id] = append(d.expectedMsgs[id], uint64(i)) + td.requiredMsgs[id] = append(td.requiredMsgs[id], uint64(i)) } - log.Debug("nn for msg", "targets", len(d.recipients[i]), "msgidx", i, "msg", common.Bytes2Hex(msgAddr[:8]), "sender", d.senders[i], "senderpo", smallestPo) + log.Debug("nn for msg", "targets", len(targets), "msgidx", i, "msg", common.Bytes2Hex(msgAddr[:8]), "sender", td.senders[i], "senderpo", smallestPo) } - log.Debug("msgs to receive", "count", d.requiredMessages) + log.Debug("recipientAddresses to receive", "count", td.requiredMsgCount) return nil } @@ -213,144 +196,161 @@ func (d *testData) init(msgCount int) error { // nodes Y and Z will be considered required recipients of the msg, // whereas nodes X, Y and Z will be allowed recipients. func TestProxNetwork(t *testing.T) { - t.Run("16/16/15", testProxNetwork) + t.Run("16_nodes,_16_messages,_16_seconds", func(t *testing.T) { + testProxNetwork(t, 16, 16, 16*time.Second) + }) } -// params in run name: nodes/msgs func TestProxNetworkLong(t *testing.T) { if !*longrunning { t.Skip("run with --longrunning flag to run extensive network tests") } - t.Run("8/100/30", testProxNetwork) - t.Run("16/100/30", testProxNetwork) - t.Run("32/100/60", testProxNetwork) - t.Run("64/100/60", testProxNetwork) - t.Run("128/100/120", testProxNetwork) + t.Run("8_nodes,_100_messages,_30_seconds", func(t *testing.T) { + testProxNetwork(t, 8, 100, 30*time.Second) + }) + t.Run("16_nodes,_100_messages,_30_seconds", func(t *testing.T) { + testProxNetwork(t, 16, 100, 30*time.Second) + }) + t.Run("32_nodes,_100_messages,_60_seconds", func(t *testing.T) { + testProxNetwork(t, 32, 100, 1*time.Minute) + }) + t.Run("64_nodes,_100_messages,_60_seconds", func(t *testing.T) { + testProxNetwork(t, 64, 100, 1*time.Minute) + }) + t.Run("128_nodes,_100_messages,_120_seconds", func(t *testing.T) { + testProxNetwork(t, 128, 100, 2*time.Minute) + }) } -func testProxNetwork(t *testing.T) { - tstdata := newTestData() - msgCount, nodeCount, timeout := getCmdParams(t) +func testProxNetwork(t *testing.T, nodeCount int, msgCount int, timeout time.Duration) { + td := newTestData() handlerContextFuncs := make(map[Topic]handlerContextFunc) handlerContextFuncs[topic] = nodeMsgHandler - services := newProxServices(tstdata, true, handlerContextFuncs, tstdata.kademlias) - tstdata.sim = simulation.New(services) - defer tstdata.sim.Close() + services := newProxServices(td, true, handlerContextFuncs, td.kademlias) + td.sim = simulation.New(services) + defer td.sim.Close() ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() filename := fmt.Sprintf("testdata/snapshot_%d.json", nodeCount) - err := tstdata.sim.UploadSnapshot(ctx, filename) + err := td.sim.UploadSnapshot(ctx, filename) if err != nil { t.Fatal(err) } - err = tstdata.init(msgCount) // initialize the test data + err = td.init(msgCount) // initialize the test data if err != nil { t.Fatal(err) } wrapper := func(c context.Context, _ *simulation.Simulation) error { - return testRoutine(tstdata, c) + return testRoutine(td, c) } - result := tstdata.sim.Run(ctx, wrapper) // call the main test function + result := td.sim.Run(ctx, wrapper) // call the main test function if result.Error != nil { - // context deadline exceeded - // however, it might just mean that not all possible messages are received - // now we must check if all required messages are received - cnt := tstdata.getMsgCount() - log.Debug("TestProxNetwork finished", "rcv", cnt) - if cnt < tstdata.requiredMessages { + timedOut := result.Error == context.DeadlineExceeded + if !timedOut || td.getMsgCount() < td.requiredMsgCount { t.Fatal(result.Error) } } - t.Logf("completed %d", result.Duration) } -func (tstdata *testData) sendAllMsgs() { - for i, msg := range tstdata.msgs { - log.Debug("sending msg", "idx", i, "from", tstdata.senders[i]) - nodeClient, err := tstdata.sim.Net.GetNode(tstdata.senders[i]).Client() +func (td *testData) sendAllMsgs() error { + nodes := make(map[int]*rpc.Client) + for i := range td.recipientAddresses { + nodeClient, err := td.sim.Net.GetNode(td.senders[i]).Client() if err != nil { - tstdata.errC <- err + return err } + nodes[i] = nodeClient + } + + for i, msg := range td.recipientAddresses { + log.Debug("sending msg", "idx", i, "from", td.senders[i]) + nodeClient := nodes[i] var uvarByte [8]byte binary.PutUvarint(uvarByte[:], uint64(i)) nodeClient.Call(nil, "pss_sendRaw", hexutil.Encode(msg), hexutil.Encode(topic[:]), hexutil.Encode(uvarByte[:])) } - log.Debug("all messages sent") + return nil +} + +func isMoreTimeLeft(ctx context.Context) bool { + select { + case <-ctx.Done(): + return false + default: + return true + } } // testRoutine is the main test function, called by Simulation.Run() -func testRoutine(tstdata *testData, ctx context.Context) error { - go handlerChannelListener(tstdata, ctx) - go tstdata.sendAllMsgs() - received := 0 +func testRoutine(td *testData, ctx context.Context) error { - // collect incoming messages and terminate with corresponding status when message handler listener ends - for { - select { - case err := <-tstdata.errC: - return err - case hn := <-tstdata.msgC: - received++ - log.Debug("msg received", "msgs_received", received, "total_expected", tstdata.requiredMessages, "id", hn.id, "serial", hn.serial) - if received == tstdata.allowedMessages { - close(tstdata.doneC) - return nil - } + hasMoreRound := func(err error, hadMessage bool) bool { + return err == nil && (hadMessage || isMoreTimeLeft(ctx)) + } + + if err := td.sendAllMsgs(); err != nil { + return err + } + + var err error + received := 0 + hadMessage := false + + for oneMoreRound := true; oneMoreRound; oneMoreRound = hasMoreRound(err, hadMessage) { + message, hadMessage := td.popNotification() + + if !isMoreTimeLeft(ctx) { + // Stop handlers from sending more messages. + // Note: only best effort, race is possible. + td.setDone() } + + if hadMessage { + if td.isAllowedMessage(message) { + received++ + log.Debug("msg received", "msgs_received", received, "total_expected", td.requiredMsgCount, "id", message.id, "serial", message.serial) + } else { + err = fmt.Errorf("message %d received by wrong recipient %v", message.serial, message.id) + } + } else { + time.Sleep(32 * time.Millisecond) + } + } + + if err != nil { + return err + } + + if td.getMsgCount() < td.requiredMsgCount { + return ctx.Err() } return nil } -func handlerChannelListener(tstdata *testData, ctx context.Context) { - for { - select { - case <-tstdata.doneC: // graceful exit - tstdata.setDone() - tstdata.errC <- nil - return - - case <-ctx.Done(): // timeout or cancel - tstdata.setDone() - tstdata.errC <- ctx.Err() - return - - // incoming message from pss message handler - case handlerNotification := <-tstdata.handlerC: - // check if recipient has already received all its messages and notify to fail the test if so - aMsgs := tstdata.allowedMsgs[handlerNotification.id] - if len(aMsgs) == 0 { - tstdata.setDone() - tstdata.errC <- fmt.Errorf("too many messages received by recipient %x", handlerNotification.id) - return - } - - // check if message serial is in expected messages for this recipient and notify to fail the test if not - idx := -1 - for i, msg := range aMsgs { - if handlerNotification.serial == msg { - idx = i - break - } - } - if idx == -1 { - tstdata.setDone() - tstdata.errC <- fmt.Errorf("message %d received by wrong recipient %v", handlerNotification.serial, handlerNotification.id) - return - } - - // message is ok, so remove that message serial from the recipient expectation array and notify the main sim thread - aMsgs[idx] = aMsgs[len(aMsgs)-1] - aMsgs = aMsgs[:len(aMsgs)-1] - tstdata.msgC <- handlerNotification +func (td *testData) isAllowedMessage(n handlerNotification) bool { + // check if message serial is in expected messages for this recipient + for _, s := range td.allowedMsgs[n.id] { + if n.serial == s { + return true } } + return false } -func nodeMsgHandler(tstdata *testData, config *adapters.NodeConfig) *handler { +func (td *testData) removeAllowedMessage(id enode.ID, index int) { + last := len(td.allowedMsgs[id]) - 1 + td.allowedMsgs[id][index] = td.allowedMsgs[id][last] + td.allowedMsgs[id] = td.allowedMsgs[id][:last] +} + +func nodeMsgHandler(td *testData, config *adapters.NodeConfig) *handler { return &handler{ f: func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { - cnt := tstdata.incrementMsgCount() - log.Debug("nodeMsgHandler rcv", "cnt", cnt) + if td.isDone() { + return nil // terminate if simulation is over + } + + td.incrementMsgCount() // using simple serial in message body, makes it easy to keep track of who's getting what serial, c := binary.Uvarint(msg) @@ -358,15 +358,7 @@ func nodeMsgHandler(tstdata *testData, config *adapters.NodeConfig) *handler { log.Crit(fmt.Sprintf("corrupt message received by %x (uvarint parse returned %d)", config.ID, c)) } - if tstdata.isDone() { - return errors.New("handlers aborted") // terminate if simulation is over - } - - // pass message context to the listener in the simulation - tstdata.handlerC <- handlerNotification{ - id: config.ID, - serial: serial, - } + td.pushNotification(handlerNotification{id: config.ID, serial: serial}) return nil }, caps: &handlerCaps{ @@ -378,7 +370,7 @@ func nodeMsgHandler(tstdata *testData, config *adapters.NodeConfig) *handler { // an adaptation of the same services setup as in pss_test.go // replaces pss_test.go when those tests are rewritten to the new swarm/network/simulation package -func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[Topic]handlerContextFunc, kademlias map[enode.ID]*network.Kademlia) map[string]simulation.ServiceFunc { +func newProxServices(td *testData, allowRaw bool, handlerContextFuncs map[Topic]handlerContextFunc, kademlias map[enode.ID]*network.Kademlia) map[string]simulation.ServiceFunc { stateStore := state.NewInmemoryStore() kademlia := func(id enode.ID, bzzkey []byte) *network.Kademlia { if k, ok := kademlias[id]; ok { @@ -415,6 +407,9 @@ func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[T UnderlayAddr: addr.Under(), HiveParams: hp, } + bzzKey := network.PrivateKeyToBzzKey(bzzPrivateKey) + pskad := kademlia(ctx.Config.ID, bzzKey) + b.Store(simulation.BucketKeyKademlia, pskad) return network.NewBzz(config, kademlia(ctx.Config.ID, addr.OAddr), stateStore, nil, nil), nil, nil }, "pss": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) { @@ -434,6 +429,7 @@ func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[T } bzzKey := network.PrivateKeyToBzzKey(bzzPrivateKey) pskad := kademlia(ctx.Config.ID, bzzKey) + b.Store(simulation.BucketKeyKademlia, pskad) ps, err := NewPss(pskad, pssp) if err != nil { return nil, nil, err @@ -442,7 +438,7 @@ func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[T // register the handlers we've been passed var deregisters []func() for tpc, hndlrFunc := range handlerContextFuncs { - deregisters = append(deregisters, ps.Register(&tpc, hndlrFunc(tstdata, ctx.Config))) + deregisters = append(deregisters, ps.Register(&tpc, hndlrFunc(td, ctx.Config))) } // if handshake mode is set, add the controller @@ -459,8 +455,6 @@ func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[T Public: false, }) - b.Store(simulation.BucketKeyKademlia, pskad) - // return Pss and cleanups return ps, func() { // run the handler deregister functions in reverse order diff --git a/swarm/pss/pss_test.go b/swarm/pss/pss_test.go index ea7a591b1e..9884ffbe94 100644 --- a/swarm/pss/pss_test.go +++ b/swarm/pss/pss_test.go @@ -1364,7 +1364,7 @@ func TestNetwork(t *testing.T) { } // params in run name: -// nodes/msgs/addrbytes/adaptertype +// nodes/recipientAddresses/addrbytes/adaptertype // if adaptertype is exec uses execadapter, simadapter otherwise func TestNetwork2000(t *testing.T) { if !*longrunning {