From 0fb094feb4065869144e93ce5490577fffed908e Mon Sep 17 00:00:00 2001 From: cody-wang-cb Date: Wed, 31 Jul 2024 16:25:53 -0400 Subject: [PATCH] Move base rate rate limit check inside handleBatchRPC() (#37) * move base rate rate limit check * fix comment --- proxyd/integration_tests/rate_limit_test.go | 22 +++++++++++++-- proxyd/server.go | 30 +++++++++------------ 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/proxyd/integration_tests/rate_limit_test.go b/proxyd/integration_tests/rate_limit_test.go index 4e17f62..0801d05 100644 --- a/proxyd/integration_tests/rate_limit_test.go +++ b/proxyd/integration_tests/rate_limit_test.go @@ -16,7 +16,6 @@ type resWithCode struct { res []byte } -const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit with special message"},"id":null,"jsonrpc":"2.0"}` const frontendOverLimitResponseWithID = `{"error":{"code":-32016,"message":"over rate limit with special message"},"id":999,"jsonrpc":"2.0"}` var ethChainID = "eth_chainId" @@ -37,7 +36,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) { limitedRes, codes := spamReqs(t, client, ethChainID, 429, 3) require.Equal(t, 1, codes[429]) require.Equal(t, 2, codes[200]) - RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes) + RequireEqualJSON(t, []byte(frontendOverLimitResponseWithID), limitedRes) }) t.Run("exempt user agent over limit", func(t *testing.T) { @@ -106,6 +105,25 @@ func TestFrontendMaxRPSLimit(t *testing.T) { time.Sleep(time.Second) + t.Run("Batch RPC with some requests rate limited", func(t *testing.T) { + client := NewProxydClient("http://127.0.0.1:8545") + req := NewRPCReq("123", "eth_chainId", nil) + out, code, err := client.SendBatchRPC(req, req, req) + require.NoError(t, err) + var res []proxyd.RPCRes + require.NoError(t, json.Unmarshal(out, &res)) + + expCode := proxyd.ErrOverRateLimit.Code + require.Equal(t, 200, code) + require.Equal(t, 3, len(res)) + require.Nil(t, res[0].Error) + require.Nil(t, res[1].Error) + // base rate = 2, so the third request should be rate limited + require.Equal(t, expCode, res[2].Error.Code) + }) + + time.Sleep(time.Second) + t.Run("RPC override in batch exempt", func(t *testing.T) { h := make(http.Header) h.Set("User-Agent", "exempt_agent") diff --git a/proxyd/server.go b/proxyd/server.go index c663f42..0c053f5 100644 --- a/proxyd/server.go +++ b/proxyd/server.go @@ -301,20 +301,6 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { return !ok } - if isLimited("") { - RecordRPCError(ctx, BackendProxyd, "unknown", ErrOverRateLimit) - log.Warn( - "rate limited request", - "req_id", GetReqID(ctx), - "auth", GetAuthCtx(ctx), - "user_agent", userAgent, - "origin", origin, - "remote_ip", xff, - ) - writeRPCError(ctx, w, nil, ErrOverRateLimit) - return - } - log.Info( "received RPC request", "req_id", GetReqID(ctx), @@ -469,10 +455,20 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL continue } + // Take base rate limit first + if isLimited("") { + log.Info( + "rate limited individual RPC in a batch request", + "source", "rpc", + "req_id", parsedReq.ID, + "method", parsedReq.Method, + ) + RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit) + responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit) + continue + } + // Take rate limit for specific methods. - // NOTE: eventually, this should apply to all batch requests. However, - // since we don't have data right now on the size of each batch, we - // only apply this to the methods that have an additional rate limit. if _, ok := s.overrideLims[parsedReq.Method]; ok && isLimited(parsedReq.Method) { log.Info( "rate limited specific RPC",