proxyd: Custom rate limiter implementation

Our current proxyd deployment does not share rate limit state across multiple servers within a backend group. This means that rate limits on the public endpoint are artifically high.

This PR adds a Redis-based rate limiter to fix this problem. While our current rate limiting library (github.com/sethvargo/go-limiter) _does_ support Redis, the client library it uses is not type safe, is less performant, and would require us to update the other places we use Redis. To avoid these issues, I created a simple rate limiting interface with both Redis and memory backend.

Note that this PR only adds the new implementations - it does not integrate them with the rest of the codebase. I'll do that in a separate PR to make review easier.
This commit is contained in:
Matthew Slipper 2022-10-09 14:20:29 -05:00
parent 7cadaca188
commit fa7425683a
4 changed files with 175 additions and 0 deletions

@ -41,6 +41,7 @@ type MetricsConfig struct {
} }
type RateLimitConfig struct { type RateLimitConfig struct {
UseRedis bool `toml:"use_redis"`
RatePerSecond int `toml:"rate_per_second"` RatePerSecond int `toml:"rate_per_second"`
ExemptOrigins []string `toml:"exempt_origins"` ExemptOrigins []string `toml:"exempt_origins"`
ExemptUserAgents []string `toml:"exempt_user_agents"` ExemptUserAgents []string `toml:"exempt_user_agents"`

@ -0,0 +1,121 @@
package proxyd
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-redis/redis/v8"
)
type FrontendRateLimiter interface {
// Take consumes a key, and a maximum number of requests
// per time interval. It returns a boolean denoting if
// the limit could be taken, or an error if a failure
// occurred in the backing rate limit implementation.
//
// No error will be returned if the limit could not be taken
// as a result of the requestor being over the limit.
Take(ctx context.Context, key string, max int) (bool, error)
}
// limitedKeys is a wrapper around a map that stores a truncated
// timestamp and a mutex. The map is used to keep track of rate
// limit keys, and their used limits.
type limitedKeys struct {
truncTS int64
keys map[string]int
mtx sync.Mutex
}
func newLimitedKeys(t int64) *limitedKeys {
return &limitedKeys{
truncTS: t,
keys: make(map[string]int),
}
}
func (l *limitedKeys) Take(key string, max int) bool {
l.mtx.Lock()
defer l.mtx.Unlock()
val, ok := l.keys[key]
if !ok {
l.keys[key] = 0
val = 0
}
l.keys[key] = val + 1
return val < max
}
// MemoryFrontendRateLimiter is a rate limiter that stores
// all rate limiting information in local memory. It works
// by storing a limitedKeys struct that references the
// truncated timestamp at which the struct was created. If
// the current truncated timestamp doesn't match what's
// referenced, the limit is reset. Otherwise, values in
// a map are incremented to represent the limit.
type MemoryFrontendRateLimiter struct {
currGeneration *limitedKeys
dur time.Duration
mtx sync.Mutex
}
func NewMemoryFrontendRateLimit(dur time.Duration) FrontendRateLimiter {
return &MemoryFrontendRateLimiter{
dur: dur,
}
}
func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
m.mtx.Lock()
// Create truncated timestamp
truncTS := truncateNow(m.dur)
// If there is no current rate limit map or the rate limit map reference
// a different timestamp, reset limits.
if m.currGeneration == nil || m.currGeneration.truncTS != truncTS {
m.currGeneration = newLimitedKeys(truncTS)
}
// Pull out the limiter so we can unlock before incrementing the limit.
limiter := m.currGeneration
m.mtx.Unlock()
return limiter.Take(key, max), nil
}
// RedisFrontendRateLimiter is a rate limiter that stores data in Redis.
// It uses the basic rate limiter pattern described on the Redis best
// practices website: https://redis.com/redis-best-practices/basic-rate-limiting/.
type RedisFrontendRateLimiter struct {
r *redis.Client
dur time.Duration
}
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration) FrontendRateLimiter {
return &RedisFrontendRateLimiter{r: r, dur: dur}
}
func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
var incr *redis.IntCmd
truncTS := truncateNow(r.dur)
fullKey := fmt.Sprintf("%s:%d", key, truncTS)
_, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error {
incr = pipe.Incr(ctx, fullKey)
pipe.Expire(ctx, fullKey, r.dur-time.Second)
return nil
})
if err != nil {
return false, err
}
return incr.Val()-1 < int64(max), nil
}
// truncateNow truncates the current timestamp
// to the specified duration.
func truncateNow(dur time.Duration) int64 {
return time.Now().Truncate(dur).Unix()
}

@ -0,0 +1,53 @@
package proxyd
import (
"context"
"fmt"
"testing"
"time"
"github.com/alicebob/miniredis"
"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/require"
)
func TestFrontendRateLimiter(t *testing.T) {
redisServer, err := miniredis.Run()
require.NoError(t, err)
defer redisServer.Close()
redisClient := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("127.0.0.1:%s", redisServer.Port()),
})
lims := []struct {
name string
frl FrontendRateLimiter
}{
{"memory", NewMemoryFrontendRateLimit(2 * time.Second)},
{"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second)},
}
max := 2
for _, cfg := range lims {
frl := cfg.frl
ctx := context.Background()
t.Run(cfg.name, func(t *testing.T) {
for i := 0; i < 4; i++ {
ok, err := frl.Take(ctx, "foo", max)
require.NoError(t, err)
require.Equal(t, i < max, ok)
ok, err = frl.Take(ctx, "bar", max)
require.NoError(t, err)
require.Equal(t, i < max, ok)
}
time.Sleep(2 * time.Second)
for i := 0; i < 4; i++ {
ok, _ := frl.Take(ctx, "foo", max)
require.Equal(t, i < max, ok)
ok, _ = frl.Take(ctx, "bar", max)
require.Equal(t, i < max, ok)
}
})
}
}