package proxyd import ( "context" "fmt" "sync" "time" "github.com/redis/go-redis/v9" ) type FrontendRateLimiter interface { // Take consumes a key, and a maximum number of requests // per time interval. It returns a boolean denoting if // the limit could be taken, or an error if a failure // occurred in the backing rate limit implementation. // // No error will be returned if the limit could not be taken // as a result of the requestor being over the limit. Take(ctx context.Context, key string) (bool, error) } // limitedKeys is a wrapper around a map that stores a truncated // timestamp and a mutex. The map is used to keep track of rate // limit keys, and their used limits. type limitedKeys struct { truncTS int64 keys map[string]int mtx sync.Mutex } func newLimitedKeys(t int64) *limitedKeys { return &limitedKeys{ truncTS: t, keys: make(map[string]int), } } func (l *limitedKeys) Take(key string, max int) bool { l.mtx.Lock() defer l.mtx.Unlock() val, ok := l.keys[key] if !ok { l.keys[key] = 0 val = 0 } l.keys[key] = val + 1 return val < max } // MemoryFrontendRateLimiter is a rate limiter that stores // all rate limiting information in local memory. It works // by storing a limitedKeys struct that references the // truncated timestamp at which the struct was created. If // the current truncated timestamp doesn't match what's // referenced, the limit is reset. Otherwise, values in // a map are incremented to represent the limit. type MemoryFrontendRateLimiter struct { currGeneration *limitedKeys dur time.Duration max int mtx sync.Mutex } func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter { return &MemoryFrontendRateLimiter{ dur: dur, max: max, } } func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) { m.mtx.Lock() // Create truncated timestamp truncTS := truncateNow(m.dur) // If there is no current rate limit map or the rate limit map reference // a different timestamp, reset limits. if m.currGeneration == nil || m.currGeneration.truncTS != truncTS { m.currGeneration = newLimitedKeys(truncTS) } // Pull out the limiter so we can unlock before incrementing the limit. limiter := m.currGeneration m.mtx.Unlock() return limiter.Take(key, m.max), nil } // RedisFrontendRateLimiter is a rate limiter that stores data in Redis. // It uses the basic rate limiter pattern described on the Redis best // practices website: https://redis.com/redis-best-practices/basic-rate-limiting/. type RedisFrontendRateLimiter struct { r *redis.Client dur time.Duration max int prefix string } func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter { return &RedisFrontendRateLimiter{ r: r, dur: dur, max: max, prefix: prefix, } } func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) { var incr *redis.IntCmd truncTS := truncateNow(r.dur) fullKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, truncTS) _, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error { incr = pipe.Incr(ctx, fullKey) pipe.PExpire(ctx, fullKey, r.dur-time.Millisecond) return nil }) if err != nil { frontendRateLimitTakeErrors.Inc() return false, err } return incr.Val()-1 < int64(r.max), nil } type noopFrontendRateLimiter struct{} var NoopFrontendRateLimiter = &noopFrontendRateLimiter{} func (n *noopFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) { return true, nil } // truncateNow truncates the current timestamp // to the specified duration. func truncateNow(dur time.Duration) int64 { return time.Now().Truncate(dur).Unix() }