proxyd: Add frontend rate limiting (#3166)

* proxyd: Add frontend rate limiting

To give us more flexibiltiy with rate limiting, proxyd now supports rate limiting of client (frontend) requests in addition to upstream (backend) requests. This PR also gives us the ability to exempt certain user agents/origins from rate limiting.

* lint
This commit is contained in:
Matthew Slipper 2022-08-04 11:34:43 -06:00 committed by GitHub
parent 4ea6a054c3
commit f3d3492a81
12 changed files with 256 additions and 66 deletions

@ -74,6 +74,11 @@ var (
Message: "gateway timeout",
HTTPErrorCode: 504,
}
ErrOverRateLimit = &RPCErr{
Code: JSONRPCErrorInternal - 16,
Message: "rate limited",
HTTPErrorCode: 429,
}
ErrBackendUnexpectedJSONRPC = errors.New("backend returned an unexpected JSON-RPC response")
)
@ -92,7 +97,7 @@ type Backend struct {
wsURL string
authUsername string
authPassword string
rateLimiter RateLimiter
rateLimiter BackendRateLimiter
client *LimitedHTTPClient
dialer *websocket.Dialer
maxRetries int
@ -174,7 +179,7 @@ func NewBackend(
name string,
rpcURL string,
wsURL string,
rateLimiter RateLimiter,
rateLimiter BackendRateLimiter,
rpcSemaphore *semaphore.Weighted,
opts ...BackendOpt,
) *Backend {
@ -372,10 +377,7 @@ func (b *Backend) doForward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool
xForwardedFor := GetXForwardedFor(ctx)
if b.stripTrailingXFF {
ipList := strings.Split(xForwardedFor, ", ")
if len(ipList) > 0 {
xForwardedFor = ipList[0]
}
xForwardedFor = stripXFF(xForwardedFor)
} else if b.proxydIP != "" {
xForwardedFor = fmt.Sprintf("%s, %s", xForwardedFor, b.proxydIP)
}
@ -855,3 +857,8 @@ func RecordBatchRPCForward(ctx context.Context, backendName string, reqs []*RPCR
RecordRPCForward(ctx, backendName, req.Method, source)
}
}
func stripXFF(xff string) string {
ipList := strings.Split(xff, ", ")
return strings.TrimSpace(ipList[0])
}

@ -39,6 +39,13 @@ type MetricsConfig struct {
Port int `toml:"port"`
}
type RateLimitConfig struct {
RatePerSecond int `toml:"rate_per_second"`
ExemptOrigins []string `toml:"exempt_origins"`
ExemptUserAgents []string `toml:"exempt_user_agents"`
ErrorMessage string `toml:"error_message"`
}
type BackendOptions struct {
ResponseTimeoutSeconds int `toml:"response_timeout_seconds"`
MaxResponseSizeBytes int64 `toml:"max_response_size_bytes"`
@ -75,6 +82,7 @@ type Config struct {
Cache CacheConfig `toml:"cache"`
Redis RedisConfig `toml:"redis"`
Metrics MetricsConfig `toml:"metrics"`
RateLimit RateLimitConfig `toml:"rate_limit"`
BackendOptions BackendOptions `toml:"backend"`
Backends BackendsConfig `toml:"backends"`
Authentication map[string]string `toml:"authentication"`

@ -13,6 +13,7 @@ require (
github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d
github.com/prometheus/client_golang v1.11.0
github.com/rs/cors v1.8.2
github.com/sethvargo/go-limiter v0.7.2
github.com/stretchr/testify v1.7.0
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
)
@ -59,7 +60,7 @@ require (
github.com/yusufpapurcu/wmi v1.2.2 // indirect
golang.org/x/crypto v0.0.0-20220307211146-efcb8507fb70 // indirect
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 // indirect
google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect

@ -451,6 +451,8 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/segmentio/kafka-go v0.1.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo=
github.com/segmentio/kafka-go v0.2.0/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/sethvargo/go-limiter v0.7.2 h1:FgC4N7RMpV5gMrUdda15FaFTkQ/L4fEqM7seXMs4oO8=
github.com/sethvargo/go-limiter v0.7.2/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
@ -701,8 +703,8 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs=
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 h1:ftMN5LMiBFjbzleLqtoBZk7KdJwhuybIU+FckUHgoyQ=
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

@ -1,8 +1,11 @@
package integration_tests
import (
"fmt"
"net/http"
"os"
"testing"
"time"
"github.com/ethereum-optimism/optimism/proxyd"
"github.com/stretchr/testify/require"
@ -13,18 +16,83 @@ type resWithCode struct {
res []byte
}
func TestMaxRPSLimit(t *testing.T) {
const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit"},"id":null,"jsonrpc":"2.0"}`
func TestBackendMaxRPSLimit(t *testing.T) {
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
defer goodBackend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
config := ReadConfig("rate_limit")
config := ReadConfig("backend_rate_limit")
client := NewProxydClient("http://127.0.0.1:8545")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
limitedRes, codes := spamReqs(t, client, 503)
require.Equal(t, 2, codes[200])
require.Equal(t, 1, codes[503])
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
}
func TestFrontendMaxRPSLimit(t *testing.T) {
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
defer goodBackend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
config := ReadConfig("frontend_rate_limit")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
t.Run("non-exempt over limit", func(t *testing.T) {
client := NewProxydClient("http://127.0.0.1:8545")
limitedRes, codes := spamReqs(t, client, 429)
require.Equal(t, 1, codes[429])
require.Equal(t, 2, codes[200])
RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes)
})
t.Run("exempt user agent over limit", func(t *testing.T) {
h := make(http.Header)
h.Set("User-Agent", "exempt_agent")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
_, codes := spamReqs(t, client, 429)
require.Equal(t, 3, codes[200])
})
t.Run("exempt origin over limit", func(t *testing.T) {
h := make(http.Header)
h.Set("Origin", "exempt_origin")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
_, codes := spamReqs(t, client, 429)
fmt.Println(codes)
require.Equal(t, 3, codes[200])
})
t.Run("multiple xff", func(t *testing.T) {
h1 := make(http.Header)
h1.Set("X-Forwarded-For", "0.0.0.0")
h2 := make(http.Header)
h2.Set("X-Forwarded-For", "1.1.1.1")
client1 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h1)
client2 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h2)
_, codes := spamReqs(t, client1, 429)
require.Equal(t, 1, codes[429])
require.Equal(t, 2, codes[200])
_, code, err := client2.SendRPC("eth_chainId", nil)
require.Equal(t, 200, code)
require.NoError(t, err)
time.Sleep(time.Second)
_, code, err = client2.SendRPC("eth_chainId", nil)
require.Equal(t, 200, code)
require.NoError(t, err)
})
}
func spamReqs(t *testing.T, client *ProxydHTTPClient, limCode int) ([]byte, map[int]int) {
resCh := make(chan *resWithCode)
for i := 0; i < 3; i++ {
go func() {
@ -48,13 +116,10 @@ func TestMaxRPSLimit(t *testing.T) {
codes[code] += 1
}
// 503 because there's only one backend available
if code == 503 {
if code == limCode {
limitedRes = res.res
}
}
require.Equal(t, 2, codes[200])
require.Equal(t, 1, codes[503])
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
return limitedRes, codes
}

@ -0,0 +1,23 @@
[server]
rpc_port = 8545
[backend]
response_timeout_seconds = 1
[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
[backend_groups]
[backend_groups.main]
backends = ["good"]
[rpc_method_mappings]
eth_chainId = "main"
[rate_limit]
rate_per_second = 2
exempt_origins = ["exempt_origin"]
exempt_user_agents = ["exempt_agent"]
error_message = "over rate limit"

@ -21,10 +21,20 @@ import (
type ProxydHTTPClient struct {
url string
headers http.Header
}
func NewProxydClient(url string) *ProxydHTTPClient {
return &ProxydHTTPClient{url: url}
return NewProxydClientWithHeaders(url, make(http.Header))
}
func NewProxydClientWithHeaders(url string, headers http.Header) *ProxydHTTPClient {
clonedHeaders := headers.Clone()
clonedHeaders.Set("Content-Type", "application/json")
return &ProxydHTTPClient{
url: url,
headers: clonedHeaders,
}
}
func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
@ -45,7 +55,13 @@ func (p *ProxydHTTPClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, er
}
func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) {
res, err := http.Post(p.url, "application/json", bytes.NewReader(body))
req, err := http.NewRequest("POST", p.url, bytes.NewReader(body))
if err != nil {
panic(err)
}
req.Header = p.headers
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, -1, err
}

@ -43,11 +43,11 @@ func Start(config *Config) (func(), error) {
redisURL = rURL
}
var lim RateLimiter
var lim BackendRateLimiter
var err error
if redisURL == "" {
log.Warn("redis is not configured, using local rate limiter")
lim = NewLocalRateLimiter()
lim = NewLocalBackendRateLimiter()
} else {
lim, err = NewRedisRateLimiter(redisURL)
if err != nil {
@ -212,7 +212,7 @@ func Start(config *Config) (func(), error) {
rpcCache = newRPCCache(newCacheWithCompression(cache), blockNumFn, gasPriceFn, config.Cache.NumBlockConfirmations)
}
srv := NewServer(
srv, err := NewServer(
backendGroups,
wsBackendGroup,
NewStringSetFromStrings(config.WSMethodWhitelist),
@ -222,9 +222,13 @@ func Start(config *Config) (func(), error) {
secondsToDuration(config.Server.TimeoutSeconds),
config.Server.MaxUpstreamBatchSize,
rpcCache,
config.RateLimit,
config.Server.EnableRequestLog,
config.Server.MaxRequestBodyLogLen,
)
if err != nil {
return nil, fmt.Errorf("error creating server: %w", err)
}
if config.Metrics.Enabled {
addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port)

@ -41,7 +41,7 @@ end
return false
`
type RateLimiter interface {
type BackendRateLimiter interface {
IsBackendOnline(name string) (bool, error)
SetBackendOffline(name string, duration time.Duration) error
IncBackendRPS(name string) (int, error)
@ -50,14 +50,14 @@ type RateLimiter interface {
FlushBackendWSConns(names []string) error
}
type RedisRateLimiter struct {
type RedisBackendRateLimiter struct {
rdb *redis.Client
randID string
touchKeys map[string]time.Duration
tkMtx sync.Mutex
}
func NewRedisRateLimiter(url string) (RateLimiter, error) {
func NewRedisRateLimiter(url string) (BackendRateLimiter, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
@ -66,7 +66,7 @@ func NewRedisRateLimiter(url string) (RateLimiter, error) {
if err := rdb.Ping(context.Background()).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis")
}
out := &RedisRateLimiter{
out := &RedisBackendRateLimiter{
rdb: rdb,
randID: randStr(20),
touchKeys: make(map[string]time.Duration),
@ -75,7 +75,7 @@ func NewRedisRateLimiter(url string) (RateLimiter, error) {
return out, nil
}
func (r *RedisRateLimiter) IsBackendOnline(name string) (bool, error) {
func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
exists, err := r.rdb.Exists(context.Background(), fmt.Sprintf("backend:%s:offline", name)).Result()
if err != nil {
RecordRedisError("IsBackendOnline")
@ -85,7 +85,7 @@ func (r *RedisRateLimiter) IsBackendOnline(name string) (bool, error) {
return exists == 0, nil
}
func (r *RedisRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
func (r *RedisBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
if duration == 0 {
return nil
}
@ -102,7 +102,7 @@ func (r *RedisRateLimiter) SetBackendOffline(name string, duration time.Duration
return nil
}
func (r *RedisRateLimiter) IncBackendRPS(name string) (int, error) {
func (r *RedisBackendRateLimiter) IncBackendRPS(name string) (int, error) {
cmd := r.rdb.Eval(
context.Background(),
MaxRPSScript,
@ -116,7 +116,7 @@ func (r *RedisRateLimiter) IncBackendRPS(name string) (int, error) {
return rps, nil
}
func (r *RedisRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
func (r *RedisBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
r.tkMtx.Lock()
r.touchKeys[connsKey] = 5 * time.Minute
@ -142,7 +142,7 @@ func (r *RedisRateLimiter) IncBackendWSConns(name string, max int) (bool, error)
return incremented, nil
}
func (r *RedisRateLimiter) DecBackendWSConns(name string) error {
func (r *RedisBackendRateLimiter) DecBackendWSConns(name string) error {
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
err := r.rdb.Decr(context.Background(), connsKey).Err()
if err != nil {
@ -152,7 +152,7 @@ func (r *RedisRateLimiter) DecBackendWSConns(name string) error {
return nil
}
func (r *RedisRateLimiter) FlushBackendWSConns(names []string) error {
func (r *RedisBackendRateLimiter) FlushBackendWSConns(names []string) error {
ctx := context.Background()
for _, name := range names {
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
@ -172,7 +172,7 @@ func (r *RedisRateLimiter) FlushBackendWSConns(names []string) error {
return nil
}
func (r *RedisRateLimiter) touch() {
func (r *RedisBackendRateLimiter) touch() {
for {
r.tkMtx.Lock()
for key, dur := range r.touchKeys {
@ -186,15 +186,15 @@ func (r *RedisRateLimiter) touch() {
}
}
type LocalRateLimiter struct {
type LocalBackendRateLimiter struct {
deadBackends map[string]time.Time
backendRPS map[string]int
backendWSConns map[string]int
mtx sync.RWMutex
}
func NewLocalRateLimiter() *LocalRateLimiter {
out := &LocalRateLimiter{
func NewLocalBackendRateLimiter() *LocalBackendRateLimiter {
out := &LocalBackendRateLimiter{
deadBackends: make(map[string]time.Time),
backendRPS: make(map[string]int),
backendWSConns: make(map[string]int),
@ -203,27 +203,27 @@ func NewLocalRateLimiter() *LocalRateLimiter {
return out
}
func (l *LocalRateLimiter) IsBackendOnline(name string) (bool, error) {
func (l *LocalBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
l.mtx.RLock()
defer l.mtx.RUnlock()
return l.deadBackends[name].Before(time.Now()), nil
}
func (l *LocalRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
func (l *LocalBackendRateLimiter) SetBackendOffline(name string, duration time.Duration) error {
l.mtx.Lock()
defer l.mtx.Unlock()
l.deadBackends[name] = time.Now().Add(duration)
return nil
}
func (l *LocalRateLimiter) IncBackendRPS(name string) (int, error) {
func (l *LocalBackendRateLimiter) IncBackendRPS(name string) (int, error) {
l.mtx.Lock()
defer l.mtx.Unlock()
l.backendRPS[name] += 1
return l.backendRPS[name], nil
}
func (l *LocalRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
func (l *LocalBackendRateLimiter) IncBackendWSConns(name string, max int) (bool, error) {
l.mtx.Lock()
defer l.mtx.Unlock()
if l.backendWSConns[name] == max {
@ -233,7 +233,7 @@ func (l *LocalRateLimiter) IncBackendWSConns(name string, max int) (bool, error)
return true, nil
}
func (l *LocalRateLimiter) DecBackendWSConns(name string) error {
func (l *LocalBackendRateLimiter) DecBackendWSConns(name string) error {
l.mtx.Lock()
defer l.mtx.Unlock()
if l.backendWSConns[name] == 0 {
@ -243,11 +243,11 @@ func (l *LocalRateLimiter) DecBackendWSConns(name string) error {
return nil
}
func (l *LocalRateLimiter) FlushBackendWSConns(names []string) error {
func (l *LocalBackendRateLimiter) FlushBackendWSConns(names []string) error {
return nil
}
func (l *LocalRateLimiter) clear() {
func (l *LocalBackendRateLimiter) clear() {
for {
time.Sleep(time.Second)
l.mtx.Lock()
@ -263,3 +263,6 @@ func randStr(l int) string {
}
return hex.EncodeToString(b)
}
type ServerRateLimiter struct {
}

@ -65,6 +65,14 @@ func (r *RPCErr) Error() string {
return r.Message
}
func (r *RPCErr) Clone() *RPCErr {
return &RPCErr{
Code: r.Code,
Message: r.Message,
HTTPErrorCode: r.HTTPErrorCode,
}
}
func IsValidID(id json.RawMessage) bool {
// handle the case where the ID is a string
if strings.HasPrefix(string(id), "\"") && strings.HasSuffix(string(id), "\"") {

@ -14,6 +14,10 @@ import (
"sync"
"time"
"github.com/sethvargo/go-limiter"
"github.com/sethvargo/go-limiter/memorystore"
"github.com/sethvargo/go-limiter/noopstore"
"github.com/ethereum/go-ethereum/log"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
@ -46,6 +50,10 @@ type Server struct {
timeout time.Duration
maxUpstreamBatchSize int
upgrader *websocket.Upgrader
lim limiter.Store
limConfig RateLimitConfig
limExemptOrigins map[string]bool
limExemptUserAgents map[string]bool
rpcServer *http.Server
wsServer *http.Server
cache RPCCache
@ -62,9 +70,10 @@ func NewServer(
timeout time.Duration,
maxUpstreamBatchSize int,
cache RPCCache,
rateLimitConfig RateLimitConfig,
enableRequestLog bool,
maxRequestBodyLogLen int,
) *Server {
) (*Server, error) {
if cache == nil {
cache = &NoopRPCCache{}
}
@ -81,6 +90,29 @@ func NewServer(
maxUpstreamBatchSize = defaultMaxUpstreamBatchSize
}
var lim limiter.Store
limExemptOrigins := make(map[string]bool)
limExemptUserAgents := make(map[string]bool)
if rateLimitConfig.RatePerSecond > 0 {
var err error
lim, err = memorystore.New(&memorystore.Config{
Tokens: uint64(rateLimitConfig.RatePerSecond),
Interval: time.Second,
})
if err != nil {
return nil, err
}
for _, origin := range rateLimitConfig.ExemptOrigins {
limExemptOrigins[strings.ToLower(origin)] = true
}
for _, agent := range rateLimitConfig.ExemptUserAgents {
limExemptUserAgents[strings.ToLower(agent)] = true
}
} else {
lim, _ = noopstore.New()
}
return &Server{
backendGroups: backendGroups,
wsBackendGroup: wsBackendGroup,
@ -96,7 +128,11 @@ func NewServer(
upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second,
},
}
lim: lim,
limConfig: rateLimitConfig,
limExemptOrigins: limExemptOrigins,
limExemptUserAgents: limExemptUserAgents,
}, nil
}
func (s *Server) RPCListenAndServe(host string, port int) error {
@ -160,6 +196,28 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
ctx, cancel = context.WithTimeout(ctx, s.timeout)
defer cancel()
exemptOrigin := s.limExemptOrigins[strings.ToLower(r.Header.Get("Origin"))]
exemptUserAgent := s.limExemptUserAgents[strings.ToLower(r.Header.Get("User-Agent"))]
var ok bool
if exemptOrigin || exemptUserAgent {
ok = true
} else {
// Use XFF in context since it will automatically be replaced by the remote IP
xff := stripXFF(GetXForwardedFor(ctx))
if xff == "" {
log.Warn("rejecting request without XFF or remote IP")
ok = false
} else {
_, _, _, ok, _ = s.lim.Take(ctx, xff)
}
}
if !ok {
rpcErr := ErrOverRateLimit.Clone()
rpcErr.Message = s.limConfig.ErrorMessage
writeRPCError(ctx, w, nil, rpcErr)
return
}
log.Info(
"received RPC request",
"req_id", GetReqID(ctx),
@ -390,6 +448,14 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context {
vars := mux.Vars(r)
authorization := vars["authorization"]
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
ipPort := strings.Split(r.RemoteAddr, ":")
if len(ipPort) == 2 {
xff = ipPort[0]
}
}
ctx := context.WithValue(r.Context(), ContextKeyXForwardedFor, xff) // nolint:staticcheck
if s.authenticatedPaths == nil {
// handle the edge case where auth is disabled
@ -400,13 +466,7 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context
w.WriteHeader(404)
return nil
}
return context.WithValue(
r.Context(),
ContextKeyReqID, // nolint:staticcheck
randStr(10),
)
}
} else {
if authorization == "" || s.authenticatedPaths[authorization] == "" {
log.Info("blocked unauthorized request", "authorization", authorization)
httpResponseCodesTotal.WithLabelValues("401").Inc()
@ -414,16 +474,9 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context
return nil
}
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
ipPort := strings.Split(r.RemoteAddr, ":")
if len(ipPort) == 2 {
xff = ipPort[0]
}
ctx = context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck
}
ctx := context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck
ctx = context.WithValue(ctx, ContextKeyXForwardedFor, xff) // nolint:staticcheck
return context.WithValue(
ctx,
ContextKeyReqID, // nolint:staticcheck