diff --git a/signer/core/apitypes/signed_data_internal_test.go b/signer/core/apitypes/signed_data_internal_test.go index 121cc00dec..8379c0a7f0 100644 --- a/signer/core/apitypes/signed_data_internal_test.go +++ b/signer/core/apitypes/signed_data_internal_test.go @@ -21,6 +21,7 @@ import ( "math/big" "testing" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" ) @@ -84,6 +85,55 @@ func TestBytesPadding(t *testing.T) { } } +func TestParseAddress(t *testing.T) { + tests := []struct { + Input interface{} + Output []byte // nil => error + }{ + { + Input: [20]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14}, + Output: common.FromHex("0x0000000000000000000000000102030405060708090A0B0C0D0E0F1011121314"), + }, + { + Input: "0x0102030405060708090A0B0C0D0E0F1011121314", + Output: common.FromHex("0x0000000000000000000000000102030405060708090A0B0C0D0E0F1011121314"), + }, + { + Input: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14}, + Output: common.FromHex("0x0000000000000000000000000102030405060708090A0B0C0D0E0F1011121314"), + }, + // Various error-cases: + {Input: "0x000102030405060708090A0B0C0D0E0F1011121314"}, // too long string + {Input: "0x01"}, // too short string + {Input: ""}, + {Input: [32]byte{}}, // too long fixed-size array + {Input: [21]byte{}}, // too long fixed-size array + {Input: make([]byte, 19)}, // too short slice + {Input: make([]byte, 21)}, // too long slice + {Input: nil}, + } + + d := TypedData{} + for i, test := range tests { + val, err := d.EncodePrimitiveValue("address", test.Input, 1) + if test.Output == nil { + if err == nil { + t.Errorf("test %d: expected error, got no error (result %x)", i, val) + } + continue + } + if err != nil { + t.Errorf("test %d: expected no error, got %v", i, err) + } + if have, want := len(val), 32; have != want { + t.Errorf("test %d: have len %d, want %d", i, have, want) + } + if !bytes.Equal(val, test.Output) { + t.Errorf("test %d: want %x, have %x", i, test.Output, val) + } + } +} + func TestParseBytes(t *testing.T) { for i, tt := range []struct { v interface{} @@ -98,6 +148,9 @@ func TestParseBytes(t *testing.T) { {"not a hex string", nil}, {15, nil}, {nil, nil}, + {[2]byte{12, 34}, []byte{12, 34}}, + {[8]byte{12, 34, 56, 78, 90, 12, 34, 56}, []byte{12, 34, 56, 78, 90, 12, 34, 56}}, + {[16]byte{12, 34, 56, 78, 90, 12, 34, 56, 12, 34, 56, 78, 90, 12, 34, 56}, []byte{12, 34, 56, 78, 90, 12, 34, 56, 12, 34, 56, 78, 90, 12, 34, 56}}, } { out, ok := parseBytes(tt.v) if tt.exp == nil { @@ -123,6 +176,7 @@ func TestParseInteger(t *testing.T) { }{ {"uint32", "-123", nil}, {"int32", "-123", big.NewInt(-123)}, + {"int32", big.NewInt(-124), big.NewInt(-124)}, {"uint32", "0xff", big.NewInt(0xff)}, {"int8", "0xffff", nil}, } { diff --git a/signer/core/apitypes/types.go b/signer/core/apitypes/types.go index 2c8907ac82..6e883b27c8 100644 --- a/signer/core/apitypes/types.go +++ b/signer/core/apitypes/types.go @@ -418,6 +418,14 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter // Attempt to parse bytes in different formats: byte array, hex string, hexutil.Bytes. func parseBytes(encType interface{}) ([]byte, bool) { + // Handle array types. + val := reflect.ValueOf(encType) + if val.Kind() == reflect.Array && val.Type().Elem().Kind() == reflect.Uint8 { + v := reflect.MakeSlice(reflect.TypeOf([]byte{}), val.Len(), val.Len()) + reflect.Copy(v, val) + return v.Bytes(), true + } + switch v := encType.(type) { case []byte: return v, true @@ -458,6 +466,8 @@ func parseInteger(encType string, encValue interface{}) (*big.Int, error) { switch v := encValue.(type) { case *math.HexOrDecimal256: b = (*big.Int)(v) + case *big.Int: + b = v case string: var hexIntValue math.HexOrDecimal256 if err := hexIntValue.UnmarshalText([]byte(v)); err != nil { @@ -490,13 +500,23 @@ func parseInteger(encType string, encValue interface{}) (*big.Int, error) { func (typedData *TypedData) EncodePrimitiveValue(encType string, encValue interface{}, depth int) ([]byte, error) { switch encType { case "address": - stringValue, ok := encValue.(string) - if !ok || !common.IsHexAddress(stringValue) { - return nil, dataMismatchError(encType, encValue) - } retval := make([]byte, 32) - copy(retval[12:], common.HexToAddress(stringValue).Bytes()) - return retval, nil + switch val := encValue.(type) { + case string: + if common.IsHexAddress(val) { + copy(retval[12:], common.HexToAddress(val).Bytes()) + return retval, nil + } + case []byte: + if len(val) == 20 { + copy(retval[12:], val) + return retval, nil + } + case [20]byte: + copy(retval[12:], val[:]) + return retval, nil + } + return nil, dataMismatchError(encType, encValue) case "bool": boolValue, ok := encValue.(bool) if !ok {