proxyd: Add frontend rate limiting (#3166)

* proxyd: Add frontend rate limiting

To give us more flexibiltiy with rate limiting, proxyd now supports rate limiting of client (frontend) requests in addition to upstream (backend) requests. This PR also gives us the ability to exempt certain user agents/origins from rate limiting.

* lint
This commit is contained in:
Matthew Slipper 2022-08-04 11:34:43 -06:00 committed by GitHub
parent 4ea6a054c3
commit f3d3492a81
12 changed files with 256 additions and 66 deletions

@ -74,6 +74,11 @@ var (
Message: "gateway timeout", Message: "gateway timeout",
HTTPErrorCode: 504, HTTPErrorCode: 504,
} }
ErrOverRateLimit = &RPCErr{
Code: JSONRPCErrorInternal - 16,
Message: "rate limited",
HTTPErrorCode: 429,
}
ErrBackendUnexpectedJSONRPC = errors.New("backend returned an unexpected JSON-RPC response") ErrBackendUnexpectedJSONRPC = errors.New("backend returned an unexpected JSON-RPC response")
) )
@ -92,7 +97,7 @@ type Backend struct {
wsURL string wsURL string
authUsername string authUsername string
authPassword string authPassword string
rateLimiter RateLimiter rateLimiter BackendRateLimiter
client *LimitedHTTPClient client *LimitedHTTPClient
dialer *websocket.Dialer dialer *websocket.Dialer
maxRetries int maxRetries int
@ -174,7 +179,7 @@ func NewBackend(
name string, name string,
rpcURL string, rpcURL string,
wsURL string, wsURL string,
rateLimiter RateLimiter, rateLimiter BackendRateLimiter,
rpcSemaphore *semaphore.Weighted, rpcSemaphore *semaphore.Weighted,
opts ...BackendOpt, opts ...BackendOpt,
) *Backend { ) *Backend {
@ -372,10 +377,7 @@ func (b *Backend) doForward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool
xForwardedFor := GetXForwardedFor(ctx) xForwardedFor := GetXForwardedFor(ctx)
if b.stripTrailingXFF { if b.stripTrailingXFF {
ipList := strings.Split(xForwardedFor, ", ") xForwardedFor = stripXFF(xForwardedFor)
if len(ipList) > 0 {
xForwardedFor = ipList[0]
}
} else if b.proxydIP != "" { } else if b.proxydIP != "" {
xForwardedFor = fmt.Sprintf("%s, %s", xForwardedFor, b.proxydIP) xForwardedFor = fmt.Sprintf("%s, %s", xForwardedFor, b.proxydIP)
} }
@ -855,3 +857,8 @@ func RecordBatchRPCForward(ctx context.Context, backendName string, reqs []*RPCR
RecordRPCForward(ctx, backendName, req.Method, source) RecordRPCForward(ctx, backendName, req.Method, source)
} }
} }
func stripXFF(xff string) string {
ipList := strings.Split(xff, ", ")
return strings.TrimSpace(ipList[0])
}

