Listen websocket on same port, allow direct connections and more helpful logging

This commit is contained in:
Tornado Contrib 2024-06-21 11:45:08 +00:00
parent fa11fda45e
commit bba4471c79
Signed by: tornadocontrib
GPG Key ID: 60B4DF1A076C64B1
3 changed files with 85 additions and 14 deletions

@ -11,12 +11,16 @@ import (
type ServerConfig struct { type ServerConfig struct {
RPCHost string `toml:"rpc_host"` RPCHost string `toml:"rpc_host"`
RPCPort int `toml:"rpc_port"` RPCPort int `toml:"rpc_port"`
EnableWS bool `toml:"enable_ws"`
WSHost string `toml:"ws_host"` WSHost string `toml:"ws_host"`
WSPort int `toml:"ws_port"` WSPort int `toml:"ws_port"`
MaxBodySizeBytes int64 `toml:"max_body_size_bytes"` MaxBodySizeBytes int64 `toml:"max_body_size_bytes"`
MaxConcurrentRPCs int64 `toml:"max_concurrent_rpcs"` MaxConcurrentRPCs int64 `toml:"max_concurrent_rpcs"`
LogLevel string `toml:"log_level"` LogLevel string `toml:"log_level"`
// Allow direct client connection without x_forwarded_for header for local tests
AllowDirect bool `toml:"allow_direct"`
// TimeoutSeconds specifies the maximum time spent serving an HTTP request. Note that isn't used for websocket connections // TimeoutSeconds specifies the maximum time spent serving an HTTP request. Note that isn't used for websocket connections
TimeoutSeconds int `toml:"timeout_seconds"` TimeoutSeconds int `toml:"timeout_seconds"`

@ -327,7 +327,7 @@ func Start(config *Config) (*Server, func(), error) {
if config.Server.RPCPort != 0 { if config.Server.RPCPort != 0 {
go func() { go func() {
if err := srv.RPCListenAndServe(config.Server.RPCHost, config.Server.RPCPort); err != nil { if err := srv.RPCListenAndServe(config.Server); err != nil {
if errors.Is(err, http.ErrServerClosed) { if errors.Is(err, http.ErrServerClosed) {
log.Info("RPC server shut down") log.Info("RPC server shut down")
return return
@ -347,7 +347,7 @@ func Start(config *Config) (*Server, func(), error) {
log.Crit("error starting WS server", "err", err) log.Crit("error starting WS server", "err", err)
} }
}() }()
} else { } else if !config.Server.EnableWS {
log.Info("WS server not enabled (ws_port is set to 0)") log.Info("WS server not enabled (ws_port is set to 0)")
} }

@ -200,12 +200,23 @@ func NewServer(
}, nil }, nil
} }
func (s *Server) RPCListenAndServe(host string, port int) error { func (s *Server) RPCListenAndServe(serverConfig ServerConfig) error {
host := serverConfig.RPCHost
port := serverConfig.RPCPort
enableWS := serverConfig.EnableWS
var handleRpc ReqHandle = s.GetRPCHandle(serverConfig)
s.srvMu.Lock() s.srvMu.Lock()
hdlr := mux.NewRouter() hdlr := mux.NewRouter()
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET") hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST") hdlr.HandleFunc("/", handleRpc).Methods("POST")
hdlr.HandleFunc("/{authorization}", s.HandleRPC).Methods("POST") hdlr.HandleFunc("/{authorization}", handleRpc).Methods("POST")
if enableWS {
var handleWS ReqHandle = s.GetWSHandle(true)
hdlr.HandleFunc("/", handleWS)
hdlr.HandleFunc("/{authorization}", handleWS)
}
c := cors.New(cors.Options{ c := cors.New(cors.Options{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
}) })
@ -215,15 +226,20 @@ func (s *Server) RPCListenAndServe(host string, port int) error {
Addr: addr, Addr: addr,
} }
log.Info("starting HTTP server", "addr", addr) log.Info("starting HTTP server", "addr", addr)
if enableWS {
log.Info("starting WS server", "addr", addr)
}
s.srvMu.Unlock() s.srvMu.Unlock()
return s.rpcServer.ListenAndServe() return s.rpcServer.ListenAndServe()
} }
func (s *Server) WSListenAndServe(host string, port int) error { func (s *Server) WSListenAndServe(host string, port int) error {
s.srvMu.Lock() s.srvMu.Lock()
var handleWS ReqHandle = s.GetWSHandle(false)
hdlr := mux.NewRouter() hdlr := mux.NewRouter()
hdlr.HandleFunc("/", s.HandleWS) hdlr.HandleFunc("/", handleWS)
hdlr.HandleFunc("/{authorization}", s.HandleWS) hdlr.HandleFunc("/{authorization}", handleWS)
c := cors.New(cors.Options{ c := cors.New(cors.Options{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
}) })
@ -255,7 +271,15 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("OK")) _, _ = w.Write([]byte("OK"))
} }
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { type ReqHandle func(http.ResponseWriter, *http.Request)
func (s *Server) GetRPCHandle(serverConfig ServerConfig) ReqHandle {
return func(w http.ResponseWriter, r *http.Request) {
s.HandleRPC(w, r, serverConfig)
}
}
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request, serverConfig ServerConfig) {
ctx := s.populateContext(w, r) ctx := s.populateContext(w, r)
if ctx == nil { if ctx == nil {
return return
@ -272,9 +296,14 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent) isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent)
if xff == "" { if xff == "" {
// Just use remote addr from socket when the request doesn't have x_forwarded_for header
if (serverConfig.AllowDirect) {
xff = r.RemoteAddr
} else {
writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP")) writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP"))
return return
} }
}
isLimited := func(method string) bool { isLimited := func(method string) bool {
isGloballyLimitedMethod := s.isGlobalLimit(method) isGloballyLimitedMethod := s.isGlobalLimit(method)
@ -368,7 +397,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return return
} }
batchRes, batchContainsCached, servedBy, err := s.handleBatchRPC(ctx, reqs, isLimited, true) batchRes, batchContainsCached, servedBy, err := s.handleBatchRPC(xff, r, ctx, reqs, isLimited, true)
if err == context.DeadlineExceeded { if err == context.DeadlineExceeded {
writeRPCError(ctx, w, nil, ErrGatewayTimeout) writeRPCError(ctx, w, nil, ErrGatewayTimeout)
return return
@ -391,7 +420,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
} }
rawBody := json.RawMessage(body) rawBody := json.RawMessage(body)
backendRes, cached, servedBy, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, isLimited, false) backendRes, cached, servedBy, err := s.handleBatchRPC(xff, r, ctx, []json.RawMessage{rawBody}, isLimited, false)
if err != nil { if err != nil {
if errors.Is(err, ErrConsensusGetReceiptsCantBeBatched) || if errors.Is(err, ErrConsensusGetReceiptsCantBeBatched) ||
errors.Is(err, ErrConsensusGetReceiptsInvalidTarget) { errors.Is(err, ErrConsensusGetReceiptsInvalidTarget) {
@ -408,7 +437,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
writeRPCRes(ctx, w, backendRes[0]) writeRPCRes(ctx, w, backendRes[0])
} }
func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, string, error) { func (s *Server) handleBatchRPC(xff string, r *http.Request, ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, string, error) {
// A request set is transformed into groups of batches. // A request set is transformed into groups of batches.
// Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints) // Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints)
// A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's // A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's
@ -420,6 +449,10 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
backendGroup string backendGroup string
} }
// Retrieve info from header
origin := r.Header.Get("Origin")
userAgent := r.Header.Get("User-Agent")
responses := make([]*RPCRes, len(reqs)) responses := make([]*RPCRes, len(reqs))
batches := make(map[batchGroup][]batchElem) batches := make(map[batchGroup][]batchElem)
ids := make(map[string]int, len(reqs)) ids := make(map[string]int, len(reqs))
@ -432,6 +465,15 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
continue continue
} }
log.Debug(
"received RPC method",
"req_id", GetReqID(ctx),
"method", parsedReq.Method,
"user_agent", userAgent,
"origin", origin,
"remote_ip", xff,
)
// Simple health check // Simple health check
if len(reqs) == 1 && parsedReq.Method == proxydHealthzMethod { if len(reqs) == 1 && parsedReq.Method == proxydHealthzMethod {
res := &RPCRes{ res := &RPCRes{
@ -463,6 +505,9 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
"source", "rpc", "source", "rpc",
"req_id", GetReqID(ctx), "req_id", GetReqID(ctx),
"method", parsedReq.Method, "method", parsedReq.Method,
"user_agent", userAgent,
"origin", origin,
"remote_ip", xff,
) )
RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted) RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
responses[i] = NewRPCErrorRes(parsedReq.ID, ErrMethodNotWhitelisted) responses[i] = NewRPCErrorRes(parsedReq.ID, ErrMethodNotWhitelisted)
@ -479,6 +524,9 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
"source", "rpc", "source", "rpc",
"req_id", GetReqID(ctx), "req_id", GetReqID(ctx),
"method", parsedReq.Method, "method", parsedReq.Method,
"user_agent", userAgent,
"origin", origin,
"remote_ip", xff,
) )
RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit) RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit)
responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit) responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit)
@ -583,12 +631,31 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
return responses, cached, servedByString, nil return responses, cached, servedByString, nil
} }
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { func (s *Server) GetWSHandle(fromRpc bool) ReqHandle {
return func(w http.ResponseWriter, r *http.Request) {
s.HandleWS(w, r, fromRpc)
}
}
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request, fromRpc bool) {
ctx := s.populateContext(w, r) ctx := s.populateContext(w, r)
if ctx == nil { if ctx == nil {
return return
} }
// Handle upgrade header request
upgrade := false
for _, header := range r.Header["Upgrade"] {
if header == "websocket" {
upgrade = true
break
}
}
// Filter out non websocket requests
if fromRpc && !upgrade {
return
}
log.Info("received WS connection", "req_id", GetReqID(ctx)) log.Info("received WS connection", "req_id", GetReqID(ctx))
clientConn, err := s.upgrader.Upgrade(w, r, nil) clientConn, err := s.upgrader.Upgrade(w, r, nil)