signer/core/apitypes: support more input types for eip-712 encoding (#26074)

* apitypes: synchronize handling of types

* signer/core/apitypes: improve array check

* apitypes: add a test for big.Int -> int32

* signer/core/apitypes: Add a test for parsing addresses from [20]byte, []byte and string

* signer/core/apitypes: add some testcases

Co-authored-by: Felix Lange <fjl@twurst.com>
Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
Obtuse7772 2022-11-04 20:58:12 +05:30 committed by GitHub
parent a51188a163
commit 6d55908347
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 6 deletions

@ -21,6 +21,7 @@ import (
"math/big" "math/big"
"testing" "testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
) )
@ -84,6 +85,55 @@ func TestBytesPadding(t *testing.T) {
} }
} }
func TestParseAddress(t *testing.T) {
tests := []struct {
Input interface{}
Output []byte // nil => error
}{
{
Input: [20]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14},
Output: common.FromHex("0x0000000000000000000000000102030405060708090A0B0C0D0E0F1011121314"),
},
{
Input: "0x0102030405060708090A0B0C0D0E0F1011121314",
Output: common.FromHex("0x0000000000000000000000000102030405060708090A0B0C0D0E0F1011121314"),
},
{
Input: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14},
Output: common.FromHex("0x0000000000000000000000000102030405060708090A0B0C0D0E0F1011121314"),
},
// Various error-cases:
{Input: "0x000102030405060708090A0B0C0D0E0F1011121314"}, // too long string
{Input: "0x01"}, // too short string
{Input: ""},
{Input: [32]byte{}}, // too long fixed-size array
{Input: [21]byte{}}, // too long fixed-size array
{Input: make([]byte, 19)}, // too short slice
{Input: make([]byte, 21)}, // too long slice
{Input: nil},
}
d := TypedData{}
for i, test := range tests {
val, err := d.EncodePrimitiveValue("address", test.Input, 1)
if test.Output == nil {
if err == nil {
t.Errorf("test %d: expected error, got no error (result %x)", i, val)
}
continue
}
if err != nil {
t.Errorf("test %d: expected no error, got %v", i, err)
}
if have, want := len(val), 32; have != want {
t.Errorf("test %d: have len %d, want %d", i, have, want)
}
if !bytes.Equal(val, test.Output) {
t.Errorf("test %d: want %x, have %x", i, test.Output, val)
}
}
}
func TestParseBytes(t *testing.T) { func TestParseBytes(t *testing.T) {
for i, tt := range []struct { for i, tt := range []struct {
v interface{} v interface{}
@ -98,6 +148,9 @@ func TestParseBytes(t *testing.T) {
{"not a hex string", nil}, {"not a hex string", nil},
{15, nil}, {15, nil},
{nil, nil}, {nil, nil},
{[2]byte{12, 34}, []byte{12, 34}},
{[8]byte{12, 34, 56, 78, 90, 12, 34, 56}, []byte{12, 34, 56, 78, 90, 12, 34, 56}},
{[16]byte{12, 34, 56, 78, 90, 12, 34, 56, 12, 34, 56, 78, 90, 12, 34, 56}, []byte{12, 34, 56, 78, 90, 12, 34, 56, 12, 34, 56, 78, 90, 12, 34, 56}},
} { } {
out, ok := parseBytes(tt.v) out, ok := parseBytes(tt.v)
if tt.exp == nil { if tt.exp == nil {
@ -123,6 +176,7 @@ func TestParseInteger(t *testing.T) {
}{ }{
{"uint32", "-123", nil}, {"uint32", "-123", nil},
{"int32", "-123", big.NewInt(-123)}, {"int32", "-123", big.NewInt(-123)},
{"int32", big.NewInt(-124), big.NewInt(-124)},
{"uint32", "0xff", big.NewInt(0xff)}, {"uint32", "0xff", big.NewInt(0xff)},
{"int8", "0xffff", nil}, {"int8", "0xffff", nil},
} { } {

@ -418,6 +418,14 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter
// Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes. // Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes.
func parseBytes(encType interface{}) ([]byte, bool) { func parseBytes(encType interface{}) ([]byte, bool) {
// Handle array types.
val := reflect.ValueOf(encType)
if val.Kind() == reflect.Array && val.Type().Elem().Kind() == reflect.Uint8 {
v := reflect.MakeSlice(reflect.TypeOf([]byte{}), val.Len(), val.Len())
reflect.Copy(v, val)
return v.Bytes(), true
}
switch v := encType.(type) { switch v := encType.(type) {
case []byte: case []byte:
return v, true return v, true
@ -458,6 +466,8 @@ func parseInteger(encType string, encValue interface{}) (*big.Int, error) {
switch v := encValue.(type) { switch v := encValue.(type) {
case *math.HexOrDecimal256: case *math.HexOrDecimal256:
b = (*big.Int)(v) b = (*big.Int)(v)
case *big.Int:
b = v
case string: case string:
var hexIntValue math.HexOrDecimal256 var hexIntValue math.HexOrDecimal256
if err := hexIntValue.UnmarshalText([]byte(v)); err != nil { if err := hexIntValue.UnmarshalText([]byte(v)); err != nil {
@ -490,13 +500,23 @@ func parseInteger(encType string, encValue interface{}) (*big.Int, error) {
func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interface{}, depth int) ([]byte, error) { func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interface{}, depth int) ([]byte, error) {
switch encType { switch encType {
case "address": case "address":
stringValue, ok := encValue.(string)
if !ok || !common.IsHexAddress(stringValue) {
return nil, dataMismatchError(encType, encValue)
}
retval := make([]byte, 32) retval := make([]byte, 32)
copy(retval[12:], common.HexToAddress(stringValue).Bytes()) switch val := encValue.(type) {
return retval, nil case string:
if common.IsHexAddress(val) {
copy(retval[12:], common.HexToAddress(val).Bytes())
return retval, nil
}
case []byte:
if len(val) == 20 {
copy(retval[12:], val)
return retval, nil
}
case [20]byte:
copy(retval[12:], val[:])
return retval, nil
}
return nil, dataMismatchError(encType, encValue)
case "bool": case "bool":
boolValue, ok := encValue.(bool) boolValue, ok := encValue.(bool)
if !ok { if !ok {