diff --git a/proxyd/proxyd/backend.go b/proxyd/proxyd/backend.go index 9c4234f..4314848 100644 --- a/proxyd/proxyd/backend.go +++ b/proxyd/proxyd/backend.go @@ -854,9 +854,12 @@ func calcBackoff(i int) time.Duration { type WSProxier struct { backend *Backend clientConn *websocket.Conn - backendConn *websocket.Conn - methodWhitelist *StringSet clientConnMu sync.Mutex + backendConn *websocket.Conn + backendConnMu sync.Mutex + methodWhitelist *StringSet + readTimeout time.Duration + writeTimeout time.Duration } func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier { @@ -865,6 +868,8 @@ func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, met clientConn: clientConn, backendConn: backendConn, methodWhitelist: methodWhitelist, + readTimeout: defaultWSReadTimeout, + writeTimeout: defaultWSWriteTimeout, } } @@ -879,14 +884,21 @@ func (w *WSProxier) Proxy(ctx context.Context) error { func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { for { + err := w.clientConn.SetReadDeadline(time.Now().Add(w.readTimeout)) + if err != nil { + log.Error("ws client read timeout", "err", err) + errC <- err + return + } + // Block until we get a message. msgType, msg, err := w.clientConn.ReadMessage() if err != nil { - errC <- err - if err := w.backendConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil { + if err := w.writeBackendConn(websocket.CloseMessage, formatWSError(err)); err != nil { log.Error("error writing backendConn message", "err", err) + errC <- err + return } - return } RecordWSMessage(ctx, w.backend.Name, SourceClient) @@ -894,7 +906,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { // Route control messages to the backend. These don't // count towards the total RPC requests count. if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { - err := w.backendConn.WriteMessage(msgType, msg) + err := w.writeBackendConn(msgType, msg) if err != nil { errC <- err return @@ -952,7 +964,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { "req_id", GetReqID(ctx), ) - err = w.backendConn.WriteMessage(msgType, msg) + err = w.writeBackendConn(msgType, msg) if err != nil { errC <- err return @@ -962,14 +974,21 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { func (w *WSProxier) backendPump(ctx context.Context, errC chan error) { for { + err := w.backendConn.SetReadDeadline(time.Now().Add(w.readTimeout)) + if err != nil { + log.Error("ws backend read timeout", "err", err) + errC <- err + return + } + // Block until we get a message. msgType, msg, err := w.backendConn.ReadMessage() if err != nil { - errC <- err if err := w.writeClientConn(websocket.CloseMessage, formatWSError(err)); err != nil { log.Error("error writing clientConn message", "err", err) + errC <- err + return } - return } RecordWSMessage(ctx, w.backend.Name, SourceBackend) @@ -1050,8 +1069,23 @@ func (w *WSProxier) parseBackendMsg(msg []byte) (*RPCRes, error) { func (w *WSProxier) writeClientConn(msgType int, msg []byte) error { w.clientConnMu.Lock() + defer w.clientConnMu.Unlock() + if err := w.clientConn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil { + log.Error("ws client write timeout", "err", err) + return err + } err := w.clientConn.WriteMessage(msgType, msg) - w.clientConnMu.Unlock() + return err +} + +func (w *WSProxier) writeBackendConn(msgType int, msg []byte) error { + w.backendConnMu.Lock() + defer w.backendConnMu.Unlock() + if err := w.backendConn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil { + log.Error("ws backend write timeout", "err", err) + return err + } + err := w.writeBackendConn(msgType, msg) return err } diff --git a/proxyd/proxyd/server.go b/proxyd/proxyd/server.go index efe7c5f..280dc30 100644 --- a/proxyd/proxyd/server.go +++ b/proxyd/proxyd/server.go @@ -27,6 +27,7 @@ import ( "github.com/gorilla/websocket" "github.com/prometheus/client_golang/prometheus" "github.com/rs/cors" + "github.com/syndtr/goleveldb/leveldb/opt" ) const ( @@ -35,7 +36,11 @@ const ( ContextKeyXForwardedFor = "x_forwarded_for" MaxBatchRPCCallsHardLimit = 100 cacheStatusHdr = "X-Proxyd-Cache-Status" - defaultServerTimeout = time.Second * 10 + defaultRPCTimeout = 10 * time.Second + defaultBodySizeLimit = 256 * opt.KiB + defaultWSHandshakeTimeout = 10 * time.Second + defaultWSReadTimeout = 2 * time.Minute + defaultWSWriteTimeout = 10 * time.Second maxRequestBodyLogLen = 2000 defaultMaxUpstreamBatchSize = 10 ) @@ -92,11 +97,11 @@ func NewServer( } if maxBodySize == 0 { - maxBodySize = math.MaxInt64 + maxBodySize = defaultBodySizeLimit } if timeout == 0 { - timeout = defaultServerTimeout + timeout = defaultRPCTimeout } if maxUpstreamBatchSize == 0 { @@ -170,7 +175,7 @@ func NewServer( maxRequestBodyLogLen: maxRequestBodyLogLen, maxBatchSize: maxBatchSize, upgrader: &websocket.Upgrader{ - HandshakeTimeout: 5 * time.Second, + HandshakeTimeout: defaultWSHandshakeTimeout, }, mainLim: mainLim, overrideLims: overrideLims, @@ -547,6 +552,7 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { log.Error("error upgrading client conn", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err) return } + clientConn.SetReadLimit(s.maxBodySize) proxier, err := s.wsBackendGroup.ProxyWS(ctx, clientConn, s.wsMethodWhitelist) if err != nil {