diff --git a/proxyd/proxyd/config.go b/proxyd/proxyd/config.go index 56eb3cc..ba60d67 100644 --- a/proxyd/proxyd/config.go +++ b/proxyd/proxyd/config.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "strings" + "time" ) type ServerConfig struct { @@ -40,10 +41,28 @@ type MetricsConfig struct { } type RateLimitConfig struct { - RatePerSecond int `toml:"rate_per_second"` - ExemptOrigins []string `toml:"exempt_origins"` - ExemptUserAgents []string `toml:"exempt_user_agents"` - ErrorMessage string `toml:"error_message"` + RatePerSecond int `toml:"rate_per_second"` + ExemptOrigins []string `toml:"exempt_origins"` + ExemptUserAgents []string `toml:"exempt_user_agents"` + ErrorMessage string `toml:"error_message"` + MethodOverrides map[string]*RateLimitMethodOverride `toml:"method_overrides"` +} + +type RateLimitMethodOverride struct { + Limit int `toml:"limit"` + Interval TOMLDuration `toml:"interval"` +} + +type TOMLDuration time.Duration + +func (t *TOMLDuration) UnmarshalText(b []byte) error { + d, err := time.ParseDuration(string(b)) + if err != nil { + return err + } + + *t = TOMLDuration(d) + return nil } type BackendOptions struct { diff --git a/proxyd/proxyd/integration_tests/rate_limit_test.go b/proxyd/proxyd/integration_tests/rate_limit_test.go index 310be22..5d3163a 100644 --- a/proxyd/proxyd/integration_tests/rate_limit_test.go +++ b/proxyd/proxyd/integration_tests/rate_limit_test.go @@ -1,6 +1,7 @@ package integration_tests import ( + "encoding/json" "net/http" "os" "testing" @@ -17,6 +18,8 @@ type resWithCode struct { const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit"},"id":null,"jsonrpc":"2.0"}` +var ethChainID = "eth_chainId" + func TestBackendMaxRPSLimit(t *testing.T) { goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse)) defer goodBackend.Close() @@ -28,8 +31,7 @@ func TestBackendMaxRPSLimit(t *testing.T) { shutdown, err := proxyd.Start(config) require.NoError(t, err) defer shutdown() - - limitedRes, codes := spamReqs(t, client, 503) + limitedRes, codes := spamReqs(t, client, ethChainID, 503) require.Equal(t, 2, codes[200]) require.Equal(t, 1, codes[503]) RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes) @@ -48,7 +50,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) { t.Run("non-exempt over limit", func(t *testing.T) { client := NewProxydClient("http://127.0.0.1:8545") - limitedRes, codes := spamReqs(t, client, 429) + limitedRes, codes := spamReqs(t, client, ethChainID, 429) require.Equal(t, 1, codes[429]) require.Equal(t, 2, codes[200]) RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes) @@ -58,7 +60,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) { h := make(http.Header) h.Set("User-Agent", "exempt_agent") client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h) - _, codes := spamReqs(t, client, 429) + _, codes := spamReqs(t, client, ethChainID, 429) require.Equal(t, 3, codes[200]) }) @@ -66,7 +68,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) { h := make(http.Header) h.Set("Origin", "exempt_origin") client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h) - _, codes := spamReqs(t, client, 429) + _, codes := spamReqs(t, client, ethChainID, 429) require.Equal(t, 3, codes[200]) }) @@ -77,24 +79,72 @@ func TestFrontendMaxRPSLimit(t *testing.T) { h2.Set("X-Forwarded-For", "1.1.1.1") client1 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h1) client2 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h2) - _, codes := spamReqs(t, client1, 429) + _, codes := spamReqs(t, client1, ethChainID, 429) require.Equal(t, 1, codes[429]) require.Equal(t, 2, codes[200]) - _, code, err := client2.SendRPC("eth_chainId", nil) + _, code, err := client2.SendRPC(ethChainID, nil) require.Equal(t, 200, code) require.NoError(t, err) time.Sleep(time.Second) - _, code, err = client2.SendRPC("eth_chainId", nil) + _, code, err = client2.SendRPC(ethChainID, nil) require.Equal(t, 200, code) require.NoError(t, err) }) + + time.Sleep(time.Second) + + t.Run("RPC override", func(t *testing.T) { + client := NewProxydClient("http://127.0.0.1:8545") + limitedRes, codes := spamReqs(t, client, "eth_foobar", 429) + // use 2 and 1 here since the limit for eth_foobar is 1 + require.Equal(t, 2, codes[429]) + require.Equal(t, 1, codes[200]) + RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes) + }) + + time.Sleep(time.Second) + + t.Run("RPC override in batch", func(t *testing.T) { + client := NewProxydClient("http://127.0.0.1:8545") + req := NewRPCReq("123", "eth_foobar", nil) + out, code, err := client.SendBatchRPC(req, req, req) + require.NoError(t, err) + var res []proxyd.RPCRes + require.NoError(t, json.Unmarshal(out, &res)) + + expCode := proxyd.ErrOverRateLimit.Code + require.Equal(t, 200, code) + require.Equal(t, 3, len(res)) + require.Nil(t, res[0].Error) + require.Equal(t, expCode, res[1].Error.Code) + require.Equal(t, expCode, res[2].Error.Code) + }) + + time.Sleep(time.Second) + + t.Run("RPC override in batch exempt", func(t *testing.T) { + h := make(http.Header) + h.Set("User-Agent", "exempt_agent") + client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h) + req := NewRPCReq("123", "eth_foobar", nil) + out, code, err := client.SendBatchRPC(req, req, req) + require.NoError(t, err) + var res []proxyd.RPCRes + require.NoError(t, json.Unmarshal(out, &res)) + + require.Equal(t, 200, code) + require.Equal(t, 3, len(res)) + require.Nil(t, res[0].Error) + require.Nil(t, res[1].Error) + require.Nil(t, res[2].Error) + }) } -func spamReqs(t *testing.T, client *ProxydHTTPClient, limCode int) ([]byte, map[int]int) { +func spamReqs(t *testing.T, client *ProxydHTTPClient, method string, limCode int) ([]byte, map[int]int) { resCh := make(chan *resWithCode) for i := 0; i < 3; i++ { go func() { - res, code, err := client.SendRPC("eth_chainId", nil) + res, code, err := client.SendRPC(method, nil) require.NoError(t, err) resCh <- &resWithCode{ code: code, diff --git a/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml b/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml index e838ace..8615f3f 100644 --- a/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml +++ b/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml @@ -15,9 +15,14 @@ backends = ["good"] [rpc_method_mappings] eth_chainId = "main" +eth_foobar = "main" [rate_limit] rate_per_second = 2 exempt_origins = ["exempt_origin"] exempt_user_agents = ["exempt_agent"] error_message = "over rate limit" + +[rate_limit.method_overrides.eth_foobar] +limit = 1 +interval = "1s" \ No newline at end of file diff --git a/proxyd/proxyd/server.go b/proxyd/proxyd/server.go index 86d7d5f..b3ff52e 100644 --- a/proxyd/proxyd/server.go +++ b/proxyd/proxyd/server.go @@ -49,7 +49,8 @@ type Server struct { timeout time.Duration maxUpstreamBatchSize int upgrader *websocket.Upgrader - lim limiter.Store + mainLim limiter.Store + overrideLims map[string]limiter.Store limConfig RateLimitConfig limExemptOrigins map[string]bool limExemptUserAgents map[string]bool @@ -59,6 +60,8 @@ type Server struct { srvMu sync.Mutex } +type limiterFunc func(method string) bool + func NewServer( backendGroups map[string]*BackendGroup, wsBackendGroup *BackendGroup, @@ -89,12 +92,12 @@ func NewServer( maxUpstreamBatchSize = defaultMaxUpstreamBatchSize } - var lim limiter.Store + var mainLim limiter.Store limExemptOrigins := make(map[string]bool) limExemptUserAgents := make(map[string]bool) if rateLimitConfig.RatePerSecond > 0 { var err error - lim, err = memorystore.New(&memorystore.Config{ + mainLim, err = memorystore.New(&memorystore.Config{ Tokens: uint64(rateLimitConfig.RatePerSecond), Interval: time.Second, }) @@ -109,7 +112,19 @@ func NewServer( limExemptUserAgents[strings.ToLower(agent)] = true } } else { - lim, _ = noopstore.New() + mainLim, _ = noopstore.New() + } + + overrideLims := make(map[string]limiter.Store) + for method, override := range rateLimitConfig.MethodOverrides { + var err error + overrideLims[method], err = memorystore.New(&memorystore.Config{ + Tokens: uint64(override.Limit), + Interval: time.Duration(override.Interval), + }) + if err != nil { + return nil, err + } } return &Server{ @@ -127,7 +142,8 @@ func NewServer( upgrader: &websocket.Upgrader{ HandshakeTimeout: 5 * time.Second, }, - lim: lim, + mainLim: mainLim, + overrideLims: overrideLims, limConfig: rateLimitConfig, limExemptOrigins: limExemptOrigins, limExemptUserAgents: limExemptUserAgents, @@ -197,22 +213,37 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") userAgent := r.Header.Get("User-Agent") - exemptOrigin := s.limExemptOrigins[strings.ToLower(origin)] - exemptUserAgent := s.limExemptUserAgents[strings.ToLower(userAgent)] // Use XFF in context since it will automatically be replaced by the remote IP xff := stripXFF(GetXForwardedFor(ctx)) - var ok bool - if exemptOrigin || exemptUserAgent { - ok = true - } else { - if xff == "" { - log.Warn("rejecting request without XFF or remote IP") - ok = false - } else { - _, _, _, ok, _ = s.lim.Take(ctx, xff) - } + isUnlimitedOrigin := s.isUnlimitedOrigin(origin) + isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent) + + if xff == "" { + writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP")) + return } - if !ok { + + isLimited := func(method string) bool { + if isUnlimitedOrigin || isUnlimitedUserAgent { + return false + } + + var lim limiter.Store + if method == "" { + lim = s.mainLim + } else { + lim = s.overrideLims[method] + } + + if lim == nil { + return false + } + + _, _, _, ok, _ := lim.Take(ctx, xff) + return !ok + } + + if isLimited("") { rpcErr := ErrOverRateLimit.Clone() rpcErr.Message = s.limConfig.ErrorMessage RecordRPCError(ctx, BackendProxyd, "unknown", rpcErr) @@ -271,7 +302,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { return } - batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, true) + batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, isLimited, true) if err == context.DeadlineExceeded { writeRPCError(ctx, w, nil, ErrGatewayTimeout) return @@ -287,7 +318,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { } rawBody := json.RawMessage(body) - backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, false) + backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, isLimited, false) if err != nil { writeRPCError(ctx, w, nil, ErrInternal) return @@ -296,7 +327,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { writeRPCRes(ctx, w, backendRes[0]) } -func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isBatch bool) ([]*RPCRes, bool, error) { +func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, error) { // A request set is transformed into groups of batches. // Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints) // A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's @@ -347,6 +378,22 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isB continue } + // Take rate limit for specific methods. + // NOTE: eventually, this should apply to all batch requests. However, + // since we don't have data right now on the size of each batch, we + // only apply this to the methods that have an additional rate limit. + if _, ok := s.overrideLims[parsedReq.Method]; ok && isLimited(parsedReq.Method) { + log.Info( + "rate limited specific RPC", + "source", "rpc", + "req_id", GetReqID(ctx), + "method", parsedReq.Method, + ) + RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit) + responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit) + continue + } + id := string(parsedReq.ID) // If this is a duplicate Request ID, move the Request to a new batchGroup ids[id]++ @@ -494,6 +541,14 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context ) } +func (s *Server) isUnlimitedOrigin(origin string) bool { + return s.limExemptOrigins[strings.ToLower(origin)] +} + +func (s *Server) isUnlimitedUserAgent(origin string) bool { + return s.limExemptUserAgents[strings.ToLower(origin)] +} + func setCacheHeader(w http.ResponseWriter, cached bool) { if cached { w.Header().Set(cacheStatusHdr, "HIT")