diff --git a/proxyd/proxyd/backend.go b/proxyd/proxyd/backend.go index f50926b..6b00a2e 100644 --- a/proxyd/proxyd/backend.go +++ b/proxyd/proxyd/backend.go @@ -90,6 +90,11 @@ var ( Message: "backend is currently not healthy to serve traffic", HTTPErrorCode: 503, } + ErrBlockOutOfRange = &RPCErr{ + Code: JSONRPCErrorInternal - 19, + Message: "block is out of range", + HTTPErrorCode: 400, + } ErrBackendUnexpectedJSONRPC = errors.New("backend returned an unexpected JSON-RPC response") ) @@ -220,6 +225,12 @@ func WithMaxErrorRateThreshold(maxErrorRateThreshold float64) BackendOpt { } } +type indexedReqRes struct { + index int + req *RPCReq + res *RPCRes +} + func NewBackend( name string, rpcURL string, @@ -599,47 +610,96 @@ func (b *BackendGroup) Forward(ctx context.Context, rpcReqs []*RPCReq, isBatch b backends := b.Backends - // When `consensus_aware` is set to `true`, the backend group acts as a load balancer - // serving traffic from any backend that agrees in the consensus group + overriddenResponses := make([]*indexedReqRes, 0) + rewrittenReqs := make([]*RPCReq, 0, len(rpcReqs)) + if b.Consensus != nil { + // When `consensus_aware` is set to `true`, the backend group acts as a load balancer + // serving traffic from any backend that agrees in the consensus group backends = b.loadBalancedConsensusGroup() + + // We also rewrite block tags to enforce compliance with consensus + rctx := RewriteContext{latest: b.Consensus.GetConsensusBlockNumber()} + + for i, req := range rpcReqs { + res := RPCRes{JSONRPC: JSONRPCVersion, ID: req.ID} + result, err := RewriteTags(rctx, req, &res) + switch result { + case RewriteOverrideError: + overriddenResponses = append(overriddenResponses, &indexedReqRes{ + index: i, + req: req, + res: &res, + }) + if errors.Is(err, ErrRewriteBlockOutOfRange) { + res.Error = ErrBlockOutOfRange + } else { + res.Error = ErrParseErr + } + case RewriteOverrideResponse: + overriddenResponses = append(overriddenResponses, &indexedReqRes{ + index: i, + req: req, + res: &res, + }) + case RewriteOverrideRequest, RewriteNone: + rewrittenReqs = append(rewrittenReqs, req) + } + } + rpcReqs = rewrittenReqs } rpcRequestsTotal.Inc() for _, back := range backends { - res, err := back.Forward(ctx, rpcReqs, isBatch) - if errors.Is(err, ErrMethodNotWhitelisted) { - return nil, err + res := make([]*RPCRes, 0) + var err error + + if len(rpcReqs) > 0 { + res, err = back.Forward(ctx, rpcReqs, isBatch) + if errors.Is(err, ErrMethodNotWhitelisted) { + return nil, err + } + if errors.Is(err, ErrBackendOffline) { + log.Warn( + "skipping offline backend", + "name", back.Name, + "auth", GetAuthCtx(ctx), + "req_id", GetReqID(ctx), + ) + continue + } + if errors.Is(err, ErrBackendOverCapacity) { + log.Warn( + "skipping over-capacity backend", + "name", back.Name, + "auth", GetAuthCtx(ctx), + "req_id", GetReqID(ctx), + ) + continue + } + if err != nil { + log.Error( + "error forwarding request to backend", + "name", back.Name, + "req_id", GetReqID(ctx), + "auth", GetAuthCtx(ctx), + "err", err, + ) + continue + } } - if errors.Is(err, ErrBackendOffline) { - log.Warn( - "skipping offline backend", - "name", back.Name, - "auth", GetAuthCtx(ctx), - "req_id", GetReqID(ctx), - ) - continue - } - if errors.Is(err, ErrBackendOverCapacity) { - log.Warn( - "skipping over-capacity backend", - "name", back.Name, - "auth", GetAuthCtx(ctx), - "req_id", GetReqID(ctx), - ) - continue - } - if err != nil { - log.Error( - "error forwarding request to backend", - "name", back.Name, - "req_id", GetReqID(ctx), - "auth", GetAuthCtx(ctx), - "err", err, - ) - continue + + // re-apply overridden responses + for _, ov := range overriddenResponses { + if len(res) > 0 { + // insert ov.res at position ov.index + res = append(res[:ov.index], append([]*RPCRes{ov.res}, res[ov.index:]...)...) + } else { + res = append(res, ov.res) + } } + return res, nil } diff --git a/proxyd/proxyd/integration_tests/consensus_test.go b/proxyd/proxyd/integration_tests/consensus_test.go index aa7b7e1..51d1f27 100644 --- a/proxyd/proxyd/integration_tests/consensus_test.go +++ b/proxyd/proxyd/integration_tests/consensus_test.go @@ -433,6 +433,184 @@ func TestConsensus(t *testing.T) { require.Equal(t, len(node1.Requests()), 0, msg) require.Equal(t, len(node2.Requests()), 10, msg) }) + + t.Run("rewrite response of eth_blockNumber", func(t *testing.T) { + h1.ResetOverrides() + h2.ResetOverrides() + node1.Reset() + node2.Reset() + bg.Consensus.Unban() + + // establish the consensus + + h1.AddOverride(&ms.MethodTemplate{ + Method: "eth_getBlockByNumber", + Block: "latest", + Response: buildGetBlockResponse("0x2", "hash2"), + }) + h2.AddOverride(&ms.MethodTemplate{ + Method: "eth_getBlockByNumber", + Block: "latest", + Response: buildGetBlockResponse("0x2", "hash2"), + }) + + for _, be := range bg.Backends { + bg.Consensus.UpdateBackend(ctx, be) + } + bg.Consensus.UpdateBackendGroupConsensus(ctx) + + totalRequests := len(node1.Requests()) + len(node2.Requests()) + + require.Equal(t, 2, len(bg.Consensus.GetConsensusGroup())) + + // pretend backends advanced in consensus, but we are still serving the latest value of the consensus + // until it gets updated again + + h1.AddOverride(&ms.MethodTemplate{ + Method: "eth_getBlockByNumber", + Block: "latest", + Response: buildGetBlockResponse("0x3", "hash3"), + }) + h2.AddOverride(&ms.MethodTemplate{ + Method: "eth_getBlockByNumber", + Block: "latest", + Response: buildGetBlockResponse("0x3", "hash3"), + }) + + resRaw, statusCode, err := client.SendRPC("eth_blockNumber", nil) + require.NoError(t, err) + require.Equal(t, 200, statusCode) + + var jsonMap map[string]interface{} + err = json.Unmarshal(resRaw, &jsonMap) + require.NoError(t, err) + require.Equal(t, "0x2", jsonMap["result"]) + + // no extra request hit the backends + require.Equal(t, totalRequests, len(node1.Requests())+len(node2.Requests())) + }) + + t.Run("rewrite request of eth_getBlockByNumber", func(t *testing.T) { + h1.ResetOverrides() + h2.ResetOverrides() + bg.Consensus.Unban() + + // establish the consensus and ban node2 for now + h1.AddOverride(&ms.MethodTemplate{ + Method: "eth_getBlockByNumber", + Block: "latest", + Response: buildGetBlockResponse("0x2", "hash2"), + }) + h2.AddOverride(&ms.MethodTemplate{ + Method: "net_peerCount", + Block: "", + Response: buildPeerCountResponse(1), + }) + + for _, be := range bg.Backends { + bg.Consensus.UpdateBackend(ctx, be) + } + bg.Consensus.UpdateBackendGroupConsensus(ctx) + + require.Equal(t, 1, len(bg.Consensus.GetConsensusGroup())) + + node1.Reset() + + _, statusCode, err := client.SendRPC("eth_getBlockByNumber", []interface{}{"latest"}) + require.NoError(t, err) + require.Equal(t, 200, statusCode) + + var jsonMap map[string]interface{} + err = json.Unmarshal(node1.Requests()[0].Body, &jsonMap) + require.NoError(t, err) + require.Equal(t, "0x2", jsonMap["params"].([]interface{})[0]) + }) + + t.Run("rewrite request of eth_getBlockByNumber - out of range", func(t *testing.T) { + h1.ResetOverrides() + h2.ResetOverrides() + bg.Consensus.Unban() + + // establish the consensus and ban node2 for now + h1.AddOverride(&ms.MethodTemplate{ + Method: "eth_getBlockByNumber", + Block: "latest", + Response: buildGetBlockResponse("0x2", "hash2"), + }) + h2.AddOverride(&ms.MethodTemplate{ + Method: "net_peerCount", + Block: "", + Response: buildPeerCountResponse(1), + }) + + for _, be := range bg.Backends { + bg.Consensus.UpdateBackend(ctx, be) + } + bg.Consensus.UpdateBackendGroupConsensus(ctx) + + require.Equal(t, 1, len(bg.Consensus.GetConsensusGroup())) + + node1.Reset() + + resRaw, statusCode, err := client.SendRPC("eth_getBlockByNumber", []interface{}{"0x10"}) + require.NoError(t, err) + require.Equal(t, 400, statusCode) + + var jsonMap map[string]interface{} + err = json.Unmarshal(resRaw, &jsonMap) + require.NoError(t, err) + require.Equal(t, -32019, int(jsonMap["error"].(map[string]interface{})["code"].(float64))) + require.Equal(t, "block is out of range", jsonMap["error"].(map[string]interface{})["message"]) + }) + + t.Run("batched rewrite", func(t *testing.T) { + h1.ResetOverrides() + h2.ResetOverrides() + bg.Consensus.Unban() + + // establish the consensus and ban node2 for now + h1.AddOverride(&ms.MethodTemplate{ + Method: "eth_getBlockByNumber", + Block: "latest", + Response: buildGetBlockResponse("0x2", "hash2"), + }) + h2.AddOverride(&ms.MethodTemplate{ + Method: "net_peerCount", + Block: "", + Response: buildPeerCountResponse(1), + }) + + for _, be := range bg.Backends { + bg.Consensus.UpdateBackend(ctx, be) + } + bg.Consensus.UpdateBackendGroupConsensus(ctx) + + require.Equal(t, 1, len(bg.Consensus.GetConsensusGroup())) + + node1.Reset() + + resRaw, statusCode, err := client.SendBatchRPC( + NewRPCReq("1", "eth_getBlockByNumber", []interface{}{"latest"}), + NewRPCReq("2", "eth_getBlockByNumber", []interface{}{"0x10"}), + NewRPCReq("3", "eth_getBlockByNumber", []interface{}{"0x1"})) + require.NoError(t, err) + require.Equal(t, 200, statusCode) + + var jsonMap []map[string]interface{} + err = json.Unmarshal(resRaw, &jsonMap) + require.NoError(t, err) + require.Equal(t, 3, len(jsonMap)) + + // rewrite latest to 0x2 + require.Equal(t, "0x2", jsonMap[0]["result"].(map[string]interface{})["number"]) + + // out of bounds for block 0x10 + require.Equal(t, -32019, int(jsonMap[1]["error"].(map[string]interface{})["code"].(float64))) + require.Equal(t, "block is out of range", jsonMap[1]["error"].(map[string]interface{})["message"]) + + // dont rewrite for 0x1 + require.Equal(t, "0x1", jsonMap[2]["result"].(map[string]interface{})["number"]) + }) } func backend(bg *proxyd.BackendGroup, name string) *proxyd.Backend { diff --git a/proxyd/proxyd/rewriter.go b/proxyd/proxyd/rewriter.go new file mode 100644 index 0000000..6b01c1a --- /dev/null +++ b/proxyd/proxyd/rewriter.go @@ -0,0 +1,175 @@ +package proxyd + +import ( + "encoding/json" + "errors" + "strings" + + "github.com/ethereum/go-ethereum/common/hexutil" +) + +type RewriteContext struct { + latest hexutil.Uint64 +} + +type RewriteResult uint8 + +const ( + // RewriteNone means request should be forwarded as-is + RewriteNone RewriteResult = iota + + // RewriteOverrideError means there was an error attempting to rewrite + RewriteOverrideError + + // RewriteOverrideRequest means the modified request should be forwarded to the backend + RewriteOverrideRequest + + // RewriteOverrideResponse means to skip calling the backend and serve the overridden response + RewriteOverrideResponse +) + +var ( + ErrRewriteBlockOutOfRange = errors.New("block is out of range") +) + +// RewriteTags modifies the request and the response based on block tags +func RewriteTags(rctx RewriteContext, req *RPCReq, res *RPCRes) (RewriteResult, error) { + rw, err := RewriteResponse(rctx, req, res) + if rw == RewriteOverrideResponse { + return rw, err + } + return RewriteRequest(rctx, req, res) +} + +// RewriteResponse modifies the response object to comply with the rewrite context +// after the method has been called at the backend +// RewriteResult informs the decision of the rewrite +func RewriteResponse(rctx RewriteContext, req *RPCReq, res *RPCRes) (RewriteResult, error) { + switch req.Method { + case "eth_blockNumber": + res.Result = rctx.latest + return RewriteOverrideResponse, nil + } + return RewriteNone, nil +} + +// RewriteRequest modifies the request object to comply with the rewrite context +// before the method has been called at the backend +// it returns false if nothing was changed +func RewriteRequest(rctx RewriteContext, req *RPCReq, res *RPCRes) (RewriteResult, error) { + switch req.Method { + case "eth_getLogs", + "eth_newFilter": + return rewriteRange(rctx, req, res, 0) + case "eth_getBalance", + "eth_getCode", + "eth_getTransactionCount", + "eth_call": + return rewriteParam(rctx, req, res, 1) + case "eth_getStorageAt": + return rewriteParam(rctx, req, res, 2) + case "eth_getBlockTransactionCountByNumber", + "eth_getUncleCountByBlockNumber", + "eth_getBlockByNumber", + "eth_getTransactionByBlockNumberAndIndex", + "eth_getUncleByBlockNumberAndIndex": + return rewriteParam(rctx, req, res, 0) + } + return RewriteNone, nil +} + +func rewriteParam(rctx RewriteContext, req *RPCReq, res *RPCRes, pos int) (RewriteResult, error) { + var p []interface{} + err := json.Unmarshal(req.Params, &p) + if err != nil { + return RewriteOverrideError, err + } + + if len(p) <= pos { + p = append(p, "latest") + } + + val, rw, err := rewriteTag(rctx, p[pos].(string)) + if err != nil { + return RewriteOverrideError, err + } + + if rw { + p[pos] = val + paramRaw, err := json.Marshal(p) + if err != nil { + return RewriteOverrideError, err + } + req.Params = paramRaw + return RewriteOverrideRequest, nil + } + return RewriteNone, nil +} + +func rewriteRange(rctx RewriteContext, req *RPCReq, res *RPCRes, pos int) (RewriteResult, error) { + var p []map[string]interface{} + err := json.Unmarshal(req.Params, &p) + if err != nil { + return RewriteOverrideError, err + } + + modifiedFrom, err := rewriteTagMap(rctx, p[pos], "fromBlock") + if err != nil { + return RewriteOverrideError, err + } + + modifiedTo, err := rewriteTagMap(rctx, p[pos], "toBlock") + if err != nil { + return RewriteOverrideError, err + } + + // if any of the fields the request have been changed, re-marshal the params + if modifiedFrom || modifiedTo { + paramsRaw, err := json.Marshal(p) + req.Params = paramsRaw + if err != nil { + return RewriteOverrideError, err + } + return RewriteOverrideRequest, nil + } + + return RewriteNone, nil +} + +func rewriteTagMap(rctx RewriteContext, m map[string]interface{}, key string) (bool, error) { + if m[key] == nil || m[key] == "" { + return false, nil + } + + current, ok := m[key].(string) + if !ok { + return false, errors.New("expected string") + } + + val, rw, err := rewriteTag(rctx, current) + if err != nil { + return false, err + } + if rw { + m[key] = val + return true, nil + } + + return false, nil +} + +func rewriteTag(rctx RewriteContext, current string) (string, bool, error) { + if current == "latest" { + return rctx.latest.String(), true, nil + } else if strings.HasPrefix(current, "0x") { + decode, err := hexutil.DecodeUint64(current) + if err != nil { + return current, false, err + } + b := hexutil.Uint64(decode) + if b > rctx.latest { + return "", false, ErrRewriteBlockOutOfRange + } + } + return current, false, nil +} diff --git a/proxyd/proxyd/rewriter_test.go b/proxyd/proxyd/rewriter_test.go new file mode 100644 index 0000000..566ccdb --- /dev/null +++ b/proxyd/proxyd/rewriter_test.go @@ -0,0 +1,441 @@ +package proxyd + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/stretchr/testify/require" +) + +type args struct { + rctx RewriteContext + req *RPCReq + res *RPCRes +} + +type rewriteTest struct { + name string + args args + expected RewriteResult + expectedErr error + check func(*testing.T, args) +} + +func TestRewriteRequest(t *testing.T) { + tests := []rewriteTest{ + /* range scoped */ + { + name: "eth_getLogs fromBlock latest", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getLogs", Params: mustMarshalJSON([]map[string]interface{}{{"fromBlock": "latest"}})}, + res: nil, + }, + expected: RewriteOverrideRequest, + check: func(t *testing.T, args args) { + var p []map[string]interface{} + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, hexutil.Uint64(100).String(), p[0]["fromBlock"]) + }, + }, + { + name: "eth_getLogs fromBlock within range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getLogs", Params: mustMarshalJSON([]map[string]interface{}{{"fromBlock": hexutil.Uint64(55).String()}})}, + res: nil, + }, + expected: RewriteNone, + check: func(t *testing.T, args args) { + var p []map[string]interface{} + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, hexutil.Uint64(55).String(), p[0]["fromBlock"]) + }, + }, + { + name: "eth_getLogs fromBlock out of range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getLogs", Params: mustMarshalJSON([]map[string]interface{}{{"fromBlock": hexutil.Uint64(111).String()}})}, + res: nil, + }, + expected: RewriteOverrideError, + expectedErr: ErrRewriteBlockOutOfRange, + }, + { + name: "eth_getLogs toBlock latest", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getLogs", Params: mustMarshalJSON([]map[string]interface{}{{"toBlock": "latest"}})}, + res: nil, + }, + expected: RewriteOverrideRequest, + check: func(t *testing.T, args args) { + var p []map[string]interface{} + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, hexutil.Uint64(100).String(), p[0]["toBlock"]) + }, + }, + { + name: "eth_getLogs toBlock within range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getLogs", Params: mustMarshalJSON([]map[string]interface{}{{"toBlock": hexutil.Uint64(55).String()}})}, + res: nil, + }, + expected: RewriteNone, + check: func(t *testing.T, args args) { + var p []map[string]interface{} + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, hexutil.Uint64(55).String(), p[0]["toBlock"]) + }, + }, + { + name: "eth_getLogs toBlock out of range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getLogs", Params: mustMarshalJSON([]map[string]interface{}{{"toBlock": hexutil.Uint64(111).String()}})}, + res: nil, + }, + expected: RewriteOverrideError, + expectedErr: ErrRewriteBlockOutOfRange, + }, + { + name: "eth_getLogs fromBlock, toBlock latest", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getLogs", Params: mustMarshalJSON([]map[string]interface{}{{"fromBlock": "latest", "toBlock": "latest"}})}, + res: nil, + }, + expected: RewriteOverrideRequest, + check: func(t *testing.T, args args) { + var p []map[string]interface{} + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, hexutil.Uint64(100).String(), p[0]["fromBlock"]) + require.Equal(t, hexutil.Uint64(100).String(), p[0]["toBlock"]) + }, + }, + { + name: "eth_getLogs fromBlock, toBlock within range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getLogs", Params: mustMarshalJSON([]map[string]interface{}{{"fromBlock": hexutil.Uint64(55).String(), "toBlock": hexutil.Uint64(77).String()}})}, + res: nil, + }, + expected: RewriteNone, + check: func(t *testing.T, args args) { + var p []map[string]interface{} + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, hexutil.Uint64(55).String(), p[0]["fromBlock"]) + require.Equal(t, hexutil.Uint64(77).String(), p[0]["toBlock"]) + }, + }, + { + name: "eth_getLogs fromBlock, toBlock out of range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getLogs", Params: mustMarshalJSON([]map[string]interface{}{{"fromBlock": hexutil.Uint64(111).String(), "toBlock": hexutil.Uint64(222).String()}})}, + res: nil, + }, + expected: RewriteOverrideError, + expectedErr: ErrRewriteBlockOutOfRange, + }, + /* default block parameter */ + { + name: "eth_getCode omit block, should add", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getCode", Params: mustMarshalJSON([]string{"0x123"})}, + res: nil, + }, + expected: RewriteOverrideRequest, + check: func(t *testing.T, args args) { + var p []string + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, 2, len(p)) + require.Equal(t, "0x123", p[0]) + require.Equal(t, hexutil.Uint64(100).String(), p[1]) + }, + }, + { + name: "eth_getCode latest", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getCode", Params: mustMarshalJSON([]string{"0x123", "latest"})}, + res: nil, + }, + expected: RewriteOverrideRequest, + check: func(t *testing.T, args args) { + var p []string + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, 2, len(p)) + require.Equal(t, "0x123", p[0]) + require.Equal(t, hexutil.Uint64(100).String(), p[1]) + }, + }, + { + name: "eth_getCode within range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getCode", Params: mustMarshalJSON([]string{"0x123", hexutil.Uint64(55).String()})}, + res: nil, + }, + expected: RewriteNone, + check: func(t *testing.T, args args) { + var p []string + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, 2, len(p)) + require.Equal(t, "0x123", p[0]) + require.Equal(t, hexutil.Uint64(55).String(), p[1]) + }, + }, + { + name: "eth_getCode out of range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getCode", Params: mustMarshalJSON([]string{"0x123", hexutil.Uint64(111).String()})}, + res: nil, + }, + expected: RewriteOverrideError, + expectedErr: ErrRewriteBlockOutOfRange, + }, + /* default block parameter, at position 2 */ + { + name: "eth_getStorageAt omit block, should add", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getStorageAt", Params: mustMarshalJSON([]string{"0x123", "5"})}, + res: nil, + }, + expected: RewriteOverrideRequest, + check: func(t *testing.T, args args) { + var p []string + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, 3, len(p)) + require.Equal(t, "0x123", p[0]) + require.Equal(t, "5", p[1]) + require.Equal(t, hexutil.Uint64(100).String(), p[2]) + }, + }, + { + name: "eth_getStorageAt latest", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getStorageAt", Params: mustMarshalJSON([]string{"0x123", "5", "latest"})}, + res: nil, + }, + expected: RewriteOverrideRequest, + check: func(t *testing.T, args args) { + var p []string + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, 3, len(p)) + require.Equal(t, "0x123", p[0]) + require.Equal(t, "5", p[1]) + require.Equal(t, hexutil.Uint64(100).String(), p[2]) + }, + }, + { + name: "eth_getStorageAt within range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getStorageAt", Params: mustMarshalJSON([]string{"0x123", "5", hexutil.Uint64(55).String()})}, + res: nil, + }, + expected: RewriteNone, + check: func(t *testing.T, args args) { + var p []string + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, 3, len(p)) + require.Equal(t, "0x123", p[0]) + require.Equal(t, "5", p[1]) + require.Equal(t, hexutil.Uint64(55).String(), p[2]) + }, + }, + { + name: "eth_getStorageAt out of range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getStorageAt", Params: mustMarshalJSON([]string{"0x123", "5", hexutil.Uint64(111).String()})}, + res: nil, + }, + expected: RewriteOverrideError, + expectedErr: ErrRewriteBlockOutOfRange, + }, + /* default block parameter, at position 0 */ + { + name: "eth_getBlockByNumber omit block, should add", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getBlockByNumber", Params: mustMarshalJSON([]string{})}, + res: nil, + }, + expected: RewriteOverrideRequest, + check: func(t *testing.T, args args) { + var p []string + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, 1, len(p)) + require.Equal(t, hexutil.Uint64(100).String(), p[0]) + }, + }, + { + name: "eth_getBlockByNumber latest", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getBlockByNumber", Params: mustMarshalJSON([]string{"latest"})}, + res: nil, + }, + expected: RewriteOverrideRequest, + check: func(t *testing.T, args args) { + var p []string + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, 1, len(p)) + require.Equal(t, hexutil.Uint64(100).String(), p[0]) + }, + }, + { + name: "eth_getBlockByNumber within range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getBlockByNumber", Params: mustMarshalJSON([]string{hexutil.Uint64(55).String()})}, + res: nil, + }, + expected: RewriteNone, + check: func(t *testing.T, args args) { + var p []string + err := json.Unmarshal(args.req.Params, &p) + require.Nil(t, err) + require.Equal(t, 1, len(p)) + require.Equal(t, hexutil.Uint64(55).String(), p[0]) + }, + }, + { + name: "eth_getBlockByNumber out of range", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_getBlockByNumber", Params: mustMarshalJSON([]string{hexutil.Uint64(111).String()})}, + res: nil, + }, + expected: RewriteOverrideError, + expectedErr: ErrRewriteBlockOutOfRange, + }, + } + + // generalize tests for other methods with same interface and behavior + tests = generalize(tests, "eth_getLogs", "eth_newFilter") + tests = generalize(tests, "eth_getCode", "eth_getBalance") + tests = generalize(tests, "eth_getCode", "eth_getTransactionCount") + tests = generalize(tests, "eth_getCode", "eth_call") + tests = generalize(tests, "eth_getBlockByNumber", "eth_getBlockTransactionCountByNumber") + tests = generalize(tests, "eth_getBlockByNumber", "eth_getUncleCountByBlockNumber") + tests = generalize(tests, "eth_getBlockByNumber", "eth_getTransactionByBlockNumberAndIndex") + tests = generalize(tests, "eth_getBlockByNumber", "eth_getUncleByBlockNumberAndIndex") + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := RewriteRequest(tt.args.rctx, tt.args.req, tt.args.res) + if result != RewriteOverrideError { + require.Nil(t, err) + require.Equal(t, tt.expected, result) + } else { + require.Equal(t, tt.expectedErr, err) + } + if tt.check != nil { + tt.check(t, tt.args) + } + }) + } +} + +func generalize(tests []rewriteTest, baseMethod string, generalizedMethod string) []rewriteTest { + newCases := make([]rewriteTest, 0) + for _, t := range tests { + if t.args.req.Method == baseMethod { + newName := strings.Replace(t.name, baseMethod, generalizedMethod, -1) + var req *RPCReq + var res *RPCRes + + if t.args.req != nil { + req = &RPCReq{ + JSONRPC: t.args.req.JSONRPC, + Method: generalizedMethod, + Params: t.args.req.Params, + ID: t.args.req.ID, + } + } + + if t.args.res != nil { + res = &RPCRes{ + JSONRPC: t.args.res.JSONRPC, + Result: t.args.res.Result, + Error: t.args.res.Error, + ID: t.args.res.ID, + } + } + newCases = append(newCases, rewriteTest{ + name: newName, + args: args{ + rctx: t.args.rctx, + req: req, + res: res, + }, + expected: t.expected, + expectedErr: t.expectedErr, + check: t.check, + }) + } + } + return append(tests, newCases...) +} + +func TestRewriteResponse(t *testing.T) { + type args struct { + rctx RewriteContext + req *RPCReq + res *RPCRes + } + tests := []struct { + name string + args args + expected RewriteResult + check func(*testing.T, args) + }{ + { + name: "eth_blockNumber latest", + args: args{ + rctx: RewriteContext{latest: hexutil.Uint64(100)}, + req: &RPCReq{Method: "eth_blockNumber"}, + res: &RPCRes{Result: hexutil.Uint64(200)}, + }, + expected: RewriteOverrideResponse, + check: func(t *testing.T, args args) { + require.Equal(t, args.res.Result, hexutil.Uint64(100)) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := RewriteResponse(tt.args.rctx, tt.args.req, tt.args.res) + require.Nil(t, err) + require.Equal(t, tt.expected, result) + if tt.check != nil { + tt.check(t, tt.args) + } + }) + } +} diff --git a/proxyd/proxyd/tools/mockserver/handler/handler.go b/proxyd/proxyd/tools/mockserver/handler/handler.go index 18d6026..04f30e7 100644 --- a/proxyd/proxyd/tools/mockserver/handler/handler.go +++ b/proxyd/proxyd/tools/mockserver/handler/handler.go @@ -6,6 +6,9 @@ import ( "io" "net/http" "os" + "strings" + + "github.com/ethereum-optimism/optimism/proxyd" "github.com/gorilla/mux" "github.com/pkg/errors" @@ -46,12 +49,6 @@ func (mh *MockedHandler) Handler(w http.ResponseWriter, req *http.Request) { fmt.Printf("error reading request: %v\n", err) } - var j map[string]interface{} - err = json.Unmarshal(body, &j) - if err != nil { - fmt.Printf("error reading request: %v\n", err) - } - var template []*MethodTemplate if mh.Autoload { template = append(template, mh.LoadFromFile(mh.AutoloadFile)...) @@ -60,23 +57,51 @@ func (mh *MockedHandler) Handler(w http.ResponseWriter, req *http.Request) { template = append(template, mh.Overrides...) } - method := j["method"] - block := "" - if method == "eth_getBlockByNumber" { - block = (j["params"].([]interface{})[0]).(string) + batched := proxyd.IsBatch(body) + var requests []map[string]interface{} + if batched { + err = json.Unmarshal(body, &requests) + if err != nil { + fmt.Printf("error reading request: %v\n", err) + } + } else { + var j map[string]interface{} + err = json.Unmarshal(body, &j) + if err != nil { + fmt.Printf("error reading request: %v\n", err) + } + requests = append(requests, j) } - var selectedResponse *string - for _, r := range template { - if r.Method == method && r.Block == block { - selectedResponse = &r.Response + var responses []string + for _, r := range requests { + method := r["method"] + block := "" + if method == "eth_getBlockByNumber" { + block = (r["params"].([]interface{})[0]).(string) + } + + var selectedResponse string + for _, r := range template { + if r.Method == method && r.Block == block { + selectedResponse = r.Response + } + } + if selectedResponse != "" { + responses = append(responses, selectedResponse) } } - if selectedResponse != nil { - _, err := fmt.Fprintf(w, *selectedResponse) - if err != nil { - fmt.Printf("error writing response: %v\n", err) - } + + resBody := "" + if batched { + resBody = "[" + strings.Join(responses, ",") + "]" + } else { + resBody = responses[0] + } + + _, err = fmt.Fprint(w, resBody) + if err != nil { + fmt.Printf("error writing response: %v\n", err) } }