From fa7425683a1715f68c9c2349df5f36b57201c30b Mon Sep 17 00:00:00 2001 From: Matthew Slipper Date: Sun, 9 Oct 2022 14:20:29 -0500 Subject: [PATCH] proxyd: Custom rate limiter implementation Our current proxyd deployment does not share rate limit state across multiple servers within a backend group. This means that rate limits on the public endpoint are artifically high. This PR adds a Redis-based rate limiter to fix this problem. While our current rate limiting library (github.com/sethvargo/go-limiter) _does_ support Redis, the client library it uses is not type safe, is less performant, and would require us to update the other places we use Redis. To avoid these issues, I created a simple rate limiting interface with both Redis and memory backend. Note that this PR only adds the new implementations - it does not integrate them with the rest of the codebase. I'll do that in a separate PR to make review easier. --- ...ate_limiter.go => backend_rate_limiter.go} | 0 proxyd/proxyd/config.go | 1 + proxyd/proxyd/frontend_rate_limiter.go | 121 ++++++++++++++++++ proxyd/proxyd/frontend_rate_limiter_test.go | 53 ++++++++ 4 files changed, 175 insertions(+) rename proxyd/proxyd/{rate_limiter.go => backend_rate_limiter.go} (100%) create mode 100644 proxyd/proxyd/frontend_rate_limiter.go create mode 100644 proxyd/proxyd/frontend_rate_limiter_test.go diff --git a/proxyd/proxyd/rate_limiter.go b/proxyd/proxyd/backend_rate_limiter.go similarity index 100% rename from proxyd/proxyd/rate_limiter.go rename to proxyd/proxyd/backend_rate_limiter.go diff --git a/proxyd/proxyd/config.go b/proxyd/proxyd/config.go index 0647074..aea54a9 100644 --- a/proxyd/proxyd/config.go +++ b/proxyd/proxyd/config.go @@ -41,6 +41,7 @@ type MetricsConfig struct { } type RateLimitConfig struct { + UseRedis bool `toml:"use_redis"` RatePerSecond int `toml:"rate_per_second"` ExemptOrigins []string `toml:"exempt_origins"` ExemptUserAgents []string `toml:"exempt_user_agents"` diff --git a/proxyd/proxyd/frontend_rate_limiter.go b/proxyd/proxyd/frontend_rate_limiter.go new file mode 100644 index 0000000..3b06052 --- /dev/null +++ b/proxyd/proxyd/frontend_rate_limiter.go @@ -0,0 +1,121 @@ +package proxyd + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/go-redis/redis/v8" +) + +type FrontendRateLimiter interface { + // Take consumes a key, and a maximum number of requests + // per time interval. It returns a boolean denoting if + // the limit could be taken, or an error if a failure + // occurred in the backing rate limit implementation. + // + // No error will be returned if the limit could not be taken + // as a result of the requestor being over the limit. + Take(ctx context.Context, key string, max int) (bool, error) +} + +// limitedKeys is a wrapper around a map that stores a truncated +// timestamp and a mutex. The map is used to keep track of rate +// limit keys, and their used limits. +type limitedKeys struct { + truncTS int64 + keys map[string]int + mtx sync.Mutex +} + +func newLimitedKeys(t int64) *limitedKeys { + return &limitedKeys{ + truncTS: t, + keys: make(map[string]int), + } +} + +func (l *limitedKeys) Take(key string, max int) bool { + l.mtx.Lock() + defer l.mtx.Unlock() + val, ok := l.keys[key] + if !ok { + l.keys[key] = 0 + val = 0 + } + l.keys[key] = val + 1 + return val < max +} + +// MemoryFrontendRateLimiter is a rate limiter that stores +// all rate limiting information in local memory. It works +// by storing a limitedKeys struct that references the +// truncated timestamp at which the struct was created. If +// the current truncated timestamp doesn't match what's +// referenced, the limit is reset. Otherwise, values in +// a map are incremented to represent the limit. +type MemoryFrontendRateLimiter struct { + currGeneration *limitedKeys + dur time.Duration + mtx sync.Mutex +} + +func NewMemoryFrontendRateLimit(dur time.Duration) FrontendRateLimiter { + return &MemoryFrontendRateLimiter{ + dur: dur, + } +} + +func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) { + m.mtx.Lock() + // Create truncated timestamp + truncTS := truncateNow(m.dur) + + // If there is no current rate limit map or the rate limit map reference + // a different timestamp, reset limits. + if m.currGeneration == nil || m.currGeneration.truncTS != truncTS { + m.currGeneration = newLimitedKeys(truncTS) + } + + // Pull out the limiter so we can unlock before incrementing the limit. + limiter := m.currGeneration + + m.mtx.Unlock() + + return limiter.Take(key, max), nil +} + +// RedisFrontendRateLimiter is a rate limiter that stores data in Redis. +// It uses the basic rate limiter pattern described on the Redis best +// practices website: https://redis.com/redis-best-practices/basic-rate-limiting/. +type RedisFrontendRateLimiter struct { + r *redis.Client + dur time.Duration +} + +func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration) FrontendRateLimiter { + return &RedisFrontendRateLimiter{r: r, dur: dur} +} + +func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) { + var incr *redis.IntCmd + truncTS := truncateNow(r.dur) + fullKey := fmt.Sprintf("%s:%d", key, truncTS) + _, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error { + incr = pipe.Incr(ctx, fullKey) + pipe.Expire(ctx, fullKey, r.dur-time.Second) + return nil + }) + if err != nil { + return false, err + } + + return incr.Val()-1 < int64(max), nil +} + +// truncateNow truncates the current timestamp +// to the specified duration. +func truncateNow(dur time.Duration) int64 { + return time.Now().Truncate(dur).Unix() +} diff --git a/proxyd/proxyd/frontend_rate_limiter_test.go b/proxyd/proxyd/frontend_rate_limiter_test.go new file mode 100644 index 0000000..c3d43cd --- /dev/null +++ b/proxyd/proxyd/frontend_rate_limiter_test.go @@ -0,0 +1,53 @@ +package proxyd + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/alicebob/miniredis" + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/require" +) + +func TestFrontendRateLimiter(t *testing.T) { + redisServer, err := miniredis.Run() + require.NoError(t, err) + defer redisServer.Close() + + redisClient := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("127.0.0.1:%s", redisServer.Port()), + }) + + lims := []struct { + name string + frl FrontendRateLimiter + }{ + {"memory", NewMemoryFrontendRateLimit(2 * time.Second)}, + {"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second)}, + } + + max := 2 + for _, cfg := range lims { + frl := cfg.frl + ctx := context.Background() + t.Run(cfg.name, func(t *testing.T) { + for i := 0; i < 4; i++ { + ok, err := frl.Take(ctx, "foo", max) + require.NoError(t, err) + require.Equal(t, i < max, ok) + ok, err = frl.Take(ctx, "bar", max) + require.NoError(t, err) + require.Equal(t, i < max, ok) + } + time.Sleep(2 * time.Second) + for i := 0; i < 4; i++ { + ok, _ := frl.Take(ctx, "foo", max) + require.Equal(t, i < max, ok) + ok, _ = frl.Take(ctx, "bar", max) + require.Equal(t, i < max, ok) + } + }) + } +}