proxyd: Add global flag to method overrides

This allows us to set global limits on individual RPCs that ignore any origin/user agent exemption.
This commit is contained in:
Matthew Slipper 2023-02-15 00:42:44 -07:00
parent 454bc10e44
commit c17bcc9b83
6 changed files with 61 additions and 31 deletions

@ -55,6 +55,7 @@ type RateLimitConfig struct {
type RateLimitMethodOverride struct { type RateLimitMethodOverride struct {
Limit int `toml:"limit"` Limit int `toml:"limit"`
Interval TOMLDuration `toml:"interval"` Interval TOMLDuration `toml:"interval"`
Global bool `toml:"global"`
} }
type TOMLDuration time.Duration type TOMLDuration time.Duration

@ -13,7 +13,6 @@ require (
github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d
github.com/prometheus/client_golang v1.11.0 github.com/prometheus/client_golang v1.11.0
github.com/rs/cors v1.8.2 github.com/rs/cors v1.8.2
github.com/sethvargo/go-limiter v0.7.2
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
) )

@ -451,8 +451,6 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/segmentio/kafka-go v0.1.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= github.com/segmentio/kafka-go v0.1.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo=
github.com/segmentio/kafka-go v0.2.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= github.com/segmentio/kafka-go v0.2.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8=
github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=

@ -139,6 +139,19 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
require.Nil(t, res[1].Error) require.Nil(t, res[1].Error)
require.Nil(t, res[2].Error) require.Nil(t, res[2].Error)
}) })
time.Sleep(time.Second)
t.Run("global RPC override", func(t *testing.T) {
h := make(http.Header)
h.Set("User-Agent", "exempt_agent")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
limitedRes, codes := spamReqs(t, client, "eth_baz", 429, 2)
// use 1 and 1 here since the limit for eth_baz is 1
require.Equal(t, 1, codes[429])
require.Equal(t, 1, codes[200])
RequireEqualJSON(t, []byte(frontendOverLimitResponseWithID), limitedRes)
})
} }
func spamReqs(t *testing.T, client *ProxydHTTPClient, method string, limCode int, n int) ([]byte, map[int]int) { func spamReqs(t *testing.T, client *ProxydHTTPClient, method string, limCode int, n int) ([]byte, map[int]int) {

@ -16,6 +16,7 @@ backends = ["good"]
[rpc_method_mappings] [rpc_method_mappings]
eth_chainId = "main" eth_chainId = "main"
eth_foobar = "main" eth_foobar = "main"
eth_baz = "main"
[rate_limit] [rate_limit]
base_rate = 2 base_rate = 2
@ -26,4 +27,9 @@ error_message = "over rate limit with special message"
[rate_limit.method_overrides.eth_foobar] [rate_limit.method_overrides.eth_foobar]
limit = 1 limit = 1
interval = "1s" interval = "1s"
[rate_limit.method_overrides.eth_baz]
limit = 1
interval = "1s"
global = true

@ -39,27 +39,28 @@ const (
var emptyArrayResponse = json.RawMessage("[]") var emptyArrayResponse = json.RawMessage("[]")
type Server struct { type Server struct {
backendGroups map[string]*BackendGroup backendGroups map[string]*BackendGroup
wsBackendGroup *BackendGroup wsBackendGroup *BackendGroup
wsMethodWhitelist *StringSet wsMethodWhitelist *StringSet
rpcMethodMappings map[string]string rpcMethodMappings map[string]string
maxBodySize int64 maxBodySize int64
enableRequestLog bool enableRequestLog bool
maxRequestBodyLogLen int maxRequestBodyLogLen int
authenticatedPaths map[string]string authenticatedPaths map[string]string
timeout time.Duration timeout time.Duration
maxUpstreamBatchSize int maxUpstreamBatchSize int
maxBatchSize int maxBatchSize int
upgrader *websocket.Upgrader upgrader *websocket.Upgrader
mainLim FrontendRateLimiter mainLim FrontendRateLimiter
overrideLims map[string]FrontendRateLimiter overrideLims map[string]FrontendRateLimiter
senderLim FrontendRateLimiter senderLim FrontendRateLimiter
limExemptOrigins []*regexp.Regexp limExemptOrigins []*regexp.Regexp
limExemptUserAgents []*regexp.Regexp limExemptUserAgents []*regexp.Regexp
rpcServer *http.Server globallyLimitedMethods map[string]bool
wsServer *http.Server rpcServer *http.Server
cache RPCCache wsServer *http.Server
srvMu sync.Mutex cache RPCCache
srvMu sync.Mutex
} }
type limiterFunc func(method string) bool type limiterFunc func(method string) bool
@ -133,12 +134,17 @@ func NewServer(
} }
overrideLims := make(map[string]FrontendRateLimiter) overrideLims := make(map[string]FrontendRateLimiter)
globalMethodLims := make(map[string]bool)
for method, override := range rateLimitConfig.MethodOverrides { for method, override := range rateLimitConfig.MethodOverrides {
var err error var err error
overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method) overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if override.Global {
globalMethodLims[method] = true
}
} }
var senderLim FrontendRateLimiter var senderLim FrontendRateLimiter
if senderRateLimitConfig.Enabled { if senderRateLimitConfig.Enabled {
@ -161,11 +167,12 @@ func NewServer(
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
}, },
mainLim: mainLim, mainLim: mainLim,
overrideLims: overrideLims, overrideLims: overrideLims,
senderLim: senderLim, globallyLimitedMethods: globalMethodLims,
limExemptOrigins: limExemptOrigins, senderLim: senderLim,
limExemptUserAgents: limExemptUserAgents, limExemptOrigins: limExemptOrigins,
limExemptUserAgents: limExemptUserAgents,
}, nil }, nil
} }
@ -243,7 +250,9 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
} }
isLimited := func(method string) bool { isLimited := func(method string) bool {
if isUnlimitedOrigin || isUnlimitedUserAgent { isGloballyLimitedMethod := s.isGlobalLimit(method)
fmt.Println(method, isGloballyLimitedMethod)
if !isGloballyLimitedMethod && (isUnlimitedOrigin || isUnlimitedUserAgent) {
return false return false
} }
@ -597,6 +606,10 @@ func (s *Server) isUnlimitedUserAgent(origin string) bool {
return false return false
} }
func (s *Server) isGlobalLimit(method string) bool {
return s.globallyLimitedMethods[method]
}
func (s *Server) rateLimitSender(ctx context.Context, req *RPCReq) error { func (s *Server) rateLimitSender(ctx context.Context, req *RPCReq) error {
var params []string var params []string
if err := json.Unmarshal(req.Params, &params); err != nil { if err := json.Unmarshal(req.Params, &params); err != nil {