diff --git a/proxyd/proxyd/config.go b/proxyd/proxyd/config.go index c81fec6..7a004f0 100644 --- a/proxyd/proxyd/config.go +++ b/proxyd/proxyd/config.go @@ -55,6 +55,7 @@ type RateLimitConfig struct { type RateLimitMethodOverride struct { Limit int `toml:"limit"` Interval TOMLDuration `toml:"interval"` + Global bool `toml:"global"` } type TOMLDuration time.Duration diff --git a/proxyd/proxyd/go.mod b/proxyd/proxyd/go.mod index 0a866b8..be300c1 100644 --- a/proxyd/proxyd/go.mod +++ b/proxyd/proxyd/go.mod @@ -13,7 +13,6 @@ require ( github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d github.com/prometheus/client_golang v1.11.0 github.com/rs/cors v1.8.2 - github.com/sethvargo/go-limiter v0.7.2 github.com/stretchr/testify v1.7.0 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c ) diff --git a/proxyd/proxyd/go.sum b/proxyd/proxyd/go.sum index 8e5fd8d..15350df 100644 --- a/proxyd/proxyd/go.sum +++ b/proxyd/proxyd/go.sum @@ -451,8 +451,6 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/segmentio/kafka-go v0.1.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= github.com/segmentio/kafka-go v0.2.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8= -github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU= github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= diff --git a/proxyd/proxyd/integration_tests/rate_limit_test.go b/proxyd/proxyd/integration_tests/rate_limit_test.go index e4cc698..7a70dea 100644 --- a/proxyd/proxyd/integration_tests/rate_limit_test.go +++ b/proxyd/proxyd/integration_tests/rate_limit_test.go @@ -139,6 +139,19 @@ func TestFrontendMaxRPSLimit(t *testing.T) { require.Nil(t, res[1].Error) require.Nil(t, res[2].Error) }) + + time.Sleep(time.Second) + + t.Run("global RPC override", func(t *testing.T) { + h := make(http.Header) + h.Set("User-Agent", "exempt_agent") + client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h) + limitedRes, codes := spamReqs(t, client, "eth_baz", 429, 2) + // use 1 and 1 here since the limit for eth_baz is 1 + require.Equal(t, 1, codes[429]) + require.Equal(t, 1, codes[200]) + RequireEqualJSON(t, []byte(frontendOverLimitResponseWithID), limitedRes) + }) } func spamReqs(t *testing.T, client *ProxydHTTPClient, method string, limCode int, n int) ([]byte, map[int]int) { diff --git a/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml b/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml index affb855..8aa9d19 100644 --- a/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml +++ b/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml @@ -16,6 +16,7 @@ backends = ["good"] [rpc_method_mappings] eth_chainId = "main" eth_foobar = "main" +eth_baz = "main" [rate_limit] base_rate = 2 @@ -26,4 +27,9 @@ error_message = "over rate limit with special message" [rate_limit.method_overrides.eth_foobar] limit = 1 -interval = "1s" \ No newline at end of file +interval = "1s" + +[rate_limit.method_overrides.eth_baz] +limit = 1 +interval = "1s" +global = true \ No newline at end of file diff --git a/proxyd/proxyd/server.go b/proxyd/proxyd/server.go index 5eab80e..5502d74 100644 --- a/proxyd/proxyd/server.go +++ b/proxyd/proxyd/server.go @@ -39,27 +39,28 @@ const ( var emptyArrayResponse = json.RawMessage("[]") type Server struct { - backendGroups map[string]*BackendGroup - wsBackendGroup *BackendGroup - wsMethodWhitelist *StringSet - rpcMethodMappings map[string]string - maxBodySize int64 - enableRequestLog bool - maxRequestBodyLogLen int - authenticatedPaths map[string]string - timeout time.Duration - maxUpstreamBatchSize int - maxBatchSize int - upgrader *websocket.Upgrader - mainLim FrontendRateLimiter - overrideLims map[string]FrontendRateLimiter - senderLim FrontendRateLimiter - limExemptOrigins []*regexp.Regexp - limExemptUserAgents []*regexp.Regexp - rpcServer *http.Server - wsServer *http.Server - cache RPCCache - srvMu sync.Mutex + backendGroups map[string]*BackendGroup + wsBackendGroup *BackendGroup + wsMethodWhitelist *StringSet + rpcMethodMappings map[string]string + maxBodySize int64 + enableRequestLog bool + maxRequestBodyLogLen int + authenticatedPaths map[string]string + timeout time.Duration + maxUpstreamBatchSize int + maxBatchSize int + upgrader *websocket.Upgrader + mainLim FrontendRateLimiter + overrideLims map[string]FrontendRateLimiter + senderLim FrontendRateLimiter + limExemptOrigins []*regexp.Regexp + limExemptUserAgents []*regexp.Regexp + globallyLimitedMethods map[string]bool + rpcServer *http.Server + wsServer *http.Server + cache RPCCache + srvMu sync.Mutex } type limiterFunc func(method string) bool @@ -133,12 +134,17 @@ func NewServer( } overrideLims := make(map[string]FrontendRateLimiter) + globalMethodLims := make(map[string]bool) for method, override := range rateLimitConfig.MethodOverrides { var err error overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method) if err != nil { return nil, err } + + if override.Global { + globalMethodLims[method] = true + } } var senderLim FrontendRateLimiter if senderRateLimitConfig.Enabled { @@ -161,11 +167,12 @@ func NewServer( upgrader: &websocket.Upgrader{ HandshakeTimeout: 5 * time.Second, }, - mainLim: mainLim, - overrideLims: overrideLims, - senderLim: senderLim, - limExemptOrigins: limExemptOrigins, - limExemptUserAgents: limExemptUserAgents, + mainLim: mainLim, + overrideLims: overrideLims, + globallyLimitedMethods: globalMethodLims, + senderLim: senderLim, + limExemptOrigins: limExemptOrigins, + limExemptUserAgents: limExemptUserAgents, }, nil } @@ -243,7 +250,9 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { } isLimited := func(method string) bool { - if isUnlimitedOrigin || isUnlimitedUserAgent { + isGloballyLimitedMethod := s.isGlobalLimit(method) + fmt.Println(method, isGloballyLimitedMethod) + if !isGloballyLimitedMethod && (isUnlimitedOrigin || isUnlimitedUserAgent) { return false } @@ -474,6 +483,7 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL "error forwarding RPC batch", "batch_size", len(elems), "backend_group", group, + "req_id", GetReqID(ctx), "err", err, ) res = nil @@ -596,6 +606,10 @@ func (s *Server) isUnlimitedUserAgent(origin string) bool { return false } +func (s *Server) isGlobalLimit(method string) bool { + return s.globallyLimitedMethods[method] +} + func (s *Server) rateLimitSender(ctx context.Context, req *RPCReq) error { var params []string if err := json.Unmarshal(req.Params, ¶ms); err != nil { @@ -631,7 +645,7 @@ func (s *Server) rateLimitSender(ctx context.Context, req *RPCReq) error { return ErrInvalidParams(err.Error()) } - ok, err := s.senderLim.Take(ctx, msg.From().Hex()) + ok, err := s.senderLim.Take(ctx, fmt.Sprintf("%s:%d", msg.From().Hex(), tx.Nonce())) if err != nil { log.Error("error taking from sender limiter", "err", err, "req_id", GetReqID(ctx)) return ErrInternal