proxyd: Integrate custom rate limiter
Integrates the custom rate limiter in the previous PR into the rest of the application. Also takes the opportunity to clean up how we instantiate Redis clients so that we can share them among multiple different services. There are some config changes in this PR. Specifically, you must specify a `base_rate` and `base_interval` in the rate limit config.
This commit is contained in:
parent
fa7425683a
commit
f737002baa
@ -57,22 +57,14 @@ type RedisBackendRateLimiter struct {
|
||||
tkMtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewRedisRateLimiter(url string) (BackendRateLimiter, error) {
|
||||
opts, err := redis.ParseURL(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rdb := redis.NewClient(opts)
|
||||
if err := rdb.Ping(context.Background()).Err(); err != nil {
|
||||
return nil, wrapErr(err, "error connecting to redis")
|
||||
}
|
||||
func NewRedisRateLimiter(rdb *redis.Client) BackendRateLimiter {
|
||||
out := &RedisBackendRateLimiter{
|
||||
rdb: rdb,
|
||||
randID: randStr(20),
|
||||
touchKeys: make(map[string]time.Duration),
|
||||
}
|
||||
go out.touch()
|
||||
return out, nil
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
|
||||
|
@ -46,16 +46,8 @@ type redisCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func newRedisCache(url string) (*redisCache, error) {
|
||||
opts, err := redis.ParseURL(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rdb := redis.NewClient(opts)
|
||||
if err := rdb.Ping(context.Background()).Err(); err != nil {
|
||||
return nil, wrapErr(err, "error connecting to redis")
|
||||
}
|
||||
return &redisCache{rdb}, nil
|
||||
func newRedisCache(rdb *redis.Client) *redisCache {
|
||||
return &redisCache{rdb}
|
||||
}
|
||||
|
||||
func (c *redisCache) Get(ctx context.Context, key string) (string, error) {
|
||||
|
@ -42,7 +42,8 @@ type MetricsConfig struct {
|
||||
|
||||
type RateLimitConfig struct {
|
||||
UseRedis bool `toml:"use_redis"`
|
||||
RatePerSecond int `toml:"rate_per_second"`
|
||||
BaseRate int `toml:"base_rate"`
|
||||
BaseInterval TOMLDuration `toml:"base_interval"`
|
||||
ExemptOrigins []string `toml:"exempt_origins"`
|
||||
ExemptUserAgents []string `toml:"exempt_user_agents"`
|
||||
ErrorMessage string `toml:"error_message"`
|
||||
|
@ -17,7 +17,7 @@ type FrontendRateLimiter interface {
|
||||
//
|
||||
// No error will be returned if the limit could not be taken
|
||||
// as a result of the requestor being over the limit.
|
||||
Take(ctx context.Context, key string, max int) (bool, error)
|
||||
Take(ctx context.Context, key string) (bool, error)
|
||||
}
|
||||
|
||||
// limitedKeys is a wrapper around a map that stores a truncated
|
||||
@ -58,16 +58,18 @@ func (l *limitedKeys) Take(key string, max int) bool {
|
||||
type MemoryFrontendRateLimiter struct {
|
||||
currGeneration *limitedKeys
|
||||
dur time.Duration
|
||||
max int
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewMemoryFrontendRateLimit(dur time.Duration) FrontendRateLimiter {
|
||||
func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter {
|
||||
return &MemoryFrontendRateLimiter{
|
||||
dur: dur,
|
||||
max: max,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
|
||||
func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
|
||||
m.mtx.Lock()
|
||||
// Create truncated timestamp
|
||||
truncTS := truncateNow(m.dur)
|
||||
@ -83,7 +85,7 @@ func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max in
|
||||
|
||||
m.mtx.Unlock()
|
||||
|
||||
return limiter.Take(key, max), nil
|
||||
return limiter.Take(key, m.max), nil
|
||||
}
|
||||
|
||||
// RedisFrontendRateLimiter is a rate limiter that stores data in Redis.
|
||||
@ -92,26 +94,42 @@ func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max in
|
||||
type RedisFrontendRateLimiter struct {
|
||||
r *redis.Client
|
||||
dur time.Duration
|
||||
max int
|
||||
prefix string
|
||||
}
|
||||
|
||||
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration) FrontendRateLimiter {
|
||||
return &RedisFrontendRateLimiter{r: r, dur: dur}
|
||||
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter {
|
||||
return &RedisFrontendRateLimiter{
|
||||
r: r,
|
||||
dur: dur,
|
||||
max: max,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
|
||||
func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
|
||||
var incr *redis.IntCmd
|
||||
truncTS := truncateNow(r.dur)
|
||||
fullKey := fmt.Sprintf("%s:%d", key, truncTS)
|
||||
fullKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, truncTS)
|
||||
_, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error {
|
||||
incr = pipe.Incr(ctx, fullKey)
|
||||
pipe.Expire(ctx, fullKey, r.dur-time.Second)
|
||||
pipe.PExpire(ctx, fullKey, r.dur-time.Millisecond)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
frontendRateLimitTakeErrors.Inc()
|
||||
return false, err
|
||||
}
|
||||
|
||||
return incr.Val()-1 < int64(max), nil
|
||||
return incr.Val()-1 < int64(r.max), nil
|
||||
}
|
||||
|
||||
type noopFrontendRateLimiter struct{}
|
||||
|
||||
var NoopFrontendRateLimiter = &noopFrontendRateLimiter{}
|
||||
|
||||
func (n *noopFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// truncateNow truncates the current timestamp
|
||||
|
@ -20,32 +20,32 @@ func TestFrontendRateLimiter(t *testing.T) {
|
||||
Addr: fmt.Sprintf("127.0.0.1:%s", redisServer.Port()),
|
||||
})
|
||||
|
||||
max := 2
|
||||
lims := []struct {
|
||||
name string
|
||||
frl FrontendRateLimiter
|
||||
}{
|
||||
{"memory", NewMemoryFrontendRateLimit(2 * time.Second)},
|
||||
{"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second)},
|
||||
{"memory", NewMemoryFrontendRateLimit(2*time.Second, max)},
|
||||
{"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second, max, "")},
|
||||
}
|
||||
|
||||
max := 2
|
||||
for _, cfg := range lims {
|
||||
frl := cfg.frl
|
||||
ctx := context.Background()
|
||||
t.Run(cfg.name, func(t *testing.T) {
|
||||
for i := 0; i < 4; i++ {
|
||||
ok, err := frl.Take(ctx, "foo", max)
|
||||
ok, err := frl.Take(ctx, "foo")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, i < max, ok)
|
||||
ok, err = frl.Take(ctx, "bar", max)
|
||||
ok, err = frl.Take(ctx, "bar")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, i < max, ok)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
for i := 0; i < 4; i++ {
|
||||
ok, _ := frl.Take(ctx, "foo", max)
|
||||
ok, _ := frl.Take(ctx, "foo")
|
||||
require.Equal(t, i < max, ok)
|
||||
ok, _ = frl.Take(ctx, "bar", max)
|
||||
ok, _ = frl.Take(ctx, "bar")
|
||||
require.Equal(t, i < max, ok)
|
||||
}
|
||||
})
|
||||
|
@ -261,6 +261,8 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
|
||||
config.BackendOptions.MaxRetries = 2
|
||||
// Setup redis to detect offline backends
|
||||
config.Redis.URL = fmt.Sprintf("redis://127.0.0.1:%s", redis.Port())
|
||||
redisClient, err := proxyd.NewRedisClient(config.Redis.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse, goodResponse))
|
||||
defer goodBackend.Close()
|
||||
@ -285,7 +287,7 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
|
||||
require.Equal(t, 1, len(badBackend.Requests()))
|
||||
require.Equal(t, 1, len(goodBackend.Requests()))
|
||||
|
||||
rr, err := proxyd.NewRedisRateLimiter(config.Redis.URL)
|
||||
rr := proxyd.NewRedisRateLimiter(redisClient)
|
||||
require.NoError(t, err)
|
||||
online, err := rr.IsBackendOnline("bad")
|
||||
require.NoError(t, err)
|
||||
|
@ -18,7 +18,8 @@ eth_chainId = "main"
|
||||
eth_foobar = "main"
|
||||
|
||||
[rate_limit]
|
||||
rate_per_second = 2
|
||||
base_rate = 2
|
||||
base_interval = "1s"
|
||||
exempt_origins = ["exempt_origin"]
|
||||
exempt_user_agents = ["exempt_agent"]
|
||||
error_message = "over rate limit with special message"
|
||||
|
@ -236,6 +236,12 @@ var (
|
||||
100,
|
||||
},
|
||||
})
|
||||
|
||||
frontendRateLimitTakeErrors = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Name: "rate_limit_take_errors",
|
||||
Help: "Count of errors taking frontend rate limits",
|
||||
})
|
||||
)
|
||||
|
||||
func RecordRedisError(source string) {
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/ethereum/go-ethereum/common/math"
|
||||
"github.com/ethereum/go-ethereum/ethclient"
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
@ -34,25 +35,29 @@ func Start(config *Config) (func(), error) {
|
||||
}
|
||||
}
|
||||
|
||||
var redisURL string
|
||||
var redisClient *redis.Client
|
||||
if config.Redis.URL != "" {
|
||||
rURL, err := ReadFromEnvOrConfig(config.Redis.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
redisURL = rURL
|
||||
redisClient, err = NewRedisClient(rURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if redisClient == nil && config.RateLimit.UseRedis {
|
||||
return nil, errors.New("must specify a Redis URL if UseRedis is true in rate limit config")
|
||||
}
|
||||
|
||||
var lim BackendRateLimiter
|
||||
var err error
|
||||
if redisURL == "" {
|
||||
if redisClient == nil {
|
||||
log.Warn("redis is not configured, using local rate limiter")
|
||||
lim = NewLocalBackendRateLimiter()
|
||||
} else {
|
||||
lim, err = NewRedisRateLimiter(redisURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lim = NewRedisRateLimiter(redisClient)
|
||||
}
|
||||
|
||||
// While modifying shared globals is a bad practice, the alternative
|
||||
@ -206,13 +211,11 @@ func Start(config *Config) (func(), error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if redisURL != "" {
|
||||
if cache, err = newRedisCache(redisURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if redisClient == nil {
|
||||
log.Warn("redis is not configured, using in-memory cache")
|
||||
cache = newMemoryCache()
|
||||
} else {
|
||||
cache = newRedisCache(redisClient)
|
||||
}
|
||||
// Ideally, the BlocKSyncRPCURL should be the sequencer or a HA replica that's not far behind
|
||||
ethClient, err := ethclient.Dial(blockSyncRPCURL)
|
||||
@ -240,6 +243,7 @@ func Start(config *Config) (func(), error) {
|
||||
config.Server.EnableRequestLog,
|
||||
config.Server.MaxRequestBodyLogLen,
|
||||
config.BatchConfig.MaxSize,
|
||||
redisClient,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating server: %w", err)
|
||||
|
22
proxyd/proxyd/redis.go
Normal file
22
proxyd/proxyd/redis.go
Normal file
@ -0,0 +1,22 @@
|
||||
package proxyd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
func NewRedisClient(url string) (*redis.Client, error) {
|
||||
opts, err := redis.ParseURL(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := redis.NewClient(opts)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, wrapErr(err, "error connecting to redis")
|
||||
}
|
||||
return client, nil
|
||||
}
|
@ -13,11 +13,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sethvargo/go-limiter"
|
||||
"github.com/sethvargo/go-limiter/memorystore"
|
||||
"github.com/sethvargo/go-limiter/noopstore"
|
||||
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@ -50,9 +47,8 @@ type Server struct {
|
||||
maxUpstreamBatchSize int
|
||||
maxBatchSize int
|
||||
upgrader *websocket.Upgrader
|
||||
mainLim limiter.Store
|
||||
overrideLims map[string]limiter.Store
|
||||
limConfig RateLimitConfig
|
||||
mainLim FrontendRateLimiter
|
||||
overrideLims map[string]FrontendRateLimiter
|
||||
limExemptOrigins map[string]bool
|
||||
limExemptUserAgents map[string]bool
|
||||
rpcServer *http.Server
|
||||
@ -77,6 +73,7 @@ func NewServer(
|
||||
enableRequestLog bool,
|
||||
maxRequestBodyLogLen int,
|
||||
maxBatchSize int,
|
||||
redisClient *redis.Client,
|
||||
) (*Server, error) {
|
||||
if cache == nil {
|
||||
cache = &NoopRPCCache{}
|
||||
@ -98,19 +95,19 @@ func NewServer(
|
||||
maxBatchSize = MaxBatchRPCCallsHardLimit
|
||||
}
|
||||
|
||||
var mainLim limiter.Store
|
||||
limExemptOrigins := make(map[string]bool)
|
||||
limExemptUserAgents := make(map[string]bool)
|
||||
if rateLimitConfig.RatePerSecond > 0 {
|
||||
var err error
|
||||
mainLim, err = memorystore.New(&memorystore.Config{
|
||||
Tokens: uint64(rateLimitConfig.RatePerSecond),
|
||||
Interval: time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
limiterFactory := func(dur time.Duration, max int, prefix string) FrontendRateLimiter {
|
||||
if rateLimitConfig.UseRedis {
|
||||
return NewRedisFrontendRateLimiter(redisClient, dur, max, prefix)
|
||||
}
|
||||
|
||||
return NewMemoryFrontendRateLimit(dur, max)
|
||||
}
|
||||
|
||||
var mainLim FrontendRateLimiter
|
||||
limExemptOrigins := make(map[string]bool)
|
||||
limExemptUserAgents := make(map[string]bool)
|
||||
if rateLimitConfig.BaseRate > 0 {
|
||||
mainLim = limiterFactory(time.Duration(rateLimitConfig.BaseInterval), rateLimitConfig.BaseRate, "main")
|
||||
for _, origin := range rateLimitConfig.ExemptOrigins {
|
||||
limExemptOrigins[strings.ToLower(origin)] = true
|
||||
}
|
||||
@ -118,16 +115,13 @@ func NewServer(
|
||||
limExemptUserAgents[strings.ToLower(agent)] = true
|
||||
}
|
||||
} else {
|
||||
mainLim, _ = noopstore.New()
|
||||
mainLim = NoopFrontendRateLimiter
|
||||
}
|
||||
|
||||
overrideLims := make(map[string]limiter.Store)
|
||||
overrideLims := make(map[string]FrontendRateLimiter)
|
||||
for method, override := range rateLimitConfig.MethodOverrides {
|
||||
var err error
|
||||
overrideLims[method], err = memorystore.New(&memorystore.Config{
|
||||
Tokens: uint64(override.Limit),
|
||||
Interval: time.Duration(override.Interval),
|
||||
})
|
||||
overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -151,7 +145,6 @@ func NewServer(
|
||||
},
|
||||
mainLim: mainLim,
|
||||
overrideLims: overrideLims,
|
||||
limConfig: rateLimitConfig,
|
||||
limExemptOrigins: limExemptOrigins,
|
||||
limExemptUserAgents: limExemptUserAgents,
|
||||
}, nil
|
||||
@ -235,7 +228,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
||||
return false
|
||||
}
|
||||
|
||||
var lim limiter.Store
|
||||
var lim FrontendRateLimiter
|
||||
if method == "" {
|
||||
lim = s.mainLim
|
||||
} else {
|
||||
@ -246,7 +239,11 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
||||
return false
|
||||
}
|
||||
|
||||
_, _, _, ok, _ := lim.Take(ctx, xff)
|
||||
ok, err := lim.Take(ctx, xff)
|
||||
if err != nil {
|
||||
log.Warn("error taking rate limit", "err", err)
|
||||
return true
|
||||
}
|
||||
return !ok
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user