@ -39,6 +39,13 @@ type MetricsConfig struct {
Port int `toml:"port"` Port int `toml:"port"`
} }
type RateLimitConfig struct {
RatePerSecond int `toml:"rate_per_second"`
ExemptOrigins []string `toml:"exempt_origins"`
ExemptUserAgents []string `toml:"exempt_user_agents"`
ErrorMessage string `toml:"error_message"`
}
type BackendOptions struct { type BackendOptions struct {
ResponseTimeoutSeconds int `toml:"response_timeout_seconds"` ResponseTimeoutSeconds int `toml:"response_timeout_seconds"`
MaxResponseSizeBytes int64 `toml:"max_response_size_bytes"` MaxResponseSizeBytes int64 `toml:"max_response_size_bytes"`
@ -75,6 +82,7 @@ type Config struct {
Cache CacheConfig `toml:"cache"` Cache CacheConfig `toml:"cache"`
Redis RedisConfig `toml:"redis"` Redis RedisConfig `toml:"redis"`
Metrics MetricsConfig `toml:"metrics"` Metrics MetricsConfig `toml:"metrics"`
RateLimit RateLimitConfig `toml:"rate_limit"`
BackendOptions BackendOptions `toml:"backend"` BackendOptions BackendOptions `toml:"backend"`
Backends BackendsConfig `toml:"backends"` Backends BackendsConfig `toml:"backends"`
Authentication map[string]string `toml:"authentication"` Authentication map[string]string `toml:"authentication"`

@ -13,6 +13,7 @@ require (
github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d
github.com/prometheus/client_golang v1.11.0 github.com/prometheus/client_golang v1.11.0
github.com/rs/cors v1.8.2 github.com/rs/cors v1.8.2
github.com/sethvargo/go-limiter v0.7.2
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
) )
@ -59,7 +60,7 @@ require (
github.com/yusufpapurcu/wmi v1.2.2 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect
golang.org/x/crypto v0.0.0-20220307211146-efcb8507fb70 // indirect golang.org/x/crypto v0.0.0-20220307211146-efcb8507fb70 // indirect
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 // indirect
google.golang.org/protobuf v1.27.1 // indirect google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect

@ -451,6 +451,8 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/segmentio/kafka-go v0.1.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= github.com/segmentio/kafka-go v0.1.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo=
github.com/segmentio/kafka-go v0.2.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= github.com/segmentio/kafka-go v0.2.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8=
github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
@ -701,8 +703,8 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 h1:ftMN5LMiBFjbzleLqtoBZk7KdJwhuybIU+FckUHgoyQ=
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

@ -1,8 +1,11 @@
package integration_tests package integration_tests
import ( import (
"fmt"
"net/http"
"os" "os"
"testing" "testing"
"time"
"github.com/ethereum-optimism/optimism/proxyd" "github.com/ethereum-optimism/optimism/proxyd"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -13,18 +16,83 @@ type resWithCode struct {
res []byte res []byte
} }
func TestMaxRPSLimit(t *testing.T) { const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit"},"id":null,"jsonrpc":"2.0"}`
func TestBackendMaxRPSLimit(t *testing.T) {
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse)) goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
defer goodBackend.Close() defer goodBackend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL())) require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
config := ReadConfig("rate_limit") config := ReadConfig("backend_rate_limit")
client := NewProxydClient("http://127.0.0.1:8545") client := NewProxydClient("http://127.0.0.1:8545")
shutdown, err := proxyd.Start(config) shutdown, err := proxyd.Start(config)
require.NoError(t, err) require.NoError(t, err)
defer shutdown() defer shutdown()
limitedRes, codes := spamReqs(t, client, 503)
require.Equal(t, 2, codes[200])
require.Equal(t, 1, codes[503])
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
}
func TestFrontendMaxRPSLimit(t *testing.T) {
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
defer goodBackend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
config := ReadConfig("frontend_rate_limit")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
t.Run("non-exempt over limit", func(t *testing.T) {
client := NewProxydClient("http://127.0.0.1:8545")
limitedRes, codes := spamReqs(t, client, 429)
require.Equal(t, 1, codes[429])
require.Equal(t, 2, codes[200])
RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes)
})
t.Run("exempt user agent over limit", func(t *testing.T) {
h := make(http.Header)
h.Set("User-Agent", "exempt_agent")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
_, codes := spamReqs(t, client, 429)
require.Equal(t, 3, codes[200])
})
t.Run("exempt origin over limit", func(t *testing.T) {
h := make(http.Header)
h.Set("Origin", "exempt_origin")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
_, codes := spamReqs(t, client, 429)
fmt.Println(codes)
require.Equal(t, 3, codes[200])
})
t.Run("multiple xff", func(t *testing.T) {
h1 := make(http.Header)
h1.Set("X-Forwarded-For", "0.0.0.0")
h2 := make(http.Header)
h2.Set("X-Forwarded-For", "1.1.1.1")
client1 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h1)
client2 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h2)
_, codes := spamReqs(t, client1, 429)
require.Equal(t, 1, codes[429])
require.Equal(t, 2, codes[200])
_, code, err := client2.SendRPC("eth_chainId", nil)
require.Equal(t, 200, code)
require.NoError(t, err)
time.Sleep(time.Second)
_, code, err = client2.SendRPC("eth_chainId", nil)
require.Equal(t, 200, code)
require.NoError(t, err)
})
}
func spamReqs(t *testing.T, client *ProxydHTTPClient, limCode int) ([]byte, map[int]int) {
resCh := make(chan *resWithCode) resCh := make(chan *resWithCode)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
go func() { go func() {
@ -48,13 +116,10 @@ func TestMaxRPSLimit(t *testing.T) {
codes[code] += 1 codes[code] += 1
} }
// 503 because there's only one backend available if code == limCode {
if code == 503 {
limitedRes = res.res limitedRes = res.res
} }
} }
require.Equal(t, 2, codes[200]) return limitedRes, codes
require.Equal(t, 1, codes[503])
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
} }

