Merge pull request #3479 from ethereum-optimism/develop

Develop -> Master
This commit is contained in:
Matthew Slipper 2022-09-15 10:49:26 +02:00 committed by GitHub
commit c1ab3f356d
4 changed files with 164 additions and 35 deletions

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
"time"
) )
type ServerConfig struct { type ServerConfig struct {
@ -44,6 +45,24 @@ type RateLimitConfig struct {
ExemptOrigins []string `toml:"exempt_origins"` ExemptOrigins []string `toml:"exempt_origins"`
ExemptUserAgents []string `toml:"exempt_user_agents"` ExemptUserAgents []string `toml:"exempt_user_agents"`
ErrorMessage string `toml:"error_message"` ErrorMessage string `toml:"error_message"`
MethodOverrides map[string]*RateLimitMethodOverride `toml:"method_overrides"`
}
type RateLimitMethodOverride struct {
Limit int `toml:"limit"`
Interval TOMLDuration `toml:"interval"`
}
type TOMLDuration time.Duration
func (t *TOMLDuration) UnmarshalText(b []byte) error {
d, err := time.ParseDuration(string(b))
if err != nil {
return err
}
*t = TOMLDuration(d)
return nil
} }
type BackendOptions struct { type BackendOptions struct {

@ -1,6 +1,7 @@
package integration_tests package integration_tests
import ( import (
"encoding/json"
"net/http" "net/http"
"os" "os"
"testing" "testing"
@ -17,6 +18,8 @@ type resWithCode struct {
const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit"},"id":null,"jsonrpc":"2.0"}` const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit"},"id":null,"jsonrpc":"2.0"}`
var ethChainID = "eth_chainId"
func TestBackendMaxRPSLimit(t *testing.T) { func TestBackendMaxRPSLimit(t *testing.T) {
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse)) goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
defer goodBackend.Close() defer goodBackend.Close()
@ -28,8 +31,7 @@ func TestBackendMaxRPSLimit(t *testing.T) {
shutdown, err := proxyd.Start(config) shutdown, err := proxyd.Start(config)
require.NoError(t, err) require.NoError(t, err)
defer shutdown() defer shutdown()
limitedRes, codes := spamReqs(t, client, ethChainID, 503)
limitedRes, codes := spamReqs(t, client, 503)
require.Equal(t, 2, codes[200]) require.Equal(t, 2, codes[200])
require.Equal(t, 1, codes[503]) require.Equal(t, 1, codes[503])
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes) RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
@ -48,7 +50,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
t.Run("non-exempt over limit", func(t *testing.T) { t.Run("non-exempt over limit", func(t *testing.T) {
client := NewProxydClient("http://127.0.0.1:8545") client := NewProxydClient("http://127.0.0.1:8545")
limitedRes, codes := spamReqs(t, client, 429) limitedRes, codes := spamReqs(t, client, ethChainID, 429)
require.Equal(t, 1, codes[429]) require.Equal(t, 1, codes[429])
require.Equal(t, 2, codes[200]) require.Equal(t, 2, codes[200])
RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes) RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes)
@ -58,7 +60,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
h := make(http.Header) h := make(http.Header)
h.Set("User-Agent", "exempt_agent") h.Set("User-Agent", "exempt_agent")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h) client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
_, codes := spamReqs(t, client, 429) _, codes := spamReqs(t, client, ethChainID, 429)
require.Equal(t, 3, codes[200]) require.Equal(t, 3, codes[200])
}) })
@ -66,7 +68,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
h := make(http.Header) h := make(http.Header)
h.Set("Origin", "exempt_origin") h.Set("Origin", "exempt_origin")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h) client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
_, codes := spamReqs(t, client, 429) _, codes := spamReqs(t, client, ethChainID, 429)
require.Equal(t, 3, codes[200]) require.Equal(t, 3, codes[200])
}) })
@ -77,24 +79,72 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
h2.Set("X-Forwarded-For", "1.1.1.1") h2.Set("X-Forwarded-For", "1.1.1.1")
client1 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h1) client1 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h1)
client2 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h2) client2 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h2)
_, codes := spamReqs(t, client1, 429) _, codes := spamReqs(t, client1, ethChainID, 429)
require.Equal(t, 1, codes[429]) require.Equal(t, 1, codes[429])
require.Equal(t, 2, codes[200]) require.Equal(t, 2, codes[200])
_, code, err := client2.SendRPC("eth_chainId", nil) _, code, err := client2.SendRPC(ethChainID, nil)
require.Equal(t, 200, code) require.Equal(t, 200, code)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
_, code, err = client2.SendRPC("eth_chainId", nil) _, code, err = client2.SendRPC(ethChainID, nil)
require.Equal(t, 200, code) require.Equal(t, 200, code)
require.NoError(t, err) require.NoError(t, err)
}) })
time.Sleep(time.Second)
t.Run("RPC override", func(t *testing.T) {
client := NewProxydClient("http://127.0.0.1:8545")
limitedRes, codes := spamReqs(t, client, "eth_foobar", 429)
// use 2 and 1 here since the limit for eth_foobar is 1
require.Equal(t, 2, codes[429])
require.Equal(t, 1, codes[200])
RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes)
})
time.Sleep(time.Second)
t.Run("RPC override in batch", func(t *testing.T) {
client := NewProxydClient("http://127.0.0.1:8545")
req := NewRPCReq("123", "eth_foobar", nil)
out, code, err := client.SendBatchRPC(req, req, req)
require.NoError(t, err)
var res []proxyd.RPCRes
require.NoError(t, json.Unmarshal(out, &res))
expCode := proxyd.ErrOverRateLimit.Code
require.Equal(t, 200, code)
require.Equal(t, 3, len(res))
require.Nil(t, res[0].Error)
require.Equal(t, expCode, res[1].Error.Code)
require.Equal(t, expCode, res[2].Error.Code)
})
time.Sleep(time.Second)
t.Run("RPC override in batch exempt", func(t *testing.T) {
h := make(http.Header)
h.Set("User-Agent", "exempt_agent")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
req := NewRPCReq("123", "eth_foobar", nil)
out, code, err := client.SendBatchRPC(req, req, req)
require.NoError(t, err)
var res []proxyd.RPCRes
require.NoError(t, json.Unmarshal(out, &res))
require.Equal(t, 200, code)
require.Equal(t, 3, len(res))
require.Nil(t, res[0].Error)
require.Nil(t, res[1].Error)
require.Nil(t, res[2].Error)
})
} }
func spamReqs(t *testing.T, client *ProxydHTTPClient, limCode int) ([]byte, map[int]int) { func spamReqs(t *testing.T, client *ProxydHTTPClient, method string, limCode int) ([]byte, map[int]int) {
resCh := make(chan *resWithCode) resCh := make(chan *resWithCode)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
go func() { go func() {
res, code, err := client.SendRPC("eth_chainId", nil) res, code, err := client.SendRPC(method, nil)
require.NoError(t, err) require.NoError(t, err)
resCh <- &resWithCode{ resCh <- &resWithCode{
code: code, code: code,

@ -15,9 +15,14 @@ backends = ["good"]
[rpc_method_mappings] [rpc_method_mappings]
eth_chainId = "main" eth_chainId = "main"
eth_foobar = "main"
[rate_limit] [rate_limit]
rate_per_second = 2 rate_per_second = 2
exempt_origins = ["exempt_origin"] exempt_origins = ["exempt_origin"]
exempt_user_agents = ["exempt_agent"] exempt_user_agents = ["exempt_agent"]
error_message = "over rate limit" error_message = "over rate limit"
[rate_limit.method_overrides.eth_foobar]
limit = 1
interval = "1s"

@ -49,7 +49,8 @@ type Server struct {
timeout time.Duration timeout time.Duration
maxUpstreamBatchSize int maxUpstreamBatchSize int
upgrader *websocket.Upgrader upgrader *websocket.Upgrader
lim limiter.Store mainLim limiter.Store
overrideLims map[string]limiter.Store
limConfig RateLimitConfig limConfig RateLimitConfig
limExemptOrigins map[string]bool limExemptOrigins map[string]bool
limExemptUserAgents map[string]bool limExemptUserAgents map[string]bool
@ -59,6 +60,8 @@ type Server struct {
srvMu sync.Mutex srvMu sync.Mutex
} }
type limiterFunc func(method string) bool
func NewServer( func NewServer(
backendGroups map[string]*BackendGroup, backendGroups map[string]*BackendGroup,
wsBackendGroup *BackendGroup, wsBackendGroup *BackendGroup,
@ -89,12 +92,12 @@ func NewServer(
maxUpstreamBatchSize = defaultMaxUpstreamBatchSize maxUpstreamBatchSize = defaultMaxUpstreamBatchSize
} }
var lim limiter.Store var mainLim limiter.Store
limExemptOrigins := make(map[string]bool) limExemptOrigins := make(map[string]bool)
limExemptUserAgents := make(map[string]bool) limExemptUserAgents := make(map[string]bool)
if rateLimitConfig.RatePerSecond > 0 { if rateLimitConfig.RatePerSecond > 0 {
var err error var err error
lim, err = memorystore.New(&memorystore.Config{ mainLim, err = memorystore.New(&memorystore.Config{
Tokens: uint64(rateLimitConfig.RatePerSecond), Tokens: uint64(rateLimitConfig.RatePerSecond),
Interval: time.Second, Interval: time.Second,
}) })
@ -109,7 +112,19 @@ func NewServer(
limExemptUserAgents[strings.ToLower(agent)] = true limExemptUserAgents[strings.ToLower(agent)] = true
} }
} else { } else {
lim, _ = noopstore.New() mainLim, _ = noopstore.New()
}
overrideLims := make(map[string]limiter.Store)
for method, override := range rateLimitConfig.MethodOverrides {
var err error
overrideLims[method], err = memorystore.New(&memorystore.Config{
Tokens: uint64(override.Limit),
Interval: time.Duration(override.Interval),
})
if err != nil {
return nil, err
}
} }
return &Server{ return &Server{
@ -127,7 +142,8 @@ func NewServer(
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
}, },
lim: lim, mainLim: mainLim,
overrideLims: overrideLims,
limConfig: rateLimitConfig, limConfig: rateLimitConfig,
limExemptOrigins: limExemptOrigins, limExemptOrigins: limExemptOrigins,
limExemptUserAgents: limExemptUserAgents, limExemptUserAgents: limExemptUserAgents,
@ -197,22 +213,37 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
userAgent := r.Header.Get("User-Agent") userAgent := r.Header.Get("User-Agent")
exemptOrigin := s.limExemptOrigins[strings.ToLower(origin)]
exemptUserAgent := s.limExemptUserAgents[strings.ToLower(userAgent)]
// Use XFF in context since it will automatically be replaced by the remote IP // Use XFF in context since it will automatically be replaced by the remote IP
xff := stripXFF(GetXForwardedFor(ctx)) xff := stripXFF(GetXForwardedFor(ctx))
var ok bool isUnlimitedOrigin := s.isUnlimitedOrigin(origin)
if exemptOrigin || exemptUserAgent { isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent)
ok = true
} else {
if xff == "" { if xff == "" {
log.Warn("rejecting request without XFF or remote IP") writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP"))
ok = false return
}
isLimited := func(method string) bool {
if isUnlimitedOrigin || isUnlimitedUserAgent {
return false
}
var lim limiter.Store
if method == "" {
lim = s.mainLim
} else { } else {
_, _, _, ok, _ = s.lim.Take(ctx, xff) lim = s.overrideLims[method]
} }
if lim == nil {
return false
} }
if !ok {
_, _, _, ok, _ := lim.Take(ctx, xff)
return !ok
}
if isLimited("") {
rpcErr := ErrOverRateLimit.Clone() rpcErr := ErrOverRateLimit.Clone()
rpcErr.Message = s.limConfig.ErrorMessage rpcErr.Message = s.limConfig.ErrorMessage
RecordRPCError(ctx, BackendProxyd, "unknown", rpcErr) RecordRPCError(ctx, BackendProxyd, "unknown", rpcErr)
@ -271,7 +302,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return return
} }
batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, true) batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, isLimited, true)
if err == context.DeadlineExceeded { if err == context.DeadlineExceeded {
writeRPCError(ctx, w, nil, ErrGatewayTimeout) writeRPCError(ctx, w, nil, ErrGatewayTimeout)
return return
@ -287,7 +318,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
} }
rawBody := json.RawMessage(body) rawBody := json.RawMessage(body)
backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, false) backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, isLimited, false)
if err != nil { if err != nil {
writeRPCError(ctx, w, nil, ErrInternal) writeRPCError(ctx, w, nil, ErrInternal)
return return
@ -296,7 +327,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
writeRPCRes(ctx, w, backendRes[0]) writeRPCRes(ctx, w, backendRes[0])
} }
func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isBatch bool) ([]*RPCRes, bool, error) { func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, error) {
// A request set is transformed into groups of batches. // A request set is transformed into groups of batches.
// Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints) // Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints)
// A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's // A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's
@ -347,6 +378,22 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isB
continue continue
} }
// Take rate limit for specific methods.
// NOTE: eventually, this should apply to all batch requests. However,
// since we don't have data right now on the size of each batch, we
// only apply this to the methods that have an additional rate limit.
if _, ok := s.overrideLims[parsedReq.Method]; ok && isLimited(parsedReq.Method) {
log.Info(
"rate limited specific RPC",
"source", "rpc",
"req_id", GetReqID(ctx),
"method", parsedReq.Method,
)
RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit)
responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit)
continue
}
id := string(parsedReq.ID) id := string(parsedReq.ID)
// If this is a duplicate Request ID, move the Request to a new batchGroup // If this is a duplicate Request ID, move the Request to a new batchGroup
ids[id]++ ids[id]++
@ -494,6 +541,14 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context
) )
} }
func (s *Server) isUnlimitedOrigin(origin string) bool {
return s.limExemptOrigins[strings.ToLower(origin)]
}
func (s *Server) isUnlimitedUserAgent(origin string) bool {
return s.limExemptUserAgents[strings.ToLower(origin)]
}
func setCacheHeader(w http.ResponseWriter, cached bool) { func setCacheHeader(w http.ResponseWriter, cached bool) {
if cached { if cached {
w.Header().Set(cacheStatusHdr, "HIT") w.Header().Set(cacheStatusHdr, "HIT")