proxyd: configurable IP rate limit header

This commit is contained in:
Danyal Prout 2023-11-02 14:38:04 -05:00
parent 22b7237389
commit 798878e455
3 changed files with 13 additions and 5 deletions

@ -51,6 +51,7 @@ type RateLimitConfig struct {
ExemptUserAgents []string `toml:"exempt_user_agents"` ExemptUserAgents []string `toml:"exempt_user_agents"`
ErrorMessage string `toml:"error_message"` ErrorMessage string `toml:"error_message"`
MethodOverrides map[string]*RateLimitMethodOverride `toml:"method_overrides"` MethodOverrides map[string]*RateLimitMethodOverride `toml:"method_overrides"`
IPHeaderOverride string `toml:"ip_header_override"`
} }
type RateLimitMethodOverride struct { type RateLimitMethodOverride struct {

@ -138,7 +138,6 @@ github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg=
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k=
github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY=

@ -44,6 +44,7 @@ const (
defaultWSWriteTimeout = 10 * time.Second defaultWSWriteTimeout = 10 * time.Second
maxRequestBodyLogLen = 2000 maxRequestBodyLogLen = 2000
defaultMaxUpstreamBatchSize = 10 defaultMaxUpstreamBatchSize = 10
defaultRateLimitHeader = "X-Forwarded-For"
) )
var emptyArrayResponse = json.RawMessage("[]") var emptyArrayResponse = json.RawMessage("[]")
@ -73,6 +74,7 @@ type Server struct {
wsServer *http.Server wsServer *http.Server
cache RPCCache cache RPCCache
srvMu sync.Mutex srvMu sync.Mutex
rateLimitHeader string
} }
type limiterFunc func(method string) bool type limiterFunc func(method string) bool
@ -168,6 +170,11 @@ func NewServer(
senderLim = limiterFactory(time.Duration(senderRateLimitConfig.Interval), senderRateLimitConfig.Limit, "senders") senderLim = limiterFactory(time.Duration(senderRateLimitConfig.Interval), senderRateLimitConfig.Limit, "senders")
} }
rateLimitHeader := defaultRateLimitHeader
if rateLimitConfig.IPHeaderOverride != "" {
rateLimitHeader = rateLimitConfig.IPHeaderOverride
}
return &Server{ return &Server{
BackendGroups: backendGroups, BackendGroups: backendGroups,
wsBackendGroup: wsBackendGroup, wsBackendGroup: wsBackendGroup,
@ -192,6 +199,7 @@ func NewServer(
allowedChainIds: senderRateLimitConfig.AllowedChainIds, allowedChainIds: senderRateLimitConfig.AllowedChainIds,
limExemptOrigins: limExemptOrigins, limExemptOrigins: limExemptOrigins,
limExemptUserAgents: limExemptUserAgents, limExemptUserAgents: limExemptUserAgents,
rateLimitHeader: rateLimitHeader,
}, nil }, nil
} }
@ -608,7 +616,7 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context { func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context {
vars := mux.Vars(r) vars := mux.Vars(r)
authorization := vars["authorization"] authorization := vars["authorization"]
xff := r.Header.Get("X-Forwarded-For") xff := r.Header.Get(s.rateLimitHeader)
if xff == "" { if xff == "" {
ipPort := strings.Split(r.RemoteAddr, ":") ipPort := strings.Split(r.RemoteAddr, ":")
if len(ipPort) == 2 { if len(ipPort) == 2 {