diff --git a/event/event.go b/event/event.go new file mode 100644 index 0000000000..74f8043da8 --- /dev/null +++ b/event/event.go @@ -0,0 +1,162 @@ +// Package event implements an event multiplexer. +package event + +import ( + "errors" + "reflect" + "sync" +) + +type Subscription interface { + Chan() <-chan interface{} + Unsubscribe() +} + +// A TypeMux dispatches events to registered receivers. Receivers can be +// registered to handle events of certain type. Any operation +// called after mux is stopped will return ErrMuxClosed. +type TypeMux struct { + mutex sync.RWMutex + subm map[reflect.Type][]*muxsub + stopped bool +} + +var ErrMuxClosed = errors.New("event: mux closed") + +// NewTypeMux creates a running mux. +func NewTypeMux() *TypeMux { + return &TypeMux{subm: make(map[reflect.Type][]*muxsub)} +} + +// Subscribe creates a subscription for events of the given types. The +// subscription's channel is closed when it is unsubscribed +// or the mux is closed. +func (mux *TypeMux) Subscribe(types ...interface{}) Subscription { + sub := newsub(mux) + mux.mutex.Lock() + if mux.stopped { + mux.mutex.Unlock() + close(sub.postC) + } else { + for _, t := range types { + rtyp := reflect.TypeOf(t) + oldsubs := mux.subm[rtyp] + subs := make([]*muxsub, len(oldsubs)+1) + copy(subs, oldsubs) + subs[len(oldsubs)] = sub + mux.subm[rtyp] = subs + } + mux.mutex.Unlock() + } + return sub +} + +// Post sends an event to all receivers registered for the given type. +// It returns ErrMuxClosed if the mux has been stopped. +func (mux *TypeMux) Post(ev interface{}) error { + rtyp := reflect.TypeOf(ev) + mux.mutex.RLock() + if mux.stopped { + mux.mutex.RUnlock() + return ErrMuxClosed + } + subs := mux.subm[rtyp] + mux.mutex.RUnlock() + for _, sub := range subs { + sub.deliver(ev) + } + return nil +} + +// Stop closes a mux. The mux can no longer be used. +// Future Post calls will fail with ErrMuxClosed. +// Stop blocks until all current deliveries have finished. +func (mux *TypeMux) Stop() { + mux.mutex.Lock() + for _, subs := range mux.subm { + for _, sub := range subs { + sub.closewait() + } + } + mux.subm = nil + mux.stopped = true + mux.mutex.Unlock() +} + +func (mux *TypeMux) del(s *muxsub) { + mux.mutex.Lock() + for typ, subs := range mux.subm { + if pos := find(subs, s); pos >= 0 { + if len(subs) == 1 { + delete(mux.subm, typ) + } else { + mux.subm[typ] = posdelete(subs, pos) + } + } + } + s.mux.mutex.Unlock() +} + +func find(slice []*muxsub, item *muxsub) int { + for i, v := range slice { + if v == item { + return i + } + } + return -1 +} + +func posdelete(slice []*muxsub, pos int) []*muxsub { + news := make([]*muxsub, len(slice)-1) + copy(news[:pos], slice[:pos]) + copy(news[pos:], slice[pos+1:]) + return news +} + +type muxsub struct { + mux *TypeMux + mutex sync.RWMutex + closing chan struct{} + + // these two are the same channel. they are stored separately so + // postC can be set to nil without affecting the return value of + // Chan. + readC <-chan interface{} + postC chan<- interface{} +} + +func newsub(mux *TypeMux) *muxsub { + c := make(chan interface{}) + return &muxsub{ + mux: mux, + readC: c, + postC: c, + closing: make(chan struct{}), + } +} + +func (s *muxsub) Chan() <-chan interface{} { + return s.readC +} + +func (s *muxsub) Unsubscribe() { + s.mux.del(s) + s.closewait() +} + +func (s *muxsub) closewait() { + close(s.closing) + s.mutex.Lock() + close(s.postC) + s.postC = nil + s.mutex.Unlock() +} + +func (s *muxsub) deliver(ev interface{}) { + s.mutex.RLock() + select { + case s.postC <- ev: + case <-s.closing: + } + s.mutex.RUnlock() +} diff --git a/event/event_test.go b/event/event_test.go new file mode 100644 index 0000000000..385bd70b79 --- /dev/null +++ b/event/event_test.go @@ -0,0 +1,161 @@ +package event + +import ( + "math/rand" + "sync" + "testing" + "time" +) + +type testEvent int + +func TestSub(t *testing.T) { + mux := NewTypeMux() + defer mux.Stop() + + sub := mux.Subscribe(testEvent(0)) + go func() { + if err := mux.Post(testEvent(5)); err != nil { + t.Errorf("Post returned unexpected error: %v", err) + } + }() + ev := <-sub.Chan() + + if ev.(testEvent) != testEvent(5) { + t.Errorf("Got %v (%T), expected event %v (%T)", + ev, ev, testEvent(5), testEvent(5)) + } +} + +func TestMuxErrorAfterStop(t *testing.T) { + mux := NewTypeMux() + mux.Stop() + + sub := mux.Subscribe(testEvent(0)) + if _, isopen := <-sub.Chan(); isopen { + t.Errorf("subscription channel was not closed") + } + if err := mux.Post(testEvent(0)); err != ErrMuxClosed { + t.Errorf("Post error mismatch, got: %s, expected: %s", err, ErrMuxClosed) + } +} + +func TestUnsubscribeUnblockPost(t *testing.T) { + mux := NewTypeMux() + defer mux.Stop() + + sub := mux.Subscribe(testEvent(0)) + unblocked := make(chan bool) + go func() { + mux.Post(testEvent(5)) + unblocked <- true + }() + + select { + case <-unblocked: + t.Errorf("Post returned before Unsubscribe") + default: + sub.Unsubscribe() + <-unblocked + } +} + +func TestMuxConcurrent(t *testing.T) { + rand.Seed(time.Now().Unix()) + mux := NewTypeMux() + defer mux.Stop() + + recv := make(chan int) + poster := func() { + for { + err := mux.Post(testEvent(0)) + if err != nil { + return + } + } + } + sub := func(i int) { + time.Sleep(time.Duration(rand.Intn(99)) * time.Millisecond) + sub := mux.Subscribe(testEvent(0)) + <-sub.Chan() + sub.Unsubscribe() + recv <- i + } + + go poster() + go poster() + go poster() + nsubs := 1000 + for i := 0; i < nsubs; i++ { + go sub(i) + } + + // wait until everyone has been served + counts := make(map[int]int, nsubs) + for i := 0; i < nsubs; i++ { + counts[<-recv]++ + } + for i, count := range counts { + if count != 1 { + t.Errorf("receiver %d called %d times, expected only 1 call", i, count) + } + } +} + +func emptySubscriber(mux *TypeMux, types ...interface{}) { + s := mux.Subscribe(testEvent(0)) + go func() { + for _ = range s.Chan() { + } + }() +} + +func BenchmarkPost3(b *testing.B) { + var mux = NewTypeMux() + defer mux.Stop() + emptySubscriber(mux, testEvent(0)) + emptySubscriber(mux, testEvent(0)) + emptySubscriber(mux, testEvent(0)) + + for i := 0; i < b.N; i++ { + mux.Post(testEvent(0)) + } +} + +func BenchmarkPostConcurrent(b *testing.B) { + var mux = NewTypeMux() + defer mux.Stop() + emptySubscriber(mux, testEvent(0)) + emptySubscriber(mux, testEvent(0)) + emptySubscriber(mux, testEvent(0)) + + var wg sync.WaitGroup + poster := func() { + for i := 0; i < b.N; i++ { + mux.Post(testEvent(0)) + } + wg.Done() + } + wg.Add(5) + for i := 0; i < 5; i++ { + go poster() + } + wg.Wait() +} + +// for comparison +func BenchmarkChanSend(b *testing.B) { + c := make(chan interface{}) + closed := make(chan struct{}) + go func() { + for _ = range c { + } + }() + + for i := 0; i < b.N; i++ { + select { + case c <- i: + case <-closed: + } + } +}