diff --git a/rlp/decode.go b/rlp/decode.go index 3b5617475b..ca9252575b 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -9,6 +9,7 @@ import ( "io" "math/big" "reflect" + "strings" ) var ( @@ -70,14 +71,22 @@ type Decoder interface { // Non-empty interface types are not supported, nor are booleans, // signed integers, floating point numbers, maps, channels and // functions. +// +// Note that Decode does not set an input limit for all readers +// and may be vulnerable to panics cause by huge value sizes. If +// you need an input limit, use +// +// NewStream(r, limit).Decode(val) func Decode(r io.Reader, val interface{}) error { - return NewStream(r).Decode(val) + // TODO: this could use a Stream from a pool. + return NewStream(r, 0).Decode(val) } // DecodeBytes parses RLP data from b into val. // Please see the documentation of Decode for the decoding rules. func DecodeBytes(b []byte, val interface{}) error { - return NewStream(bytes.NewReader(b)).Decode(val) + // TODO: this could use a Stream from a pool. + return NewStream(bytes.NewReader(b), uint64(len(b))).Decode(val) } type decodeError struct { @@ -470,10 +479,12 @@ var ( ErrExpectedList = errors.New("rlp: expected List") ErrCanonInt = errors.New("rlp: expected Int") ErrElemTooLarge = errors.New("rlp: element is larger than containing list") + ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") // internal errors - errNotInList = errors.New("rlp: call of ListEnd outside of any list") - errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errNotInList = errors.New("rlp: call of ListEnd outside of any list") + errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errUintOverflow = errors.New("rlp: uint overflow") ) // ByteReader must be implemented by any input reader for a Stream. It @@ -496,7 +507,13 @@ type ByteReader interface { // // Stream is not safe for concurrent use. type Stream struct { - r ByteReader + r ByteReader + + // number of bytes remaining to be read from r. + remaining uint64 + limited bool + + // auxiliary buffer for integer decoding uintbuf []byte kind Kind // kind of value ahead @@ -507,12 +524,26 @@ type Stream struct { type listpos struct{ pos, size uint64 } -// NewStream creates a new stream reading from r. -// If r does not implement ByteReader, the Stream will -// introduce its own buffering. -func NewStream(r io.Reader) *Stream { +// NewStream creates a new decoding stream reading from r. +// +// If r implements the ByteReader interface, Stream will +// not introduce any buffering. +// +// For non-toplevel values, Stream returns ErrElemTooLarge +// for values that do not fit into the enclosing list. +// +// Stream supports an optional input limit. If a limit is set, the +// size of any toplevel value will be checked against the remaining +// input length. Stream operations that encounter a value exceeding +// the remaining input length will return ErrValueTooLarge. The limit +// can be set by passing a non-zero value for inputLimit. +// +// If r is a bytes.Reader or strings.Reader, the input limit is set to +// the length of r's underlying data unless an explicit limit is +// provided. +func NewStream(r io.Reader, inputLimit uint64) *Stream { s := new(Stream) - s.Reset(r) + s.Reset(r, inputLimit) return s } @@ -520,7 +551,7 @@ func NewStream(r io.Reader) *Stream { // at an encoded list of the given length. func NewListStream(r io.Reader, len uint64) *Stream { s := new(Stream) - s.Reset(r) + s.Reset(r, len) s.kind = List s.size = len return s @@ -574,8 +605,6 @@ func (s *Stream) Raw() ([]byte, error) { return buf, nil } -var errUintOverflow = errors.New("rlp: uint overflow") - // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. @@ -667,14 +696,36 @@ func (s *Stream) Decode(val interface{}) error { } // Reset discards any information about the current decoding context -// and starts reading from r. If r does not also implement ByteReader, -// Stream will do its own buffering. -func (s *Stream) Reset(r io.Reader) { +// and starts reading from r. This method is meant to facilitate reuse +// of a preallocated Stream across many decoding operations. +// +// If r does not also implement ByteReader, Stream will do its own +// buffering. +func (s *Stream) Reset(r io.Reader, inputLimit uint64) { + if inputLimit > 0 { + s.remaining = inputLimit + s.limited = true + } else { + // Attempt to automatically discover + // the limit when reading from a byte slice. + switch br := r.(type) { + case *bytes.Reader: + s.remaining = uint64(br.Len()) + s.limited = true + case *strings.Reader: + s.remaining = uint64(br.Len()) + s.limited = true + default: + s.limited = false + } + } + // Wrap r with a buffer if it doesn't have one. bufr, ok := r.(ByteReader) if !ok { bufr = bufio.NewReader(r) } s.r = bufr + // Reset the decoding context. s.stack = s.stack[:0] s.size = 0 s.kind = -1 @@ -700,6 +751,8 @@ func (s *Stream) Kind() (kind Kind, size uint64, err error) { tos = &s.stack[len(s.stack)-1] } if s.kind < 0 { + // don't read further if we're at the end of the + // innermost list. if tos != nil && tos.pos == tos.size { return 0, 0, EOL } @@ -709,8 +762,19 @@ func (s *Stream) Kind() (kind Kind, size uint64, err error) { } s.kind, s.size = kind, size } - if tos != nil && tos.pos+s.size > tos.size { - return 0, 0, ErrElemTooLarge + // Make sure size is reasonable. This is done always + // so Kind returns the same error when called multiple times. + if tos == nil { + // At toplevel, check that the value is smaller + // than the remaining input length. + if s.limited && s.size > s.remaining { + return 0, 0, ErrValueTooLarge + } + } else { + // Inside a list, check that the value doesn't overflow the list. + if tos.pos+s.size > tos.size { + return 0, 0, ErrElemTooLarge + } } return s.kind, s.size, nil } @@ -778,6 +842,9 @@ func (s *Stream) readUint(size byte) (uint64, error) { } func (s *Stream) readFull(buf []byte) (err error) { + if s.limited && s.remaining < uint64(len(buf)) { + return ErrValueTooLarge + } s.willRead(uint64(len(buf))) var nn, n int for n < len(buf) && err == nil { @@ -791,6 +858,9 @@ func (s *Stream) readFull(buf []byte) (err error) { } func (s *Stream) readByte() (byte, error) { + if s.limited && s.remaining == 0 { + return 0, io.EOF + } s.willRead(1) b, err := s.r.ReadByte() if len(s.stack) > 0 && err == io.EOF { @@ -801,6 +871,9 @@ func (s *Stream) readByte() (byte, error) { func (s *Stream) willRead(n uint64) { s.kind = -1 // rearm Kind + if s.limited { + s.remaining -= n + } if len(s.stack) > 0 { s.stack[len(s.stack)-1].pos += n } diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 73a31c67f5..6b37ab0ad7 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -36,7 +36,8 @@ func TestStreamKind(t *testing.T) { } for i, test := range tests { - s := NewStream(bytes.NewReader(unhex(test.input))) + // using plainReader to inhibit input limit errors. + s := NewStream(newPlainReader(unhex(test.input)), 0) kind, len, err := s.Kind() if err != nil { t.Errorf("test %d: Kind returned error: %v", i, err) @@ -70,29 +71,63 @@ func TestNewListStream(t *testing.T) { } func TestStreamErrors(t *testing.T) { + withoutInputLimit := func(b []byte) *Stream { + return NewStream(newPlainReader(b), 0) + } + withCustomInputLimit := func(limit uint64) func([]byte) *Stream { + return func(b []byte) *Stream { + return NewStream(bytes.NewReader(b), limit) + } + } + type calls []string tests := []struct { string calls + newStream func([]byte) *Stream // uses bytes.Reader if nil error }{ - {"", calls{"Kind"}, io.EOF}, - {"", calls{"List"}, io.EOF}, - {"", calls{"Uint"}, io.EOF}, - {"C0", calls{"Bytes"}, ErrExpectedString}, - {"C0", calls{"Uint"}, ErrExpectedString}, - {"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, - {"81", calls{"Uint"}, io.ErrUnexpectedEOF}, - {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, - {"89000000000000000001", calls{"Uint"}, errUintOverflow}, - {"00", calls{"List"}, ErrExpectedList}, - {"80", calls{"List"}, ErrExpectedList}, - {"C0", calls{"List", "Uint"}, EOL}, - {"C801", calls{"List", "Uint", "Uint"}, io.ErrUnexpectedEOF}, - {"C8C9", calls{"List", "Kind"}, ErrElemTooLarge}, - {"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, EOL}, - {"00", calls{"ListEnd"}, errNotInList}, - {"C40102", calls{"List", "Uint", "ListEnd"}, errNotAtEOL}, + {"C0", calls{"Bytes"}, nil, ErrExpectedString}, + {"C0", calls{"Uint"}, nil, ErrExpectedString}, + {"89000000000000000001", calls{"Uint"}, nil, errUintOverflow}, + {"00", calls{"List"}, nil, ErrExpectedList}, + {"80", calls{"List"}, nil, ErrExpectedList}, + {"C0", calls{"List", "Uint"}, nil, EOL}, + {"C8C9010101010101010101", calls{"List", "Kind"}, nil, ErrElemTooLarge}, + {"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, nil, EOL}, + {"00", calls{"ListEnd"}, nil, errNotInList}, + {"C401020304", calls{"List", "Uint", "ListEnd"}, nil, errNotAtEOL}, + + // Expected EOF + {"", calls{"Kind"}, nil, io.EOF}, + {"", calls{"Uint"}, nil, io.EOF}, + {"", calls{"List"}, nil, io.EOF}, + {"8105", calls{"Uint", "Uint"}, nil, io.EOF}, + {"C0", calls{"List", "ListEnd", "List"}, nil, io.EOF}, + + // Input limit errors. + {"81", calls{"Bytes"}, nil, ErrValueTooLarge}, + {"81", calls{"Uint"}, nil, ErrValueTooLarge}, + {"81", calls{"Raw"}, nil, ErrValueTooLarge}, + {"BFFFFFFFFFFFFFFFFFFF", calls{"Bytes"}, nil, ErrValueTooLarge}, + {"C801", calls{"List"}, nil, ErrValueTooLarge}, + + // Test for input limit overflow. Since we are counting the limit + // down toward zero in Stream.remaining, reading too far can overflow + // remaining to a large value, effectively disabling the limit. + {"C40102030401", calls{"Raw", "Uint"}, withCustomInputLimit(5), io.EOF}, + {"C4010203048102", calls{"Raw", "Uint"}, withCustomInputLimit(6), ErrValueTooLarge}, + + // Check that the same calls are fine without a limit. + {"C40102030401", calls{"Raw", "Uint"}, withoutInputLimit, nil}, + {"C4010203048102", calls{"Raw", "Uint"}, withoutInputLimit, nil}, + + // Unexpected EOF. This only happens when there is + // no input limit, so the reader needs to be 'dumbed down'. + {"81", calls{"Bytes"}, withoutInputLimit, io.ErrUnexpectedEOF}, + {"81", calls{"Uint"}, withoutInputLimit, io.ErrUnexpectedEOF}, + {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, withoutInputLimit, io.ErrUnexpectedEOF}, + {"C801", calls{"List", "Uint", "Uint"}, withoutInputLimit, io.ErrUnexpectedEOF}, // This test verifies that the input position is advanced // correctly when calling Bytes for empty strings. Kind can be called @@ -109,12 +144,15 @@ func TestStreamErrors(t *testing.T) { "Bytes", // past final element "Bytes", // this one should fail - }, EOL}, + }, nil, EOL}, } testfor: for i, test := range tests { - s := NewStream(bytes.NewReader(unhex(test.string))) + if test.newStream == nil { + test.newStream = func(b []byte) *Stream { return NewStream(bytes.NewReader(b), 0) } + } + s := test.newStream(unhex(test.string)) rs := reflect.ValueOf(s) for j, call := range test.calls { fval := rs.MethodByName(call) @@ -124,8 +162,12 @@ testfor: err = lastret.(error).Error() } if j == len(test.calls)-1 { - if err != test.error.Error() { - t.Errorf("test %d: last call (%s) error mismatch\ngot: %s\nwant: %v", + want := "" + if test.error != nil { + want = test.error.Error() + } + if err != want { + t.Errorf("test %d: last call (%s) error mismatch\ngot: %s\nwant: %s", i, call, err, test.error) } } else if err != "" { @@ -137,7 +179,7 @@ testfor: } func TestStreamList(t *testing.T) { - s := NewStream(bytes.NewReader(unhex("C80102030405060708"))) + s := NewStream(bytes.NewReader(unhex("C80102030405060708")), 0) len, err := s.List() if err != nil { @@ -166,7 +208,7 @@ func TestStreamList(t *testing.T) { } func TestStreamRaw(t *testing.T) { - s := NewStream(bytes.NewReader(unhex("C58401010101"))) + s := NewStream(bytes.NewReader(unhex("C58401010101")), 0) s.List() want := unhex("8401010101") @@ -284,11 +326,6 @@ var decodeTests = []decodeTest{ ptr: new([5]byte), error: "rlp: input string too long for [5]uint8", }, - { - input: "850101", - ptr: new([5]byte), - error: io.ErrUnexpectedEOF.Error(), - }, // byte array reuse (should be zeroed) {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, @@ -401,11 +438,17 @@ func TestDecodeWithByteReader(t *testing.T) { }) } -// dumbReader reads from a byte slice but does not -// implement ReadByte. -type dumbReader []byte +// plainReader reads from a byte slice but does not +// implement ReadByte. It is also not recognized by the +// size validation. This is useful to test how the decoder +// behaves on a non-buffered input stream. +type plainReader []byte -func (r *dumbReader) Read(buf []byte) (n int, err error) { +func newPlainReader(b []byte) io.Reader { + return (*plainReader)(&b) +} + +func (r *plainReader) Read(buf []byte) (n int, err error) { if len(*r) == 0 { return 0, io.EOF } @@ -416,15 +459,14 @@ func (r *dumbReader) Read(buf []byte) (n int, err error) { func TestDecodeWithNonByteReader(t *testing.T) { runTests(t, func(input []byte, into interface{}) error { - r := dumbReader(input) - return Decode(&r, into) + return Decode(newPlainReader(input), into) }) } func TestDecodeStreamReset(t *testing.T) { - s := NewStream(nil) + s := NewStream(nil, 0) runTests(t, func(input []byte, into interface{}) error { - s.Reset(bytes.NewReader(input)) + s.Reset(bytes.NewReader(input), 0) return s.Decode(into) }) } @@ -518,7 +560,7 @@ func ExampleDecode() { func ExampleStream() { input, _ := hex.DecodeString("C90A1486666F6F626172") - s := NewStream(bytes.NewReader(input)) + s := NewStream(bytes.NewReader(input), 0) // Check what kind of value lies ahead kind, size, _ := s.Kind()