From 15a59bad99d2561b9d0ff3b628e3b0ddbe65eed1 Mon Sep 17 00:00:00 2001 From: Chris Wessels Date: Tue, 11 Oct 2022 07:51:02 -0700 Subject: [PATCH 1/3] fix(proxyd): Fix compliance with JSON-RPC 2.0 spec by adding optional RPCError.Data (#3683) * fix: add optional data field to RPCError struct * fix: formatting lint * feat(proxyd): add changeset --- proxyd/proxyd/rpc.go | 1 + proxyd/proxyd/rpc_test.go | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/proxyd/proxyd/rpc.go b/proxyd/proxyd/rpc.go index e34b010..0c30d64 100644 --- a/proxyd/proxyd/rpc.go +++ b/proxyd/proxyd/rpc.go @@ -57,6 +57,7 @@ func (r *RPCRes) MarshalJSON() ([]byte, error) { type RPCErr struct { Code int `json:"code"` Message string `json:"message"` + Data string `json:"data,omitempty"` HTTPErrorCode int `json:"-"` } diff --git a/proxyd/proxyd/rpc_test.go b/proxyd/proxyd/rpc_test.go index 0d38dec..e30fe93 100644 --- a/proxyd/proxyd/rpc_test.go +++ b/proxyd/proxyd/rpc_test.go @@ -45,7 +45,7 @@ func TestRPCResJSON(t *testing.T) { `{"jsonrpc":"2.0","result":null,"id":123}`, }, { - "error result", + "error result without data", &RPCRes{ JSONRPC: JSONRPCVersion, Error: &RPCErr{ @@ -56,6 +56,19 @@ func TestRPCResJSON(t *testing.T) { }, `{"jsonrpc":"2.0","error":{"code":1234,"message":"test err"},"id":123}`, }, + { + "error result with data", + &RPCRes{ + JSONRPC: JSONRPCVersion, + Error: &RPCErr{ + Code: 1234, + Message: "test err", + Data: "revert", + }, + ID: []byte("123"), + }, + `{"jsonrpc":"2.0","error":{"code":1234,"message":"test err","data":"revert"},"id":123}`, + }, { "string ID", &RPCRes{ From fa7425683a1715f68c9c2349df5f36b57201c30b Mon Sep 17 00:00:00 2001 From: Matthew Slipper Date: Sun, 9 Oct 2022 14:20:29 -0500 Subject: [PATCH 2/3] proxyd: Custom rate limiter implementation Our current proxyd deployment does not share rate limit state across multiple servers within a backend group. This means that rate limits on the public endpoint are artificially high. This PR adds a Redis-based rate limiter to fix this problem. While our current rate limiting library (github.com/sethvargo/go-limiter) _does_ support Redis, the client library it uses is not type safe, is less performant, and would require us to update the other places we use Redis. To avoid these issues, I created a simple rate limiting interface with both Redis and memory backends. Note that this PR only adds the new implementations - it does not integrate them with the rest of the codebase. I'll do that in a separate PR to make review easier.
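As a rough sketch of the fixed-window scheme both backends implement (a minimal, self-contained toy, not the proxyd code itself; the `window` type, key, limit, and interval below are made up for illustration), each window is identified by the current timestamp truncated to the interval, and a per-key counter within that window decides whether a request fits:

    package main

    import (
        "fmt"
        "sync"
        "time"
    )

    // window counts requests per key for one fixed interval. A new
    // window implicitly begins whenever the truncated timestamp changes.
    type window struct {
        mtx   sync.Mutex
        start int64
        hits  map[string]int
    }

    func (w *window) take(key string, max int, interval time.Duration) bool {
        now := time.Now().Truncate(interval).Unix()
        w.mtx.Lock()
        defer w.mtx.Unlock()
        if w.start != now { // interval rolled over; reset all counters
            w.start = now
            w.hits = make(map[string]int)
        }
        w.hits[key]++
        return w.hits[key] <= max
    }

    func main() {
        w := &window{}
        for i := 0; i < 4; i++ {
            // With max=2, only the first two calls in a given window pass,
            // so this typically prints true, true, false, false.
            fmt.Println(w.take("203.0.113.7", 2, time.Second))
        }
    }

The Redis implementation applies the same idea across servers: it INCRs a key that embeds the truncated timestamp and lets Redis expire the key after the interval, per the basic rate limiting pattern linked in the code.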
--- ...ate_limiter.go => backend_rate_limiter.go} | 0 proxyd/proxyd/config.go | 1 + proxyd/proxyd/frontend_rate_limiter.go | 121 ++++++++++++++++++ proxyd/proxyd/frontend_rate_limiter_test.go | 53 ++++++++ 4 files changed, 175 insertions(+) rename proxyd/proxyd/{rate_limiter.go => backend_rate_limiter.go} (100%) create mode 100644 proxyd/proxyd/frontend_rate_limiter.go create mode 100644 proxyd/proxyd/frontend_rate_limiter_test.go diff --git a/proxyd/proxyd/rate_limiter.go b/proxyd/proxyd/backend_rate_limiter.go similarity index 100% rename from proxyd/proxyd/rate_limiter.go rename to proxyd/proxyd/backend_rate_limiter.go diff --git a/proxyd/proxyd/config.go b/proxyd/proxyd/config.go index 0647074..aea54a9 100644 --- a/proxyd/proxyd/config.go +++ b/proxyd/proxyd/config.go @@ -41,6 +41,7 @@ type MetricsConfig struct { } type RateLimitConfig struct { + UseRedis bool `toml:"use_redis"` RatePerSecond int `toml:"rate_per_second"` ExemptOrigins []string `toml:"exempt_origins"` ExemptUserAgents []string `toml:"exempt_user_agents"` diff --git a/proxyd/proxyd/frontend_rate_limiter.go b/proxyd/proxyd/frontend_rate_limiter.go new file mode 100644 index 0000000..3b06052 --- /dev/null +++ b/proxyd/proxyd/frontend_rate_limiter.go @@ -0,0 +1,121 @@ +package proxyd + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/go-redis/redis/v8" +) + +type FrontendRateLimiter interface { + // Take consumes a key, and a maximum number of requests + // per time interval. It returns a boolean denoting if + // the limit could be taken, or an error if a failure + // occurred in the backing rate limit implementation. + // + // No error will be returned if the limit could not be taken + // as a result of the requestor being over the limit. + Take(ctx context.Context, key string, max int) (bool, error) +} + +// limitedKeys is a wrapper around a map that stores a truncated +// timestamp and a mutex. The map is used to keep track of rate +// limit keys, and their used limits. +type limitedKeys struct { + truncTS int64 + keys map[string]int + mtx sync.Mutex +} + +func newLimitedKeys(t int64) *limitedKeys { + return &limitedKeys{ + truncTS: t, + keys: make(map[string]int), + } +} + +func (l *limitedKeys) Take(key string, max int) bool { + l.mtx.Lock() + defer l.mtx.Unlock() + val, ok := l.keys[key] + if !ok { + l.keys[key] = 0 + val = 0 + } + l.keys[key] = val + 1 + return val < max +} + +// MemoryFrontendRateLimiter is a rate limiter that stores +// all rate limiting information in local memory. It works +// by storing a limitedKeys struct that references the +// truncated timestamp at which the struct was created. If +// the current truncated timestamp doesn't match what's +// referenced, the limit is reset. Otherwise, values in +// a map are incremented to represent the limit. +type MemoryFrontendRateLimiter struct { + currGeneration *limitedKeys + dur time.Duration + mtx sync.Mutex +} + +func NewMemoryFrontendRateLimit(dur time.Duration) FrontendRateLimiter { + return &MemoryFrontendRateLimiter{ + dur: dur, + } +} + +func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) { + m.mtx.Lock() + // Create truncated timestamp + truncTS := truncateNow(m.dur) + + // If there is no current rate limit map or the rate limit map references + // a different timestamp, reset limits. + if m.currGeneration == nil || m.currGeneration.truncTS != truncTS { + m.currGeneration = newLimitedKeys(truncTS) + } + + // Pull out the limiter so we can unlock before incrementing the limit.
+ limiter := m.currGeneration + + m.mtx.Unlock() + + return limiter.Take(key, max), nil +} + +// RedisFrontendRateLimiter is a rate limiter that stores data in Redis. +// It uses the basic rate limiter pattern described on the Redis best +// practices website: https://redis.com/redis-best-practices/basic-rate-limiting/. +type RedisFrontendRateLimiter struct { + r *redis.Client + dur time.Duration +} + +func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration) FrontendRateLimiter { + return &RedisFrontendRateLimiter{r: r, dur: dur} +} + +func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) { + var incr *redis.IntCmd + truncTS := truncateNow(r.dur) + fullKey := fmt.Sprintf("%s:%d", key, truncTS) + _, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error { + incr = pipe.Incr(ctx, fullKey) + pipe.Expire(ctx, fullKey, r.dur-time.Second) + return nil + }) + if err != nil { + return false, err + } + + return incr.Val()-1 < int64(max), nil +} + +// truncateNow truncates the current timestamp +// to the specified duration. +func truncateNow(dur time.Duration) int64 { + return time.Now().Truncate(dur).Unix() +} diff --git a/proxyd/proxyd/frontend_rate_limiter_test.go b/proxyd/proxyd/frontend_rate_limiter_test.go new file mode 100644 index 0000000..c3d43cd --- /dev/null +++ b/proxyd/proxyd/frontend_rate_limiter_test.go @@ -0,0 +1,53 @@ +package proxyd + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/alicebob/miniredis" + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/require" +) + +func TestFrontendRateLimiter(t *testing.T) { + redisServer, err := miniredis.Run() + require.NoError(t, err) + defer redisServer.Close() + + redisClient := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("127.0.0.1:%s", redisServer.Port()), + }) + + lims := []struct { + name string + frl FrontendRateLimiter + }{ + {"memory", NewMemoryFrontendRateLimit(2 * time.Second)}, + {"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second)}, + } + + max := 2 + for _, cfg := range lims { + frl := cfg.frl + ctx := context.Background() + t.Run(cfg.name, func(t *testing.T) { + for i := 0; i < 4; i++ { + ok, err := frl.Take(ctx, "foo", max) + require.NoError(t, err) + require.Equal(t, i < max, ok) + ok, err = frl.Take(ctx, "bar", max) + require.NoError(t, err) + require.Equal(t, i < max, ok) + } + time.Sleep(2 * time.Second) + for i := 0; i < 4; i++ { + ok, _ := frl.Take(ctx, "foo", max) + require.Equal(t, i < max, ok) + ok, _ = frl.Take(ctx, "bar", max) + require.Equal(t, i < max, ok) + } + }) + } +} From f737002baac3dc134e4667e418b9daba8b980696 Mon Sep 17 00:00:00 2001 From: Matthew Slipper Date: Sun, 9 Oct 2022 15:26:27 -0500 Subject: [PATCH 3/3] proxyd: Integrate custom rate limiter Integrates the custom rate limiter from the previous PR into the rest of the application. Also takes the opportunity to clean up how we instantiate Redis clients so that we can share them among multiple services. There are some config changes in this PR. Specifically, you must specify a `base_rate` and `base_interval` in the rate limit config.
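For illustration (the values are illustrative, not defaults; `base_rate` and `base_interval` mirror the updated test config, while `use_redis` is shown here only as an example), a deployment that previously set `rate_per_second = 2` would now be written as:

    [rate_limit]
    use_redis = true
    base_rate = 2
    base_interval = "1s"

With `use_redis = true` the frontend limits are enforced through the shared Redis client, so a limit applies across every proxyd server behind the same Redis; if it is unset, each server falls back to its own in-memory limiter.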
--- proxyd/proxyd/backend_rate_limiter.go | 12 +---- proxyd/proxyd/cache.go | 12 +---- proxyd/proxyd/config.go | 3 +- proxyd/proxyd/frontend_rate_limiter.go | 42 ++++++++++----- proxyd/proxyd/frontend_rate_limiter_test.go | 14 ++--- .../proxyd/integration_tests/failover_test.go | 4 +- .../testdata/frontend_rate_limit.toml | 3 +- proxyd/proxyd/metrics.go | 6 +++ proxyd/proxyd/proxyd.go | 28 +++++----- proxyd/proxyd/redis.go | 22 ++++++++ proxyd/proxyd/server.go | 51 +++++++++---------- 11 files changed, 116 insertions(+), 81 deletions(-) create mode 100644 proxyd/proxyd/redis.go diff --git a/proxyd/proxyd/backend_rate_limiter.go b/proxyd/proxyd/backend_rate_limiter.go index fe286e6..03c6436 100644 --- a/proxyd/proxyd/backend_rate_limiter.go +++ b/proxyd/proxyd/backend_rate_limiter.go @@ -57,22 +57,14 @@ type RedisBackendRateLimiter struct { tkMtx sync.Mutex } -func NewRedisRateLimiter(url string) (BackendRateLimiter, error) { - opts, err := redis.ParseURL(url) - if err != nil { - return nil, err - } - rdb := redis.NewClient(opts) - if err := rdb.Ping(context.Background()).Err(); err != nil { - return nil, wrapErr(err, "error connecting to redis") - } +func NewRedisRateLimiter(rdb *redis.Client) BackendRateLimiter { out := &RedisBackendRateLimiter{ rdb: rdb, randID: randStr(20), touchKeys: make(map[string]time.Duration), } go out.touch() - return out, nil + return out } func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) { diff --git a/proxyd/proxyd/cache.go b/proxyd/proxyd/cache.go index 69dbb0b..73b7fd8 100644 --- a/proxyd/proxyd/cache.go +++ b/proxyd/proxyd/cache.go @@ -46,16 +46,8 @@ type redisCache struct { rdb *redis.Client } -func newRedisCache(url string) (*redisCache, error) { - opts, err := redis.ParseURL(url) - if err != nil { - return nil, err - } - rdb := redis.NewClient(opts) - if err := rdb.Ping(context.Background()).Err(); err != nil { - return nil, wrapErr(err, "error connecting to redis") - } - return &redisCache{rdb}, nil +func newRedisCache(rdb *redis.Client) *redisCache { + return &redisCache{rdb} } func (c *redisCache) Get(ctx context.Context, key string) (string, error) { diff --git a/proxyd/proxyd/config.go b/proxyd/proxyd/config.go index aea54a9..d0a32d6 100644 --- a/proxyd/proxyd/config.go +++ b/proxyd/proxyd/config.go @@ -42,7 +42,8 @@ type MetricsConfig struct { type RateLimitConfig struct { UseRedis bool `toml:"use_redis"` - RatePerSecond int `toml:"rate_per_second"` + BaseRate int `toml:"base_rate"` + BaseInterval TOMLDuration `toml:"base_interval"` ExemptOrigins []string `toml:"exempt_origins"` ExemptUserAgents []string `toml:"exempt_user_agents"` ErrorMessage string `toml:"error_message"` diff --git a/proxyd/proxyd/frontend_rate_limiter.go b/proxyd/proxyd/frontend_rate_limiter.go index 3b06052..d377370 100644 --- a/proxyd/proxyd/frontend_rate_limiter.go +++ b/proxyd/proxyd/frontend_rate_limiter.go @@ -17,7 +17,7 @@ type FrontendRateLimiter interface { // // No error will be returned if the limit could not be taken // as a result of the requestor being over the limit. 
- Take(ctx context.Context, key string, max int) (bool, error) + Take(ctx context.Context, key string) (bool, error) } // limitedKeys is a wrapper around a map that stores a truncated @@ -58,16 +58,18 @@ func (l *limitedKeys) Take(key string, max int) bool { type MemoryFrontendRateLimiter struct { currGeneration *limitedKeys dur time.Duration + max int mtx sync.Mutex } -func NewMemoryFrontendRateLimit(dur time.Duration) FrontendRateLimiter { +func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter { return &MemoryFrontendRateLimiter{ dur: dur, + max: max, } } -func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) { +func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) { m.mtx.Lock() // Create truncated timestamp truncTS := truncateNow(m.dur) @@ -83,35 +85,51 @@ func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max in m.mtx.Unlock() - return limiter.Take(key, max), nil + return limiter.Take(key, m.max), nil } // RedisFrontendRateLimiter is a rate limiter that stores data in Redis. // It uses the basic rate limiter pattern described on the Redis best // practices website: https://redis.com/redis-best-practices/basic-rate-limiting/. type RedisFrontendRateLimiter struct { - r *redis.Client - dur time.Duration + r *redis.Client + dur time.Duration + max int + prefix string } -func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration) FrontendRateLimiter { - return &RedisFrontendRateLimiter{r: r, dur: dur} +func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter { + return &RedisFrontendRateLimiter{ + r: r, + dur: dur, + max: max, + prefix: prefix, + } } -func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) { +func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) { var incr *redis.IntCmd truncTS := truncateNow(r.dur) - fullKey := fmt.Sprintf("%s:%d", key, truncTS) + fullKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, truncTS) _, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error { incr = pipe.Incr(ctx, fullKey) - pipe.Expire(ctx, fullKey, r.dur-time.Second) + pipe.PExpire(ctx, fullKey, r.dur-time.Millisecond) return nil }) if err != nil { + frontendRateLimitTakeErrors.Inc() return false, err } - return incr.Val()-1 < int64(max), nil + return incr.Val()-1 < int64(r.max), nil +} + +type noopFrontendRateLimiter struct{} + +var NoopFrontendRateLimiter = &noopFrontendRateLimiter{} + +func (n *noopFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) { + return true, nil } // truncateNow truncates the current timestamp diff --git a/proxyd/proxyd/frontend_rate_limiter_test.go b/proxyd/proxyd/frontend_rate_limiter_test.go index c3d43cd..f3542cf 100644 --- a/proxyd/proxyd/frontend_rate_limiter_test.go +++ b/proxyd/proxyd/frontend_rate_limiter_test.go @@ -20,32 +20,32 @@ func TestFrontendRateLimiter(t *testing.T) { Addr: fmt.Sprintf("127.0.0.1:%s", redisServer.Port()), }) + max := 2 lims := []struct { name string frl FrontendRateLimiter }{ - {"memory", NewMemoryFrontendRateLimit(2 * time.Second)}, - {"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second)}, + {"memory", NewMemoryFrontendRateLimit(2*time.Second, max)}, + {"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second, max, "")}, } - max := 2 for _, cfg := range lims { frl := cfg.frl ctx := context.Background() t.Run(cfg.name, 
func(t *testing.T) { for i := 0; i < 4; i++ { - ok, err := frl.Take(ctx, "foo", max) + ok, err := frl.Take(ctx, "foo") require.NoError(t, err) require.Equal(t, i < max, ok) - ok, err = frl.Take(ctx, "bar", max) + ok, err = frl.Take(ctx, "bar") require.NoError(t, err) require.Equal(t, i < max, ok) } time.Sleep(2 * time.Second) for i := 0; i < 4; i++ { - ok, _ := frl.Take(ctx, "foo", max) + ok, _ := frl.Take(ctx, "foo") require.Equal(t, i < max, ok) - ok, _ = frl.Take(ctx, "bar", max) + ok, _ = frl.Take(ctx, "bar") require.Equal(t, i < max, ok) } }) diff --git a/proxyd/proxyd/integration_tests/failover_test.go b/proxyd/proxyd/integration_tests/failover_test.go index f80f47c..47c9e26 100644 --- a/proxyd/proxyd/integration_tests/failover_test.go +++ b/proxyd/proxyd/integration_tests/failover_test.go @@ -261,6 +261,8 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) { config.BackendOptions.MaxRetries = 2 // Setup redis to detect offline backends config.Redis.URL = fmt.Sprintf("redis://127.0.0.1:%s", redis.Port()) + redisClient, err := proxyd.NewRedisClient(config.Redis.URL) + require.NoError(t, err) goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse, goodResponse)) defer goodBackend.Close() @@ -285,7 +287,7 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) { require.Equal(t, 1, len(badBackend.Requests())) require.Equal(t, 1, len(goodBackend.Requests())) - rr, err := proxyd.NewRedisRateLimiter(config.Redis.URL) + rr := proxyd.NewRedisRateLimiter(redisClient) require.NoError(t, err) online, err := rr.IsBackendOnline("bad") require.NoError(t, err) diff --git a/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml b/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml index f34840d..affb855 100644 --- a/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml +++ b/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml @@ -18,7 +18,8 @@ eth_chainId = "main" eth_foobar = "main" [rate_limit] -rate_per_second = 2 +base_rate = 2 +base_interval = "1s" exempt_origins = ["exempt_origin"] exempt_user_agents = ["exempt_agent"] error_message = "over rate limit with special message" diff --git a/proxyd/proxyd/metrics.go b/proxyd/proxyd/metrics.go index a3cfe45..06fef15 100644 --- a/proxyd/proxyd/metrics.go +++ b/proxyd/proxyd/metrics.go @@ -236,6 +236,12 @@ var ( 100, }, }) + + frontendRateLimitTakeErrors = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: MetricsNamespace, + Name: "rate_limit_take_errors", + Help: "Count of errors taking frontend rate limits", + }) ) func RecordRedisError(source string) { diff --git a/proxyd/proxyd/proxyd.go b/proxyd/proxyd/proxyd.go index 12a6a1a..5685633 100644 --- a/proxyd/proxyd/proxyd.go +++ b/proxyd/proxyd/proxyd.go @@ -13,6 +13,7 @@ import ( "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/log" + "github.com/go-redis/redis/v8" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/sync/semaphore" ) @@ -34,25 +35,29 @@ func Start(config *Config) (func(), error) { } } - var redisURL string + var redisClient *redis.Client if config.Redis.URL != "" { rURL, err := ReadFromEnvOrConfig(config.Redis.URL) if err != nil { return nil, err } - redisURL = rURL + redisClient, err = NewRedisClient(rURL) + if err != nil { + return nil, err + } + } + + if redisClient == nil && config.RateLimit.UseRedis { + return nil, errors.New("must specify a Redis URL if UseRedis is true in rate limit config") } var 
lim BackendRateLimiter var err error - if redisURL == "" { + if redisClient == nil { log.Warn("redis is not configured, using local rate limiter") lim = NewLocalBackendRateLimiter() } else { - lim, err = NewRedisRateLimiter(redisURL) - if err != nil { - return nil, err - } + lim = NewRedisRateLimiter(redisClient) } // While modifying shared globals is a bad practice, the alternative @@ -206,13 +211,11 @@ func Start(config *Config) (func(), error) { return nil, err } - if redisURL != "" { - if cache, err = newRedisCache(redisURL); err != nil { - return nil, err - } - } else { + if redisClient == nil { log.Warn("redis is not configured, using in-memory cache") cache = newMemoryCache() + } else { + cache = newRedisCache(redisClient) } // Ideally, the BlocKSyncRPCURL should be the sequencer or a HA replica that's not far behind ethClient, err := ethclient.Dial(blockSyncRPCURL) @@ -240,6 +243,7 @@ func Start(config *Config) (func(), error) { config.Server.EnableRequestLog, config.Server.MaxRequestBodyLogLen, config.BatchConfig.MaxSize, + redisClient, ) if err != nil { return nil, fmt.Errorf("error creating server: %w", err) diff --git a/proxyd/proxyd/redis.go b/proxyd/proxyd/redis.go new file mode 100644 index 0000000..e32bff2 --- /dev/null +++ b/proxyd/proxyd/redis.go @@ -0,0 +1,22 @@ +package proxyd + +import ( + "context" + "time" + + "github.com/go-redis/redis/v8" +) + +func NewRedisClient(url string) (*redis.Client, error) { + opts, err := redis.ParseURL(url) + if err != nil { + return nil, err + } + client := redis.NewClient(opts) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := client.Ping(ctx).Err(); err != nil { + return nil, wrapErr(err, "error connecting to redis") + } + return client, nil +} diff --git a/proxyd/proxyd/server.go b/proxyd/proxyd/server.go index 94787d1..e86ba4c 100644 --- a/proxyd/proxyd/server.go +++ b/proxyd/proxyd/server.go @@ -13,11 +13,8 @@ import ( "sync" "time" - "github.com/sethvargo/go-limiter" - "github.com/sethvargo/go-limiter/memorystore" - "github.com/sethvargo/go-limiter/noopstore" - "github.com/ethereum/go-ethereum/log" + "github.com/go-redis/redis/v8" "github.com/gorilla/mux" "github.com/gorilla/websocket" "github.com/prometheus/client_golang/prometheus" @@ -50,9 +47,8 @@ type Server struct { maxUpstreamBatchSize int maxBatchSize int upgrader *websocket.Upgrader - mainLim limiter.Store - overrideLims map[string]limiter.Store - limConfig RateLimitConfig + mainLim FrontendRateLimiter + overrideLims map[string]FrontendRateLimiter limExemptOrigins map[string]bool limExemptUserAgents map[string]bool rpcServer *http.Server @@ -77,6 +73,7 @@ func NewServer( enableRequestLog bool, maxRequestBodyLogLen int, maxBatchSize int, + redisClient *redis.Client, ) (*Server, error) { if cache == nil { cache = &NoopRPCCache{} @@ -98,19 +95,19 @@ func NewServer( maxBatchSize = MaxBatchRPCCallsHardLimit } - var mainLim limiter.Store - limExemptOrigins := make(map[string]bool) - limExemptUserAgents := make(map[string]bool) - if rateLimitConfig.RatePerSecond > 0 { - var err error - mainLim, err = memorystore.New(&memorystore.Config{ - Tokens: uint64(rateLimitConfig.RatePerSecond), - Interval: time.Second, - }) - if err != nil { - return nil, err + limiterFactory := func(dur time.Duration, max int, prefix string) FrontendRateLimiter { + if rateLimitConfig.UseRedis { + return NewRedisFrontendRateLimiter(redisClient, dur, max, prefix) } + return NewMemoryFrontendRateLimit(dur, max) + } + + var mainLim FrontendRateLimiter + 
limExemptOrigins := make(map[string]bool) + limExemptUserAgents := make(map[string]bool) + if rateLimitConfig.BaseRate > 0 { + mainLim = limiterFactory(time.Duration(rateLimitConfig.BaseInterval), rateLimitConfig.BaseRate, "main") for _, origin := range rateLimitConfig.ExemptOrigins { limExemptOrigins[strings.ToLower(origin)] = true } @@ -118,16 +115,13 @@ func NewServer( limExemptUserAgents[strings.ToLower(agent)] = true } } else { - mainLim, _ = noopstore.New() + mainLim = NoopFrontendRateLimiter } - overrideLims := make(map[string]limiter.Store) + overrideLims := make(map[string]FrontendRateLimiter) for method, override := range rateLimitConfig.MethodOverrides { var err error - overrideLims[method], err = memorystore.New(&memorystore.Config{ - Tokens: uint64(override.Limit), - Interval: time.Duration(override.Interval), - }) + overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method) if err != nil { return nil, err } @@ -151,7 +145,6 @@ func NewServer( }, mainLim: mainLim, overrideLims: overrideLims, - limConfig: rateLimitConfig, limExemptOrigins: limExemptOrigins, limExemptUserAgents: limExemptUserAgents, }, nil @@ -235,7 +228,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { return false } - var lim limiter.Store + var lim FrontendRateLimiter if method == "" { lim = s.mainLim } else { @@ -246,7 +239,11 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { return false } - _, _, _, ok, _ := lim.Take(ctx, xff) + ok, err := lim.Take(ctx, xff) + if err != nil { + log.Warn("error taking rate limit", "err", err) + return true + } return !ok }