diff --git a/proxyd/proxyd/rate_limiter.go b/proxyd/proxyd/backend_rate_limiter.go similarity index 100% rename from proxyd/proxyd/rate_limiter.go rename to proxyd/proxyd/backend_rate_limiter.go diff --git a/proxyd/proxyd/config.go b/proxyd/proxyd/config.go index 0647074..aea54a9 100644 --- a/proxyd/proxyd/config.go +++ b/proxyd/proxyd/config.go @@ -41,6 +41,7 @@ type MetricsConfig struct { } type RateLimitConfig struct { + UseRedis bool `toml:"use_redis"` RatePerSecond int `toml:"rate_per_second"` ExemptOrigins []string `toml:"exempt_origins"` ExemptUserAgents []string `toml:"exempt_user_agents"` diff --git a/proxyd/proxyd/frontend_rate_limiter.go b/proxyd/proxyd/frontend_rate_limiter.go new file mode 100644 index 0000000..3b06052 --- /dev/null +++ b/proxyd/proxyd/frontend_rate_limiter.go @@ -0,0 +1,121 @@ +package proxyd + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/go-redis/redis/v8" +) + +type FrontendRateLimiter interface { + // Take consumes a key, and a maximum number of requests + // per time interval. It returns a boolean denoting if + // the limit could be taken, or an error if a failure + // occurred in the backing rate limit implementation. + // + // No error will be returned if the limit could not be taken + // as a result of the requestor being over the limit. + Take(ctx context.Context, key string, max int) (bool, error) +} + +// limitedKeys is a wrapper around a map that stores a truncated +// timestamp and a mutex. The map is used to keep track of rate +// limit keys, and their used limits. +type limitedKeys struct { + truncTS int64 + keys map[string]int + mtx sync.Mutex +} + +func newLimitedKeys(t int64) *limitedKeys { + return &limitedKeys{ + truncTS: t, + keys: make(map[string]int), + } +} + +func (l *limitedKeys) Take(key string, max int) bool { + l.mtx.Lock() + defer l.mtx.Unlock() + val, ok := l.keys[key] + if !ok { + l.keys[key] = 0 + val = 0 + } + l.keys[key] = val + 1 + return val < max +} + +// MemoryFrontendRateLimiter is a rate limiter that stores +// all rate limiting information in local memory. It works +// by storing a limitedKeys struct that references the +// truncated timestamp at which the struct was created. If +// the current truncated timestamp doesn't match what's +// referenced, the limit is reset. Otherwise, values in +// a map are incremented to represent the limit. +type MemoryFrontendRateLimiter struct { + currGeneration *limitedKeys + dur time.Duration + mtx sync.Mutex +} + +func NewMemoryFrontendRateLimit(dur time.Duration) FrontendRateLimiter { + return &MemoryFrontendRateLimiter{ + dur: dur, + } +} + +func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) { + m.mtx.Lock() + // Create truncated timestamp + truncTS := truncateNow(m.dur) + + // If there is no current rate limit map or the rate limit map reference + // a different timestamp, reset limits. + if m.currGeneration == nil || m.currGeneration.truncTS != truncTS { + m.currGeneration = newLimitedKeys(truncTS) + } + + // Pull out the limiter so we can unlock before incrementing the limit. + limiter := m.currGeneration + + m.mtx.Unlock() + + return limiter.Take(key, max), nil +} + +// RedisFrontendRateLimiter is a rate limiter that stores data in Redis. +// It uses the basic rate limiter pattern described on the Redis best +// practices website: https://redis.com/redis-best-practices/basic-rate-limiting/. +type RedisFrontendRateLimiter struct { + r *redis.Client + dur time.Duration +} + +func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration) FrontendRateLimiter { + return &RedisFrontendRateLimiter{r: r, dur: dur} +} + +func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) { + var incr *redis.IntCmd + truncTS := truncateNow(r.dur) + fullKey := fmt.Sprintf("%s:%d", key, truncTS) + _, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error { + incr = pipe.Incr(ctx, fullKey) + pipe.Expire(ctx, fullKey, r.dur-time.Second) + return nil + }) + if err != nil { + return false, err + } + + return incr.Val()-1 < int64(max), nil +} + +// truncateNow truncates the current timestamp +// to the specified duration. +func truncateNow(dur time.Duration) int64 { + return time.Now().Truncate(dur).Unix() +} diff --git a/proxyd/proxyd/frontend_rate_limiter_test.go b/proxyd/proxyd/frontend_rate_limiter_test.go new file mode 100644 index 0000000..c3d43cd --- /dev/null +++ b/proxyd/proxyd/frontend_rate_limiter_test.go @@ -0,0 +1,53 @@ +package proxyd + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/alicebob/miniredis" + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/require" +) + +func TestFrontendRateLimiter(t *testing.T) { + redisServer, err := miniredis.Run() + require.NoError(t, err) + defer redisServer.Close() + + redisClient := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("127.0.0.1:%s", redisServer.Port()), + }) + + lims := []struct { + name string + frl FrontendRateLimiter + }{ + {"memory", NewMemoryFrontendRateLimit(2 * time.Second)}, + {"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second)}, + } + + max := 2 + for _, cfg := range lims { + frl := cfg.frl + ctx := context.Background() + t.Run(cfg.name, func(t *testing.T) { + for i := 0; i < 4; i++ { + ok, err := frl.Take(ctx, "foo", max) + require.NoError(t, err) + require.Equal(t, i < max, ok) + ok, err = frl.Take(ctx, "bar", max) + require.NoError(t, err) + require.Equal(t, i < max, ok) + } + time.Sleep(2 * time.Second) + for i := 0; i < 4; i++ { + ok, _ := frl.Take(ctx, "foo", max) + require.Equal(t, i < max, ok) + ok, _ = frl.Take(ctx, "bar", max) + require.Equal(t, i < max, ok) + } + }) + } +}