feat(proxyd): betterer timeoutz

This commit is contained in:
Felipe Andrade 2023-07-27 11:48:46 -07:00
parent 73ce23c025
commit a65810b467
2 changed files with 54 additions and 14 deletions

@ -854,9 +854,12 @@ func calcBackoff(i int) time.Duration {
type WSProxier struct { type WSProxier struct {
backend *Backend backend *Backend
clientConn *websocket.Conn clientConn *websocket.Conn
backendConn *websocket.Conn
methodWhitelist *StringSet
clientConnMu sync.Mutex clientConnMu sync.Mutex
backendConn *websocket.Conn
backendConnMu sync.Mutex
methodWhitelist *StringSet
readTimeout time.Duration
writeTimeout time.Duration
} }
func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier { func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier {
@ -865,6 +868,8 @@ func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, met
clientConn: clientConn, clientConn: clientConn,
backendConn: backendConn, backendConn: backendConn,
methodWhitelist: methodWhitelist, methodWhitelist: methodWhitelist,
readTimeout: defaultWSReadTimeout,
writeTimeout: defaultWSWriteTimeout,
} }
} }
@ -879,14 +884,21 @@ func (w *WSProxier) Proxy(ctx context.Context) error {
func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
for { for {
err := w.clientConn.SetReadDeadline(time.Now().Add(w.readTimeout))
if err != nil {
log.Error("ws client read timeout", "err", err)
errC <- err
return
}
// Block until we get a message. // Block until we get a message.
msgType, msg, err := w.clientConn.ReadMessage() msgType, msg, err := w.clientConn.ReadMessage()
if err != nil { if err != nil {
errC <- err if err := w.writeBackendConn(websocket.CloseMessage, formatWSError(err)); err != nil {
if err := w.backendConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
log.Error("error writing backendConn message", "err", err) log.Error("error writing backendConn message", "err", err)
errC <- err
return
} }
return
} }
RecordWSMessage(ctx, w.backend.Name, SourceClient) RecordWSMessage(ctx, w.backend.Name, SourceClient)
@ -894,7 +906,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
// Route control messages to the backend. These don't // Route control messages to the backend. These don't
// count towards the total RPC requests count. // count towards the total RPC requests count.
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
err := w.backendConn.WriteMessage(msgType, msg) err := w.writeBackendConn(msgType, msg)
if err != nil { if err != nil {
errC <- err errC <- err
return return
@ -952,7 +964,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
"req_id", GetReqID(ctx), "req_id", GetReqID(ctx),
) )
err = w.backendConn.WriteMessage(msgType, msg) err = w.writeBackendConn(msgType, msg)
if err != nil { if err != nil {
errC <- err errC <- err
return return
@ -962,14 +974,21 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
func (w *WSProxier) backendPump(ctx context.Context, errC chan error) { func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
for { for {
err := w.backendConn.SetReadDeadline(time.Now().Add(w.readTimeout))
if err != nil {
log.Error("ws backend read timeout", "err", err)
errC <- err
return
}
// Block until we get a message. // Block until we get a message.
msgType, msg, err := w.backendConn.ReadMessage() msgType, msg, err := w.backendConn.ReadMessage()
if err != nil { if err != nil {
errC <- err
if err := w.writeClientConn(websocket.CloseMessage, formatWSError(err)); err != nil { if err := w.writeClientConn(websocket.CloseMessage, formatWSError(err)); err != nil {
log.Error("error writing clientConn message", "err", err) log.Error("error writing clientConn message", "err", err)
errC <- err
return
} }
return
} }
RecordWSMessage(ctx, w.backend.Name, SourceBackend) RecordWSMessage(ctx, w.backend.Name, SourceBackend)
@ -1050,8 +1069,23 @@ func (w *WSProxier) parseBackendMsg(msg []byte) (*RPCRes, error) {
func (w *WSProxier) writeClientConn(msgType int, msg []byte) error { func (w *WSProxier) writeClientConn(msgType int, msg []byte) error {
w.clientConnMu.Lock() w.clientConnMu.Lock()
defer w.clientConnMu.Unlock()
if err := w.clientConn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil {
log.Error("ws client write timeout", "err", err)
return err
}
err := w.clientConn.WriteMessage(msgType, msg) err := w.clientConn.WriteMessage(msgType, msg)
w.clientConnMu.Unlock() return err
}
func (w *WSProxier) writeBackendConn(msgType int, msg []byte) error {
w.backendConnMu.Lock()
defer w.backendConnMu.Unlock()
if err := w.backendConn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil {
log.Error("ws backend write timeout", "err", err)
return err
}
err := w.writeBackendConn(msgType, msg)
return err return err
} }

@ -27,6 +27,7 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/rs/cors" "github.com/rs/cors"
"github.com/syndtr/goleveldb/leveldb/opt"
) )
const ( const (
@ -35,7 +36,11 @@ const (
ContextKeyXForwardedFor = "x_forwarded_for" ContextKeyXForwardedFor = "x_forwarded_for"
MaxBatchRPCCallsHardLimit = 100 MaxBatchRPCCallsHardLimit = 100
cacheStatusHdr = "X-Proxyd-Cache-Status" cacheStatusHdr = "X-Proxyd-Cache-Status"
defaultServerTimeout = time.Second * 10 defaultRPCTimeout = 10 * time.Second
defaultBodySizeLimit = 256 * opt.KiB
defaultWSHandshakeTimeout = 10 * time.Second
defaultWSReadTimeout = 2 * time.Minute
defaultWSWriteTimeout = 10 * time.Second
maxRequestBodyLogLen = 2000 maxRequestBodyLogLen = 2000
defaultMaxUpstreamBatchSize = 10 defaultMaxUpstreamBatchSize = 10
) )
@ -92,11 +97,11 @@ func NewServer(
} }
if maxBodySize == 0 { if maxBodySize == 0 {
maxBodySize = math.MaxInt64 maxBodySize = defaultBodySizeLimit
} }
if timeout == 0 { if timeout == 0 {
timeout = defaultServerTimeout timeout = defaultRPCTimeout
} }
if maxUpstreamBatchSize == 0 { if maxUpstreamBatchSize == 0 {
@ -170,7 +175,7 @@ func NewServer(
maxRequestBodyLogLen: maxRequestBodyLogLen, maxRequestBodyLogLen: maxRequestBodyLogLen,
maxBatchSize: maxBatchSize, maxBatchSize: maxBatchSize,
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: defaultWSHandshakeTimeout,
}, },
mainLim: mainLim, mainLim: mainLim,
overrideLims: overrideLims, overrideLims: overrideLims,
@ -547,6 +552,7 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
log.Error("error upgrading client conn", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err) log.Error("error upgrading client conn", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
return return
} }
clientConn.SetReadLimit(s.maxBodySize)
proxier, err := s.wsBackendGroup.ProxyWS(ctx, clientConn, s.wsMethodWhitelist) proxier, err := s.wsBackendGroup.ProxyWS(ctx, clientConn, s.wsMethodWhitelist)
if err != nil { if err != nil {