Added EIP-712 multi-dimensional array support (#687).

This commit is contained in:
Richard Moore 2020-10-12 00:58:04 -04:00
parent 345a830dc4
commit 5a4dd5a703
No known key found for this signature in database
GPG Key ID: 665176BE8E9DC651
2 changed files with 64 additions and 93 deletions

@ -1,67 +1,19 @@
"use strict"; "use strict";
import { Bytes, concat, hexlify } from "@ethersproject/bytes"; import { id } from "./id";
import { nameprep, toUtf8Bytes } from "@ethersproject/strings"; import { isValidName, namehash } from "./namehash";
import { keccak256 } from "@ethersproject/keccak256"; import { hashMessage, messagePrefix } from "./message";
import { Logger } from "@ethersproject/logger";
import { version } from "./_version";
const logger = new Logger(version);
import { TypedDataEncoder as _TypedDataEncoder } from "./typed-data"; import { TypedDataEncoder as _TypedDataEncoder } from "./typed-data";
import { id } from "./id";
export { export {
id,
namehash,
isValidName,
messagePrefix,
hashMessage,
_TypedDataEncoder, _TypedDataEncoder,
id
}
///////////////////////////////
const Zeros = new Uint8Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
const Partition = new RegExp("^((.*)\\.)?([^.]+)$");
export function isValidName(name: string): boolean {
try {
const comps = name.split(".");
for (let i = 0; i < comps.length; i++) {
if (nameprep(comps[i]).length === 0) {
throw new Error("empty")
}
}
return true;
} catch (error) { }
return false;
}
export function namehash(name: string): string {
/* istanbul ignore if */
if (typeof(name) !== "string") {
logger.throwArgumentError("invalid address - " + String(name), "name", name);
}
let result: string | Uint8Array = Zeros;
while (name.length) {
const partition = name.match(Partition);
const label = toUtf8Bytes(nameprep(partition[3]));
result = keccak256(concat([result, keccak256(label)]));
name = partition[2] || "";
}
return hexlify(result);
}
export const messagePrefix = "\x19Ethereum Signed Message:\n";
export function hashMessage(message: Bytes | string): string {
if (typeof(message) === "string") { message = toUtf8Bytes(message); }
return keccak256(concat([
toUtf8Bytes(messagePrefix),
toUtf8Bytes(String(message.length)),
message
]));
} }

@ -1,7 +1,7 @@
import { TypedDataDomain, TypedDataField } from "@ethersproject/abstract-signer"; import { TypedDataDomain, TypedDataField } from "@ethersproject/abstract-signer";
import { getAddress } from "@ethersproject/address"; import { getAddress } from "@ethersproject/address";
import { BigNumber, BigNumberish } from "@ethersproject/bignumber"; import { BigNumber, BigNumberish } from "@ethersproject/bignumber";
import { arrayify, BytesLike, concat, hexConcat, hexZeroPad } from "@ethersproject/bytes"; import { arrayify, BytesLike, hexConcat, hexlify, hexZeroPad } from "@ethersproject/bytes";
import { keccak256 } from "@ethersproject/keccak256"; import { keccak256 } from "@ethersproject/keccak256";
import { deepCopy, defineReadOnly } from "@ethersproject/properties"; import { deepCopy, defineReadOnly } from "@ethersproject/properties";
@ -21,7 +21,11 @@ const MaxUint256: BigNumber = BigNumber.from("0xffffffffffffffffffffffffffffffff
function hexPadRight(value: BytesLike) { function hexPadRight(value: BytesLike) {
const bytes = arrayify(value); const bytes = arrayify(value);
return hexConcat([ bytes, padding.slice(bytes.length % 32) ]); const padOffset = bytes.length % 32
if (padOffset) {
return hexConcat([ bytes, padding.slice(padOffset) ]);
}
return hexlify(bytes);
} }
const hexTrue = hexZeroPad(One.toHexString(), 32); const hexTrue = hexZeroPad(One.toHexString(), 32);
@ -35,6 +39,10 @@ const domainFieldTypes: Record<string, string> = {
salt: "bytes32" salt: "bytes32"
}; };
const domainFieldNames: Array<string> = [
"name", "version", "chainId", "verifyingContract", "salt"
];
function getBaseEncoder(type: string): (value: any) => string { function getBaseEncoder(type: string): (value: any) => string {
// intXX and uintXX // intXX and uintXX
{ {
@ -217,20 +225,38 @@ export class TypedDataEncoder {
} }
_getEncoder(type: string): (value: any) => string { _getEncoder(type: string): (value: any) => string {
const match = type.match(/^([^\x5b]*)(\x5b(\d*)\x5d)?$/);
if (!match) { logger.throwArgumentError(`unknown type: ${ type }`, "type", type); }
const baseType = match[1]; // Basic encoder type
{
const encoder = getBaseEncoder(type);
if (encoder) { return encoder; }
}
let baseEncoder = getBaseEncoder(baseType); // Array
const match = type.match(/^(.*)(\x5b(\d*)\x5d)$/);
if (match) {
const subtype = match[1];
const subEncoder = this.getEncoder(subtype);
const length = parseInt(match[3]);
return (value: Array<any>) => {
if (length >= 0 && value.length !== length) {
logger.throwArgumentError("array length mismatch; expected length ${ arrayLength }", "value", value);
}
// A struct type let result = value.map(subEncoder);
if (baseEncoder == null) { if (this._types[subtype]) {
const fields = this.types[baseType]; result = result.map(keccak256);
if (!fields) { logger.throwArgumentError(`unknown type: ${ type }`, "type", type); } }
const encodedType = id(this._types[baseType]); return keccak256(hexConcat(result));
baseEncoder = (value: Record<string, any>) => { };
}
// Struct
const fields = this.types[type];
if (fields) {
const encodedType = id(this._types[type]);
return (value: Record<string, any>) => {
const values = fields.map((f) => { const values = fields.map((f) => {
const result = this.getEncoder(f.type)(value[f.name]); const result = this.getEncoder(f.type)(value[f.name]);
if (this._types[f.type]) { return keccak256(result); } if (this._types[f.type]) { return keccak256(result); }
@ -241,23 +267,7 @@ export class TypedDataEncoder {
} }
} }
// An array type return logger.throwArgumentError(`unknown type: ${ type }`, "type", type);
if (match[2]) {
const length = (match[3] ? parseInt(match[3]): -1);
return (value: Array<any>) => {
if (length >= 0 && value.length !== length) {
logger.throwArgumentError("array length mismatch; expected length ${ arrayLength }", "value", value);
}
let result = value.map(baseEncoder);
if (this._types[baseType]) {
result = result.map(keccak256);
}
return keccak256(hexConcat(result));
};
}
return baseEncoder;
} }
encodeType(name: string): string { encodeType(name: string): string {
@ -296,7 +306,7 @@ export class TypedDataEncoder {
return TypedDataEncoder.from(types).hashStruct(name, value); return TypedDataEncoder.from(types).hashStruct(name, value);
} }
static hashTypedDataDomain(domain: TypedDataDomain): string { static hashDomain(domain: TypedDataDomain): string {
const domainFields: Array<TypedDataField> = [ ]; const domainFields: Array<TypedDataField> = [ ];
for (const name in domain) { for (const name in domain) {
const type = domainFieldTypes[name]; const type = domainFieldTypes[name];
@ -305,15 +315,24 @@ export class TypedDataEncoder {
} }
domainFields.push({ name, type }); domainFields.push({ name, type });
} }
domainFields.sort((a, b) => {
return domainFieldNames.indexOf(a.name) - domainFieldNames.indexOf(b.name);
});
return TypedDataEncoder.hashStruct("EIP712Domain", { EIP712Domain: domainFields }, domain); return TypedDataEncoder.hashStruct("EIP712Domain", { EIP712Domain: domainFields }, domain);
} }
static hashTypedData(domain: TypedDataDomain, types: Record<string, Array<TypedDataField>>, value: Record<string, any>): string { static encode(domain: TypedDataDomain, types: Record<string, Array<TypedDataField>>, value: Record<string, any>): string {
return keccak256(concat([ return hexConcat([
"0x1901", "0x1901",
TypedDataEncoder.hashTypedDataDomain(domain), TypedDataEncoder.hashDomain(domain),
TypedDataEncoder.from(types).hash(value) TypedDataEncoder.from(types).hash(value)
])); ]);
}
static hash(domain: TypedDataDomain, types: Record<string, Array<TypedDataField>>, value: Record<string, any>): string {
return keccak256(TypedDataEncoder.encode(domain, types, value));
} }
} }