proxyd: Support per-RPC rate limits (#3471)
* proxyd: Support per-RPC rate limits * add log Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
parent
f6f4a32997
commit
ccf0934459
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ServerConfig struct {
|
||||
@ -40,10 +41,28 @@ type MetricsConfig struct {
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
RatePerSecond int `toml:"rate_per_second"`
|
||||
ExemptOrigins []string `toml:"exempt_origins"`
|
||||
ExemptUserAgents []string `toml:"exempt_user_agents"`
|
||||
ErrorMessage string `toml:"error_message"`
|
||||
RatePerSecond int `toml:"rate_per_second"`
|
||||
ExemptOrigins []string `toml:"exempt_origins"`
|
||||
ExemptUserAgents []string `toml:"exempt_user_agents"`
|
||||
ErrorMessage string `toml:"error_message"`
|
||||
MethodOverrides map[string]*RateLimitMethodOverride `toml:"method_overrides"`
|
||||
}
|
||||
|
||||
type RateLimitMethodOverride struct {
|
||||
Limit int `toml:"limit"`
|
||||
Interval TOMLDuration `toml:"interval"`
|
||||
}
|
||||
|
||||
type TOMLDuration time.Duration
|
||||
|
||||
func (t *TOMLDuration) UnmarshalText(b []byte) error {
|
||||
d, err := time.ParseDuration(string(b))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*t = TOMLDuration(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
type BackendOptions struct {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package integration_tests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
@ -17,6 +18,8 @@ type resWithCode struct {
|
||||
|
||||
const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit"},"id":null,"jsonrpc":"2.0"}`
|
||||
|
||||
var ethChainID = "eth_chainId"
|
||||
|
||||
func TestBackendMaxRPSLimit(t *testing.T) {
|
||||
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
|
||||
defer goodBackend.Close()
|
||||
@ -28,8 +31,7 @@ func TestBackendMaxRPSLimit(t *testing.T) {
|
||||
shutdown, err := proxyd.Start(config)
|
||||
require.NoError(t, err)
|
||||
defer shutdown()
|
||||
|
||||
limitedRes, codes := spamReqs(t, client, 503)
|
||||
limitedRes, codes := spamReqs(t, client, ethChainID, 503)
|
||||
require.Equal(t, 2, codes[200])
|
||||
require.Equal(t, 1, codes[503])
|
||||
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
|
||||
@ -48,7 +50,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
|
||||
|
||||
t.Run("non-exempt over limit", func(t *testing.T) {
|
||||
client := NewProxydClient("http://127.0.0.1:8545")
|
||||
limitedRes, codes := spamReqs(t, client, 429)
|
||||
limitedRes, codes := spamReqs(t, client, ethChainID, 429)
|
||||
require.Equal(t, 1, codes[429])
|
||||
require.Equal(t, 2, codes[200])
|
||||
RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes)
|
||||
@ -58,7 +60,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
|
||||
h := make(http.Header)
|
||||
h.Set("User-Agent", "exempt_agent")
|
||||
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
|
||||
_, codes := spamReqs(t, client, 429)
|
||||
_, codes := spamReqs(t, client, ethChainID, 429)
|
||||
require.Equal(t, 3, codes[200])
|
||||
})
|
||||
|
||||
@ -66,7 +68,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
|
||||
h := make(http.Header)
|
||||
h.Set("Origin", "exempt_origin")
|
||||
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
|
||||
_, codes := spamReqs(t, client, 429)
|
||||
_, codes := spamReqs(t, client, ethChainID, 429)
|
||||
require.Equal(t, 3, codes[200])
|
||||
})
|
||||
|
||||
@ -77,24 +79,72 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
|
||||
h2.Set("X-Forwarded-For", "1.1.1.1")
|
||||
client1 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h1)
|
||||
client2 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h2)
|
||||
_, codes := spamReqs(t, client1, 429)
|
||||
_, codes := spamReqs(t, client1, ethChainID, 429)
|
||||
require.Equal(t, 1, codes[429])
|
||||
require.Equal(t, 2, codes[200])
|
||||
_, code, err := client2.SendRPC("eth_chainId", nil)
|
||||
_, code, err := client2.SendRPC(ethChainID, nil)
|
||||
require.Equal(t, 200, code)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
_, code, err = client2.SendRPC("eth_chainId", nil)
|
||||
_, code, err = client2.SendRPC(ethChainID, nil)
|
||||
require.Equal(t, 200, code)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
t.Run("RPC override", func(t *testing.T) {
|
||||
client := NewProxydClient("http://127.0.0.1:8545")
|
||||
limitedRes, codes := spamReqs(t, client, "eth_foobar", 429)
|
||||
// use 2 and 1 here since the limit for eth_foobar is 1
|
||||
require.Equal(t, 2, codes[429])
|
||||
require.Equal(t, 1, codes[200])
|
||||
RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes)
|
||||
})
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
t.Run("RPC override in batch", func(t *testing.T) {
|
||||
client := NewProxydClient("http://127.0.0.1:8545")
|
||||
req := NewRPCReq("123", "eth_foobar", nil)
|
||||
out, code, err := client.SendBatchRPC(req, req, req)
|
||||
require.NoError(t, err)
|
||||
var res []proxyd.RPCRes
|
||||
require.NoError(t, json.Unmarshal(out, &res))
|
||||
|
||||
expCode := proxyd.ErrOverRateLimit.Code
|
||||
require.Equal(t, 200, code)
|
||||
require.Equal(t, 3, len(res))
|
||||
require.Nil(t, res[0].Error)
|
||||
require.Equal(t, expCode, res[1].Error.Code)
|
||||
require.Equal(t, expCode, res[2].Error.Code)
|
||||
})
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
t.Run("RPC override in batch exempt", func(t *testing.T) {
|
||||
h := make(http.Header)
|
||||
h.Set("User-Agent", "exempt_agent")
|
||||
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
|
||||
req := NewRPCReq("123", "eth_foobar", nil)
|
||||
out, code, err := client.SendBatchRPC(req, req, req)
|
||||
require.NoError(t, err)
|
||||
var res []proxyd.RPCRes
|
||||
require.NoError(t, json.Unmarshal(out, &res))
|
||||
|
||||
require.Equal(t, 200, code)
|
||||
require.Equal(t, 3, len(res))
|
||||
require.Nil(t, res[0].Error)
|
||||
require.Nil(t, res[1].Error)
|
||||
require.Nil(t, res[2].Error)
|
||||
})
|
||||
}
|
||||
|
||||
func spamReqs(t *testing.T, client *ProxydHTTPClient, limCode int) ([]byte, map[int]int) {
|
||||
func spamReqs(t *testing.T, client *ProxydHTTPClient, method string, limCode int) ([]byte, map[int]int) {
|
||||
resCh := make(chan *resWithCode)
|
||||
for i := 0; i < 3; i++ {
|
||||
go func() {
|
||||
res, code, err := client.SendRPC("eth_chainId", nil)
|
||||
res, code, err := client.SendRPC(method, nil)
|
||||
require.NoError(t, err)
|
||||
resCh <- &resWithCode{
|
||||
code: code,
|
||||
|
@ -15,9 +15,14 @@ backends = ["good"]
|
||||
|
||||
[rpc_method_mappings]
|
||||
eth_chainId = "main"
|
||||
eth_foobar = "main"
|
||||
|
||||
[rate_limit]
|
||||
rate_per_second = 2
|
||||
exempt_origins = ["exempt_origin"]
|
||||
exempt_user_agents = ["exempt_agent"]
|
||||
error_message = "over rate limit"
|
||||
|
||||
[rate_limit.method_overrides.eth_foobar]
|
||||
limit = 1
|
||||
interval = "1s"
|
@ -49,7 +49,8 @@ type Server struct {
|
||||
timeout time.Duration
|
||||
maxUpstreamBatchSize int
|
||||
upgrader *websocket.Upgrader
|
||||
lim limiter.Store
|
||||
mainLim limiter.Store
|
||||
overrideLims map[string]limiter.Store
|
||||
limConfig RateLimitConfig
|
||||
limExemptOrigins map[string]bool
|
||||
limExemptUserAgents map[string]bool
|
||||
@ -59,6 +60,8 @@ type Server struct {
|
||||
srvMu sync.Mutex
|
||||
}
|
||||
|
||||
type limiterFunc func(method string) bool
|
||||
|
||||
func NewServer(
|
||||
backendGroups map[string]*BackendGroup,
|
||||
wsBackendGroup *BackendGroup,
|
||||
@ -89,12 +92,12 @@ func NewServer(
|
||||
maxUpstreamBatchSize = defaultMaxUpstreamBatchSize
|
||||
}
|
||||
|
||||
var lim limiter.Store
|
||||
var mainLim limiter.Store
|
||||
limExemptOrigins := make(map[string]bool)
|
||||
limExemptUserAgents := make(map[string]bool)
|
||||
if rateLimitConfig.RatePerSecond > 0 {
|
||||
var err error
|
||||
lim, err = memorystore.New(&memorystore.Config{
|
||||
mainLim, err = memorystore.New(&memorystore.Config{
|
||||
Tokens: uint64(rateLimitConfig.RatePerSecond),
|
||||
Interval: time.Second,
|
||||
})
|
||||
@ -109,7 +112,19 @@ func NewServer(
|
||||
limExemptUserAgents[strings.ToLower(agent)] = true
|
||||
}
|
||||
} else {
|
||||
lim, _ = noopstore.New()
|
||||
mainLim, _ = noopstore.New()
|
||||
}
|
||||
|
||||
overrideLims := make(map[string]limiter.Store)
|
||||
for method, override := range rateLimitConfig.MethodOverrides {
|
||||
var err error
|
||||
overrideLims[method], err = memorystore.New(&memorystore.Config{
|
||||
Tokens: uint64(override.Limit),
|
||||
Interval: time.Duration(override.Interval),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
@ -127,7 +142,8 @@ func NewServer(
|
||||
upgrader: &websocket.Upgrader{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
},
|
||||
lim: lim,
|
||||
mainLim: mainLim,
|
||||
overrideLims: overrideLims,
|
||||
limConfig: rateLimitConfig,
|
||||
limExemptOrigins: limExemptOrigins,
|
||||
limExemptUserAgents: limExemptUserAgents,
|
||||
@ -197,22 +213,37 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
origin := r.Header.Get("Origin")
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
exemptOrigin := s.limExemptOrigins[strings.ToLower(origin)]
|
||||
exemptUserAgent := s.limExemptUserAgents[strings.ToLower(userAgent)]
|
||||
// Use XFF in context since it will automatically be replaced by the remote IP
|
||||
xff := stripXFF(GetXForwardedFor(ctx))
|
||||
var ok bool
|
||||
if exemptOrigin || exemptUserAgent {
|
||||
ok = true
|
||||
} else {
|
||||
if xff == "" {
|
||||
log.Warn("rejecting request without XFF or remote IP")
|
||||
ok = false
|
||||
} else {
|
||||
_, _, _, ok, _ = s.lim.Take(ctx, xff)
|
||||
}
|
||||
isUnlimitedOrigin := s.isUnlimitedOrigin(origin)
|
||||
isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent)
|
||||
|
||||
if xff == "" {
|
||||
writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP"))
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
|
||||
isLimited := func(method string) bool {
|
||||
if isUnlimitedOrigin || isUnlimitedUserAgent {
|
||||
return false
|
||||
}
|
||||
|
||||
var lim limiter.Store
|
||||
if method == "" {
|
||||
lim = s.mainLim
|
||||
} else {
|
||||
lim = s.overrideLims[method]
|
||||
}
|
||||
|
||||
if lim == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, _, _, ok, _ := lim.Take(ctx, xff)
|
||||
return !ok
|
||||
}
|
||||
|
||||
if isLimited("") {
|
||||
rpcErr := ErrOverRateLimit.Clone()
|
||||
rpcErr.Message = s.limConfig.ErrorMessage
|
||||
RecordRPCError(ctx, BackendProxyd, "unknown", rpcErr)
|
||||
@ -271,7 +302,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, true)
|
||||
batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, isLimited, true)
|
||||
if err == context.DeadlineExceeded {
|
||||
writeRPCError(ctx, w, nil, ErrGatewayTimeout)
|
||||
return
|
||||
@ -287,7 +318,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
rawBody := json.RawMessage(body)
|
||||
backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, false)
|
||||
backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, isLimited, false)
|
||||
if err != nil {
|
||||
writeRPCError(ctx, w, nil, ErrInternal)
|
||||
return
|
||||
@ -296,7 +327,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
||||
writeRPCRes(ctx, w, backendRes[0])
|
||||
}
|
||||
|
||||
func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isBatch bool) ([]*RPCRes, bool, error) {
|
||||
func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, error) {
|
||||
// A request set is transformed into groups of batches.
|
||||
// Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints)
|
||||
// A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's
|
||||
@ -347,6 +378,22 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isB
|
||||
continue
|
||||
}
|
||||
|
||||
// Take rate limit for specific methods.
|
||||
// NOTE: eventually, this should apply to all batch requests. However,
|
||||
// since we don't have data right now on the size of each batch, we
|
||||
// only apply this to the methods that have an additional rate limit.
|
||||
if _, ok := s.overrideLims[parsedReq.Method]; ok && isLimited(parsedReq.Method) {
|
||||
log.Info(
|
||||
"rate limited specific RPC",
|
||||
"source", "rpc",
|
||||
"req_id", GetReqID(ctx),
|
||||
"method", parsedReq.Method,
|
||||
)
|
||||
RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit)
|
||||
responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit)
|
||||
continue
|
||||
}
|
||||
|
||||
id := string(parsedReq.ID)
|
||||
// If this is a duplicate Request ID, move the Request to a new batchGroup
|
||||
ids[id]++
|
||||
@ -494,6 +541,14 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) isUnlimitedOrigin(origin string) bool {
|
||||
return s.limExemptOrigins[strings.ToLower(origin)]
|
||||
}
|
||||
|
||||
func (s *Server) isUnlimitedUserAgent(origin string) bool {
|
||||
return s.limExemptUserAgents[strings.ToLower(origin)]
|
||||
}
|
||||
|
||||
func setCacheHeader(w http.ResponseWriter, cached bool) {
|
||||
if cached {
|
||||
w.Header().Set(cacheStatusHdr, "HIT")
|
||||
|
Loading…
Reference in New Issue
Block a user