diff --git a/ethclient/gethclient/gethclient_test.go b/ethclient/gethclient/gethclient_test.go index 7ff5836f81..c2c11d0770 100644 --- a/ethclient/gethclient/gethclient_test.go +++ b/ethclient/gethclient/gethclient_test.go @@ -105,6 +105,9 @@ func TestGethClient(t *testing.T) { { "TestGetProof", func(t *testing.T) { testGetProof(t, client) }, + }, { + "TestGetProofCanonicalizeKeys", + func(t *testing.T) { testGetProofCanonicalizeKeys(t, client) }, }, { "TestGCStats", func(t *testing.T) { testGCStats(t, client) }, @@ -218,6 +221,7 @@ func testGetProof(t *testing.T, client *rpc.Client) { if result.Balance.Cmp(balance) != 0 { t.Fatalf("invalid balance, want: %v got: %v", balance, result.Balance) } + // test storage if len(result.StorageProof) != 1 { t.Fatalf("invalid storage proof, want 1 proof, got %v proof(s)", len(result.StorageProof)) @@ -228,7 +232,37 @@ func testGetProof(t *testing.T, client *rpc.Client) { t.Fatalf("invalid storage proof value, want: %v, got: %v", slotValue, proof.Value.Bytes()) } if proof.Key != testSlot.String() { - t.Fatalf("invalid storage proof key, want: %v, got: %v", testSlot.String(), proof.Key) + t.Fatalf("invalid storage proof key, want: %q, got: %q", testSlot.String(), proof.Key) + } +} + +func testGetProofCanonicalizeKeys(t *testing.T, client *rpc.Client) { + ec := New(client) + + // Tests with non-canon input for storage keys. + // Here we check that the storage key is canonicalized. + result, err := ec.GetProof(context.Background(), testAddr, []string{"0x0dEadbeef"}, nil) + if err != nil { + t.Fatal(err) + } + if result.StorageProof[0].Key != "0xdeadbeef" { + t.Fatalf("wrong storage key encoding in proof: %q", result.StorageProof[0].Key) + } + if result, err = ec.GetProof(context.Background(), testAddr, []string{"0x000deadbeef"}, nil); err != nil { + t.Fatal(err) + } + if result.StorageProof[0].Key != "0xdeadbeef" { + t.Fatalf("wrong storage key encoding in proof: %q", result.StorageProof[0].Key) + } + + // If the requested storage key is 32 bytes long, it will be returned as is. + hashSizedKey := "0x00000000000000000000000000000000000000000000000000000000deadbeef" + result, err = ec.GetProof(context.Background(), testAddr, []string{hashSizedKey}, nil) + if err != nil { + t.Fatal(err) + } + if result.StorageProof[0].Key != hashSizedKey { + t.Fatalf("wrong storage key encoding in proof: %q", result.StorageProof[0].Key) } } diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 48c6e38e88..ec29cf28f0 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -672,19 +672,21 @@ func (n *proofList) Delete(key []byte) error { func (s *BlockChainAPI) GetProof(ctx context.Context, address common.Address, storageKeys []string, blockNrOrHash rpc.BlockNumberOrHash) (*AccountResult, error) { var ( keys = make([]common.Hash, len(storageKeys)) + keyLengths = make([]int, len(storageKeys)) storageProof = make([]StorageResult, len(storageKeys)) storageTrie state.Trie storageHash = types.EmptyRootHash codeHash = types.EmptyCodeHash ) - // Greedily deserialize all keys. This prevents state access on invalid input + // Deserialize all keys. This prevents state access on invalid input. for i, hexKey := range storageKeys { - if key, err := decodeHash(hexKey); err != nil { + var err error + keys[i], keyLengths[i], err = decodeHash(hexKey) + if err != nil { return nil, err - } else { - keys[i] = key } } + state, _, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return nil, err @@ -692,28 +694,39 @@ func (s *BlockChainAPI) GetProof(ctx context.Context, address common.Address, st if storageTrie, err = state.StorageTrie(address); err != nil { return nil, err } - // if we have a storageTrie, the account exists and we must update + + // If we have a storageTrie, the account exists and we must update // the storage root hash and the code hash. if storageTrie != nil { storageHash = storageTrie.Hash() codeHash = state.GetCodeHash(address) } - // create the proof for the storageKeys + // Create the proofs for the storageKeys. for i, key := range keys { + // Output key encoding is a bit special: if the input was a 32-byte hash, it is + // returned as such. Otherwise, we apply the QUANTITY encoding mandated by the + // JSON-RPC spec for getProof. This behavior exists to preserve backwards + // compatibility with older client versions. + var outputKey string + if keyLengths[i] != 32 { + outputKey = hexutil.EncodeBig(key.Big()) + } else { + outputKey = hexutil.Encode(key[:]) + } + if storageTrie == nil { - storageProof[i] = StorageResult{storageKeys[i], &hexutil.Big{}, []string{}} + storageProof[i] = StorageResult{outputKey, &hexutil.Big{}, []string{}} continue } var proof proofList if err := storageTrie.Prove(crypto.Keccak256(key.Bytes()), &proof); err != nil { return nil, err } - storageProof[i] = StorageResult{storageKeys[i], - (*hexutil.Big)(state.GetState(address, key).Big()), - proof} + value := (*hexutil.Big)(state.GetState(address, key).Big()) + storageProof[i] = StorageResult{outputKey, value, proof} } - // create the accountProof + // Create the accountProof. accountProof, proofErr := state.GetProof(address) if proofErr != nil { return nil, proofErr @@ -732,7 +745,7 @@ func (s *BlockChainAPI) GetProof(ctx context.Context, address common.Address, st // decodeHash parses a hex-encoded 32-byte hash. The input may optionally // be prefixed by 0x and can have a byte length up to 32. -func decodeHash(s string) (common.Hash, error) { +func decodeHash(s string) (h common.Hash, inputLength int, err error) { if strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X") { s = s[2:] } @@ -741,12 +754,12 @@ func decodeHash(s string) (common.Hash, error) { } b, err := hex.DecodeString(s) if err != nil { - return common.Hash{}, errors.New("hex string invalid") + return common.Hash{}, 0, errors.New("hex string invalid") } if len(b) > 32 { - return common.Hash{}, errors.New("hex string too long, want at most 32 bytes") + return common.Hash{}, len(b), errors.New("hex string too long, want at most 32 bytes") } - return common.BytesToHash(b), nil + return common.BytesToHash(b), len(b), nil } // GetHeaderByNumber returns the requested canonical block header. @@ -876,7 +889,7 @@ func (s *BlockChainAPI) GetStorageAt(ctx context.Context, address common.Address if state == nil || err != nil { return nil, err } - key, err := decodeHash(hexKey) + key, _, err := decodeHash(hexKey) if err != nil { return nil, fmt.Errorf("unable to decode storage key: %s", err) }