proxyd: Integrate custom rate limiter

Integrates the custom rate limiter introduced in the previous PR into the rest of the application. It also takes the opportunity to clean up how we instantiate Redis clients, so that a single client can be shared among multiple services.

There are some config changes in this PR: you must now specify a `base_rate` and `base_interval` in the `[rate_limit]` section of the config (these replace the old `rate_per_second` setting).
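For illustration, a `[rate_limit]` block under the new schema looks like the following. The values mirror the updated test config further down in this diff; `use_redis` is optional and switches the frontend limiter from in-memory counters to Redis-backed ones:

    [rate_limit]
    # use_redis = true
    base_rate = 2
    base_interval = "1s"
    exempt_origins = ["exempt_origin"]
    exempt_user_agents = ["exempt_agent"]
    error_message = "over rate limit with special message"
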
Matthew Slipper 2022-10-09 15:26:27 -05:00
parent fa7425683a
commit f737002baa
11 changed files with 116 additions and 81 deletions

@@ -57,22 +57,14 @@ type RedisBackendRateLimiter struct {
tkMtx sync.Mutex
}
func NewRedisRateLimiter(url string) (BackendRateLimiter, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
rdb := redis.NewClient(opts)
if err := rdb.Ping(context.Background()).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis")
}
func NewRedisRateLimiter(rdb *redis.Client) BackendRateLimiter {
out := &RedisBackendRateLimiter{
rdb: rdb,
randID: randStr(20),
touchKeys: make(map[string]time.Duration),
}
go out.touch()
return out, nil
return out
}
func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) {

@@ -46,16 +46,8 @@ type redisCache struct {
rdb *redis.Client
}
func newRedisCache(url string) (*redisCache, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
rdb := redis.NewClient(opts)
if err := rdb.Ping(context.Background()).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis")
}
return &redisCache{rdb}, nil
func newRedisCache(rdb *redis.Client) *redisCache {
return &redisCache{rdb}
}
func (c *redisCache) Get(ctx context.Context, key string) (string, error) {

@@ -42,7 +42,8 @@ type MetricsConfig struct {
type RateLimitConfig struct {
UseRedis bool `toml:"use_redis"`
RatePerSecond int `toml:"rate_per_second"`
BaseRate int `toml:"base_rate"`
BaseInterval TOMLDuration `toml:"base_interval"`
ExemptOrigins []string `toml:"exempt_origins"`
ExemptUserAgents []string `toml:"exempt_user_agents"`
ErrorMessage string `toml:"error_message"`

@@ -17,7 +17,7 @@ type FrontendRateLimiter interface {
//
// No error will be returned if the limit could not be taken
// as a result of the requestor being over the limit.
Take(ctx context.Context, key string, max int) (bool, error)
Take(ctx context.Context, key string) (bool, error)
}
// limitedKeys is a wrapper around a map that stores a truncated
@@ -58,16 +58,18 @@ func (l *limitedKeys) Take(key string, max int) bool {
type MemoryFrontendRateLimiter struct {
currGeneration *limitedKeys
dur time.Duration
max int
mtx sync.Mutex
}
func NewMemoryFrontendRateLimit(dur time.Duration) FrontendRateLimiter {
func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter {
return &MemoryFrontendRateLimiter{
dur: dur,
max: max,
}
}
func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
m.mtx.Lock()
// Create truncated timestamp
truncTS := truncateNow(m.dur)
@@ -83,35 +85,51 @@ func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max in
m.mtx.Unlock()
return limiter.Take(key, max), nil
return limiter.Take(key, m.max), nil
}
// RedisFrontendRateLimiter is a rate limiter that stores data in Redis.
// It uses the basic rate limiter pattern described on the Redis best
// practices website: https://redis.com/redis-best-practices/basic-rate-limiting/.
type RedisFrontendRateLimiter struct {
r *redis.Client
dur time.Duration
r *redis.Client
dur time.Duration
max int
prefix string
}
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration) FrontendRateLimiter {
return &RedisFrontendRateLimiter{r: r, dur: dur}
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter {
return &RedisFrontendRateLimiter{
r: r,
dur: dur,
max: max,
prefix: prefix,
}
}
func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
var incr *redis.IntCmd
truncTS := truncateNow(r.dur)
fullKey := fmt.Sprintf("%s:%d", key, truncTS)
fullKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, truncTS)
_, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error {
incr = pipe.Incr(ctx, fullKey)
pipe.Expire(ctx, fullKey, r.dur-time.Second)
pipe.PExpire(ctx, fullKey, r.dur-time.Millisecond)
return nil
})
if err != nil {
frontendRateLimitTakeErrors.Inc()
return false, err
}
return incr.Val()-1 < int64(max), nil
return incr.Val()-1 < int64(r.max), nil
}
type noopFrontendRateLimiter struct{}
var NoopFrontendRateLimiter = &noopFrontendRateLimiter{}
func (n *noopFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
return true, nil
}
// truncateNow truncates the current timestamp

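To show the effect of the interface change: the per-window maximum is now fixed when a limiter is constructed, so call sites only pass the key being limited. A minimal sketch using the in-memory implementation (the key and limit values are made up; assumes the usual context/fmt/time imports inside the proxyd package):

    lim := NewMemoryFrontendRateLimit(time.Second, 2) // allow at most 2 takes per key per second
    for i := 0; i < 3; i++ {
        ok, err := lim.Take(context.Background(), "203.0.113.7") // key only; no per-call max
        if err != nil {
            // a storage error, distinct from simply being over the limit
        }
        fmt.Println(ok) // true, true, then false within the same one-second window
    }

The Redis-backed variant is constructed the same way, with an extra prefix so that different limiters sharing one Redis instance keep their counters under separate rate_limit:<prefix>:<key>:<window> keys.
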
@@ -20,32 +20,32 @@ func TestFrontendRateLimiter(t *testing.T) {
Addr: fmt.Sprintf("127.0.0.1:%s", redisServer.Port()),
})
max := 2
lims := []struct {
name string
frl FrontendRateLimiter
}{
{"memory", NewMemoryFrontendRateLimit(2 * time.Second)},
{"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second)},
{"memory", NewMemoryFrontendRateLimit(2*time.Second, max)},
{"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second, max, "")},
}
max := 2
for _, cfg := range lims {
frl := cfg.frl
ctx := context.Background()
t.Run(cfg.name, func(t *testing.T) {
for i := 0; i < 4; i++ {
ok, err := frl.Take(ctx, "foo", max)
ok, err := frl.Take(ctx, "foo")
require.NoError(t, err)
require.Equal(t, i < max, ok)
ok, err = frl.Take(ctx, "bar", max)
ok, err = frl.Take(ctx, "bar")
require.NoError(t, err)
require.Equal(t, i < max, ok)
}
time.Sleep(2 * time.Second)
for i := 0; i < 4; i++ {
ok, _ := frl.Take(ctx, "foo", max)
ok, _ := frl.Take(ctx, "foo")
require.Equal(t, i < max, ok)
ok, _ = frl.Take(ctx, "bar", max)
ok, _ = frl.Take(ctx, "bar")
require.Equal(t, i < max, ok)
}
})

@@ -261,6 +261,8 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
config.BackendOptions.MaxRetries = 2
// Setup redis to detect offline backends
config.Redis.URL = fmt.Sprintf("redis://127.0.0.1:%s", redis.Port())
redisClient, err := proxyd.NewRedisClient(config.Redis.URL)
require.NoError(t, err)
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse, goodResponse))
defer goodBackend.Close()
@@ -285,7 +287,7 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
require.Equal(t, 1, len(badBackend.Requests()))
require.Equal(t, 1, len(goodBackend.Requests()))
rr, err := proxyd.NewRedisRateLimiter(config.Redis.URL)
rr := proxyd.NewRedisRateLimiter(redisClient)
require.NoError(t, err)
online, err := rr.IsBackendOnline("bad")
require.NoError(t, err)

@@ -18,7 +18,8 @@ eth_chainId = "main"
eth_foobar = "main"
[rate_limit]
rate_per_second = 2
base_rate = 2
base_interval = "1s"
exempt_origins = ["exempt_origin"]
exempt_user_agents = ["exempt_agent"]
error_message = "over rate limit with special message"

@@ -236,6 +236,12 @@ var (
100,
},
})
frontendRateLimitTakeErrors = promauto.NewCounter(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "rate_limit_take_errors",
Help: "Count of errors taking frontend rate limits",
})
)
func RecordRedisError(source string) {

@@ -13,6 +13,7 @@ import (
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/log"
"github.com/go-redis/redis/v8"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/semaphore"
)
@@ -34,25 +35,29 @@ func Start(config *Config) (func(), error) {
}
}
var redisURL string
var redisClient *redis.Client
if config.Redis.URL != "" {
rURL, err := ReadFromEnvOrConfig(config.Redis.URL)
if err != nil {
return nil, err
}
redisURL = rURL
redisClient, err = NewRedisClient(rURL)
if err != nil {
return nil, err
}
}
if redisClient == nil && config.RateLimit.UseRedis {
return nil, errors.New("must specify a Redis URL if UseRedis is true in rate limit config")
}
var lim BackendRateLimiter
var err error
if redisURL == "" {
if redisClient == nil {
log.Warn("redis is not configured, using local rate limiter")
lim = NewLocalBackendRateLimiter()
} else {
lim, err = NewRedisRateLimiter(redisURL)
if err != nil {
return nil, err
}
lim = NewRedisRateLimiter(redisClient)
}
// While modifying shared globals is a bad practice, the alternative
@@ -206,13 +211,11 @@ func Start(config *Config) (func(), error) {
return nil, err
}
if redisURL != "" {
if cache, err = newRedisCache(redisURL); err != nil {
return nil, err
}
} else {
if redisClient == nil {
log.Warn("redis is not configured, using in-memory cache")
cache = newMemoryCache()
} else {
cache = newRedisCache(redisClient)
}
// Ideally, the BlocKSyncRPCURL should be the sequencer or a HA replica that's not far behind
ethClient, err := ethclient.Dial(blockSyncRPCURL)
@@ -240,6 +243,7 @@ func Start(config *Config) (func(), error) {
config.Server.EnableRequestLog,
config.Server.MaxRequestBodyLogLen,
config.BatchConfig.MaxSize,
redisClient,
)
if err != nil {
return nil, fmt.Errorf("error creating server: %w", err)

proxyd/proxyd/redis.go (new file, 22 lines added)

@@ -0,0 +1,22 @@
package proxyd
import (
"context"
"time"
"github.com/go-redis/redis/v8"
)
func NewRedisClient(url string) (*redis.Client, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
client := redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis")
}
return client, nil
}

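As a usage sketch, this is roughly how a caller (for example the integration test above) now builds a single client and hands it to every Redis-backed service, where previously each constructor parsed the URL and dialed Redis on its own. The import path, URL, and limit values are assumptions for illustration:

    package main

    import (
        "log"
        "time"

        "github.com/ethereum-optimism/optimism/proxyd"
    )

    func main() {
        redisClient, err := proxyd.NewRedisClient("redis://127.0.0.1:6379")
        if err != nil {
            log.Fatalf("connecting to redis: %v", err) // bad URL, or no ping reply within the 5s timeout
        }
        // the same client now backs both limiters (and, inside the package, the RPC cache)
        backendLim := proxyd.NewRedisRateLimiter(redisClient) // note: no error return anymore
        frontendLim := proxyd.NewRedisFrontendRateLimiter(redisClient, time.Second, 100, "main")
        _, _ = backendLim, frontendLim
    }
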
@@ -13,11 +13,8 @@ import (
"sync"
"time"
"github.com/sethvargo/go-limiter"
"github.com/sethvargo/go-limiter/memorystore"
"github.com/sethvargo/go-limiter/noopstore"
"github.com/ethereum/go-ethereum/log"
"github.com/go-redis/redis/v8"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
@@ -50,9 +47,8 @@ type Server struct {
maxUpstreamBatchSize int
maxBatchSize int
upgrader *websocket.Upgrader
mainLim limiter.Store
overrideLims map[string]limiter.Store
limConfig RateLimitConfig
mainLim FrontendRateLimiter
overrideLims map[string]FrontendRateLimiter
limExemptOrigins map[string]bool
limExemptUserAgents map[string]bool
rpcServer *http.Server
@@ -77,6 +73,7 @@ func NewServer(
enableRequestLog bool,
maxRequestBodyLogLen int,
maxBatchSize int,
redisClient *redis.Client,
) (*Server, error) {
if cache == nil {
cache = &NoopRPCCache{}
@@ -98,19 +95,19 @@ func NewServer(
maxBatchSize = MaxBatchRPCCallsHardLimit
}
var mainLim limiter.Store
limExemptOrigins := make(map[string]bool)
limExemptUserAgents := make(map[string]bool)
if rateLimitConfig.RatePerSecond > 0 {
var err error
mainLim, err = memorystore.New(&memorystore.Config{
Tokens: uint64(rateLimitConfig.RatePerSecond),
Interval: time.Second,
})
if err != nil {
return nil, err
limiterFactory := func(dur time.Duration, max int, prefix string) FrontendRateLimiter {
if rateLimitConfig.UseRedis {
return NewRedisFrontendRateLimiter(redisClient, dur, max, prefix)
}
return NewMemoryFrontendRateLimit(dur, max)
}
var mainLim FrontendRateLimiter
limExemptOrigins := make(map[string]bool)
limExemptUserAgents := make(map[string]bool)
if rateLimitConfig.BaseRate > 0 {
mainLim = limiterFactory(time.Duration(rateLimitConfig.BaseInterval), rateLimitConfig.BaseRate, "main")
for _, origin := range rateLimitConfig.ExemptOrigins {
limExemptOrigins[strings.ToLower(origin)] = true
}
@@ -118,16 +115,13 @@ func NewServer(
limExemptUserAgents[strings.ToLower(agent)] = true
}
} else {
mainLim, _ = noopstore.New()
mainLim = NoopFrontendRateLimiter
}
overrideLims := make(map[string]limiter.Store)
overrideLims := make(map[string]FrontendRateLimiter)
for method, override := range rateLimitConfig.MethodOverrides {
var err error
overrideLims[method], err = memorystore.New(&memorystore.Config{
Tokens: uint64(override.Limit),
Interval: time.Duration(override.Interval),
})
overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method)
if err != nil {
return nil, err
}
@@ -151,7 +145,6 @@ func NewServer(
},
mainLim: mainLim,
overrideLims: overrideLims,
limConfig: rateLimitConfig,
limExemptOrigins: limExemptOrigins,
limExemptUserAgents: limExemptUserAgents,
}, nil
@@ -235,7 +228,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return false
}
var lim limiter.Store
var lim FrontendRateLimiter
if method == "" {
lim = s.mainLim
} else {
@@ -246,7 +239,11 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return false
}
_, _, _, ok, _ := lim.Take(ctx, xff)
ok, err := lim.Take(ctx, xff)
if err != nil {
log.Warn("error taking rate limit", "err", err)
return true
}
return !ok
}