proxyd: Fix concurrent WS write panic (#2711)

Fixes a panic in the websocket proxyd logic. Normally, the `clientPump` and `backendPump` methods in `WSProxier` send data in one direction. However, when the client sends an invalid RPC, the `clientPump` will send a response _directly to the client_ in order to avoid unnecessary roundtrips to the backend. This could be interleaved with concurrent writes to the client's WS in `backendPump`, and would cause a panic in the WS library.

To verify the fix, this PR includes a dedicated integration test that reliably triggers the issue. It also adds broader test coverage for WS functionality.
This commit is contained in:
Matthew Slipper 2022-06-08 09:09:32 -06:00 committed by GitHub
parent 69f189c0ea
commit e41cfc1d94
6 changed files with 509 additions and 39 deletions

@ -15,6 +15,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/ethereum/go-ethereum/log"
@ -548,6 +549,7 @@ type WSProxier struct {
clientConn *websocket.Conn
backendConn *websocket.Conn
methodWhitelist *StringSet
clientConnMu sync.Mutex
}
func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier {
@ -570,12 +572,11 @@ func (w *WSProxier) Proxy(ctx context.Context) error {
func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
for {
outConn := w.backendConn
// Block until we get a message.
msgType, msg, err := w.clientConn.ReadMessage()
if err != nil {
errC <- err
if err := outConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
if err := w.backendConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
log.Error("error writing backendConn message", "err", err)
}
return
@ -586,7 +587,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
// Route control messages to the backend. These don't
// count towards the total RPC requests count.
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
err := outConn.WriteMessage(msgType, msg)
err := w.backendConn.WriteMessage(msgType, msg)
if err != nil {
errC <- err
return
@ -612,10 +613,18 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
"req_id", GetReqID(ctx),
"err", err,
)
outConn = w.clientConn
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
RecordRPCError(ctx, BackendProxyd, method, err)
} else {
// Send error response to client
err = w.writeClientConn(msgType, msg)
if err != nil {
errC <- err
return
}
continue
}
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
log.Info(
"forwarded WS message to backend",
@ -623,9 +632,8 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
}
err = outConn.WriteMessage(msgType, msg)
err = w.backendConn.WriteMessage(msgType, msg)
if err != nil {
errC <- err
return
@ -639,7 +647,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
msgType, msg, err := w.backendConn.ReadMessage()
if err != nil {
errC <- err
if err := w.clientConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
if err := w.writeClientConn(websocket.CloseMessage, formatWSError(err)); err != nil {
log.Error("error writing clientConn message", "err", err)
}
return
@ -649,7 +657,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
// Route control messages directly to the client.
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
err := w.clientConn.WriteMessage(msgType, msg)
err := w.writeClientConn(msgType, msg)
if err != nil {
errC <- err
return
@ -664,7 +672,8 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
id = res.ID
}
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
}
log.Info("backend responded with error", "err", err)
} else {
if res.IsError() {
log.Info(
"backend responded with RPC error",
@ -682,8 +691,9 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
"req_id", GetReqID(ctx),
)
}
}
err = w.clientConn.WriteMessage(msgType, msg)
err = w.writeClientConn(msgType, msg)
if err != nil {
errC <- err
return
@ -726,6 +736,13 @@ func (w *WSProxier) parseBackendMsg(msg []byte) (*RPCRes, error) {
return res, nil
}
// writeClientConn writes one message to the client connection while holding
// clientConnMu. Both clientPump and backendPump write to the client, and the
// websocket connection must not be written to concurrently, so every client
// write goes through this method.
//
// The unlock is deferred so the mutex is released even if the underlying
// write panics; otherwise the other pump would block forever on Lock.
func (w *WSProxier) writeClientConn(msgType int, msg []byte) error {
	w.clientConnMu.Lock()
	defer w.clientConnMu.Unlock()
	return w.clientConn.WriteMessage(msgType, msg)
}
func mustMarshalJSON(in interface{}) []byte {
out, err := json.Marshal(in)
if err != nil {

@ -7,9 +7,11 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"github.com/ethereum-optimism/optimism/proxyd"
"github.com/gorilla/websocket"
)
type RecordedRequest struct {
@ -251,3 +253,72 @@ func (m *MockBackend) wrappedHandler(w http.ResponseWriter, r *http.Request) {
m.handler.ServeHTTP(w, clone)
m.mtx.Unlock()
}
// MockWSBackend is a test double for a websocket backend. It is served by an
// httptest.Server, upgrades every incoming request to a websocket connection
// and reports connection events through optional callbacks.
type MockWSBackend struct {
// connCB, msgCB and closeCB are optional; ServeHTTP skips any that are nil.
connCB MockWSBackendOnConnect
msgCB MockWSBackendOnMessage
closeCB MockWSBackendOnClose
// server is the HTTP test server that hosts this handler.
server *httptest.Server
// upgrader is used with its zero value to upgrade requests.
upgrader websocket.Upgrader
// conns records every upgraded connection so Close can close them.
// It is appended to and iterated while holding connsMu.
conns []*websocket.Conn
connsMu sync.Mutex
}
// MockWSBackendOnConnect is called once for each connection, right after the
// websocket upgrade succeeds.
type MockWSBackendOnConnect func(conn *websocket.Conn)
// MockWSBackendOnMessage is called for each message read from a connection.
type MockWSBackendOnMessage func(conn *websocket.Conn, msgType int, data []byte)
// MockWSBackendOnClose is called with the read error that ended a
// connection's read loop.
type MockWSBackendOnClose func(conn *websocket.Conn, err error)
// NewMockWSBackend builds a MockWSBackend wired to the given callbacks (any
// of which may be nil) and starts an httptest.Server that serves it.
func NewMockWSBackend(
	connCB MockWSBackendOnConnect,
	msgCB MockWSBackendOnMessage,
	closeCB MockWSBackendOnClose,
) *MockWSBackend {
	backend := new(MockWSBackend)
	backend.connCB = connCB
	backend.msgCB = msgCB
	backend.closeCB = closeCB

	// The backend is its own HTTP handler, so the server can only be
	// created once the value exists.
	backend.server = httptest.NewServer(backend)
	return backend
}
// ServeHTTP upgrades the request to a websocket connection, notifies connCB,
// starts a goroutine that reads messages until the first read error, and
// records the connection so Close can close it later. It panics if the
// upgrade fails.
func (m *MockWSBackend) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	wsConn, upgradeErr := m.upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		panic(upgradeErr)
	}

	if m.connCB != nil {
		m.connCB(wsConn)
	}

	// Read loop: forward every message to msgCB; on the first read error
	// report it to closeCB and stop.
	readLoop := func() {
		for {
			kind, payload, readErr := wsConn.ReadMessage()
			if readErr == nil {
				if m.msgCB != nil {
					m.msgCB(wsConn, kind, payload)
				}
				continue
			}
			if m.closeCB != nil {
				m.closeCB(wsConn, readErr)
			}
			return
		}
	}
	go readLoop()

	m.connsMu.Lock()
	m.conns = append(m.conns, wsConn)
	m.connsMu.Unlock()
}
// URL returns the test server's address with the first "http://" replaced
// by "ws://", so it can be dialed as a websocket endpoint.
func (m *MockWSBackend) URL() string {
	httpURL := m.server.URL
	return strings.Replace(httpURL, "http://", "ws://", 1)
}
// Close shuts down the test server and then closes every websocket
// connection recorded by ServeHTTP. Errors from closing individual
// connections are ignored.
func (m *MockWSBackend) Close() {
	m.server.Close()

	m.connsMu.Lock()
	for i := range m.conns {
		_ = m.conns[i].Close()
	}
	m.connsMu.Unlock()
}

@ -0,0 +1,25 @@
ws_backend_group = "main"
ws_method_whitelist = [
"eth_subscribe"
]
[server]
rpc_port = 8545
ws_port = 8546
[backend]
response_timeout_seconds = 1
[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
max_ws_conns = 1
[backend_groups]
[backend_groups.main]
backends = ["good"]
[rpc_method_mappings]
eth_chainId = "main"

@ -8,22 +8,26 @@ import (
"net/http"
"os"
"testing"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/gorilla/websocket"
"github.com/BurntSushi/toml"
"github.com/ethereum-optimism/optimism/proxyd"
"github.com/ethereum/go-ethereum/log"
"github.com/stretchr/testify/require"
)
type ProxydClient struct {
type ProxydHTTPClient struct {
url string
}
func NewProxydClient(url string) *ProxydClient {
return &ProxydClient{url: url}
func NewProxydClient(url string) *ProxydHTTPClient {
return &ProxydHTTPClient{url: url}
}
func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
rpcReq := NewRPCReq("999", method, params)
body, err := json.Marshal(rpcReq)
if err != nil {
@ -32,7 +36,7 @@ func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int
return p.SendRequest(body)
}
func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) {
func (p *ProxydHTTPClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) {
body, err := json.Marshal(reqs)
if err != nil {
panic(err)
@ -40,7 +44,7 @@ func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error)
return p.SendRequest(body)
}
func (p *ProxydClient) SendRequest(body []byte) ([]byte, int, error) {
func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) {
res, err := http.Post(p.url, "application/json", bytes.NewReader(body))
if err != nil {
return nil, -1, err
@ -98,6 +102,70 @@ func NewRPCReq(id string, method string, params []interface{}) *proxyd.RPCReq {
}
}
// ProxydWSClient is a test helper wrapping a websocket connection dialed by
// NewProxydWSClient. Messages and read errors observed by its read loop are
// reported through the optional msgCB and closeCB callbacks.
type ProxydWSClient struct {
conn *websocket.Conn
msgCB ProxydWSClientOnMessage
closeCB ProxydWSClientOnClose
}
// WSMessage pairs a message type with a message body.
// NOTE(review): Type is presumably a gorilla/websocket message type
// constant; nothing in the visible code uses this struct, so confirm
// against callers.
type WSMessage struct {
Type int
Body []byte
}
// ProxydWSClientOnMessage is called for each message the client's read loop
// receives.
type ProxydWSClientOnMessage func(msgType int, data []byte)
// ProxydWSClientOnClose is called with the read error that ended the
// client's read loop.
type ProxydWSClientOnClose func(err error)
// NewProxydWSClient dials url with the default websocket dialer and, on
// success, starts a background read loop that feeds msgCB and closeCB
// (either may be nil). It returns the dial error unchanged on failure.
func NewProxydWSClient(
	url string,
	msgCB ProxydWSClientOnMessage,
	closeCB ProxydWSClientOnClose,
) (*ProxydWSClient, error) {
	wsConn, _, dialErr := websocket.DefaultDialer.Dial(url, nil) // nolint:bodyclose
	if dialErr != nil {
		return nil, dialErr
	}

	client := new(ProxydWSClient)
	client.conn = wsConn
	client.msgCB = msgCB
	client.closeCB = closeCB

	go client.readPump()
	return client, nil
}
// readPump reads messages from the connection until the first read error.
// Each message is passed to msgCB; the terminating error is passed to
// closeCB. Nil callbacks are skipped.
func (h *ProxydWSClient) readPump() {
	for {
		kind, payload, readErr := h.conn.ReadMessage()
		if readErr == nil {
			if h.msgCB != nil {
				h.msgCB(kind, payload)
			}
			continue
		}
		if h.closeCB != nil {
			h.closeCB(readErr)
		}
		return
	}
}
// HardClose closes the underlying connection without sending a websocket
// close message first. The error from Close is ignored.
func (h *ProxydWSClient) HardClose() {
	_ = h.conn.Close()
}
// SoftClose sends a websocket close message with an empty payload and
// leaves the underlying connection open.
func (h *ProxydWSClient) SoftClose() error {
	return h.conn.WriteMessage(websocket.CloseMessage, nil)
}
// WriteMessage writes a single message of the given type to the connection
// and returns the write error, if any.
func (h *ProxydWSClient) WriteMessage(msgType int, msg []byte) error {
	conn := h.conn
	return conn.WriteMessage(msgType, msg)
}
// WriteControlMessage writes a control message to the connection with a
// write deadline one minute from now.
func (h *ProxydWSClient) WriteControlMessage(msgType int, msg []byte) error {
	deadline := time.Now().Add(time.Minute)
	return h.conn.WriteControl(msgType, msg, deadline)
}
func InitLogger() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlDebug,

@ -0,0 +1,281 @@
package integration_tests
import (
"os"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/ethereum-optimism/optimism/proxyd"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)
// TestConcurrentWSPanic tests for a panic in the websocket proxy
// that occurred when messages were sent from the upstream to the
// client right after the client sent an invalid request.
//
// The test passes if the process survives one second of concurrent
// traffic in both directions; it makes no assertions on message contents.
func TestConcurrentWSPanic(t *testing.T) {
	var backendToProxyConn *websocket.Conn
	var setOnce sync.Once

	readyCh := make(chan struct{}, 1)
	quitC := make(chan struct{})

	// Pull out the backend -> proxyd conn so that we can spam it directly.
	// Use a sync.Once to make sure we only do that once, for the first
	// connection.
	backend := NewMockWSBackend(func(conn *websocket.Conn) {
		setOnce.Do(func() {
			backendToProxyConn = conn
			readyCh <- struct{}{}
		})
	}, nil, nil)
	defer backend.Close()

	require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))

	config := ReadConfig("ws")
	shutdown, err := proxyd.Start(config)
	require.NoError(t, err)
	// Register shutdown before anything else can fail so the proxy's ports
	// are always released for the tests that follow.
	defer shutdown()

	client, err := NewProxydWSClient("ws://127.0.0.1:8546", nil, nil)
	require.NoError(t, err)
	defer client.HardClose()

	// Wait for proxyd to open its backend connection, but never forever.
	select {
	case <-readyCh:
	case <-time.After(30 * time.Second):
		t.Fatal("timed out waiting for proxyd to connect to the backend")
	}

	// spam messages
	go func() {
		for {
			select {
			case <-quitC:
				return
			default:
				_ = backendToProxyConn.WriteMessage(websocket.TextMessage, []byte("garbage"))
			}
		}
	}()

	// spam invalid RPCs
	go func() {
		for {
			select {
			case <-quitC:
				return
			default:
				_ = client.WriteMessage(websocket.TextMessage, []byte("{\"id\": 1, \"method\": \"eth_foo\", \"params\": [\"newHeads\"]}"))
			}
		}
	}()

	// 1 second is enough to trigger the panic due to
	// concurrent write to websocket connection
	time.Sleep(time.Second)
	close(quitC)
}
// backendHandler holds the mock backend's message and close callbacks in
// atomic.Values so a test can replace them while the backend's read
// goroutine is invoking them.
type backendHandler struct {
	msgCB   atomic.Value
	closeCB atomic.Value
}

// MsgCB invokes the stored message callback, if one has been set.
func (b *backendHandler) MsgCB(conn *websocket.Conn, msgType int, data []byte) {
	if stored := b.msgCB.Load(); stored != nil {
		stored.(MockWSBackendOnMessage)(conn, msgType, data)
	}
}

// SetMsgCB replaces the message callback.
func (b *backendHandler) SetMsgCB(cb MockWSBackendOnMessage) {
	b.msgCB.Store(cb)
}

// CloseCB invokes the stored close callback, if one has been set.
func (b *backendHandler) CloseCB(conn *websocket.Conn, err error) {
	if stored := b.closeCB.Load(); stored != nil {
		stored.(MockWSBackendOnClose)(conn, err)
	}
}

// SetCloseCB replaces the close callback.
func (b *backendHandler) SetCloseCB(cb MockWSBackendOnClose) {
	b.closeCB.Store(cb)
}
// clientHandler holds the websocket client's message callback in an
// atomic.Value so a test can replace it while the client's read goroutine
// is invoking it.
type clientHandler struct {
	msgCB atomic.Value
}

// MsgCB invokes the stored message callback. It is a no-op when no callback
// has been set yet.
func (c *clientHandler) MsgCB(msgType int, data []byte) {
	// Load returns a nil interface until the first Store, and a
	// single-value type assertion on nil panics; the two-value form makes
	// the "nothing stored" case safe.
	cb, ok := c.msgCB.Load().(ProxydWSClientOnMessage)
	if !ok || cb == nil {
		return
	}
	cb(msgType, data)
}

// SetMsgCB replaces the message callback.
func (c *clientHandler) SetMsgCB(cb ProxydWSClientOnMessage) {
	c.msgCB.Store(cb)
}
// TestWS drives a table of request/response scenarios through proxyd's
// websocket endpoint. For each case the mock backend answers any forwarded
// message with backendRes, the client sends clientReq, and the first
// message the client receives must equal expRes.
func TestWS(t *testing.T) {
	backendHdlr := new(backendHandler)
	clientHdlr := new(clientHandler)

	backend := NewMockWSBackend(nil, func(conn *websocket.Conn, msgType int, data []byte) {
		backendHdlr.MsgCB(conn, msgType, data)
	}, func(conn *websocket.Conn, err error) {
		backendHdlr.CloseCB(conn, err)
	})
	defer backend.Close()

	require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))

	config := ReadConfig("ws")
	shutdown, err := proxyd.Start(config)
	require.NoError(t, err)
	// Register shutdown immediately so a failed dial below cannot leak the
	// running proxy and its ports.
	defer shutdown()

	client, err := NewProxydWSClient("ws://127.0.0.1:8546", func(msgType int, data []byte) {
		clientHdlr.MsgCB(msgType, data)
	}, nil)
	// Check the error before deferring HardClose: client is nil on failure.
	require.NoError(t, err)
	defer client.HardClose()

	tests := []struct {
		name       string
		backendRes string
		expRes     string
		clientReq  string
	}{
		{
			"ok response",
			"{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}",
			"{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}",
			"{\"id\": 1, \"method\": \"eth_subscribe\", \"params\": [\"newHeads\"]}",
		},
		{
			"garbage backend response",
			"gibblegabble",
			"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32013,\"message\":\"backend returned an invalid response\"},\"id\":null}",
			"{\"id\": 1, \"method\": \"eth_subscribe\", \"params\": [\"newHeads\"]}",
		},
		{
			"blacklisted RPC",
			"}",
			"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32001,\"message\":\"rpc method is not whitelisted\"},\"id\":1}",
			"{\"id\": 1, \"method\": \"eth_whatever\", \"params\": []}",
		},
		{
			"garbage client request",
			"{}",
			"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32700,\"message\":\"parse error\"},\"id\":null}",
			"barf",
		},
		{
			"invalid client request",
			"{}",
			"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32700,\"message\":\"parse error\"},\"id\":null}",
			"{\"jsonrpc\": \"2.0\", \"method\": true}",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// A one-shot timer that is stopped on return, instead of a
			// ticker that would keep running after the subtest ends.
			timeout := time.NewTimer(30 * time.Second)
			defer timeout.Stop()
			resCh := make(chan string, 1)

			// The callbacks run on the backend's and client's read
			// goroutines. FailNow-based assertions are only valid on the
			// test goroutine, so the backend reports with t.Errorf and the
			// client hands its payload back over resCh.
			backendHdlr.SetMsgCB(func(conn *websocket.Conn, msgType int, data []byte) {
				if err := conn.WriteMessage(websocket.TextMessage, []byte(tt.backendRes)); err != nil {
					t.Errorf("writing backend response: %v", err)
				}
			})
			clientHdlr.SetMsgCB(func(msgType int, data []byte) {
				resCh <- string(data)
			})

			require.NoError(t, client.WriteMessage(
				websocket.TextMessage,
				[]byte(tt.clientReq),
			))

			select {
			case <-timeout.C:
				t.Fatalf("timed out")
			case got := <-resCh:
				require.Equal(t, tt.expRes, got)
			}
		})
	}
}
// TestWSClientClosure checks that the backend connection's read loop ends
// (observed through the backend close callback) when the client goes away,
// both when the client sends a websocket close message ("soft") and when it
// drops the connection ("hard").
func TestWSClientClosure(t *testing.T) {
	backendHdlr := new(backendHandler)
	clientHdlr := new(clientHandler)

	backend := NewMockWSBackend(nil, func(conn *websocket.Conn, msgType int, data []byte) {
		backendHdlr.MsgCB(conn, msgType, data)
	}, func(conn *websocket.Conn, err error) {
		backendHdlr.CloseCB(conn, err)
	})
	defer backend.Close()

	require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))

	config := ReadConfig("ws")
	shutdown, err := proxyd.Start(config)
	require.NoError(t, err)
	defer shutdown()

	for _, closeType := range []string{"soft", "hard"} {
		t.Run(closeType, func(t *testing.T) {
			client, err := NewProxydWSClient("ws://127.0.0.1:8546", func(msgType int, data []byte) {
				clientHdlr.MsgCB(msgType, data)
			}, nil)
			require.NoError(t, err)
			// Always release the socket. A soft close only sends a close
			// message, and a second HardClose in the hard case is harmless
			// because its error is ignored.
			defer client.HardClose()

			// A one-shot timer that is stopped on return, instead of a
			// ticker that would keep running after the subtest ends.
			timeout := time.NewTimer(30 * time.Second)
			defer timeout.Stop()
			doneCh := make(chan struct{}, 1)
			backendHdlr.SetCloseCB(func(conn *websocket.Conn, err error) {
				doneCh <- struct{}{}
			})

			if closeType == "soft" {
				require.NoError(t, client.SoftClose())
			} else {
				client.HardClose()
			}

			select {
			case <-timeout.C:
				t.Fatalf("timed out")
			case <-doneCh:
				return
			}
		})
	}
}
// TestWSClientMaxConns opens two client connections in a row and expects
// the second one's read loop to end with an "unexpected EOF" error.
// NOTE(review): the rejection presumably comes from max_ws_conns = 1 in the
// "ws" test config — confirm against the proxy's connection limiting.
func TestWSClientMaxConns(t *testing.T) {
	backend := NewMockWSBackend(nil, nil, nil)
	defer backend.Close()

	require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))

	config := ReadConfig("ws")
	shutdown, err := proxyd.Start(config)
	require.NoError(t, err)
	defer shutdown()

	first, err := NewProxydWSClient("ws://127.0.0.1:8546", nil, nil)
	require.NoError(t, err)
	defer first.HardClose()

	// The close callback runs on the client's read goroutine, where
	// FailNow-based assertions are not allowed, so hand the error back to
	// the test goroutine and assert there.
	closeErrCh := make(chan error, 1)
	second, err := NewProxydWSClient("ws://127.0.0.1:8546", nil, func(err error) {
		closeErrCh <- err
	})
	require.NoError(t, err)
	defer second.HardClose()

	// A one-shot timer that is stopped on return, instead of a ticker that
	// would keep running after the test ends.
	timeout := time.NewTimer(30 * time.Second)
	defer timeout.Stop()

	select {
	case <-timeout.C:
		t.Fatalf("timed out")
	case closeErr := <-closeErrCh:
		require.Contains(t, closeErr.Error(), "unexpected EOF")
	}
}

