diff --git a/proxyd/proxyd/config.go b/proxyd/proxyd/config.go index fefca9f..63f557c 100644 --- a/proxyd/proxyd/config.go +++ b/proxyd/proxyd/config.go @@ -22,9 +22,9 @@ type ServerConfig struct { MaxUpstreamBatchSize int `toml:"max_upstream_batch_size"` - EnableRequestLog bool `toml:"enable_request_log"` - MaxRequestBodyLogLen int `toml:"max_request_body_log_len"` - EnablePprof bool `toml:"enable_pprof"` + EnableRequestLog bool `toml:"enable_request_log"` + MaxRequestBodyLogLen int `toml:"max_request_body_log_len"` + EnablePprof bool `toml:"enable_pprof"` EnableXServedByHeader bool `toml:"enable_served_by_header"` } @@ -51,6 +51,7 @@ type RateLimitConfig struct { ExemptUserAgents []string `toml:"exempt_user_agents"` ErrorMessage string `toml:"error_message"` MethodOverrides map[string]*RateLimitMethodOverride `toml:"method_overrides"` + IPHeaderOverride string `toml:"ip_header_override"` } type RateLimitMethodOverride struct { diff --git a/proxyd/proxyd/go.sum b/proxyd/proxyd/go.sum index e759ce5..a54ffb5 100644 --- a/proxyd/proxyd/go.sum +++ b/proxyd/proxyd/go.sum @@ -138,7 +138,6 @@ github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2 github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= diff --git a/proxyd/proxyd/server.go b/proxyd/proxyd/server.go index 2b7a1bd..5d262da 100644 --- a/proxyd/proxyd/server.go +++ b/proxyd/proxyd/server.go @@ -44,6 +44,7 @@ const ( defaultWSWriteTimeout = 10 * time.Second maxRequestBodyLogLen = 2000 defaultMaxUpstreamBatchSize = 10 + defaultRateLimitHeader = "X-Forwarded-For" ) var emptyArrayResponse = json.RawMessage("[]") @@ -73,6 +74,7 @@ type Server struct { wsServer *http.Server cache RPCCache srvMu sync.Mutex + rateLimitHeader string } type limiterFunc func(method string) bool @@ -168,6 +170,11 @@ func NewServer( senderLim = limiterFactory(time.Duration(senderRateLimitConfig.Interval), senderRateLimitConfig.Limit, "senders") } + rateLimitHeader := defaultRateLimitHeader + if rateLimitConfig.IPHeaderOverride != "" { + rateLimitHeader = rateLimitConfig.IPHeaderOverride + } + return &Server{ BackendGroups: backendGroups, wsBackendGroup: wsBackendGroup, @@ -192,6 +199,7 @@ func NewServer( allowedChainIds: senderRateLimitConfig.AllowedChainIds, limExemptOrigins: limExemptOrigins, limExemptUserAgents: limExemptUserAgents, + rateLimitHeader: rateLimitHeader, }, nil } @@ -608,7 +616,7 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context { vars := mux.Vars(r) authorization := vars["authorization"] - xff := r.Header.Get("X-Forwarded-For") + xff := r.Header.Get(s.rateLimitHeader) if xff == "" { ipPort := strings.Split(r.RemoteAddr, ":") if len(ipPort) == 2 {