diff --git a/proxyd/proxyd/integration_tests/rate_limit_test.go b/proxyd/proxyd/integration_tests/rate_limit_test.go index 4648017..310be22 100644 --- a/proxyd/proxyd/integration_tests/rate_limit_test.go +++ b/proxyd/proxyd/integration_tests/rate_limit_test.go @@ -1,7 +1,6 @@ package integration_tests import ( - "fmt" "net/http" "os" "testing" @@ -68,7 +67,6 @@ func TestFrontendMaxRPSLimit(t *testing.T) { h.Set("Origin", "exempt_origin") client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h) _, codes := spamReqs(t, client, 429) - fmt.Println(codes) require.Equal(t, 3, codes[200]) }) diff --git a/proxyd/proxyd/server.go b/proxyd/proxyd/server.go index 27f3f0e..5f76e3e 100644 --- a/proxyd/proxyd/server.go +++ b/proxyd/proxyd/server.go @@ -196,14 +196,16 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { ctx, cancel = context.WithTimeout(ctx, s.timeout) defer cancel() - exemptOrigin := s.limExemptOrigins[strings.ToLower(r.Header.Get("Origin"))] - exemptUserAgent := s.limExemptUserAgents[strings.ToLower(r.Header.Get("User-Agent"))] + origin := r.Header.Get("Origin") + userAgent := r.Header.Get("User-Agent") + exemptOrigin := s.limExemptOrigins[strings.ToLower(origin)] + exemptUserAgent := s.limExemptUserAgents[strings.ToLower(userAgent)] + // Use XFF in context since it will automatically be replaced by the remote IP + xff := stripXFF(GetXForwardedFor(ctx)) var ok bool if exemptOrigin || exemptUserAgent { ok = true } else { - // Use XFF in context since it will automatically be replaced by the remote IP - xff := stripXFF(GetXForwardedFor(ctx)) if xff == "" { log.Warn("rejecting request without XFF or remote IP") ok = false @@ -214,6 +216,15 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { if !ok { rpcErr := ErrOverRateLimit.Clone() rpcErr.Message = s.limConfig.ErrorMessage + RecordRPCError(ctx, BackendProxyd, "unknown", rpcErr) + log.Warn( + "rate limited request", + "req_id", GetReqID(ctx), + "auth", GetAuthCtx(ctx), + "user_agent", userAgent, + "origin", origin, + "remote_ip", xff, + ) writeRPCError(ctx, w, nil, rpcErr) return } @@ -222,7 +233,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { "received RPC request", "req_id", GetReqID(ctx), "auth", GetAuthCtx(ctx), - "user_agent", r.Header.Get("user-agent"), + "user_agent", userAgent, ) body, err := ioutil.ReadAll(io.LimitReader(r.Body, s.maxBodySize))