diff --git a/rlp/decode.go b/rlp/decode.go index e638c01eaf..4f29f2fb03 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -115,15 +115,17 @@ type Decoder interface { // type, Decode will return an error. Decode also supports *big.Int. // There is no size limit for big integers. // +// To decode into a boolean, the input must contain an unsigned integer +// of value zero (false) or one (true). +// // To decode into an interface value, Decode stores one of these // in the value: // // []interface{}, for RLP lists // []byte, for RLP strings // -// Non-empty interface types are not supported, nor are booleans, -// signed integers, floating point numbers, maps, channels and -// functions. +// Non-empty interface types are not supported, nor are signed integers, +// floating point numbers, maps, channels and functions. // // Note that Decode does not set an input limit for all readers // and may be vulnerable to panics cause by huge value sizes. If @@ -306,9 +308,9 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { } return decodeByteSlice, nil } - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(etype, tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } var dec decoder switch { @@ -467,9 +469,9 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { // the pointer's element type. func makePtrDecoder(typ reflect.Type) (decoder, error) { etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(etype, tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } dec := func(s *Stream, val reflect.Value) (err error) { newval := val @@ -491,9 +493,9 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) { // This decoder is used for pointer-typed struct fields with struct tag "nil". func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(etype, tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } dec := func(s *Stream, val reflect.Value) (err error) { kind, size, err := s.Kind() @@ -814,12 +816,12 @@ func (s *Stream) Decode(val interface{}) error { if rval.IsNil() { return errDecodeIntoNil } - info, err := cachedTypeInfo(rtyp.Elem(), tags{}) + decoder, err := cachedDecoder(rtyp.Elem()) if err != nil { return err } - err = info.decoder(s, rval.Elem()) + err = decoder(s, rval.Elem()) if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 { // add decode target type to error so context has more meaning decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")")) diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 4d8abd0012..fa57182c99 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -347,6 +347,12 @@ type tailUint struct { Tail []uint `rlp:"tail"` } +type tailPrivateFields struct { + A uint + Tail []uint `rlp:"tail"` + x, y bool +} + var ( veryBigInt = big.NewInt(0).Add( big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), @@ -510,6 +516,11 @@ var decodeTests = []decodeTest{ ptr: new(tailRaw), value: tailRaw{A: 1, Tail: []RawValue{}}, }, + { + input: "C3010203", + ptr: new(tailPrivateFields), + value: tailPrivateFields{A: 1, Tail: []uint{2, 3}}, + }, // struct tag "-" { @@ -691,6 +702,27 @@ func TestDecoderInByteSlice(t *testing.T) { } } +type unencodableDecoder func() + +func (f *unencodableDecoder) DecodeRLP(s *Stream) error { + if _, err := s.List(); err != nil { + return err + } + if err := s.ListEnd(); err != nil { + return err + } + *f = func() {} + return nil +} + +func TestDecoderFunc(t *testing.T) { + var x func() + if err := DecodeBytes([]byte{0xC0}, (*unencodableDecoder)(&x)); err != nil { + t.Fatal(err) + } + x() +} + func ExampleDecode() { input, _ := hex.DecodeString("C90A1486666F6F626172") diff --git a/rlp/encode.go b/rlp/encode.go index 445b4b5b21..f255c38a9c 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -73,10 +73,12 @@ type Encoder interface { // An unsigned integer value is encoded as an RLP string. Zero always // encodes as an empty RLP string. Encode also supports *big.Int. // +// Boolean values are encoded as unsigned integers zero (false) and one (true). +// // An interface value encodes as the value contained in the interface. // -// Boolean values are not supported, nor are signed integers, floating -// point numbers, maps, channels and functions. +// Signed integers are not supported, nor are floating point numbers, maps, +// channels and functions. func Encode(w io.Writer, val interface{}) error { if outer, ok := w.(*encbuf); ok { // Encode was called by some type's EncodeRLP. @@ -180,11 +182,11 @@ func (w *encbuf) Write(b []byte) (int, error) { func (w *encbuf) encode(val interface{}) error { rval := reflect.ValueOf(val) - ti, err := cachedTypeInfo(rval.Type(), tags{}) + writer, err := cachedWriter(rval.Type()) if err != nil { return err } - return ti.writer(rval, w) + return writer(rval, w) } func (w *encbuf) encodeStringHeader(size int) { @@ -497,17 +499,17 @@ func writeInterface(val reflect.Value, w *encbuf) error { return nil } eval := val.Elem() - ti, err := cachedTypeInfo(eval.Type(), tags{}) + writer, err := cachedWriter(eval.Type()) if err != nil { return err } - return ti.writer(eval, w) + return writer(eval, w) } func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(typ.Elem(), tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr } writer := func(val reflect.Value, w *encbuf) error { if !ts.tail { @@ -543,9 +545,9 @@ func makeStructWriter(typ reflect.Type) (writer, error) { } func makePtrWriter(typ reflect.Type) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(typ.Elem(), tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr } // determine nil pointer handler @@ -577,7 +579,7 @@ func makePtrWriter(typ reflect.Type) (writer, error) { } return etypeinfo.writer(val.Elem(), w) } - return writer, err + return writer, nil } // putint writes i to the beginning of b in big endian byte diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 827960f7c1..6e49b89a85 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -49,6 +49,13 @@ func (e byteEncoder) EncodeRLP(w io.Writer) error { return nil } +type undecodableEncoder func() + +func (f undecodableEncoder) EncodeRLP(w io.Writer) error { + _, err := w.Write(EmptyList) + return err +} + type encodableReader struct { A, B uint } @@ -239,6 +246,8 @@ var encTests = []encTest{ {val: (*testEncoder)(nil), output: "00000000"}, {val: &testEncoder{}, output: "00010001000100010001"}, {val: &testEncoder{errors.New("test error")}, error: "test error"}, + // verify that the Encoder interface works for unsupported types like func(). + {val: undecodableEncoder(func() {}), output: "C0"}, // verify that pointer method testEncoder.EncodeRLP is called for // addressable non-pointer values. {val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"}, diff --git a/rlp/typecache.go b/rlp/typecache.go index 8c2dd518e2..ab5ee3da76 100644 --- a/rlp/typecache.go +++ b/rlp/typecache.go @@ -29,8 +29,10 @@ var ( ) type typeinfo struct { - decoder - writer + decoder decoder + decoderErr error // error from makeDecoder + writer writer + writerErr error // error from makeWriter } // represents struct tags @@ -56,12 +58,22 @@ type decoder func(*Stream, reflect.Value) error type writer func(reflect.Value, *encbuf) error -func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { +func cachedDecoder(typ reflect.Type) (decoder, error) { + info := cachedTypeInfo(typ, tags{}) + return info.decoder, info.decoderErr +} + +func cachedWriter(typ reflect.Type) (writer, error) { + info := cachedTypeInfo(typ, tags{}) + return info.writer, info.writerErr +} + +func cachedTypeInfo(typ reflect.Type, tags tags) *typeinfo { typeCacheMutex.RLock() info := typeCache[typekey{typ, tags}] typeCacheMutex.RUnlock() if info != nil { - return info, nil + return info } // not in the cache, need to generate info for this type. typeCacheMutex.Lock() @@ -69,25 +81,20 @@ func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { return cachedTypeInfo1(typ, tags) } -func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) { +func cachedTypeInfo1(typ reflect.Type, tags tags) *typeinfo { key := typekey{typ, tags} info := typeCache[key] if info != nil { // another goroutine got the write lock first - return info, nil + return info } // put a dummy value into the cache before generating. // if the generator tries to lookup itself, it will get // the dummy value and won't call itself recursively. - typeCache[key] = new(typeinfo) - info, err := genTypeInfo(typ, tags) - if err != nil { - // remove the dummy value if the generator fails - delete(typeCache, key) - return nil, err - } - *typeCache[key] = *info - return typeCache[key], err + info = new(typeinfo) + typeCache[key] = info + info.generate(typ, tags) + return info } type field struct { @@ -96,26 +103,24 @@ type field struct { } func structFields(typ reflect.Type) (fields []field, err error) { + lastPublic := lastPublicField(typ) for i := 0; i < typ.NumField(); i++ { if f := typ.Field(i); f.PkgPath == "" { // exported - tags, err := parseStructTag(typ, i) + tags, err := parseStructTag(typ, i, lastPublic) if err != nil { return nil, err } if tags.ignored { continue } - info, err := cachedTypeInfo1(f.Type, tags) - if err != nil { - return nil, err - } + info := cachedTypeInfo1(f.Type, tags) fields = append(fields, field{i, info}) } } return fields, nil } -func parseStructTag(typ reflect.Type, fi int) (tags, error) { +func parseStructTag(typ reflect.Type, fi, lastPublic int) (tags, error) { f := typ.Field(fi) var ts tags for _, t := range strings.Split(f.Tag.Get("rlp"), ",") { @@ -127,7 +132,7 @@ func parseStructTag(typ reflect.Type, fi int) (tags, error) { ts.nilOK = true case "tail": ts.tail = true - if fi != typ.NumField()-1 { + if fi != lastPublic { return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name) } if f.Type.Kind() != reflect.Slice { @@ -140,15 +145,19 @@ func parseStructTag(typ reflect.Type, fi int) (tags, error) { return ts, nil } -func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) { - info = new(typeinfo) - if info.decoder, err = makeDecoder(typ, tags); err != nil { - return nil, err +func lastPublicField(typ reflect.Type) int { + last := 0 + for i := 0; i < typ.NumField(); i++ { + if typ.Field(i).PkgPath == "" { + last = i + } } - if info.writer, err = makeWriter(typ, tags); err != nil { - return nil, err - } - return info, nil + return last +} + +func (i *typeinfo) generate(typ reflect.Type, tags tags) { + i.decoder, i.decoderErr = makeDecoder(typ, tags) + i.writer, i.writerErr = makeWriter(typ, tags) } func isUint(k reflect.Kind) bool {