@ -0,0 +1,23 @@
[server]
rpc_port = 8545
[backend]
response_timeout_seconds = 1
[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
[backend_groups]
[backend_groups.main]
backends = ["good"]
[rpc_method_mappings]
eth_chainId = "main"
[rate_limit]
rate_per_second = 2
exempt_origins = ["exempt_origin"]
exempt_user_agents = ["exempt_agent"]
error_message = "over rate limit"

@ -20,11 +20,21 @@ import (
) )
type ProxydHTTPClient struct { type ProxydHTTPClient struct {
url string url string
headers http.Header
} }
func NewProxydClient(url string) *ProxydHTTPClient { func NewProxydClient(url string) *ProxydHTTPClient {
return &ProxydHTTPClient{url: url} return NewProxydClientWithHeaders(url, make(http.Header))
}
func NewProxydClientWithHeaders(url string, headers http.Header) *ProxydHTTPClient {
clonedHeaders := headers.Clone()
clonedHeaders.Set("Content-Type", "application/json")
return &ProxydHTTPClient{
url: url,
headers: clonedHeaders,
}
} }
func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) { func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
@ -45,7 +55,13 @@ func (p *ProxydHTTPClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, er
} }
func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) { func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) {
res, err := http.Post(p.url, "application/json", bytes.NewReader(body)) req, err := http.NewRequest("POST", p.url, bytes.NewReader(body))
if err != nil {
panic(err)
}
req.Header = p.headers
res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, -1, err return nil, -1, err
} }

