diff --git a/rpc/subscription.go b/rpc/subscription.go index 6ce7befa1..6bbb6f75d 100644 --- a/rpc/subscription.go +++ b/rpc/subscription.go @@ -52,9 +52,10 @@ type notifierKey struct{} // Server callbacks use the notifier to send notifications. type Notifier struct { codec ServerCodec - subMu sync.RWMutex // guards active and inactive maps + subMu sync.Mutex active map[ID]*Subscription inactive map[ID]*Subscription + buffer map[ID][]interface{} // unsent notifications of inactive subscriptions } // newNotifier creates a new notifier that can be used to send subscription @@ -64,6 +65,7 @@ func newNotifier(codec ServerCodec) *Notifier { codec: codec, active: make(map[ID]*Subscription), inactive: make(map[ID]*Subscription), + buffer: make(map[ID][]interface{}), } } @@ -88,20 +90,26 @@ func (n *Notifier) CreateSubscription() *Subscription { // Notify sends a notification to the client with the given data as payload. // If an error occurs the RPC connection is closed and the error is returned. func (n *Notifier) Notify(id ID, data interface{}) error { - n.subMu.RLock() - defer n.subMu.RUnlock() + n.subMu.Lock() + defer n.subMu.Unlock() - sub, active := n.active[id] - if active { - notification := n.codec.CreateNotification(string(id), sub.namespace, data) - if err := n.codec.Write(notification); err != nil { - n.codec.Close() - return err - } + if sub, active := n.active[id]; active { + n.send(sub, data) + } else { + n.buffer[id] = append(n.buffer[id], data) } return nil } +func (n *Notifier) send(sub *Subscription, data interface{}) error { + notification := n.codec.CreateNotification(string(sub.ID), sub.namespace, data) + err := n.codec.Write(notification) + if err != nil { + n.codec.Close() + } + return err +} + // Closed returns a channel that is closed when the RPC connection is closed. func (n *Notifier) Closed() <-chan interface{} { return n.codec.Closed() @@ -127,9 +135,15 @@ func (n *Notifier) unsubscribe(id ID) error { func (n *Notifier) activate(id ID, namespace string) { n.subMu.Lock() defer n.subMu.Unlock() + if sub, found := n.inactive[id]; found { sub.namespace = namespace n.active[id] = sub delete(n.inactive, id) + // Send buffered notifications. + for _, data := range n.buffer[id] { + n.send(sub, data) + } + delete(n.buffer, id) } } diff --git a/rpc/subscription_test.go b/rpc/subscription_test.go index 0ba177e63..24febc919 100644 --- a/rpc/subscription_test.go +++ b/rpc/subscription_test.go @@ -27,9 +27,8 @@ import ( ) type NotificationTestService struct { - mu sync.Mutex - unsubscribed bool - + mu sync.Mutex + unsubscribed chan string gotHangSubscriptionReq chan struct{} unblockHangSubscription chan struct{} } @@ -38,16 +37,10 @@ func (s *NotificationTestService) Echo(i int) int { return i } -func (s *NotificationTestService) wasUnsubCallbackCalled() bool { - s.mu.Lock() - defer s.mu.Unlock() - return s.unsubscribed -} - func (s *NotificationTestService) Unsubscribe(subid string) { - s.mu.Lock() - s.unsubscribed = true - s.mu.Unlock() + if s.unsubscribed != nil { + s.unsubscribed <- subid + } } func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) { @@ -65,7 +58,6 @@ func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val i // test expects n events, if we begin sending event immediately some events // will probably be dropped since the subscription ID might not be send to // the client. - time.Sleep(5 * time.Second) for i := 0; i < n; i++ { if err := notifier.Notify(subscription.ID, val+i); err != nil { return @@ -74,13 +66,10 @@ func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val i select { case <-notifier.Closed(): - s.mu.Lock() - s.unsubscribed = true - s.mu.Unlock() case <-subscription.Err(): - s.mu.Lock() - s.unsubscribed = true - s.mu.Unlock() + } + if s.unsubscribed != nil { + s.unsubscribed <- string(subscription.ID) } }() @@ -107,7 +96,7 @@ func (s *NotificationTestService) HangSubscription(ctx context.Context, val int) func TestNotifications(t *testing.T) { server := NewServer() - service := &NotificationTestService{} + service := &NotificationTestService{unsubscribed: make(chan string)} if err := server.RegisterName("eth", service); err != nil { t.Fatalf("unable to register test service %v", err) @@ -157,10 +146,10 @@ func TestNotifications(t *testing.T) { } clientConn.Close() // causes notification unsubscribe callback to be called - time.Sleep(1 * time.Second) - - if !service.wasUnsubCallbackCalled() { - t.Error("unsubscribe callback not called after closing connection") + select { + case <-service.unsubscribed: + case <-time.After(1 * time.Second): + t.Fatal("Unsubscribe not called after one second") } } @@ -227,18 +216,19 @@ func waitForMessages(t *testing.T, in *json.Decoder, successes chan<- jsonSucces // for multiple different namespaces. func TestSubscriptionMultipleNamespaces(t *testing.T) { var ( - namespaces = []string{"eth", "shh", "bzz"} + namespaces = []string{"eth", "shh", "bzz"} + service = NotificationTestService{} + subCount = len(namespaces) * 2 + notificationCount = 3 + server = NewServer() - service = NotificationTestService{} clientConn, serverConn = net.Pipe() - - out = json.NewEncoder(clientConn) - in = json.NewDecoder(clientConn) - successes = make(chan jsonSuccessResponse) - failures = make(chan jsonErrResponse) - notifications = make(chan jsonNotification) - - errors = make(chan error, 10) + out = json.NewEncoder(clientConn) + in = json.NewDecoder(clientConn) + successes = make(chan jsonSuccessResponse) + failures = make(chan jsonErrResponse) + notifications = make(chan jsonNotification) + errors = make(chan error, 10) ) // setup and start server @@ -255,13 +245,12 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) { go waitForMessages(t, in, successes, failures, notifications, errors) // create subscriptions one by one - n := 3 for i, namespace := range namespaces { request := map[string]interface{}{ "id": i, "method": fmt.Sprintf("%s_subscribe", namespace), "version": "2.0", - "params": []interface{}{"someSubscription", n, i}, + "params": []interface{}{"someSubscription", notificationCount, i}, } if err := out.Encode(&request); err != nil { @@ -276,7 +265,7 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) { "id": i, "method": fmt.Sprintf("%s_subscribe", namespace), "version": "2.0", - "params": []interface{}{"someSubscription", n, i}, + "params": []interface{}{"someSubscription", notificationCount, i}, }) } @@ -285,46 +274,40 @@ func TestSubscriptionMultipleNamespaces(t *testing.T) { } timeout := time.After(30 * time.Second) - subids := make(map[string]string, 2*len(namespaces)) - count := make(map[string]int, 2*len(namespaces)) - - for { - done := true - for id := range count { - if count, found := count[id]; !found || count < (2*n) { + subids := make(map[string]string, subCount) + count := make(map[string]int, subCount) + allReceived := func() bool { + done := len(count) == subCount + for _, c := range count { + if c < notificationCount { done = false } } + return done + } - if done && len(count) == len(namespaces) { - break - } - + for !allReceived() { select { - case err := <-errors: - t.Fatal(err) case suc := <-successes: // subscription created subids[namespaces[int(suc.Id.(float64))]] = suc.Result.(string) + case notification := <-notifications: + count[notification.Params.Subscription]++ + case err := <-errors: + t.Fatal(err) case failure := <-failures: t.Errorf("received error: %v", failure.Error) - case notification := <-notifications: - if cnt, found := count[notification.Params.Subscription]; found { - count[notification.Params.Subscription] = cnt + 1 - } else { - count[notification.Params.Subscription] = 1 - } case <-timeout: for _, namespace := range namespaces { subid, found := subids[namespace] if !found { - t.Errorf("Subscription for '%s' not created", namespace) + t.Errorf("subscription for %q not created", namespace) continue } - if count, found := count[subid]; !found || count < n { - t.Errorf("Didn't receive all notifications (%d<%d) in time for namespace '%s'", count, n, namespace) + if count, found := count[subid]; !found || count < notificationCount { + t.Errorf("didn't receive all notifications (%d<%d) in time for namespace %q", count, notificationCount, namespace) } } - return + t.Fatal("timed out") } } }