diff --git a/proxyd/config.go b/proxyd/config.go index 4719a55..e89b86d 100644 --- a/proxyd/config.go +++ b/proxyd/config.go @@ -11,12 +11,16 @@ import ( type ServerConfig struct { RPCHost string `toml:"rpc_host"` RPCPort int `toml:"rpc_port"` + EnableWS bool `toml:"enable_ws"` WSHost string `toml:"ws_host"` WSPort int `toml:"ws_port"` MaxBodySizeBytes int64 `toml:"max_body_size_bytes"` MaxConcurrentRPCs int64 `toml:"max_concurrent_rpcs"` LogLevel string `toml:"log_level"` + // Allow direct client connection without x_forwarded_for header for local tests + AllowDirect bool `toml:"allow_direct"` + // TimeoutSeconds specifies the maximum time spent serving an HTTP request. Note that isn't used for websocket connections TimeoutSeconds int `toml:"timeout_seconds"` diff --git a/proxyd/proxyd.go b/proxyd/proxyd.go index 402909b..02551e7 100644 --- a/proxyd/proxyd.go +++ b/proxyd/proxyd.go @@ -327,7 +327,7 @@ func Start(config *Config) (*Server, func(), error) { if config.Server.RPCPort != 0 { go func() { - if err := srv.RPCListenAndServe(config.Server.RPCHost, config.Server.RPCPort); err != nil { + if err := srv.RPCListenAndServe(config.Server); err != nil { if errors.Is(err, http.ErrServerClosed) { log.Info("RPC server shut down") return @@ -347,7 +347,7 @@ func Start(config *Config) (*Server, func(), error) { log.Crit("error starting WS server", "err", err) } }() - } else { + } else if !config.Server.EnableWS { log.Info("WS server not enabled (ws_port is set to 0)") } diff --git a/proxyd/server.go b/proxyd/server.go index c663f42..56edc16 100644 --- a/proxyd/server.go +++ b/proxyd/server.go @@ -200,12 +200,23 @@ func NewServer( }, nil } -func (s *Server) RPCListenAndServe(host string, port int) error { +func (s *Server) RPCListenAndServe(serverConfig ServerConfig) error { + host := serverConfig.RPCHost + port := serverConfig.RPCPort + enableWS := serverConfig.EnableWS + + var handleRpc ReqHandle = s.GetRPCHandle(serverConfig) + s.srvMu.Lock() hdlr := mux.NewRouter() hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET") - hdlr.HandleFunc("/", s.HandleRPC).Methods("POST") - hdlr.HandleFunc("/{authorization}", s.HandleRPC).Methods("POST") + hdlr.HandleFunc("/", handleRpc).Methods("POST") + hdlr.HandleFunc("/{authorization}", handleRpc).Methods("POST") + if enableWS { + var handleWS ReqHandle = s.GetWSHandle(true) + hdlr.HandleFunc("/", handleWS) + hdlr.HandleFunc("/{authorization}", handleWS) + } c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, }) @@ -215,15 +226,20 @@ func (s *Server) RPCListenAndServe(host string, port int) error { Addr: addr, } log.Info("starting HTTP server", "addr", addr) + if enableWS { + log.Info("starting WS server", "addr", addr) + } s.srvMu.Unlock() return s.rpcServer.ListenAndServe() } func (s *Server) WSListenAndServe(host string, port int) error { s.srvMu.Lock() + var handleWS ReqHandle = s.GetWSHandle(false) + hdlr := mux.NewRouter() - hdlr.HandleFunc("/", s.HandleWS) - hdlr.HandleFunc("/{authorization}", s.HandleWS) + hdlr.HandleFunc("/", handleWS) + hdlr.HandleFunc("/{authorization}", handleWS) c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, }) @@ -255,7 +271,15 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("OK")) } -func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { +type ReqHandle func(http.ResponseWriter, *http.Request) + +func (s *Server) GetRPCHandle(serverConfig ServerConfig) ReqHandle { + return func(w http.ResponseWriter, r *http.Request) { + s.HandleRPC(w, r, serverConfig) + } +} + +func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request, serverConfig ServerConfig) { ctx := s.populateContext(w, r) if ctx == nil { return @@ -272,8 +296,13 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent) if xff == "" { - writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP")) - return + // Just use remote addr from socket when the request doesn't have x_forwarded_for header + if (serverConfig.AllowDirect) { + xff = r.RemoteAddr + } else { + writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP")) + return + } } isLimited := func(method string) bool { @@ -368,7 +397,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { return } - batchRes, batchContainsCached, servedBy, err := s.handleBatchRPC(ctx, reqs, isLimited, true) + batchRes, batchContainsCached, servedBy, err := s.handleBatchRPC(xff, r, ctx, reqs, isLimited, true) if err == context.DeadlineExceeded { writeRPCError(ctx, w, nil, ErrGatewayTimeout) return @@ -391,7 +420,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { } rawBody := json.RawMessage(body) - backendRes, cached, servedBy, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, isLimited, false) + backendRes, cached, servedBy, err := s.handleBatchRPC(xff, r, ctx, []json.RawMessage{rawBody}, isLimited, false) if err != nil { if errors.Is(err, ErrConsensusGetReceiptsCantBeBatched) || errors.Is(err, ErrConsensusGetReceiptsInvalidTarget) { @@ -408,7 +437,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { writeRPCRes(ctx, w, backendRes[0]) } -func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, string, error) { +func (s *Server) handleBatchRPC(xff string, r *http.Request, ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, string, error) { // A request set is transformed into groups of batches. // Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints) // A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's @@ -420,6 +449,10 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL backendGroup string } + // Retrieve info from header + origin := r.Header.Get("Origin") + userAgent := r.Header.Get("User-Agent") + responses := make([]*RPCRes, len(reqs)) batches := make(map[batchGroup][]batchElem) ids := make(map[string]int, len(reqs)) @@ -432,6 +465,15 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL continue } + log.Debug( + "received RPC method", + "req_id", GetReqID(ctx), + "method", parsedReq.Method, + "user_agent", userAgent, + "origin", origin, + "remote_ip", xff, + ) + // Simple health check if len(reqs) == 1 && parsedReq.Method == proxydHealthzMethod { res := &RPCRes{ @@ -463,6 +505,9 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL "source", "rpc", "req_id", GetReqID(ctx), "method", parsedReq.Method, + "user_agent", userAgent, + "origin", origin, + "remote_ip", xff, ) RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted) responses[i] = NewRPCErrorRes(parsedReq.ID, ErrMethodNotWhitelisted) @@ -479,6 +524,9 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL "source", "rpc", "req_id", GetReqID(ctx), "method", parsedReq.Method, + "user_agent", userAgent, + "origin", origin, + "remote_ip", xff, ) RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit) responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit) @@ -583,12 +631,31 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL return responses, cached, servedByString, nil } -func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { +func (s *Server) GetWSHandle(fromRpc bool) ReqHandle { + return func(w http.ResponseWriter, r *http.Request) { + s.HandleWS(w, r, fromRpc) + } +} + +func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request, fromRpc bool) { ctx := s.populateContext(w, r) if ctx == nil { return } + // Handle upgrade header request + upgrade := false + for _, header := range r.Header["Upgrade"] { + if header == "websocket" { + upgrade = true + break + } + } + // Filter out non websocket requests + if fromRpc && !upgrade { + return + } + log.Info("received WS connection", "req_id", GetReqID(ctx)) clientConn, err := s.upgrader.Upgrade(w, r, nil)