From f3d3492a816fc96442fec640bf558be3ebb3677a Mon Sep 17 00:00:00 2001 From: Matthew Slipper Date: Thu, 4 Aug 2022 11:34:43 -0600 Subject: [PATCH] proxyd: Add frontend rate limiting (#3166) * proxyd: Add frontend rate limiting To give us more flexibiltiy with rate limiting, proxyd now supports rate limiting of client (frontend) requests in addition to upstream (backend) requests. This PR also gives us the ability to exempt certain user agents/origins from rate limiting. * lint --- proxyd/proxyd/backend.go | 19 ++-- proxyd/proxyd/config.go | 8 ++ proxyd/proxyd/go.mod | 3 +- proxyd/proxyd/go.sum | 6 +- .../integration_tests/rate_limit_test.go | 79 +++++++++++++-- ...ate_limit.toml => backend_rate_limit.toml} | 0 .../testdata/frontend_rate_limit.toml | 23 +++++ proxyd/proxyd/integration_tests/util_test.go | 22 ++++- proxyd/proxyd/proxyd.go | 10 +- proxyd/proxyd/rate_limiter.go | 45 +++++---- proxyd/proxyd/rpc.go | 8 ++ proxyd/proxyd/server.go | 99 ++++++++++++++----- 12 files changed, 256 insertions(+), 66 deletions(-) rename proxyd/proxyd/integration_tests/testdata/{rate_limit.toml => backend_rate_limit.toml} (100%) create mode 100644 proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml diff --git a/proxyd/proxyd/backend.go b/proxyd/proxyd/backend.go index f4d7d93..80de372 100644 --- a/proxyd/proxyd/backend.go +++ b/proxyd/proxyd/backend.go @@ -74,6 +74,11 @@ var ( Message: "gateway timeout", HTTPErrorCode: 504, } + ErrOverRateLimit = &RPCErr{ + Code: JSONRPCErrorInternal - 16, + Message: "rate limited", + HTTPErrorCode: 429, + } ErrBackendUnexpectedJSONRPC = errors.New("backend returned an unexpected JSON-RPC response") ) @@ -92,7 +97,7 @@ type Backend struct { wsURL string authUsername string authPassword string - rateLimiter RateLimiter + rateLimiter BackendRateLimiter client *LimitedHTTPClient dialer *websocket.Dialer maxRetries int @@ -174,7 +179,7 @@ func NewBackend( name string, rpcURL string, wsURL string, - rateLimiter RateLimiter, + rateLimiter BackendRateLimiter, rpcSemaphore *semaphore.Weighted, opts ...BackendOpt, ) *Backend { @@ -372,10 +377,7 @@ func (b *Backend) doForward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool xForwardedFor := GetXForwardedFor(ctx) if b.stripTrailingXFF { - ipList := strings.Split(xForwardedFor, ", ") - if len(ipList) > 0 { - xForwardedFor = ipList[0] - } + xForwardedFor = stripXFF(xForwardedFor) } else if b.proxydIP != "" { xForwardedFor = fmt.Sprintf("%s, %s", xForwardedFor, b.proxydIP) } @@ -855,3 +857,8 @@ func RecordBatchRPCForward(ctx context.Context, backendName string, reqs []*RPCR RecordRPCForward(ctx, backendName, req.Method, source) } } + +func stripXFF(xff string) string { + ipList := strings.Split(xff, ", ") + return strings.TrimSpace(ipList[0]) +} diff --git a/proxyd/proxyd/config.go b/proxyd/proxyd/config.go index db46167..56eb3cc 100644 --- a/proxyd/proxyd/config.go +++ b/proxyd/proxyd/config.go @@ -39,6 +39,13 @@ type MetricsConfig struct { Port int `toml:"port"` } +type RateLimitConfig struct { + RatePerSecond int `toml:"rate_per_second"` + ExemptOrigins []string `toml:"exempt_origins"` + ExemptUserAgents []string `toml:"exempt_user_agents"` + ErrorMessage string `toml:"error_message"` +} + type BackendOptions struct { ResponseTimeoutSeconds int `toml:"response_timeout_seconds"` MaxResponseSizeBytes int64 `toml:"max_response_size_bytes"` @@ -75,6 +82,7 @@ type Config struct { Cache CacheConfig `toml:"cache"` Redis RedisConfig `toml:"redis"` Metrics MetricsConfig `toml:"metrics"` + RateLimit RateLimitConfig `toml:"rate_limit"` BackendOptions BackendOptions `toml:"backend"` Backends BackendsConfig `toml:"backends"` Authentication map[string]string `toml:"authentication"` diff --git a/proxyd/proxyd/go.mod b/proxyd/proxyd/go.mod index 73d7d85..0a866b8 100644 --- a/proxyd/proxyd/go.mod +++ b/proxyd/proxyd/go.mod @@ -13,6 +13,7 @@ require ( github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d github.com/prometheus/client_golang v1.11.0 github.com/rs/cors v1.8.2 + github.com/sethvargo/go-limiter v0.7.2 github.com/stretchr/testify v1.7.0 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c ) @@ -59,7 +60,7 @@ require ( github.com/yusufpapurcu/wmi v1.2.2 // indirect golang.org/x/crypto v0.0.0-20220307211146-efcb8507fb70 // indirect golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect - golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect + golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 // indirect google.golang.org/protobuf v1.27.1 // indirect gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect diff --git a/proxyd/proxyd/go.sum b/proxyd/proxyd/go.sum index d6d5053..8e5fd8d 100644 --- a/proxyd/proxyd/go.sum +++ b/proxyd/proxyd/go.sum @@ -451,6 +451,8 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/segmentio/kafka-go v0.1.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= github.com/segmentio/kafka-go v0.2.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8= +github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU= github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= @@ -701,8 +703,8 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= -golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 h1:ftMN5LMiBFjbzleLqtoBZk7KdJwhuybIU+FckUHgoyQ= +golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/proxyd/proxyd/integration_tests/rate_limit_test.go b/proxyd/proxyd/integration_tests/rate_limit_test.go index 409598e..4648017 100644 --- a/proxyd/proxyd/integration_tests/rate_limit_test.go +++ b/proxyd/proxyd/integration_tests/rate_limit_test.go @@ -1,8 +1,11 @@ package integration_tests import ( + "fmt" + "net/http" "os" "testing" + "time" "github.com/ethereum-optimism/optimism/proxyd" "github.com/stretchr/testify/require" @@ -13,18 +16,83 @@ type resWithCode struct { res []byte } -func TestMaxRPSLimit(t *testing.T) { +const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit"},"id":null,"jsonrpc":"2.0"}` + +func TestBackendMaxRPSLimit(t *testing.T) { goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse)) defer goodBackend.Close() require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL())) - config := ReadConfig("rate_limit") + config := ReadConfig("backend_rate_limit") client := NewProxydClient("http://127.0.0.1:8545") shutdown, err := proxyd.Start(config) require.NoError(t, err) defer shutdown() + limitedRes, codes := spamReqs(t, client, 503) + require.Equal(t, 2, codes[200]) + require.Equal(t, 1, codes[503]) + RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes) +} + +func TestFrontendMaxRPSLimit(t *testing.T) { + goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse)) + defer goodBackend.Close() + + require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL())) + + config := ReadConfig("frontend_rate_limit") + shutdown, err := proxyd.Start(config) + require.NoError(t, err) + defer shutdown() + + t.Run("non-exempt over limit", func(t *testing.T) { + client := NewProxydClient("http://127.0.0.1:8545") + limitedRes, codes := spamReqs(t, client, 429) + require.Equal(t, 1, codes[429]) + require.Equal(t, 2, codes[200]) + RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes) + }) + + t.Run("exempt user agent over limit", func(t *testing.T) { + h := make(http.Header) + h.Set("User-Agent", "exempt_agent") + client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h) + _, codes := spamReqs(t, client, 429) + require.Equal(t, 3, codes[200]) + }) + + t.Run("exempt origin over limit", func(t *testing.T) { + h := make(http.Header) + h.Set("Origin", "exempt_origin") + client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h) + _, codes := spamReqs(t, client, 429) + fmt.Println(codes) + require.Equal(t, 3, codes[200]) + }) + + t.Run("multiple xff", func(t *testing.T) { + h1 := make(http.Header) + h1.Set("X-Forwarded-For", "0.0.0.0") + h2 := make(http.Header) + h2.Set("X-Forwarded-For", "1.1.1.1") + client1 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h1) + client2 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h2) + _, codes := spamReqs(t, client1, 429) + require.Equal(t, 1, codes[429]) + require.Equal(t, 2, codes[200]) + _, code, err := client2.SendRPC("eth_chainId", nil) + require.Equal(t, 200, code) + require.NoError(t, err) + time.Sleep(time.Second) + _, code, err = client2.SendRPC("eth_chainId", nil) + require.Equal(t, 200, code) + require.NoError(t, err) + }) +} + +func spamReqs(t *testing.T, client *ProxydHTTPClient, limCode int) ([]byte, map[int]int) { resCh := make(chan *resWithCode) for i := 0; i < 3; i++ { go func() { @@ -48,13 +116,10 @@ func TestMaxRPSLimit(t *testing.T) { codes[code] += 1 } - // 503 because there's only one backend available - if code == 503 { + if code == limCode { limitedRes = res.res } } - require.Equal(t, 2, codes[200]) - require.Equal(t, 1, codes[503]) - RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes) + return limitedRes, codes } diff --git a/proxyd/proxyd/integration_tests/testdata/rate_limit.toml b/proxyd/proxyd/integration_tests/testdata/backend_rate_limit.toml similarity index 100% rename from proxyd/proxyd/integration_tests/testdata/rate_limit.toml rename to proxyd/proxyd/integration_tests/testdata/backend_rate_limit.toml diff --git a/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml b/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml new file mode 100644 index 0000000..e838ace --- /dev/null +++ b/proxyd/proxyd/integration_tests/testdata/frontend_rate_limit.toml @@ -0,0 +1,23 @@ +[server] +rpc_port = 8545 + +[backend] +response_timeout_seconds = 1 + +[backends] +[backends.good] +rpc_url = "$GOOD_BACKEND_RPC_URL" +ws_url = "$GOOD_BACKEND_RPC_URL" + +[backend_groups] +[backend_groups.main] +backends = ["good"] + +[rpc_method_mappings] +eth_chainId = "main" + +[rate_limit] +rate_per_second = 2 +exempt_origins = ["exempt_origin"] +exempt_user_agents = ["exempt_agent"] +error_message = "over rate limit" diff --git a/proxyd/proxyd/integration_tests/util_test.go b/proxyd/proxyd/integration_tests/util_test.go index c5c15a6..2e443f6 100644 --- a/proxyd/proxyd/integration_tests/util_test.go +++ b/proxyd/proxyd/integration_tests/util_test.go @@ -20,11 +20,21 @@ import ( ) type ProxydHTTPClient struct { - url string + url string + headers http.Header } func NewProxydClient(url string) *ProxydHTTPClient { - return &ProxydHTTPClient{url: url} + return NewProxydClientWithHeaders(url, make(http.Header)) +} + +func NewProxydClientWithHeaders(url string, headers http.Header) *ProxydHTTPClient { + clonedHeaders := headers.Clone() + clonedHeaders.Set("Content-Type", "application/json") + return &ProxydHTTPClient{ + url: url, + headers: clonedHeaders, + } } func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) { @@ -45,7 +55,13 @@ func (p *ProxydHTTPClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, er } func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) { - res, err := http.Post(p.url, "application/json", bytes.NewReader(body)) + req, err := http.NewRequest("POST", p.url, bytes.NewReader(body)) + if err != nil { + panic(err) + } + req.Header = p.headers + + res, err := http.DefaultClient.Do(req) if err != nil { return nil, -1, err } diff --git a/proxyd/proxyd/proxyd.go b/proxyd/proxyd/proxyd.go index 4b616c5..e9bbe42 100644 --- a/proxyd/proxyd/proxyd.go +++ b/proxyd/proxyd/proxyd.go @@ -43,11 +43,11 @@ func Start(config *Config) (func(), error) { redisURL = rURL } - var lim RateLimiter + var lim BackendRateLimiter var err error if redisURL == "" { log.Warn("redis is not configured, using local rate limiter") - lim = NewLocalRateLimiter() + lim = NewLocalBackendRateLimiter() } else { lim, err = NewRedisRateLimiter(redisURL) if err != nil { @@ -212,7 +212,7 @@ func Start(config *Config) (func(), error) { rpcCache = newRPCCache(newCacheWithCompression(cache), blockNumFn, gasPriceFn, config.Cache.NumBlockConfirmations) } - srv := NewServer( + srv, err := NewServer( backendGroups, wsBackendGroup, NewStringSetFromStrings(config.WSMethodWhitelist), @@ -222,9 +222,13 @@ func Start(config *Config) (func(), error) { secondsToDuration(config.Server.TimeoutSeconds), config.Server.MaxUpstreamBatchSize, rpcCache, + config.RateLimit, config.Server.EnableRequestLog, config.Server.MaxRequestBodyLogLen, ) + if err != nil { + return nil, fmt.Errorf("error creating server: %w", err) + } if config.Metrics.Enabled { addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port) diff --git a/proxyd/proxyd/rate_limiter.go b/proxyd/proxyd/rate_limiter.go index 5c4a4d6..fe286e6 100644 --- a/proxyd/proxyd/rate_limiter.go +++ b/proxyd/proxyd/rate_limiter.go @@ -41,7 +41,7 @@ end return false ` -type RateLimiter interface { +type BackendRateLimiter interface { IsBackendOnline(name string) (bool, error) SetBackendOffline(name string, duration time.Duration) error IncBackendRPS(name string) (int, error) @@ -50,14 +50,14 @@ type RateLimiter interface { FlushBackendWSConns(names []string) error } -type RedisRateLimiter struct { +type RedisBackendRateLimiter struct { rdb *redis.Client randID string touchKeys map[string]time.Duration tkMtx sync.Mutex } -func NewRedisRateLimiter(url string) (RateLimiter, error) { +func NewRedisRateLimiter(url string) (BackendRateLimiter, error) { opts, err := redis.ParseURL(url) if err != nil { return nil, err @@ -66,7 +66,7 @@ func NewRedisRateLimiter(url string) (RateLimiter, error) { if err := rdb.Ping(context.Background()).Err(); err != nil { return nil, wrapErr(err, "error connecting to redis") } - out := &RedisRateLimiter{ + out := &RedisBackendRateLimiter{ rdb: rdb, randID: randStr(20), touchKeys: make(map[string]time.Duration), @@ -75,7 +75,7 @@ func NewRedisRateLimiter(url string) (RateLimiter, error) { return out, nil } -func (r *RedisRateLimiter) IsBackendOnline(name string) (bool, error) { +func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) { exists, err := r.rdb.Exists(context.Background(), fmt.Sprintf("backend:%s:offline", name)).Result() if err != nil { RecordRedisError("IsBackendOnline") @@ -85,7 +85,7 @@ func (r *RedisRateLimiter) IsBackendOnline(name string) (bool, error) { return exists == 0, nil } -func (r *RedisRateLimiter) SetBackendOffline(name string, duration time.Duration) error { +func (r *RedisBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error { if duration == 0 { return nil } @@ -102,7 +102,7 @@ func (r *RedisRateLimiter) SetBackendOffline(name string, duration time.Duration return nil } -func (r *RedisRateLimiter) IncBackendRPS(name string) (int, error) { +func (r *RedisBackendRateLimiter) IncBackendRPS(name string) (int, error) { cmd := r.rdb.Eval( context.Background(), MaxRPSScript, @@ -116,7 +116,7 @@ func (r *RedisRateLimiter) IncBackendRPS(name string) (int, error) { return rps, nil } -func (r *RedisRateLimiter) IncBackendWSConns(name string, max int) (bool, error) { +func (r *RedisBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) { connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name) r.tkMtx.Lock() r.touchKeys[connsKey] = 5 * time.Minute @@ -142,7 +142,7 @@ func (r *RedisRateLimiter) IncBackendWSConns(name string, max int) (bool, error) return incremented, nil } -func (r *RedisRateLimiter) DecBackendWSConns(name string) error { +func (r *RedisBackendRateLimiter) DecBackendWSConns(name string) error { connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name) err := r.rdb.Decr(context.Background(), connsKey).Err() if err != nil { @@ -152,7 +152,7 @@ func (r *RedisRateLimiter) DecBackendWSConns(name string) error { return nil } -func (r *RedisRateLimiter) FlushBackendWSConns(names []string) error { +func (r *RedisBackendRateLimiter) FlushBackendWSConns(names []string) error { ctx := context.Background() for _, name := range names { connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name) @@ -172,7 +172,7 @@ func (r *RedisRateLimiter) FlushBackendWSConns(names []string) error { return nil } -func (r *RedisRateLimiter) touch() { +func (r *RedisBackendRateLimiter) touch() { for { r.tkMtx.Lock() for key, dur := range r.touchKeys { @@ -186,15 +186,15 @@ func (r *RedisRateLimiter) touch() { } } -type LocalRateLimiter struct { +type LocalBackendRateLimiter struct { deadBackends map[string]time.Time backendRPS map[string]int backendWSConns map[string]int mtx sync.RWMutex } -func NewLocalRateLimiter() *LocalRateLimiter { - out := &LocalRateLimiter{ +func NewLocalBackendRateLimiter() *LocalBackendRateLimiter { + out := &LocalBackendRateLimiter{ deadBackends: make(map[string]time.Time), backendRPS: make(map[string]int), backendWSConns: make(map[string]int), @@ -203,27 +203,27 @@ func NewLocalRateLimiter() *LocalRateLimiter { return out } -func (l *LocalRateLimiter) IsBackendOnline(name string) (bool, error) { +func (l *LocalBackendRateLimiter) IsBackendOnline(name string) (bool, error) { l.mtx.RLock() defer l.mtx.RUnlock() return l.deadBackends[name].Before(time.Now()), nil } -func (l *LocalRateLimiter) SetBackendOffline(name string, duration time.Duration) error { +func (l *LocalBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error { l.mtx.Lock() defer l.mtx.Unlock() l.deadBackends[name] = time.Now().Add(duration) return nil } -func (l *LocalRateLimiter) IncBackendRPS(name string) (int, error) { +func (l *LocalBackendRateLimiter) IncBackendRPS(name string) (int, error) { l.mtx.Lock() defer l.mtx.Unlock() l.backendRPS[name] += 1 return l.backendRPS[name], nil } -func (l *LocalRateLimiter) IncBackendWSConns(name string, max int) (bool, error) { +func (l *LocalBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) { l.mtx.Lock() defer l.mtx.Unlock() if l.backendWSConns[name] == max { @@ -233,7 +233,7 @@ func (l *LocalRateLimiter) IncBackendWSConns(name string, max int) (bool, error) return true, nil } -func (l *LocalRateLimiter) DecBackendWSConns(name string) error { +func (l *LocalBackendRateLimiter) DecBackendWSConns(name string) error { l.mtx.Lock() defer l.mtx.Unlock() if l.backendWSConns[name] == 0 { @@ -243,11 +243,11 @@ func (l *LocalRateLimiter) DecBackendWSConns(name string) error { return nil } -func (l *LocalRateLimiter) FlushBackendWSConns(names []string) error { +func (l *LocalBackendRateLimiter) FlushBackendWSConns(names []string) error { return nil } -func (l *LocalRateLimiter) clear() { +func (l *LocalBackendRateLimiter) clear() { for { time.Sleep(time.Second) l.mtx.Lock() @@ -263,3 +263,6 @@ func randStr(l int) string { } return hex.EncodeToString(b) } + +type ServerRateLimiter struct { +} diff --git a/proxyd/proxyd/rpc.go b/proxyd/proxyd/rpc.go index 5f16822..ccd7c5f 100644 --- a/proxyd/proxyd/rpc.go +++ b/proxyd/proxyd/rpc.go @@ -65,6 +65,14 @@ func (r *RPCErr) Error() string { return r.Message } +func (r *RPCErr) Clone() *RPCErr { + return &RPCErr{ + Code: r.Code, + Message: r.Message, + HTTPErrorCode: r.HTTPErrorCode, + } +} + func IsValidID(id json.RawMessage) bool { // handle the case where the ID is a string if strings.HasPrefix(string(id), "\"") && strings.HasSuffix(string(id), "\"") { diff --git a/proxyd/proxyd/server.go b/proxyd/proxyd/server.go index 58cd739..27f3f0e 100644 --- a/proxyd/proxyd/server.go +++ b/proxyd/proxyd/server.go @@ -14,6 +14,10 @@ import ( "sync" "time" + "github.com/sethvargo/go-limiter" + "github.com/sethvargo/go-limiter/memorystore" + "github.com/sethvargo/go-limiter/noopstore" + "github.com/ethereum/go-ethereum/log" "github.com/gorilla/mux" "github.com/gorilla/websocket" @@ -46,6 +50,10 @@ type Server struct { timeout time.Duration maxUpstreamBatchSize int upgrader *websocket.Upgrader + lim limiter.Store + limConfig RateLimitConfig + limExemptOrigins map[string]bool + limExemptUserAgents map[string]bool rpcServer *http.Server wsServer *http.Server cache RPCCache @@ -62,9 +70,10 @@ func NewServer( timeout time.Duration, maxUpstreamBatchSize int, cache RPCCache, + rateLimitConfig RateLimitConfig, enableRequestLog bool, maxRequestBodyLogLen int, -) *Server { +) (*Server, error) { if cache == nil { cache = &NoopRPCCache{} } @@ -81,6 +90,29 @@ func NewServer( maxUpstreamBatchSize = defaultMaxUpstreamBatchSize } + var lim limiter.Store + limExemptOrigins := make(map[string]bool) + limExemptUserAgents := make(map[string]bool) + if rateLimitConfig.RatePerSecond > 0 { + var err error + lim, err = memorystore.New(&memorystore.Config{ + Tokens: uint64(rateLimitConfig.RatePerSecond), + Interval: time.Second, + }) + if err != nil { + return nil, err + } + + for _, origin := range rateLimitConfig.ExemptOrigins { + limExemptOrigins[strings.ToLower(origin)] = true + } + for _, agent := range rateLimitConfig.ExemptUserAgents { + limExemptUserAgents[strings.ToLower(agent)] = true + } + } else { + lim, _ = noopstore.New() + } + return &Server{ backendGroups: backendGroups, wsBackendGroup: wsBackendGroup, @@ -96,7 +128,11 @@ func NewServer( upgrader: &websocket.Upgrader{ HandshakeTimeout: 5 * time.Second, }, - } + lim: lim, + limConfig: rateLimitConfig, + limExemptOrigins: limExemptOrigins, + limExemptUserAgents: limExemptUserAgents, + }, nil } func (s *Server) RPCListenAndServe(host string, port int) error { @@ -160,6 +196,28 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { ctx, cancel = context.WithTimeout(ctx, s.timeout) defer cancel() + exemptOrigin := s.limExemptOrigins[strings.ToLower(r.Header.Get("Origin"))] + exemptUserAgent := s.limExemptUserAgents[strings.ToLower(r.Header.Get("User-Agent"))] + var ok bool + if exemptOrigin || exemptUserAgent { + ok = true + } else { + // Use XFF in context since it will automatically be replaced by the remote IP + xff := stripXFF(GetXForwardedFor(ctx)) + if xff == "" { + log.Warn("rejecting request without XFF or remote IP") + ok = false + } else { + _, _, _, ok, _ = s.lim.Take(ctx, xff) + } + } + if !ok { + rpcErr := ErrOverRateLimit.Clone() + rpcErr.Message = s.limConfig.ErrorMessage + writeRPCError(ctx, w, nil, rpcErr) + return + } + log.Info( "received RPC request", "req_id", GetReqID(ctx), @@ -390,6 +448,14 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context { vars := mux.Vars(r) authorization := vars["authorization"] + xff := r.Header.Get("X-Forwarded-For") + if xff == "" { + ipPort := strings.Split(r.RemoteAddr, ":") + if len(ipPort) == 2 { + xff = ipPort[0] + } + } + ctx := context.WithValue(r.Context(), ContextKeyXForwardedFor, xff) // nolint:staticcheck if s.authenticatedPaths == nil { // handle the edge case where auth is disabled @@ -400,30 +466,17 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context w.WriteHeader(404) return nil } - return context.WithValue( - r.Context(), - ContextKeyReqID, // nolint:staticcheck - randStr(10), - ) - } - - if authorization == "" || s.authenticatedPaths[authorization] == "" { - log.Info("blocked unauthorized request", "authorization", authorization) - httpResponseCodesTotal.WithLabelValues("401").Inc() - w.WriteHeader(401) - return nil - } - - xff := r.Header.Get("X-Forwarded-For") - if xff == "" { - ipPort := strings.Split(r.RemoteAddr, ":") - if len(ipPort) == 2 { - xff = ipPort[0] + } else { + if authorization == "" || s.authenticatedPaths[authorization] == "" { + log.Info("blocked unauthorized request", "authorization", authorization) + httpResponseCodesTotal.WithLabelValues("401").Inc() + w.WriteHeader(401) + return nil } + + ctx = context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck } - ctx := context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck - ctx = context.WithValue(ctx, ContextKeyXForwardedFor, xff) // nolint:staticcheck return context.WithValue( ctx, ContextKeyReqID, // nolint:staticcheck