Listen websocket on same port, allow direct connections and more helpful logging

This commit is contained in:
Tornado Contrib 2024-08-09 19:48:33 +00:00
parent d8cec08bd9
commit 1ad517098a
Signed by: tornadocontrib
GPG Key ID: 60B4DF1A076C64B1
3 changed files with 85 additions and 14 deletions

@ -13,12 +13,16 @@ import (
type ServerConfig struct {
RPCHost string `toml:"rpc_host"`
RPCPort int `toml:"rpc_port"`
EnableWS bool `toml:"enable_ws"`
WSHost string `toml:"ws_host"`
WSPort int `toml:"ws_port"`
MaxBodySizeBytes int64 `toml:"max_body_size_bytes"`
MaxConcurrentRPCs int64 `toml:"max_concurrent_rpcs"`
LogLevel string `toml:"log_level"`
// Allow direct client connection without x_forwarded_for header for local tests
AllowDirect bool `toml:"allow_direct"`
// TimeoutSeconds specifies the maximum time spent serving an HTTP request. Note that isn't used for websocket connections
TimeoutSeconds int `toml:"timeout_seconds"`

@ -346,7 +346,7 @@ func Start(config *Config) (*Server, func(), error) {
if config.Server.RPCPort != 0 {
go func() {
if err := srv.RPCListenAndServe(config.Server.RPCHost, config.Server.RPCPort); err != nil {
if err := srv.RPCListenAndServe(config.Server); err != nil {
if errors.Is(err, http.ErrServerClosed) {
log.Info("RPC server shut down")
return
@ -366,7 +366,7 @@ func Start(config *Config) (*Server, func(), error) {
log.Crit("error starting WS server", "err", err)
}
}()
} else {
} else if !config.Server.EnableWS {
log.Info("WS server not enabled (ws_port is set to 0)")
}

@ -200,12 +200,23 @@ func NewServer(
}, nil
}
func (s *Server) RPCListenAndServe(host string, port int) error {
func (s *Server) RPCListenAndServe(serverConfig ServerConfig) error {
host := serverConfig.RPCHost
port := serverConfig.RPCPort
enableWS := serverConfig.EnableWS
var handleRpc ReqHandle = s.GetRPCHandle(serverConfig)
s.srvMu.Lock()
hdlr := mux.NewRouter()
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/{authorization}", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/", handleRpc).Methods("POST")
hdlr.HandleFunc("/{authorization}", handleRpc).Methods("POST")
if enableWS {
var handleWS ReqHandle = s.GetWSHandle(true)
hdlr.HandleFunc("/", handleWS)
hdlr.HandleFunc("/{authorization}", handleWS)
}
c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
})
@ -215,15 +226,20 @@ func (s *Server) RPCListenAndServe(host string, port int) error {
Addr: addr,
}
log.Info("starting HTTP server", "addr", addr)
if enableWS {
log.Info("starting WS server", "addr", addr)
}
s.srvMu.Unlock()
return s.rpcServer.ListenAndServe()
}
func (s *Server) WSListenAndServe(host string, port int) error {
s.srvMu.Lock()
var handleWS ReqHandle = s.GetWSHandle(false)
hdlr := mux.NewRouter()
hdlr.HandleFunc("/", s.HandleWS)
hdlr.HandleFunc("/{authorization}", s.HandleWS)
hdlr.HandleFunc("/", handleWS)
hdlr.HandleFunc("/{authorization}", handleWS)
c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
})
@ -255,7 +271,15 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("OK"))
}
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
type ReqHandle func(http.ResponseWriter, *http.Request)
func (s *Server) GetRPCHandle(serverConfig ServerConfig) ReqHandle {
return func(w http.ResponseWriter, r *http.Request) {
s.HandleRPC(w, r, serverConfig)
}
}
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request, serverConfig ServerConfig) {
ctx := s.populateContext(w, r)
if ctx == nil {
return
@ -272,8 +296,13 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent)
if xff == "" {
writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP"))
return
// Just use remote addr from socket when the request doesn't have x_forwarded_for header
if (serverConfig.AllowDirect) {
xff = r.RemoteAddr
} else {
writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP"))
return
}
}
isLimited := func(method string) bool {
@ -354,7 +383,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return
}
batchRes, batchContainsCached, servedBy, err := s.handleBatchRPC(ctx, reqs, isLimited, true)
batchRes, batchContainsCached, servedBy, err := s.handleBatchRPC(xff, r, ctx, reqs, isLimited, true)
if err == context.DeadlineExceeded {
writeRPCError(ctx, w, nil, ErrGatewayTimeout)
return
@ -377,7 +406,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
}
rawBody := json.RawMessage(body)
backendRes, cached, servedBy, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, isLimited, false)
backendRes, cached, servedBy, err := s.handleBatchRPC(xff, r, ctx, []json.RawMessage{rawBody}, isLimited, false)
if err != nil {
if errors.Is(err, ErrConsensusGetReceiptsCantBeBatched) ||
errors.Is(err, ErrConsensusGetReceiptsInvalidTarget) {
@ -394,7 +423,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
writeRPCRes(ctx, w, backendRes[0])
}
func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, string, error) {
func (s *Server) handleBatchRPC(xff string, r *http.Request, ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, string, error) {
// A request set is transformed into groups of batches.
// Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints)
// A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's
@ -406,6 +435,10 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
backendGroup string
}
// Retrieve info from header
origin := r.Header.Get("Origin")
userAgent := r.Header.Get("User-Agent")
responses := make([]*RPCRes, len(reqs))
batches := make(map[batchGroup][]batchElem)
ids := make(map[string]int, len(reqs))
@ -418,6 +451,15 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
continue
}
log.Debug(
"received RPC method",
"req_id", GetReqID(ctx),
"method", parsedReq.Method,
"user_agent", userAgent,
"origin", origin,
"remote_ip", xff,
)
// Simple health check
if len(reqs) == 1 && parsedReq.Method == proxydHealthzMethod {
res := &RPCRes{
@ -449,6 +491,9 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
"source", "rpc",
"req_id", GetReqID(ctx),
"method", parsedReq.Method,
"user_agent", userAgent,
"origin", origin,
"remote_ip", xff,
)
RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
responses[i] = NewRPCErrorRes(parsedReq.ID, ErrMethodNotWhitelisted)
@ -475,6 +520,9 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
"source", "rpc",
"req_id", GetReqID(ctx),
"method", parsedReq.Method,
"user_agent", userAgent,
"origin", origin,
"remote_ip", xff,
)
RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit)
responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit)
@ -579,12 +627,31 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isL
return responses, cached, servedByString, nil
}
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
func (s *Server) GetWSHandle(fromRpc bool) ReqHandle {
return func(w http.ResponseWriter, r *http.Request) {
s.HandleWS(w, r, fromRpc)
}
}
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request, fromRpc bool) {
ctx := s.populateContext(w, r)
if ctx == nil {
return
}
// Handle upgrade header request
upgrade := false
for _, header := range r.Header["Upgrade"] {
if header == "websocket" {
upgrade = true
break
}
}
// Filter out non websocket requests
if fromRpc && !upgrade {
return
}
log.Info("received WS connection", "req_id", GetReqID(ctx))
clientConn, err := s.upgrader.Upgrade(w, r, nil)