@ -43,11 +43,11 @@ func Start(config *Config) (func(), error) {
redisURL = rURL redisURL = rURL
} }
var lim RateLimiter var lim BackendRateLimiter
var err error var err error
if redisURL == "" { if redisURL == "" {
log.Warn("redis is not configured, using local rate limiter") log.Warn("redis is not configured, using local rate limiter")
lim = NewLocalRateLimiter() lim = NewLocalBackendRateLimiter()
} else { } else {
lim, err = NewRedisRateLimiter(redisURL) lim, err = NewRedisRateLimiter(redisURL)
if err != nil { if err != nil {
@ -212,7 +212,7 @@ func Start(config *Config) (func(), error) {
rpcCache = newRPCCache(newCacheWithCompression(cache), blockNumFn, gasPriceFn, config.Cache.NumBlockConfirmations) rpcCache = newRPCCache(newCacheWithCompression(cache), blockNumFn, gasPriceFn, config.Cache.NumBlockConfirmations)
} }
srv := NewServer( srv, err := NewServer(
backendGroups, backendGroups,
wsBackendGroup, wsBackendGroup,
NewStringSetFromStrings(config.WSMethodWhitelist), NewStringSetFromStrings(config.WSMethodWhitelist),
@ -222,9 +222,13 @@ func Start(config *Config) (func(), error) {
secondsToDuration(config.Server.TimeoutSeconds), secondsToDuration(config.Server.TimeoutSeconds),
config.Server.MaxUpstreamBatchSize, config.Server.MaxUpstreamBatchSize,
rpcCache, rpcCache,
config.RateLimit,
config.Server.EnableRequestLog, config.Server.EnableRequestLog,
config.Server.MaxRequestBodyLogLen, config.Server.MaxRequestBodyLogLen,
) )
if err != nil {
return nil, fmt.Errorf("error creating server: %w", err)
}
if config.Metrics.Enabled { if config.Metrics.Enabled {
addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port) addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port)

@ -41,7 +41,7 @@ end
return false return false
` `
type RateLimiter interface { type BackendRateLimiter interface {
IsBackendOnline(name string) (bool, error) IsBackendOnline(name string) (bool, error)
SetBackendOffline(name string, duration time.Duration) error SetBackendOffline(name string, duration time.Duration) error
IncBackendRPS(name string) (int, error) IncBackendRPS(name string) (int, error)
@ -50,14 +50,14 @@ type RateLimiter interface {
FlushBackendWSConns(names []string) error FlushBackendWSConns(names []string) error
} }
type RedisRateLimiter struct { type RedisBackendRateLimiter struct {
rdb *redis.Client rdb *redis.Client
randID string randID string
touchKeys map[string]time.Duration touchKeys map[string]time.Duration
tkMtx sync.Mutex tkMtx sync.Mutex
} }
func NewRedisRateLimiter(url string) (RateLimiter, error) { func NewRedisRateLimiter(url string) (BackendRateLimiter, error) {
opts, err := redis.ParseURL(url) opts, err := redis.ParseURL(url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -66,7 +66,7 @@ func NewRedisRateLimiter(url string) (RateLimiter, error) {
if err := rdb.Ping(context.Background()).Err(); err != nil { if err := rdb.Ping(context.Background()).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis") return nil, wrapErr(err, "error connecting to redis")
} }
out := &RedisRateLimiter{ out := &RedisBackendRateLimiter{
rdb: rdb, rdb: rdb,
randID: randStr(20), randID: randStr(20),
touchKeys: make(map[string]time.Duration), touchKeys: make(map[string]time.Duration),
@ -75,7 +75,7 @@ func NewRedisRateLimiter(url string) (RateLimiter, error) {
return out, nil return out, nil
} }
func (r *RedisRateLimiter) IsBackendOnline(name string) (bool, error) { func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
exists, err := r.rdb.Exists(context.Background(), fmt.Sprintf("backend:%s:offline", name)).Result() exists, err := r.rdb.Exists(context.Background(), fmt.Sprintf("backend:%s:offline", name)).Result()
if err != nil { if err != nil {
RecordRedisError("IsBackendOnline") RecordRedisError("IsBackendOnline")
@ -85,7 +85,7 @@ func (r *RedisRateLimiter) IsBackendOnline(name string) (bool, error) {
return exists == 0, nil return exists == 0, nil
} }
func (r *RedisRateLimiter) SetBackendOffline(name string, duration time.Duration) error { func (r *RedisBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
if duration == 0 { if duration == 0 {
return nil return nil
} }
@ -102,7 +102,7 @@ func (r *RedisRateLimiter) SetBackendOffline(name string, duration time.Duration
return nil return nil
} }
func (r *RedisRateLimiter) IncBackendRPS(name string) (int, error) { func (r *RedisBackendRateLimiter) IncBackendRPS(name string) (int, error) {
cmd := r.rdb.Eval( cmd := r.rdb.Eval(
context.Background(), context.Background(),
MaxRPSScript, MaxRPSScript,
@ -116,7 +116,7 @@ func (r *RedisRateLimiter) IncBackendRPS(name string) (int, error) {
return rps, nil return rps, nil
} }
func (r *RedisRateLimiter) IncBackendWSConns(name string, max int) (bool, error) { func (r *RedisBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name) connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
r.tkMtx.Lock() r.tkMtx.Lock()
r.touchKeys[connsKey] = 5 * time.Minute r.touchKeys[connsKey] = 5 * time.Minute
@ -142,7 +142,7 @@ func (r *RedisRateLimiter) IncBackendWSConns(name string, max int) (bool, error)
return incremented, nil return incremented, nil
} }
func (r *RedisRateLimiter) DecBackendWSConns(name string) error { func (r *RedisBackendRateLimiter) DecBackendWSConns(name string) error {
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name) connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
err := r.rdb.Decr(context.Background(), connsKey).Err() err := r.rdb.Decr(context.Background(), connsKey).Err()
if err != nil { if err != nil {
@ -152,7 +152,7 @@ func (r *RedisRateLimiter) DecBackendWSConns(name string) error {
return nil return nil
} }
func (r *RedisRateLimiter) FlushBackendWSConns(names []string) error { func (r *RedisBackendRateLimiter) FlushBackendWSConns(names []string) error {
ctx := context.Background() ctx := context.Background()
for _, name := range names { for _, name := range names {
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name) connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
@ -172,7 +172,7 @@ func (r *RedisRateLimiter) FlushBackendWSConns(names []string) error {
return nil return nil
} }
func (r *RedisRateLimiter) touch() { func (r *RedisBackendRateLimiter) touch() {
for { for {
r.tkMtx.Lock() r.tkMtx.Lock()
for key, dur := range r.touchKeys { for key, dur := range r.touchKeys {
@ -186,15 +186,15 @@ func (r *RedisRateLimiter) touch() {
} }
} }
type LocalRateLimiter struct { type LocalBackendRateLimiter struct {
deadBackends map[string]time.Time deadBackends map[string]time.Time
backendRPS map[string]int backendRPS map[string]int
backendWSConns map[string]int backendWSConns map[string]int
mtx sync.RWMutex mtx sync.RWMutex
} }
func NewLocalRateLimiter() *LocalRateLimiter { func NewLocalBackendRateLimiter() *LocalBackendRateLimiter {
out := &LocalRateLimiter{ out := &LocalBackendRateLimiter{
deadBackends: make(map[string]time.Time), deadBackends: make(map[string]time.Time),
backendRPS: make(map[string]int), backendRPS: make(map[string]int),
backendWSConns: make(map[string]int), backendWSConns: make(map[string]int),
@ -203,27 +203,27 @@ func NewLocalRateLimiter() *LocalRateLimiter {
return out return out
} }
func (l *LocalRateLimiter) IsBackendOnline(name string) (bool, error) { func (l *LocalBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
l.mtx.RLock() l.mtx.RLock()
defer l.mtx.RUnlock() defer l.mtx.RUnlock()
return l.deadBackends[name].Before(time.Now()), nil return l.deadBackends[name].Before(time.Now()), nil
} }
func (l *LocalRateLimiter) SetBackendOffline(name string, duration time.Duration) error { func (l *LocalBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
l.mtx.Lock() l.mtx.Lock()
defer l.mtx.Unlock() defer l.mtx.Unlock()
l.deadBackends[name] = time.Now().Add(duration) l.deadBackends[name] = time.Now().Add(duration)
return nil return nil
} }
func (l *LocalRateLimiter) IncBackendRPS(name string) (int, error) { func (l *LocalBackendRateLimiter) IncBackendRPS(name string) (int, error) {
l.mtx.Lock() l.mtx.Lock()
defer l.mtx.Unlock() defer l.mtx.Unlock()
l.backendRPS[name] += 1 l.backendRPS[name] += 1
return l.backendRPS[name], nil return l.backendRPS[name], nil
} }
func (l *LocalRateLimiter) IncBackendWSConns(name string, max int) (bool, error) { func (l *LocalBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
l.mtx.Lock() l.mtx.Lock()
defer l.mtx.Unlock() defer l.mtx.Unlock()
if l.backendWSConns[name] == max { if l.backendWSConns[name] == max {
@ -233,7 +233,7 @@ func (l *LocalRateLimiter) IncBackendWSConns(name string, max int) (bool, error)
return true, nil return true, nil
} }
func (l *LocalRateLimiter) DecBackendWSConns(name string) error { func (l *LocalBackendRateLimiter) DecBackendWSConns(name string) error {
l.mtx.Lock() l.mtx.Lock()
defer l.mtx.Unlock() defer l.mtx.Unlock()
if l.backendWSConns[name] == 0 { if l.backendWSConns[name] == 0 {
@ -243,11 +243,11 @@ func (l *LocalRateLimiter) DecBackendWSConns(name string) error {
return nil return nil
} }
func (l *LocalRateLimiter) FlushBackendWSConns(names []string) error { func (l *LocalBackendRateLimiter) FlushBackendWSConns(names []string) error {
return nil return nil
} }
func (l *LocalRateLimiter) clear() { func (l *LocalBackendRateLimiter) clear() {
for { for {
time.Sleep(time.Second) time.Sleep(time.Second)
l.mtx.Lock() l.mtx.Lock()
@ -263,3 +263,6 @@ func randStr(l int) string {
} }
return hex.EncodeToString(b) return hex.EncodeToString(b)
} }
type ServerRateLimiter struct {
}

@ -65,6 +65,14 @@ func (r *RPCErr) Error() string {
return r.Message return r.Message
} }
func (r *RPCErr) Clone() *RPCErr {
return &RPCErr{
Code: r.Code,
Message: r.Message,
HTTPErrorCode: r.HTTPErrorCode,
}
}
func IsValidID(id json.RawMessage) bool { func IsValidID(id json.RawMessage) bool {
// handle the case where the ID is a string // handle the case where the ID is a string
if strings.HasPrefix(string(id), "\"") && strings.HasSuffix(string(id), "\"") { if strings.HasPrefix(string(id), "\"") && strings.HasSuffix(string(id), "\"") {

@ -14,6 +14,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/sethvargo/go-limiter"
"github.com/sethvargo/go-limiter/memorystore"
"github.com/sethvargo/go-limiter/noopstore"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -46,6 +50,10 @@ type Server struct {
timeout time.Duration timeout time.Duration
maxUpstreamBatchSize int maxUpstreamBatchSize int
upgrader *websocket.Upgrader upgrader *websocket.Upgrader
lim limiter.Store
limConfig RateLimitConfig
limExemptOrigins map[string]bool
limExemptUserAgents map[string]bool
rpcServer *http.Server rpcServer *http.Server
wsServer *http.Server wsServer *http.Server
cache RPCCache cache RPCCache
@ -62,9 +70,10 @@ func NewServer(
timeout time.Duration, timeout time.Duration,
maxUpstreamBatchSize int, maxUpstreamBatchSize int,
cache RPCCache, cache RPCCache,
rateLimitConfig RateLimitConfig,
enableRequestLog bool, enableRequestLog bool,
maxRequestBodyLogLen int, maxRequestBodyLogLen int,
) *Server { ) (*Server, error) {
if cache == nil { if cache == nil {
cache = &NoopRPCCache{} cache = &NoopRPCCache{}
} }
@ -81,6 +90,29 @@ func NewServer(
maxUpstreamBatchSize = defaultMaxUpstreamBatchSize maxUpstreamBatchSize = defaultMaxUpstreamBatchSize
} }
var lim limiter.Store
limExemptOrigins := make(map[string]bool)
limExemptUserAgents := make(map[string]bool)
if rateLimitConfig.RatePerSecond > 0 {
var err error
lim, err = memorystore.New(&memorystore.Config{
Tokens: uint64(rateLimitConfig.RatePerSecond),
Interval: time.Second,
})
if err != nil {
return nil, err
}
for _, origin := range rateLimitConfig.ExemptOrigins {
limExemptOrigins[strings.ToLower(origin)] = true
}
for _, agent := range rateLimitConfig.ExemptUserAgents {
limExemptUserAgents[strings.ToLower(agent)] = true
}
} else {
lim, _ = noopstore.New()
}
return &Server{ return &Server{
backendGroups: backendGroups, backendGroups: backendGroups,
wsBackendGroup: wsBackendGroup, wsBackendGroup: wsBackendGroup,
@ -96,7 +128,11 @@ func NewServer(
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
}, },
} lim: lim,
limConfig: rateLimitConfig,
limExemptOrigins: limExemptOrigins,
limExemptUserAgents: limExemptUserAgents,
}, nil
} }
func (s *Server) RPCListenAndServe(host string, port int) error { func (s *Server) RPCListenAndServe(host string, port int) error {
@ -160,6 +196,28 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
ctx, cancel = context.WithTimeout(ctx, s.timeout) ctx, cancel = context.WithTimeout(ctx, s.timeout)
defer cancel() defer cancel()
exemptOrigin := s.limExemptOrigins[strings.ToLower(r.Header.Get("Origin"))]
exemptUserAgent := s.limExemptUserAgents[strings.ToLower(r.Header.Get("User-Agent"))]
var ok bool
if exemptOrigin || exemptUserAgent {
ok = true
} else {
// Use XFF in context since it will automatically be replaced by the remote IP
xff := stripXFF(GetXForwardedFor(ctx))
if xff == "" {
log.Warn("rejecting request without XFF or remote IP")
ok = false
} else {
_, _, _, ok, _ = s.lim.Take(ctx, xff)
}
}
if !ok {
rpcErr := ErrOverRateLimit.Clone()
rpcErr.Message = s.limConfig.ErrorMessage
writeRPCError(ctx, w, nil, rpcErr)
return
}
log.Info( log.Info(
"received RPC request", "received RPC request",
"req_id", GetReqID(ctx), "req_id", GetReqID(ctx),
@ -390,6 +448,14 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context { func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context {
vars := mux.Vars(r) vars := mux.Vars(r)
authorization := vars["authorization"] authorization := vars["authorization"]
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
ipPort := strings.Split(r.RemoteAddr, ":")
if len(ipPort) == 2 {
xff = ipPort[0]
}
}
ctx := context.WithValue(r.Context(), ContextKeyXForwardedFor, xff) // nolint:staticcheck
if s.authenticatedPaths == nil { if s.authenticatedPaths == nil {
// handle the edge case where auth is disabled // handle the edge case where auth is disabled
@ -400,30 +466,17 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context
w.WriteHeader(404) w.WriteHeader(404)
return nil return nil
} }
return context.WithValue( } else {
r.Context(), if authorization == "" || s.authenticatedPaths[authorization] == "" {
ContextKeyReqID, // nolint:staticcheck log.Info("blocked unauthorized request", "authorization", authorization)
randStr(10), httpResponseCodesTotal.WithLabelValues("401").Inc()
) w.WriteHeader(401)
} return nil
if authorization == "" || s.authenticatedPaths[authorization] == "" {
log.Info("blocked unauthorized request", "authorization", authorization)
httpResponseCodesTotal.WithLabelValues("401").Inc()
w.WriteHeader(401)
return nil
}
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
ipPort := strings.Split(r.RemoteAddr, ":")
if len(ipPort) == 2 {
xff = ipPort[0]
} }
ctx = context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck
} }
ctx := context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck
ctx = context.WithValue(ctx, ContextKeyXForwardedFor, xff) // nolint:staticcheck
return context.WithValue( return context.WithValue(
ctx, ctx,
ContextKeyReqID, // nolint:staticcheck ContextKeyReqID, // nolint:staticcheck