p2p: improve and test eofSignal
This commit is contained in:
parent
9423401d73
commit
e28c60caf9
17
p2p/peer.go
17
p2p/peer.go
@ -300,7 +300,7 @@ func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error)
|
||||
proto.in <- msg
|
||||
} else {
|
||||
wait = true
|
||||
pr := &eofSignal{msg.Payload, protoDone}
|
||||
pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
|
||||
msg.Payload = pr
|
||||
proto.in <- msg
|
||||
}
|
||||
@ -438,18 +438,25 @@ func (rw *proto) ReadMsg() (Msg, error) {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// eofSignal wraps a reader with eof signaling.
|
||||
// the eof channel is closed when the wrapped reader
|
||||
// reaches EOF.
|
||||
// eofSignal wraps a reader with eof signaling. the eof channel is
|
||||
// closed when the wrapped reader returns an error or when count bytes
|
||||
// have been read.
|
||||
//
|
||||
type eofSignal struct {
|
||||
wrapped io.Reader
|
||||
count int64
|
||||
eof chan<- struct{}
|
||||
}
|
||||
|
||||
// note: when using eofSignal to detect whether a message payload
|
||||
// has been read, Read might not be called for zero sized messages.
|
||||
|
||||
func (r *eofSignal) Read(buf []byte) (int, error) {
|
||||
n, err := r.wrapped.Read(buf)
|
||||
if err != nil {
|
||||
r.count -= int64(n)
|
||||
if (err != nil || r.count <= 0) && r.eof != nil {
|
||||
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
||||
r.eof = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"reflect"
|
||||
@ -237,3 +238,58 @@ func TestNewPeer(t *testing.T) {
|
||||
// Should not hang.
|
||||
p.Disconnect(DiscAlreadyConnected)
|
||||
}
|
||||
|
||||
func TestEOFSignal(t *testing.T) {
|
||||
rb := make([]byte, 10)
|
||||
|
||||
// empty reader
|
||||
eof := make(chan struct{}, 1)
|
||||
sig := &eofSignal{new(bytes.Buffer), 0, eof}
|
||||
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
default:
|
||||
t.Error("EOF chan not signaled")
|
||||
}
|
||||
|
||||
// count before error
|
||||
eof = make(chan struct{}, 1)
|
||||
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
|
||||
if n, err := sig.Read(rb); n != 8 || err != nil {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
default:
|
||||
t.Error("EOF chan not signaled")
|
||||
}
|
||||
|
||||
// error before count
|
||||
eof = make(chan struct{}, 1)
|
||||
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
|
||||
if n, err := sig.Read(rb); n != 4 || err != nil {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
default:
|
||||
t.Error("EOF chan not signaled")
|
||||
}
|
||||
|
||||
// no signal if neither occurs
|
||||
eof = make(chan struct{}, 1)
|
||||
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
|
||||
if n, err := sig.Read(rb); n != 10 || err != nil {
|
||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||
}
|
||||
select {
|
||||
case <-eof:
|
||||
t.Error("unexpected EOF signal")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user