@ -11,6 +11,7 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/ethereum/go-ethereum/log"
@ -44,6 +45,7 @@ type Server struct {
rpcServer *http.Server
wsServer *http.Server
cache RPCCache
srvMu sync.Mutex
}
func NewServer(
@ -90,6 +92,7 @@ func NewServer(
}
func (s *Server) RPCListenAndServe(host string, port int) error {
s.srvMu.Lock()
hdlr := mux.NewRouter()
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
@ -103,10 +106,12 @@ func (s *Server) RPCListenAndServe(host string, port int) error {
Addr: addr,
}
log.Info("starting HTTP server", "addr", addr)
s.srvMu.Unlock()
return s.rpcServer.ListenAndServe()
}
func (s *Server) WSListenAndServe(host string, port int) error {
s.srvMu.Lock()
hdlr := mux.NewRouter()
hdlr.HandleFunc("/", s.HandleWS)
hdlr.HandleFunc("/{authorization}", s.HandleWS)
@ -119,10 +124,13 @@ func (s *Server) WSListenAndServe(host string, port int) error {
Addr: addr,
}
log.Info("starting WS server", "addr", addr)
s.srvMu.Unlock()
return s.wsServer.ListenAndServe()
}
func (s *Server) Shutdown() {
s.srvMu.Lock()
defer s.srvMu.Unlock()
if s.rpcServer != nil {
_ = s.rpcServer.Shutdown(context.Background())
}