diff --git a/rpc/websocket.go b/rpc/websocket.go index b1213fdfa..86cf50594 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -278,24 +278,21 @@ type websocketCodec struct { conn *websocket.Conn info PeerInfo - wg sync.WaitGroup - pingReset chan struct{} + wg sync.WaitGroup + pingReset chan struct{} + pongReceived chan struct{} } func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec { conn.SetReadLimit(wsMessageSizeLimit) - conn.SetPongHandler(func(appData string) error { - conn.SetReadDeadline(time.Time{}) - return nil - }) - encode := func(v interface{}, isErrorResponse bool) error { return conn.WriteJSON(v) } wc := &websocketCodec{ - jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), - conn: conn, - pingReset: make(chan struct{}, 1), + jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), + conn: conn, + pingReset: make(chan struct{}, 1), + pongReceived: make(chan struct{}), info: PeerInfo{ Transport: "ws", RemoteAddr: conn.RemoteAddr().String(), @@ -306,6 +303,13 @@ func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) Serve wc.info.HTTP.Origin = req.Get("Origin") wc.info.HTTP.UserAgent = req.Get("User-Agent") // Start pinger. + conn.SetPongHandler(func(appData string) error { + select { + case wc.pongReceived <- struct{}{}: + case <-wc.closed(): + } + return nil + }) wc.wg.Add(1) go wc.pingLoop() return wc @@ -334,26 +338,31 @@ func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError // pingLoop sends periodic ping frames when the connection is idle. func (wc *websocketCodec) pingLoop() { - var timer = time.NewTimer(wsPingInterval) + var pingTimer = time.NewTimer(wsPingInterval) defer wc.wg.Done() - defer timer.Stop() + defer pingTimer.Stop() for { select { case <-wc.closed(): return + case <-wc.pingReset: - if !timer.Stop() { - <-timer.C + if !pingTimer.Stop() { + <-pingTimer.C } - timer.Reset(wsPingInterval) - case <-timer.C: + pingTimer.Reset(wsPingInterval) + + case <-pingTimer.C: wc.jsonCodec.encMu.Lock() wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) wc.conn.WriteMessage(websocket.PingMessage, nil) wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) wc.jsonCodec.encMu.Unlock() - timer.Reset(wsPingInterval) + pingTimer.Reset(wsPingInterval) + + case <-wc.pongReceived: + wc.conn.SetReadDeadline(time.Time{}) } } }