From c1fca72552386868d28ce7541691e53e55673549 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Mon, 24 Nov 2014 19:02:48 +0100 Subject: [PATCH] p2p: use package rlp --- p2p/message.go | 101 ++++++++++++++------------------------------ p2p/message_test.go | 3 ++ p2p/peer_test.go | 2 +- 3 files changed, 35 insertions(+), 71 deletions(-) diff --git a/p2p/message.go b/p2p/message.go index 89ad189d76..ade39d25a3 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -3,12 +3,12 @@ package p2p import ( "bytes" "encoding/binary" - "fmt" "io" "io/ioutil" "math/big" "github.com/ethereum/go-ethereum/ethutil" + "github.com/ethereum/go-ethereum/rlp" ) // Msg defines the structure of a p2p message. @@ -43,16 +43,10 @@ func encodePayload(params ...interface{}) []byte { // Data returns the decoded RLP payload items in a message. func (msg Msg) Data() (*ethutil.Value, error) { - // TODO: avoid copying when we have a better RLP decoder - buf := new(bytes.Buffer) - var s []interface{} - if _, err := buf.ReadFrom(msg.Payload); err != nil { - return nil, err - } - for buf.Len() > 0 { - s = append(s, ethutil.DecodeWithReader(buf)) - } - return ethutil.NewValue(s), nil + s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) + var v []interface{} + err := s.Decode(&v) + return ethutil.NewValue(v), err } // Discard reads any remaining payload data into a black hole. @@ -137,13 +131,9 @@ func makeListHeader(length uint32) []byte { return append([]byte{lenb}, enc...) } -type byteReader interface { - io.Reader - io.ByteReader -} - // readMsg reads a message header from r. -func readMsg(r byteReader) (msg Msg, err error) { +// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer. +func readMsg(r rlp.ByteReader) (msg Msg, err error) { // read magic and payload size start := make([]byte, 8) if _, err = io.ReadFull(r, start); err != nil { @@ -155,64 +145,35 @@ func readMsg(r byteReader) (msg Msg, err error) { size := binary.BigEndian.Uint32(start[4:]) // decode start of RLP message to get the message code - _, hdrlen, err := readListHeader(r) + posr := &postrack{r, 0} + s := rlp.NewStream(posr) + if _, err := s.List(); err != nil { + return msg, err + } + code, err := s.Uint() if err != nil { return msg, err } - code, codelen, err := readMsgCode(r) - if err != nil { - return msg, err - } - - rlpsize := size - hdrlen - codelen - return Msg{ - Code: code, - Size: rlpsize, - Payload: io.LimitReader(r, int64(rlpsize)), - }, nil + payloadsize := size - posr.p + return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil } -// readListHeader reads an RLP list header from r. -func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) { - b, err := r.ReadByte() - if err != nil { - return 0, 0, err - } - if b < 0xC0 { - return 0, 0, fmt.Errorf("expected list start byte >= 0xC0, got %x", b) - } else if b < 0xF7 { - len = uint64(b - 0xc0) - hdrlen = 1 - } else { - lenlen := b - 0xF7 - lenbuf := make([]byte, 8) - if _, err := io.ReadFull(r, lenbuf[8-lenlen:]); err != nil { - return 0, 0, err - } - len = binary.BigEndian.Uint64(lenbuf) - hdrlen = 1 + uint32(lenlen) - } - return len, hdrlen, nil +// postrack wraps an rlp.ByteReader with a position counter. +type postrack struct { + r rlp.ByteReader + p uint32 } -// readUint reads an RLP-encoded unsigned integer from r. -func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) { - b, err := r.ReadByte() - if err != nil { - return 0, 0, err - } - if b < 0x80 { - return uint64(b), 1, nil - } else if b < 0x89 { // max length for uint64 is 8 bytes - codelen = uint32(b - 0x80) - if codelen == 0 { - return 0, 1, nil - } - buf := make([]byte, 8) - if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil { - return 0, 0, err - } - return binary.BigEndian.Uint64(buf), codelen, nil - } - return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b) +func (r *postrack) Read(buf []byte) (int, error) { + n, err := r.r.Read(buf) + r.p += uint32(n) + return n, err +} + +func (r *postrack) ReadByte() (byte, error) { + b, err := r.r.ReadByte() + if err == nil { + r.p++ + } + return b, err } diff --git a/p2p/message_test.go b/p2p/message_test.go index 1edabc4e7b..02d70a28ba 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -46,6 +46,9 @@ func TestEncodeDecodeMsg(t *testing.T) { if err != nil { t.Fatalf("first payload item decode error: %v", err) } + if v := data.Len(); v != 2 { + t.Errorf("incorrect data.Len(): got %v, expected %d", v, 1) + } if v := data.Get(0).Uint(); v != 1 { t.Errorf("incorrect data[0]: got %v, expected %d", v, 1) } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 1afa0ab170..56cd4d8902 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -57,7 +57,7 @@ func TestPeerProtoReadMsg(t *testing.T) { if err != nil { t.Errorf("data decoding error: %v", err) } - expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}} + expdata := []interface{}{[]byte{0x01}, []byte{0x30, 0x30, 0x30}} if !reflect.DeepEqual(data.Slice(), expdata) { t.Errorf("incorrect msg data %#v", data.Slice()) }