proxyd: Add frontend rate limiting (#3166)
* proxyd: Add frontend rate limiting To give us more flexibiltiy with rate limiting, proxyd now supports rate limiting of client (frontend) requests in addition to upstream (backend) requests. This PR also gives us the ability to exempt certain user agents/origins from rate limiting. * lint
This commit is contained in:
parent
4ea6a054c3
commit
f3d3492a81
@ -74,6 +74,11 @@ var (
|
||||
Message: "gateway timeout",
|
||||
HTTPErrorCode: 504,
|
||||
}
|
||||
ErrOverRateLimit = &RPCErr{
|
||||
Code: JSONRPCErrorInternal - 16,
|
||||
Message: "rate limited",
|
||||
HTTPErrorCode: 429,
|
||||
}
|
||||
|
||||
ErrBackendUnexpectedJSONRPC = errors.New("backend returned an unexpected JSON-RPC response")
|
||||
)
|
||||
@ -92,7 +97,7 @@ type Backend struct {
|
||||
wsURL string
|
||||
authUsername string
|
||||
authPassword string
|
||||
rateLimiter RateLimiter
|
||||
rateLimiter BackendRateLimiter
|
||||
client *LimitedHTTPClient
|
||||
dialer *websocket.Dialer
|
||||
maxRetries int
|
||||
@ -174,7 +179,7 @@ func NewBackend(
|
||||
name string,
|
||||
rpcURL string,
|
||||
wsURL string,
|
||||
rateLimiter RateLimiter,
|
||||
rateLimiter BackendRateLimiter,
|
||||
rpcSemaphore *semaphore.Weighted,
|
||||
opts ...BackendOpt,
|
||||
) *Backend {
|
||||
@ -372,10 +377,7 @@ func (b *Backend) doForward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool
|
||||
|
||||
xForwardedFor := GetXForwardedFor(ctx)
|
||||
if b.stripTrailingXFF {
|
||||
ipList := strings.Split(xForwardedFor, ", ")
|
||||
if len(ipList) > 0 {
|
||||
xForwardedFor = ipList[0]
|
||||
}
|
||||
xForwardedFor = stripXFF(xForwardedFor)
|
||||
} else if b.proxydIP != "" {
|
||||
xForwardedFor = fmt.Sprintf("%s, %s", xForwardedFor, b.proxydIP)
|
||||
}
|
||||
@ -855,3 +857,8 @@ func RecordBatchRPCForward(ctx context.Context, backendName string, reqs []*RPCR
|
||||
RecordRPCForward(ctx, backendName, req.Method, source)
|
||||
}
|
||||
}
|
||||
|
||||
func stripXFF(xff string) string {
|
||||
ipList := strings.Split(xff, ", ")
|
||||
return strings.TrimSpace(ipList[0])
|
||||
}
|
||||
|
@ -39,6 +39,13 @@ type MetricsConfig struct {
|
||||
Port int `toml:"port"`
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
RatePerSecond int `toml:"rate_per_second"`
|
||||
ExemptOrigins []string `toml:"exempt_origins"`
|
||||
ExemptUserAgents []string `toml:"exempt_user_agents"`
|
||||
ErrorMessage string `toml:"error_message"`
|
||||
}
|
||||
|
||||
type BackendOptions struct {
|
||||
ResponseTimeoutSeconds int `toml:"response_timeout_seconds"`
|
||||
MaxResponseSizeBytes int64 `toml:"max_response_size_bytes"`
|
||||
@ -75,6 +82,7 @@ type Config struct {
|
||||
Cache CacheConfig `toml:"cache"`
|
||||
Redis RedisConfig `toml:"redis"`
|
||||
Metrics MetricsConfig `toml:"metrics"`
|
||||
RateLimit RateLimitConfig `toml:"rate_limit"`
|
||||
BackendOptions BackendOptions `toml:"backend"`
|
||||
Backends BackendsConfig `toml:"backends"`
|
||||
Authentication map[string]string `toml:"authentication"`
|
||||
|
@ -13,6 +13,7 @@ require (
|
||||
github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d
|
||||
github.com/prometheus/client_golang v1.11.0
|
||||
github.com/rs/cors v1.8.2
|
||||
github.com/sethvargo/go-limiter v0.7.2
|
||||
github.com/stretchr/testify v1.7.0
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
|
||||
)
|
||||
@ -59,7 +60,7 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.2 // indirect
|
||||
golang.org/x/crypto v0.0.0-20220307211146-efcb8507fb70 // indirect
|
||||
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect
|
||||
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect
|
||||
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 // indirect
|
||||
google.golang.org/protobuf v1.27.1 // indirect
|
||||
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
|
@ -451,6 +451,8 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
|
||||
github.com/segmentio/kafka-go v0.1.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo=
|
||||
github.com/segmentio/kafka-go v0.2.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo=
|
||||
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
|
||||
github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8=
|
||||
github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
|
||||
github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
@ -701,8 +703,8 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs=
|
||||
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 h1:ftMN5LMiBFjbzleLqtoBZk7KdJwhuybIU+FckUHgoyQ=
|
||||
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
@ -1,8 +1,11 @@
|
||||
package integration_tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum-optimism/optimism/proxyd"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -13,18 +16,83 @@ type resWithCode struct {
|
||||
res []byte
|
||||
}
|
||||
|
||||
func TestMaxRPSLimit(t *testing.T) {
|
||||
const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit"},"id":null,"jsonrpc":"2.0"}`
|
||||
|
||||
func TestBackendMaxRPSLimit(t *testing.T) {
|
||||
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
|
||||
defer goodBackend.Close()
|
||||
|
||||
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
|
||||
|
||||
config := ReadConfig("rate_limit")
|
||||
config := ReadConfig("backend_rate_limit")
|
||||
client := NewProxydClient("http://127.0.0.1:8545")
|
||||
shutdown, err := proxyd.Start(config)
|
||||
require.NoError(t, err)
|
||||
defer shutdown()
|
||||
|
||||
limitedRes, codes := spamReqs(t, client, 503)
|
||||
require.Equal(t, 2, codes[200])
|
||||
require.Equal(t, 1, codes[503])
|
||||
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
|
||||
}
|
||||
|
||||
func TestFrontendMaxRPSLimit(t *testing.T) {
|
||||
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
|
||||
defer goodBackend.Close()
|
||||
|
||||
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
|
||||
|
||||
config := ReadConfig("frontend_rate_limit")
|
||||
shutdown, err := proxyd.Start(config)
|
||||
require.NoError(t, err)
|
||||
defer shutdown()
|
||||
|
||||
t.Run("non-exempt over limit", func(t *testing.T) {
|
||||
client := NewProxydClient("http://127.0.0.1:8545")
|
||||
limitedRes, codes := spamReqs(t, client, 429)
|
||||
require.Equal(t, 1, codes[429])
|
||||
require.Equal(t, 2, codes[200])
|
||||
RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes)
|
||||
})
|
||||
|
||||
t.Run("exempt user agent over limit", func(t *testing.T) {
|
||||
h := make(http.Header)
|
||||
h.Set("User-Agent", "exempt_agent")
|
||||
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
|
||||
_, codes := spamReqs(t, client, 429)
|
||||
require.Equal(t, 3, codes[200])
|
||||
})
|
||||
|
||||
t.Run("exempt origin over limit", func(t *testing.T) {
|
||||
h := make(http.Header)
|
||||
h.Set("Origin", "exempt_origin")
|
||||
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
|
||||
_, codes := spamReqs(t, client, 429)
|
||||
fmt.Println(codes)
|
||||
require.Equal(t, 3, codes[200])
|
||||
})
|
||||
|
||||
t.Run("multiple xff", func(t *testing.T) {
|
||||
h1 := make(http.Header)
|
||||
h1.Set("X-Forwarded-For", "0.0.0.0")
|
||||
h2 := make(http.Header)
|
||||
h2.Set("X-Forwarded-For", "1.1.1.1")
|
||||
client1 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h1)
|
||||
client2 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h2)
|
||||
_, codes := spamReqs(t, client1, 429)
|
||||
require.Equal(t, 1, codes[429])
|
||||
require.Equal(t, 2, codes[200])
|
||||
_, code, err := client2.SendRPC("eth_chainId", nil)
|
||||
require.Equal(t, 200, code)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
_, code, err = client2.SendRPC("eth_chainId", nil)
|
||||
require.Equal(t, 200, code)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func spamReqs(t *testing.T, client *ProxydHTTPClient, limCode int) ([]byte, map[int]int) {
|
||||
resCh := make(chan *resWithCode)
|
||||
for i := 0; i < 3; i++ {
|
||||
go func() {
|
||||
@ -48,13 +116,10 @@ func TestMaxRPSLimit(t *testing.T) {
|
||||
codes[code] += 1
|
||||
}
|
||||
|
||||
// 503 because there's only one backend available
|
||||
if code == 503 {
|
||||
if code == limCode {
|
||||
limitedRes = res.res
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 2, codes[200])
|
||||
require.Equal(t, 1, codes[503])
|
||||
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
|
||||
return limitedRes, codes
|
||||
}
|
||||
|
23
proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml
vendored
Normal file
23
proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml
vendored
Normal file
@ -0,0 +1,23 @@
|
||||
[server]
|
||||
rpc_port = 8545
|
||||
|
||||
[backend]
|
||||
response_timeout_seconds = 1
|
||||
|
||||
[backends]
|
||||
[backends.good]
|
||||
rpc_url = "$GOOD_BACKEND_RPC_URL"
|
||||
ws_url = "$GOOD_BACKEND_RPC_URL"
|
||||
|
||||
[backend_groups]
|
||||
[backend_groups.main]
|
||||
backends = ["good"]
|
||||
|
||||
[rpc_method_mappings]
|
||||
eth_chainId = "main"
|
||||
|
||||
[rate_limit]
|
||||
rate_per_second = 2
|
||||
exempt_origins = ["exempt_origin"]
|
||||
exempt_user_agents = ["exempt_agent"]
|
||||
error_message = "over rate limit"
|
@ -20,11 +20,21 @@ import (
|
||||
)
|
||||
|
||||
type ProxydHTTPClient struct {
|
||||
url string
|
||||
url string
|
||||
headers http.Header
|
||||
}
|
||||
|
||||
func NewProxydClient(url string) *ProxydHTTPClient {
|
||||
return &ProxydHTTPClient{url: url}
|
||||
return NewProxydClientWithHeaders(url, make(http.Header))
|
||||
}
|
||||
|
||||
func NewProxydClientWithHeaders(url string, headers http.Header) *ProxydHTTPClient {
|
||||
clonedHeaders := headers.Clone()
|
||||
clonedHeaders.Set("Content-Type", "application/json")
|
||||
return &ProxydHTTPClient{
|
||||
url: url,
|
||||
headers: clonedHeaders,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
|
||||
@ -45,7 +55,13 @@ func (p *ProxydHTTPClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, er
|
||||
}
|
||||
|
||||
func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) {
|
||||
res, err := http.Post(p.url, "application/json", bytes.NewReader(body))
|
||||
req, err := http.NewRequest("POST", p.url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header = p.headers
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
}
|
||||
|
@ -43,11 +43,11 @@ func Start(config *Config) (func(), error) {
|
||||
redisURL = rURL
|
||||
}
|
||||
|
||||
var lim RateLimiter
|
||||
var lim BackendRateLimiter
|
||||
var err error
|
||||
if redisURL == "" {
|
||||
log.Warn("redis is not configured, using local rate limiter")
|
||||
lim = NewLocalRateLimiter()
|
||||
lim = NewLocalBackendRateLimiter()
|
||||
} else {
|
||||
lim, err = NewRedisRateLimiter(redisURL)
|
||||
if err != nil {
|
||||
@ -212,7 +212,7 @@ func Start(config *Config) (func(), error) {
|
||||
rpcCache = newRPCCache(newCacheWithCompression(cache), blockNumFn, gasPriceFn, config.Cache.NumBlockConfirmations)
|
||||
}
|
||||
|
||||
srv := NewServer(
|
||||
srv, err := NewServer(
|
||||
backendGroups,
|
||||
wsBackendGroup,
|
||||
NewStringSetFromStrings(config.WSMethodWhitelist),
|
||||
@ -222,9 +222,13 @@ func Start(config *Config) (func(), error) {
|
||||
secondsToDuration(config.Server.TimeoutSeconds),
|
||||
config.Server.MaxUpstreamBatchSize,
|
||||
rpcCache,
|
||||
config.RateLimit,
|
||||
config.Server.EnableRequestLog,
|
||||
config.Server.MaxRequestBodyLogLen,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating server: %w", err)
|
||||
}
|
||||
|
||||
if config.Metrics.Enabled {
|
||||
addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port)
|
||||
|
@ -41,7 +41,7 @@ end
|
||||
return false
|
||||
`
|
||||
|
||||
type RateLimiter interface {
|
||||
type BackendRateLimiter interface {
|
||||
IsBackendOnline(name string) (bool, error)
|
||||
SetBackendOffline(name string, duration time.Duration) error
|
||||
IncBackendRPS(name string) (int, error)
|
||||
@ -50,14 +50,14 @@ type RateLimiter interface {
|
||||
FlushBackendWSConns(names []string) error
|
||||
}
|
||||
|
||||
type RedisRateLimiter struct {
|
||||
type RedisBackendRateLimiter struct {
|
||||
rdb *redis.Client
|
||||
randID string
|
||||
touchKeys map[string]time.Duration
|
||||
tkMtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewRedisRateLimiter(url string) (RateLimiter, error) {
|
||||
func NewRedisRateLimiter(url string) (BackendRateLimiter, error) {
|
||||
opts, err := redis.ParseURL(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -66,7 +66,7 @@ func NewRedisRateLimiter(url string) (RateLimiter, error) {
|
||||
if err := rdb.Ping(context.Background()).Err(); err != nil {
|
||||
return nil, wrapErr(err, "error connecting to redis")
|
||||
}
|
||||
out := &RedisRateLimiter{
|
||||
out := &RedisBackendRateLimiter{
|
||||
rdb: rdb,
|
||||
randID: randStr(20),
|
||||
touchKeys: make(map[string]time.Duration),
|
||||
@ -75,7 +75,7 @@ func NewRedisRateLimiter(url string) (RateLimiter, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *RedisRateLimiter) IsBackendOnline(name string) (bool, error) {
|
||||
func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
|
||||
exists, err := r.rdb.Exists(context.Background(), fmt.Sprintf("backend:%s:offline", name)).Result()
|
||||
if err != nil {
|
||||
RecordRedisError("IsBackendOnline")
|
||||
@ -85,7 +85,7 @@ func (r *RedisRateLimiter) IsBackendOnline(name string) (bool, error) {
|
||||
return exists == 0, nil
|
||||
}
|
||||
|
||||
func (r *RedisRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
|
||||
func (r *RedisBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
|
||||
if duration == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -102,7 +102,7 @@ func (r *RedisRateLimiter) SetBackendOffline(name string, duration time.Duration
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisRateLimiter) IncBackendRPS(name string) (int, error) {
|
||||
func (r *RedisBackendRateLimiter) IncBackendRPS(name string) (int, error) {
|
||||
cmd := r.rdb.Eval(
|
||||
context.Background(),
|
||||
MaxRPSScript,
|
||||
@ -116,7 +116,7 @@ func (r *RedisRateLimiter) IncBackendRPS(name string) (int, error) {
|
||||
return rps, nil
|
||||
}
|
||||
|
||||
func (r *RedisRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
|
||||
func (r *RedisBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
|
||||
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
|
||||
r.tkMtx.Lock()
|
||||
r.touchKeys[connsKey] = 5 * time.Minute
|
||||
@ -142,7 +142,7 @@ func (r *RedisRateLimiter) IncBackendWSConns(name string, max int) (bool, error)
|
||||
return incremented, nil
|
||||
}
|
||||
|
||||
func (r *RedisRateLimiter) DecBackendWSConns(name string) error {
|
||||
func (r *RedisBackendRateLimiter) DecBackendWSConns(name string) error {
|
||||
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
|
||||
err := r.rdb.Decr(context.Background(), connsKey).Err()
|
||||
if err != nil {
|
||||
@ -152,7 +152,7 @@ func (r *RedisRateLimiter) DecBackendWSConns(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisRateLimiter) FlushBackendWSConns(names []string) error {
|
||||
func (r *RedisBackendRateLimiter) FlushBackendWSConns(names []string) error {
|
||||
ctx := context.Background()
|
||||
for _, name := range names {
|
||||
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
|
||||
@ -172,7 +172,7 @@ func (r *RedisRateLimiter) FlushBackendWSConns(names []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisRateLimiter) touch() {
|
||||
func (r *RedisBackendRateLimiter) touch() {
|
||||
for {
|
||||
r.tkMtx.Lock()
|
||||
for key, dur := range r.touchKeys {
|
||||
@ -186,15 +186,15 @@ func (r *RedisRateLimiter) touch() {
|
||||
}
|
||||
}
|
||||
|
||||
type LocalRateLimiter struct {
|
||||
type LocalBackendRateLimiter struct {
|
||||
deadBackends map[string]time.Time
|
||||
backendRPS map[string]int
|
||||
backendWSConns map[string]int
|
||||
mtx sync.RWMutex
|
||||
}
|
||||
|
||||
func NewLocalRateLimiter() *LocalRateLimiter {
|
||||
out := &LocalRateLimiter{
|
||||
func NewLocalBackendRateLimiter() *LocalBackendRateLimiter {
|
||||
out := &LocalBackendRateLimiter{
|
||||
deadBackends: make(map[string]time.Time),
|
||||
backendRPS: make(map[string]int),
|
||||
backendWSConns: make(map[string]int),
|
||||
@ -203,27 +203,27 @@ func NewLocalRateLimiter() *LocalRateLimiter {
|
||||
return out
|
||||
}
|
||||
|
||||
func (l *LocalRateLimiter) IsBackendOnline(name string) (bool, error) {
|
||||
func (l *LocalBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
|
||||
l.mtx.RLock()
|
||||
defer l.mtx.RUnlock()
|
||||
return l.deadBackends[name].Before(time.Now()), nil
|
||||
}
|
||||
|
||||
func (l *LocalRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
|
||||
func (l *LocalBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
|
||||
l.mtx.Lock()
|
||||
defer l.mtx.Unlock()
|
||||
l.deadBackends[name] = time.Now().Add(duration)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalRateLimiter) IncBackendRPS(name string) (int, error) {
|
||||
func (l *LocalBackendRateLimiter) IncBackendRPS(name string) (int, error) {
|
||||
l.mtx.Lock()
|
||||
defer l.mtx.Unlock()
|
||||
l.backendRPS[name] += 1
|
||||
return l.backendRPS[name], nil
|
||||
}
|
||||
|
||||
func (l *LocalRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
|
||||
func (l *LocalBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
|
||||
l.mtx.Lock()
|
||||
defer l.mtx.Unlock()
|
||||
if l.backendWSConns[name] == max {
|
||||
@ -233,7 +233,7 @@ func (l *LocalRateLimiter) IncBackendWSConns(name string, max int) (bool, error)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (l *LocalRateLimiter) DecBackendWSConns(name string) error {
|
||||
func (l *LocalBackendRateLimiter) DecBackendWSConns(name string) error {
|
||||
l.mtx.Lock()
|
||||
defer l.mtx.Unlock()
|
||||
if l.backendWSConns[name] == 0 {
|
||||
@ -243,11 +243,11 @@ func (l *LocalRateLimiter) DecBackendWSConns(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalRateLimiter) FlushBackendWSConns(names []string) error {
|
||||
func (l *LocalBackendRateLimiter) FlushBackendWSConns(names []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalRateLimiter) clear() {
|
||||
func (l *LocalBackendRateLimiter) clear() {
|
||||
for {
|
||||
time.Sleep(time.Second)
|
||||
l.mtx.Lock()
|
||||
@ -263,3 +263,6 @@ func randStr(l int) string {
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
type ServerRateLimiter struct {
|
||||
}
|
||||
|
@ -65,6 +65,14 @@ func (r *RPCErr) Error() string {
|
||||
return r.Message
|
||||
}
|
||||
|
||||
func (r *RPCErr) Clone() *RPCErr {
|
||||
return &RPCErr{
|
||||
Code: r.Code,
|
||||
Message: r.Message,
|
||||
HTTPErrorCode: r.HTTPErrorCode,
|
||||
}
|
||||
}
|
||||
|
||||
func IsValidID(id json.RawMessage) bool {
|
||||
// handle the case where the ID is a string
|
||||
if strings.HasPrefix(string(id), "\"") && strings.HasSuffix(string(id), "\"") {
|
||||
|
@ -14,6 +14,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sethvargo/go-limiter"
|
||||
"github.com/sethvargo/go-limiter/memorystore"
|
||||
"github.com/sethvargo/go-limiter/noopstore"
|
||||
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
@ -46,6 +50,10 @@ type Server struct {
|
||||
timeout time.Duration
|
||||
maxUpstreamBatchSize int
|
||||
upgrader *websocket.Upgrader
|
||||
lim limiter.Store
|
||||
limConfig RateLimitConfig
|
||||
limExemptOrigins map[string]bool
|
||||
limExemptUserAgents map[string]bool
|
||||
rpcServer *http.Server
|
||||
wsServer *http.Server
|
||||
cache RPCCache
|
||||
@ -62,9 +70,10 @@ func NewServer(
|
||||
timeout time.Duration,
|
||||
maxUpstreamBatchSize int,
|
||||
cache RPCCache,
|
||||
rateLimitConfig RateLimitConfig,
|
||||
enableRequestLog bool,
|
||||
maxRequestBodyLogLen int,
|
||||
) *Server {
|
||||
) (*Server, error) {
|
||||
if cache == nil {
|
||||
cache = &NoopRPCCache{}
|
||||
}
|
||||
@ -81,6 +90,29 @@ func NewServer(
|
||||
maxUpstreamBatchSize = defaultMaxUpstreamBatchSize
|
||||
}
|
||||
|
||||
var lim limiter.Store
|
||||
limExemptOrigins := make(map[string]bool)
|
||||
limExemptUserAgents := make(map[string]bool)
|
||||
if rateLimitConfig.RatePerSecond > 0 {
|
||||
var err error
|
||||
lim, err = memorystore.New(&memorystore.Config{
|
||||
Tokens: uint64(rateLimitConfig.RatePerSecond),
|
||||
Interval: time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, origin := range rateLimitConfig.ExemptOrigins {
|
||||
limExemptOrigins[strings.ToLower(origin)] = true
|
||||
}
|
||||
for _, agent := range rateLimitConfig.ExemptUserAgents {
|
||||
limExemptUserAgents[strings.ToLower(agent)] = true
|
||||
}
|
||||
} else {
|
||||
lim, _ = noopstore.New()
|
||||
}
|
||||
|
||||
return &Server{
|
||||
backendGroups: backendGroups,
|
||||
wsBackendGroup: wsBackendGroup,
|
||||
@ -96,7 +128,11 @@ func NewServer(
|
||||
upgrader: &websocket.Upgrader{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
lim: lim,
|
||||
limConfig: rateLimitConfig,
|
||||
limExemptOrigins: limExemptOrigins,
|
||||
limExemptUserAgents: limExemptUserAgents,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) RPCListenAndServe(host string, port int) error {
|
||||
@ -160,6 +196,28 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel = context.WithTimeout(ctx, s.timeout)
|
||||
defer cancel()
|
||||
|
||||
exemptOrigin := s.limExemptOrigins[strings.ToLower(r.Header.Get("Origin"))]
|
||||
exemptUserAgent := s.limExemptUserAgents[strings.ToLower(r.Header.Get("User-Agent"))]
|
||||
var ok bool
|
||||
if exemptOrigin || exemptUserAgent {
|
||||
ok = true
|
||||
} else {
|
||||
// Use XFF in context since it will automatically be replaced by the remote IP
|
||||
xff := stripXFF(GetXForwardedFor(ctx))
|
||||
if xff == "" {
|
||||
log.Warn("rejecting request without XFF or remote IP")
|
||||
ok = false
|
||||
} else {
|
||||
_, _, _, ok, _ = s.lim.Take(ctx, xff)
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
rpcErr := ErrOverRateLimit.Clone()
|
||||
rpcErr.Message = s.limConfig.ErrorMessage
|
||||
writeRPCError(ctx, w, nil, rpcErr)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info(
|
||||
"received RPC request",
|
||||
"req_id", GetReqID(ctx),
|
||||
@ -390,6 +448,14 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context {
|
||||
vars := mux.Vars(r)
|
||||
authorization := vars["authorization"]
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff == "" {
|
||||
ipPort := strings.Split(r.RemoteAddr, ":")
|
||||
if len(ipPort) == 2 {
|
||||
xff = ipPort[0]
|
||||
}
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ContextKeyXForwardedFor, xff) // nolint:staticcheck
|
||||
|
||||
if s.authenticatedPaths == nil {
|
||||
// handle the edge case where auth is disabled
|
||||
@ -400,30 +466,17 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context
|
||||
w.WriteHeader(404)
|
||||
return nil
|
||||
}
|
||||
return context.WithValue(
|
||||
r.Context(),
|
||||
ContextKeyReqID, // nolint:staticcheck
|
||||
randStr(10),
|
||||
)
|
||||
}
|
||||
|
||||
if authorization == "" || s.authenticatedPaths[authorization] == "" {
|
||||
log.Info("blocked unauthorized request", "authorization", authorization)
|
||||
httpResponseCodesTotal.WithLabelValues("401").Inc()
|
||||
w.WriteHeader(401)
|
||||
return nil
|
||||
}
|
||||
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff == "" {
|
||||
ipPort := strings.Split(r.RemoteAddr, ":")
|
||||
if len(ipPort) == 2 {
|
||||
xff = ipPort[0]
|
||||
} else {
|
||||
if authorization == "" || s.authenticatedPaths[authorization] == "" {
|
||||
log.Info("blocked unauthorized request", "authorization", authorization)
|
||||
httpResponseCodesTotal.WithLabelValues("401").Inc()
|
||||
w.WriteHeader(401)
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx = context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, ContextKeyXForwardedFor, xff) // nolint:staticcheck
|
||||
return context.WithValue(
|
||||
ctx,
|
||||
ContextKeyReqID, // nolint:staticcheck
|
||||
|
Loading…
Reference in New Issue
Block a user