infra/proxyd/proxyd/frontend_rate_limiter.go
Matthew Slipper f737002baa proxyd: Integrate custom rate limiter
Integrates the custom rate limiter in the previous PR into the rest of the application. Also takes the opportunity to clean up how we instantiate Redis clients so that we can share them among multiple different services.

There are some config changes in this PR. Specifically, you must specify a `base_rate` and `base_interval` in the rate limit config.
2022-10-11 17:20:23 -05:00

140 lines
3.7 KiB
Go

package proxyd
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-redis/redis/v8"
)
// FrontendRateLimiter gates incoming requests by an arbitrary string key.
type FrontendRateLimiter interface {
	// Take records one request for key and reports whether that request
	// is allowed. The error result is reserved for failures inside the
	// backing rate limit implementation: a key that is merely over its
	// limit yields (false, nil), never an error.
	Take(ctx context.Context, key string) (bool, error)
}
// limitedKeys is one generation of per-key request counters. The
// generation is tagged with the truncated timestamp it was created for;
// the counters themselves live in a mutex-guarded map.
type limitedKeys struct {
	truncTS int64          // truncated timestamp this generation belongs to
	keys    map[string]int // requests seen per key, guarded by mtx
	mtx     sync.Mutex
}

// newLimitedKeys returns an empty generation tagged with timestamp t.
func newLimitedKeys(t int64) *limitedKeys {
	lk := new(limitedKeys)
	lk.truncTS = t
	lk.keys = make(map[string]int)
	return lk
}

// Take increments the counter for key and reports whether the count
// before this call was below max. Every call is counted, including the
// ones that end up rejected.
func (l *limitedKeys) Take(key string, max int) bool {
	l.mtx.Lock()
	defer l.mtx.Unlock()

	// A key that has never been seen reads as zero, so no explicit
	// initialization is required before incrementing.
	prev := l.keys[key]
	l.keys[key]++
	return prev < max
}
// MemoryFrontendRateLimiter keeps all rate limit state in process memory.
// It holds a single generation of per-key counters tagged with the
// truncated timestamp it was created for. Whenever Take sees a different
// truncated timestamp, that generation is replaced by an empty one, which
// resets every key at once. Within a generation, each call increments the
// caller's counter.
type MemoryFrontendRateLimiter struct {
	currGeneration *limitedKeys
	dur            time.Duration
	max            int
	mtx            sync.Mutex
}

// NewMemoryFrontendRateLimit returns an in-memory limiter that admits up
// to max requests per key in each window of length dur.
func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter {
	limiter := &MemoryFrontendRateLimiter{}
	limiter.dur = dur
	limiter.max = max
	return limiter
}

// Take implements FrontendRateLimiter. It never returns a non-nil error.
func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
	// The limiter-wide lock is released inside currentGeneration, so the
	// per-key increment below only contends on the generation's own lock.
	return m.currentGeneration().Take(key, m.max), nil
}

// currentGeneration returns the counter generation for the current
// window, starting a fresh one when the truncated timestamp has moved on.
// The timestamp is read while holding m.mtx so that generation swaps are
// serialized with the check.
func (m *MemoryFrontendRateLimiter) currentGeneration() *limitedKeys {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	ts := truncateNow(m.dur)
	if gen := m.currGeneration; gen != nil && gen.truncTS == ts {
		return gen
	}
	m.currGeneration = newLimitedKeys(ts)
	return m.currGeneration
}
// RedisFrontendRateLimiter keeps its counters in Redis, following the
// basic fixed-window pattern from the Redis best practices guide:
// https://redis.com/redis-best-practices/basic-rate-limiting/.
// Every (prefix, key, window) combination gets its own counter key.
type RedisFrontendRateLimiter struct {
	r      *redis.Client
	dur    time.Duration
	max    int
	prefix string
}

// NewRedisFrontendRateLimiter returns a limiter backed by the Redis client
// r that admits up to max requests per key in each window of length dur.
// prefix is embedded in every counter key this limiter writes, so limiters
// with different prefixes keep separate counters in the same Redis.
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter {
	limiter := RedisFrontendRateLimiter{
		prefix: prefix,
		max:    max,
		dur:    dur,
		r:      r,
	}
	return &limiter
}

// Take increments the counter for key in the current window and reports
// whether the count before this call was below the configured maximum.
// INCR and PEXPIRE are sent in a single pipeline. The expiry is refreshed
// on every call to just under one window length, so keys for past windows
// age out. On a Redis failure the frontendRateLimitTakeErrors metric is
// incremented and the error is returned with a false result.
func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
	redisKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, truncateNow(r.dur))
	ttl := r.dur - time.Millisecond

	var count *redis.IntCmd
	if _, err := r.r.Pipelined(ctx, func(p redis.Pipeliner) error {
		count = p.Incr(ctx, redisKey)
		p.PExpire(ctx, redisKey, ttl)
		return nil
	}); err != nil {
		frontendRateLimitTakeErrors.Inc()
		return false, err
	}

	// INCR returns the value after incrementing; compare the prior count.
	prev := count.Val() - 1
	return prev < int64(r.max), nil
}
// noopFrontendRateLimiter is a FrontendRateLimiter that never limits.
type noopFrontendRateLimiter struct{}

// NoopFrontendRateLimiter admits every request unconditionally.
var NoopFrontendRateLimiter = new(noopFrontendRateLimiter)

// Take always allows the request and never returns an error.
func (*noopFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
	return true, nil
}
// truncateNow rounds the current time down to a multiple of dur and
// returns the result in Unix nanoseconds. The value identifies the rate
// limit window the current time falls into. Every call within the same
// window returns the same value, and consecutive windows return values
// exactly dur apart.
//
// Nanoseconds are used because Time.Unix drops the sub-second part. With
// dur < 1s, every window inside the same wall-clock second would share one
// identifier, silently stretching the window to a full second.
//
// If dur <= 0, Time.Truncate returns the time unchanged, so each call
// effectively gets its own window.
func truncateNow(dur time.Duration) int64 {
	return time.Now().Truncate(dur).UnixNano()
}