signer: EIP 712, parse bytes and bytesX as hex strings + correct padding (#21307)

* Handle hex strings for bytesX types

* Add tests for parseBytes

* Improve tests

* Return nil bytes if error is non-nil

* Right-pad instead of left-pad bytes

* More tests
This commit is contained in:
Natsu Kagami 2020-08-03 19:53:12 +00:00 committed by GitHub
parent c0c01612e9
commit 90dedea40f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 118 additions and 3 deletions

@ -481,6 +481,24 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
return buffer.Bytes(), nil return buffer.Bytes(), nil
} }
// Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes.
func parseBytes(encType interface{}) ([]byte, bool) {
switch v := encType.(type) {
case []byte:
return v, true
case hexutil.Bytes:
return []byte(v), true
case string:
bytes, err := hexutil.Decode(v)
if err != nil {
return nil, false
}
return bytes, true
default:
return nil, false
}
}
func parseInteger(encType string, encValue interface{}) (*big.Int, error) { func parseInteger(encType string, encValue interface{}) (*big.Int, error) {
var ( var (
length int length int
@ -560,7 +578,7 @@ func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interf
} }
return crypto.Keccak256([]byte(strVal)), nil return crypto.Keccak256([]byte(strVal)), nil
case "bytes": case "bytes":
bytesValue, ok := encValue.([]byte) bytesValue, ok := parseBytes(encValue)
if !ok { if !ok {
return nil, dataMismatchError(encType, encValue) return nil, dataMismatchError(encType, encValue)
} }
@ -575,10 +593,13 @@ func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interf
if length < 0 || length > 32 { if length < 0 || length > 32 {
return nil, fmt.Errorf("invalid size on bytes: %d", length) return nil, fmt.Errorf("invalid size on bytes: %d", length)
} }
if byteValue, ok := encValue.(hexutil.Bytes); !ok { if byteValue, ok := parseBytes(encValue); !ok || len(byteValue) != length {
return nil, dataMismatchError(encType, encValue) return nil, dataMismatchError(encType, encValue)
} else { } else {
return math.PaddedBigBytes(new(big.Int).SetBytes(byteValue), 32), nil // Right-pad the bits
dst := make([]byte, 32)
copy(dst, byteValue)
return dst, nil
} }
} }
if strings.HasPrefix(encType, "int") || strings.HasPrefix(encType, "uint") { if strings.HasPrefix(encType, "int") || strings.HasPrefix(encType, "uint") {

@ -17,10 +17,104 @@
package core package core
import ( import (
"bytes"
"math/big" "math/big"
"testing" "testing"
"github.com/ethereum/go-ethereum/common/hexutil"
) )
func TestBytesPadding(t *testing.T) {
tests := []struct {
Type string
Input []byte
Output []byte // nil => error
}{
{
// Fail on wrong length
Type: "bytes20",
Input: []byte{},
Output: nil,
},
{
Type: "bytes1",
Input: []byte{1},
Output: []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
Type: "bytes1",
Input: []byte{1, 2},
Output: nil,
},
{
Type: "bytes7",
Input: []byte{1, 2, 3, 4, 5, 6, 7},
Output: []byte{1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
Type: "bytes32",
Input: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
Output: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
},
{
Type: "bytes32",
Input: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33},
Output: nil,
},
}
d := TypedData{}
for i, test := range tests {
val, err := d.EncodePrimitiveValue(test.Type, test.Input, 1)
if test.Output == nil {
if err == nil {
t.Errorf("test %d: expected error, got no error (result %x)", i, val)
}
} else {
if err != nil {
t.Errorf("test %d: expected no error, got %v", i, err)
}
if len(val) != 32 {
t.Errorf("test %d: expected len 32, got %d", i, len(val))
}
if !bytes.Equal(val, test.Output) {
t.Errorf("test %d: expected %x, got %x", i, test.Output, val)
}
}
}
}
func TestParseBytes(t *testing.T) {
for i, tt := range []struct {
v interface{}
exp []byte
}{
{"0x", []byte{}},
{"0x1234", []byte{0x12, 0x34}},
{[]byte{12, 34}, []byte{12, 34}},
{hexutil.Bytes([]byte{12, 34}), []byte{12, 34}},
{"1234", nil}, // not a proper hex-string
{"0x01233", nil}, // nibbles should be rejected
{"not a hex string", nil},
{15, nil},
{nil, nil},
} {
out, ok := parseBytes(tt.v)
if tt.exp == nil {
if ok || out != nil {
t.Errorf("test %d: expected !ok, got ok = %v with out = %x", i, ok, out)
}
continue
}
if !ok {
t.Errorf("test %d: expected ok got !ok", i)
}
if !bytes.Equal(out, tt.exp) {
t.Errorf("test %d: expected %x got %x", i, tt.exp, out)
}
}
}
func TestParseInteger(t *testing.T) { func TestParseInteger(t *testing.T) {
for i, tt := range []struct { for i, tt := range []struct {
t string t string