From 69f189c0ea3c5cf30afd793bf2dbfd5489cba6cb Mon Sep 17 00:00:00 2001 From: Murphy Law Date: Wed, 8 Jun 2022 09:56:24 -0400 Subject: [PATCH 1/2] proxyd: Handle unexpected JSON-RPC responses (#2628) This fixes a bug where the Infura backend would be labeled offline because it returns an unexpected JSON-RPC response. Unexpected, but well-formed, JSON-RPC responses are handled specially. Such errors are surfaced to the backend proxier so failover still occurs. Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- proxyd/proxyd/backend.go | 31 +++++++++-- .../proxyd/integration_tests/failover_test.go | 51 +++++++++++++++++++ 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/proxyd/proxyd/backend.go b/proxyd/proxyd/backend.go index 2e974af..cc06137 100644 --- a/proxyd/proxyd/backend.go +++ b/proxyd/proxyd/backend.go @@ -73,6 +73,8 @@ var ( Message: "gateway timeout", HTTPErrorCode: 504, } + + ErrBackendUnexpectedJSONRPC = errors.New("backend returned an unexpected JSON-RPC response") ) func ErrInvalidRequest(msg string) *RPCErr { @@ -228,7 +230,20 @@ func (b *Backend) Forward(ctx context.Context, reqs []*RPCReq, isBatch bool) ([] ) res, err := b.doForward(ctx, reqs, isBatch) - if err != nil { + switch err { + case nil: // do nothing + // ErrBackendUnexpectedJSONRPC occurs because infura responds with a single JSON-RPC object + // to a batch request whenever any Request Object in the batch would induce a partial error. + // We don't label the backend offline in this case. But the error is still returned to + // callers so failover can occur if needed. + case ErrBackendUnexpectedJSONRPC: + log.Debug( + "Received unexpected JSON-RPC response", + "name", b.Name, + "req_id", GetReqID(ctx), + "err", err, + ) + default: lastError = err log.Warn( "backend request failed, trying again", @@ -244,7 +259,7 @@ func (b *Backend) Forward(ctx context.Context, reqs []*RPCReq, isBatch bool) ([] timer.ObserveDuration() MaybeRecordErrorsInRPCRes(ctx, b.Name, reqs, res) - return res, nil + return res, err } b.setOffline() @@ -387,12 +402,15 @@ func (b *Backend) doForward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool var res []*RPCRes if err := json.Unmarshal(resB, &res); err != nil { + // Infura may return a single JSON-RPC response if, for example, the batch contains a request for an unsupported method + if responseIsNotBatched(resB) { + return nil, ErrBackendUnexpectedJSONRPC + } return nil, ErrBackendBadResponse } - // Alas! Certain node providers (Infura) always return a single JSON object for some types of errors if len(rpcReqs) != len(res) { - return nil, ErrBackendBadResponse + return nil, ErrBackendUnexpectedJSONRPC } // capture the HTTP status code in the response.
this will only @@ -407,6 +425,11 @@ func (b *Backend) doForward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool return res, nil } +func responseIsNotBatched(b []byte) bool { + var r RPCRes + return json.Unmarshal(b, &r) == nil +} + // sortBatchRPCResponse sorts the RPCRes slice according to the position of its corresponding ID in the RPCReq slice func sortBatchRPCResponse(req []*RPCReq, res []*RPCRes) { pos := make(map[string]int, len(req)) diff --git a/proxyd/proxyd/integration_tests/failover_test.go b/proxyd/proxyd/integration_tests/failover_test.go index f99d1e6..f80f47c 100644 --- a/proxyd/proxyd/integration_tests/failover_test.go +++ b/proxyd/proxyd/integration_tests/failover_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/alicebob/miniredis" "github.com/ethereum-optimism/optimism/proxyd" "github.com/stretchr/testify/require" ) @@ -15,6 +16,7 @@ import ( const ( goodResponse = `{"jsonrpc": "2.0", "result": "hello", "id": 999}` noBackendsResponse = `{"error":{"code":-32011,"message":"no backends available for method"},"id":999,"jsonrpc":"2.0"}` + unexpectedResponse = `{"error":{"code":-32011,"message":"some error"},"id":999,"jsonrpc":"2.0"}` ) func TestFailover(t *testing.T) { @@ -240,3 +242,52 @@ func TestBatchWithPartialFailover(t *testing.T) { require.Equal(t, 2, len(badBackend.Requests())) require.Equal(t, 2, len(goodBackend.Requests())) } + +func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) { + InitLogger() + // Scenario: + // 1. Send batch to BAD_BACKEND (Infura) + // 2. Infura fails completely due to a partially erroneous batch request (one of the N+1 request objects is invalid) + // 3. Assert that the request batch is re-routed to the failover provider + // 4. Assert that BAD_BACKEND is NOT labeled offline + // 5. Assert that BAD_BACKEND is NOT retried + + redis, err := miniredis.Run() + require.NoError(t, err) + defer redis.Close() + + config := ReadConfig("failover") + config.Server.MaxUpstreamBatchSize = 2 + config.BackendOptions.MaxRetries = 2 + // Set up redis to detect offline backends + config.Redis.URL = fmt.Sprintf("redis://127.0.0.1:%s", redis.Port()) + + goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse, goodResponse)) + defer goodBackend.Close() + badBackend := NewMockBackend(SingleResponseHandler(200, unexpectedResponse)) + defer badBackend.Close() + + require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL())) + require.NoError(t, os.Setenv("BAD_BACKEND_RPC_URL", badBackend.URL())) + + client := NewProxydClient("http://127.0.0.1:8545") + shutdown, err := proxyd.Start(config) + require.NoError(t, err) + defer shutdown() + + res, statusCode, err := client.SendBatchRPC( + NewRPCReq("1", "eth_chainId", nil), + NewRPCReq("2", "eth_chainId", nil), + ) + require.NoError(t, err) + require.Equal(t, 200, statusCode) + RequireEqualJSON(t, []byte(asArray(goodResponse, goodResponse)), res) + require.Equal(t, 1, len(badBackend.Requests())) + require.Equal(t, 1, len(goodBackend.Requests())) + + rr, err := proxyd.NewRedisRateLimiter(config.Redis.URL) + require.NoError(t, err) + online, err := rr.IsBackendOnline("bad") + require.NoError(t, err) + require.Equal(t, true, online) +} From e41cfc1d945654a63fab26da9f16e44b22248c75 Mon Sep 17 00:00:00 2001 From: Matthew Slipper Date: Wed, 8 Jun 2022 09:09:32 -0600 Subject: [PATCH 2/2] proxyd: Fix concurrent WS write panic (#2711) Fixes a panic in the websocket proxyd logic. Normally, the `clientPump` and `backendPump` methods in `WSProxier` send data in one direction.
However, when the client sends an invalid RPC, the `clientPump` will send a response _directly to the client_ in order to avoid unnecessary roundtrips to the backend. This could be interleaved with concurrent writes to the client's WS in `backendPump`, and would cause a panic in the WS library. To verify the fix, this PR includes a dedicated integration test that reliably triggers the issue. It also adds further test coverage for WS functionality. --- proxyd/proxyd/backend.go | 81 +++-- .../integration_tests/mock_backend_test.go | 71 +++++ .../proxyd/integration_tests/testdata/ws.toml | 25 ++ proxyd/proxyd/integration_tests/util_test.go | 82 ++++- proxyd/proxyd/integration_tests/ws_test.go | 281 ++++++++++++++++++ proxyd/proxyd/server.go | 8 + 6 files changed, 509 insertions(+), 39 deletions(-) create mode 100644 proxyd/proxyd/integration_tests/testdata/ws.toml create mode 100644 proxyd/proxyd/integration_tests/ws_test.go diff --git a/proxyd/proxyd/backend.go b/proxyd/proxyd/backend.go index cc06137..94e94d1 100644 --- a/proxyd/proxyd/backend.go +++ b/proxyd/proxyd/backend.go @@ -15,6 +15,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "github.com/ethereum/go-ethereum/log" @@ -548,6 +549,7 @@ type WSProxier struct { clientConn *websocket.Conn backendConn *websocket.Conn methodWhitelist *StringSet + clientConnMu sync.Mutex } func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier { @@ -570,12 +572,11 @@ func (w *WSProxier) Proxy(ctx context.Context) error { func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { for { - outConn := w.backendConn // Block until we get a message. msgType, msg, err := w.clientConn.ReadMessage() if err != nil { errC <- err - if err := outConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil { + if err := w.backendConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil { log.Error("error writing backendConn message", "err", err) } return @@ -586,7 +587,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { // Route control messages to the backend. These don't // count towards the total RPC requests count.
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { - err := outConn.WriteMessage(msgType, msg) + err := w.backendConn.WriteMessage(msgType, msg) if err != nil { errC <- err return @@ -612,20 +613,27 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { "req_id", GetReqID(ctx), "err", err, ) - outConn = w.clientConn msg = mustMarshalJSON(NewRPCErrorRes(id, err)) RecordRPCError(ctx, BackendProxyd, method, err) - } else { - RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS) - log.Info( - "forwarded WS message to backend", - "method", req.Method, - "auth", GetAuthCtx(ctx), - "req_id", GetReqID(ctx), - ) + + // Send error response to client + err = w.writeClientConn(msgType, msg) + if err != nil { + errC <- err + return + } + continue } - err = outConn.WriteMessage(msgType, msg) + RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS) + log.Info( + "forwarded WS message to backend", + "method", req.Method, + "auth", GetAuthCtx(ctx), + "req_id", GetReqID(ctx), + ) + + err = w.backendConn.WriteMessage(msgType, msg) if err != nil { errC <- err return @@ -639,7 +647,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) { msgType, msg, err := w.backendConn.ReadMessage() if err != nil { errC <- err - if err := w.clientConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil { + if err := w.writeClientConn(websocket.CloseMessage, formatWSError(err)); err != nil { log.Error("error writing clientConn message", "err", err) } return @@ -649,7 +657,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) { // Route control messages directly to the client. if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { - err := w.clientConn.WriteMessage(msgType, msg) + err := w.writeClientConn(msgType, msg) if err != nil { errC <- err return @@ -664,26 +672,28 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) { id = res.ID } msg = mustMarshalJSON(NewRPCErrorRes(id, err)) - } - if res.IsError() { - log.Info( - "backend responded with RPC error", - "code", res.Error.Code, - "msg", res.Error.Message, - "source", "ws", - "auth", GetAuthCtx(ctx), - "req_id", GetReqID(ctx), - ) - RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error) + log.Info("backend responded with error", "err", err) } else { - log.Info( - "forwarded WS message to client", - "auth", GetAuthCtx(ctx), - "req_id", GetReqID(ctx), - ) + if res.IsError() { + log.Info( + "backend responded with RPC error", + "code", res.Error.Code, + "msg", res.Error.Message, + "source", "ws", + "auth", GetAuthCtx(ctx), + "req_id", GetReqID(ctx), + ) + RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error) + } else { + log.Info( + "forwarded WS message to client", + "auth", GetAuthCtx(ctx), + "req_id", GetReqID(ctx), + ) + } } - err = w.clientConn.WriteMessage(msgType, msg) + err = w.writeClientConn(msgType, msg) if err != nil { errC <- err return @@ -726,6 +736,13 @@ func (w *WSProxier) parseBackendMsg(msg []byte) (*RPCRes, error) { return res, nil } +func (w *WSProxier) writeClientConn(msgType int, msg []byte) error { + w.clientConnMu.Lock() + err := w.clientConn.WriteMessage(msgType, msg) + w.clientConnMu.Unlock() + return err +} + func mustMarshalJSON(in interface{}) []byte { out, err := json.Marshal(in) if err != nil { diff --git a/proxyd/proxyd/integration_tests/mock_backend_test.go b/proxyd/proxyd/integration_tests/mock_backend_test.go index 94b00e1..7f40ffe 100644 --- 
a/proxyd/proxyd/integration_tests/mock_backend_test.go +++ b/proxyd/proxyd/integration_tests/mock_backend_test.go @@ -7,9 +7,11 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "strings" "sync" "github.com/ethereum-optimism/optimism/proxyd" + "github.com/gorilla/websocket" ) type RecordedRequest struct { @@ -251,3 +253,72 @@ func (m *MockBackend) wrappedHandler(w http.ResponseWriter, r *http.Request) { m.handler.ServeHTTP(w, clone) m.mtx.Unlock() } + +type MockWSBackend struct { + connCB MockWSBackendOnConnect + msgCB MockWSBackendOnMessage + closeCB MockWSBackendOnClose + server *httptest.Server + upgrader websocket.Upgrader + conns []*websocket.Conn + connsMu sync.Mutex +} + +type MockWSBackendOnConnect func(conn *websocket.Conn) +type MockWSBackendOnMessage func(conn *websocket.Conn, msgType int, data []byte) +type MockWSBackendOnClose func(conn *websocket.Conn, err error) + +func NewMockWSBackend( + connCB MockWSBackendOnConnect, + msgCB MockWSBackendOnMessage, + closeCB MockWSBackendOnClose, +) *MockWSBackend { + mb := &MockWSBackend{ + connCB: connCB, + msgCB: msgCB, + closeCB: closeCB, + } + mb.server = httptest.NewServer(mb) + return mb +} + +func (m *MockWSBackend) ServeHTTP(w http.ResponseWriter, r *http.Request) { + conn, err := m.upgrader.Upgrade(w, r, nil) + if err != nil { + panic(err) + } + if m.connCB != nil { + m.connCB(conn) + } + go func() { + for { + mType, msg, err := conn.ReadMessage() + if err != nil { + if m.closeCB != nil { + m.closeCB(conn, err) + } + return + } + if m.msgCB != nil { + m.msgCB(conn, mType, msg) + } + } + }() + m.connsMu.Lock() + m.conns = append(m.conns, conn) + m.connsMu.Unlock() +} + +func (m *MockWSBackend) URL() string { + return strings.Replace(m.server.URL, "http://", "ws://", 1) +} + +func (m *MockWSBackend) Close() { + m.server.Close() + + m.connsMu.Lock() + for _, conn := range m.conns { + conn.Close() + } + m.connsMu.Unlock() +} diff --git a/proxyd/proxyd/integration_tests/testdata/ws.toml b/proxyd/proxyd/integration_tests/testdata/ws.toml new file mode 100644 index 0000000..27ecb23 --- /dev/null +++ b/proxyd/proxyd/integration_tests/testdata/ws.toml @@ -0,0 +1,25 @@ +ws_backend_group = "main" + +ws_method_whitelist = [ + "eth_subscribe" +] + +[server] +rpc_port = 8545 +ws_port = 8546 + +[backend] +response_timeout_seconds = 1 + +[backends] +[backends.good] +rpc_url = "$GOOD_BACKEND_RPC_URL" +ws_url = "$GOOD_BACKEND_RPC_URL" +max_ws_conns = 1 + +[backend_groups] +[backend_groups.main] +backends = ["good"] + +[rpc_method_mappings] +eth_chainId = "main" diff --git a/proxyd/proxyd/integration_tests/util_test.go b/proxyd/proxyd/integration_tests/util_test.go index bd798ec..c5c15a6 100644 --- a/proxyd/proxyd/integration_tests/util_test.go +++ b/proxyd/proxyd/integration_tests/util_test.go @@ -8,22 +8,26 @@ import ( "net/http" "os" "testing" + "time" + + "github.com/ethereum/go-ethereum/log" + + "github.com/gorilla/websocket" "github.com/BurntSushi/toml" "github.com/ethereum-optimism/optimism/proxyd" - "github.com/ethereum/go-ethereum/log" "github.com/stretchr/testify/require" ) -type ProxydClient struct { +type ProxydHTTPClient struct { url string } -func NewProxydClient(url string) *ProxydClient { - return &ProxydClient{url: url} +func NewProxydClient(url string) *ProxydHTTPClient { + return &ProxydHTTPClient{url: url} } -func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int, error) { +func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) { rpcReq := NewRPCReq("999", 
method, params) body, err := json.Marshal(rpcReq) if err != nil { @@ -32,7 +36,7 @@ func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int return p.SendRequest(body) } -func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) { +func (p *ProxydHTTPClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) { body, err := json.Marshal(reqs) if err != nil { panic(err) @@ -40,7 +44,7 @@ func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) return p.SendRequest(body) } -func (p *ProxydClient) SendRequest(body []byte) ([]byte, int, error) { +func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) { res, err := http.Post(p.url, "application/json", bytes.NewReader(body)) if err != nil { return nil, -1, err @@ -98,6 +102,70 @@ func NewRPCReq(id string, method string, params []interface{}) *proxyd.RPCReq { } } +type ProxydWSClient struct { + conn *websocket.Conn + msgCB ProxydWSClientOnMessage + closeCB ProxydWSClientOnClose +} + +type WSMessage struct { + Type int + Body []byte +} + +type ProxydWSClientOnMessage func(msgType int, data []byte) +type ProxydWSClientOnClose func(err error) + +func NewProxydWSClient( + url string, + msgCB ProxydWSClientOnMessage, + closeCB ProxydWSClientOnClose, +) (*ProxydWSClient, error) { + conn, _, err := websocket.DefaultDialer.Dial(url, nil) // nolint:bodyclose + if err != nil { + return nil, err + } + + c := &ProxydWSClient{ + conn: conn, + msgCB: msgCB, + closeCB: closeCB, + } + go c.readPump() + return c, nil +} + +func (h *ProxydWSClient) readPump() { + for { + mType, msg, err := h.conn.ReadMessage() + if err != nil { + if h.closeCB != nil { + h.closeCB(err) + } + return + } + if h.msgCB != nil { + h.msgCB(mType, msg) + } + } +} + +func (h *ProxydWSClient) HardClose() { + h.conn.Close() +} + +func (h *ProxydWSClient) SoftClose() error { + return h.WriteMessage(websocket.CloseMessage, nil) +} + +func (h *ProxydWSClient) WriteMessage(msgType int, msg []byte) error { + return h.conn.WriteMessage(msgType, msg) +} + +func (h *ProxydWSClient) WriteControlMessage(msgType int, msg []byte) error { + return h.conn.WriteControl(msgType, msg, time.Now().Add(time.Minute)) +} + func InitLogger() { log.Root().SetHandler( log.LvlFilterHandler(log.LvlDebug, diff --git a/proxyd/proxyd/integration_tests/ws_test.go b/proxyd/proxyd/integration_tests/ws_test.go new file mode 100644 index 0000000..563b689 --- /dev/null +++ b/proxyd/proxyd/integration_tests/ws_test.go @@ -0,0 +1,281 @@ +package integration_tests + +import ( + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum-optimism/optimism/proxyd" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +// TestConcurrentWSPanic tests for a panic in the websocket proxy +// that occurred when messages were sent from the upstream to the +// client right after the client sent an invalid request. +func TestConcurrentWSPanic(t *testing.T) { + var backendToProxyConn *websocket.Conn + var setOnce sync.Once + + readyCh := make(chan struct{}, 1) + quitC := make(chan struct{}) + + // Pull out the backend -> proxyd conn so that we can spam it directly. + // Use a sync.Once to make sure we only do that once, for the first + // connection. 
+ backend := NewMockWSBackend(func(conn *websocket.Conn) { + setOnce.Do(func() { + backendToProxyConn = conn + readyCh <- struct{}{} + }) + }, nil, nil) + defer backend.Close() + + require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL())) + + config := ReadConfig("ws") + shutdown, err := proxyd.Start(config) + require.NoError(t, err) + client, err := NewProxydWSClient("ws://127.0.0.1:8546", nil, nil) + require.NoError(t, err) + defer shutdown() + + <-readyCh + + // spam messages + go func() { + for { + select { + case <-quitC: + return + default: + _ = backendToProxyConn.WriteMessage(websocket.TextMessage, []byte("garbage")) + } + } + }() + + // spam invalid RPCs + go func() { + for { + select { + case <-quitC: + return + default: + _ = client.WriteMessage(websocket.TextMessage, []byte("{\"id\": 1, \"method\": \"eth_foo\", \"params\": [\"newHeads\"]}")) + } + } + }() + + // 1 second is enough to trigger the panic due to + // concurrent write to websocket connection + time.Sleep(time.Second) + close(quitC) +} + +type backendHandler struct { + msgCB atomic.Value + closeCB atomic.Value +} + +func (b *backendHandler) MsgCB(conn *websocket.Conn, msgType int, data []byte) { + cb := b.msgCB.Load() + if cb == nil { + return + } + cb.(MockWSBackendOnMessage)(conn, msgType, data) +} + +func (b *backendHandler) SetMsgCB(cb MockWSBackendOnMessage) { + b.msgCB.Store(cb) +} + +func (b *backendHandler) CloseCB(conn *websocket.Conn, err error) { + cb := b.closeCB.Load() + if cb == nil { + return + } + cb.(MockWSBackendOnClose)(conn, err) +} + +func (b *backendHandler) SetCloseCB(cb MockWSBackendOnClose) { + b.closeCB.Store(cb) +} + +type clientHandler struct { + msgCB atomic.Value +} + +func (c *clientHandler) MsgCB(msgType int, data []byte) { + cb := c.msgCB.Load().(ProxydWSClientOnMessage) + if cb == nil { + return + } + cb(msgType, data) +} + +func (c *clientHandler) SetMsgCB(cb ProxydWSClientOnMessage) { + c.msgCB.Store(cb) +} + +func TestWS(t *testing.T) { + backendHdlr := new(backendHandler) + clientHdlr := new(clientHandler) + + backend := NewMockWSBackend(nil, func(conn *websocket.Conn, msgType int, data []byte) { + backendHdlr.MsgCB(conn, msgType, data) + }, func(conn *websocket.Conn, err error) { + backendHdlr.CloseCB(conn, err) + }) + defer backend.Close() + + require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL())) + + config := ReadConfig("ws") + shutdown, err := proxyd.Start(config) + require.NoError(t, err) + client, err := NewProxydWSClient("ws://127.0.0.1:8546", func(msgType int, data []byte) { + clientHdlr.MsgCB(msgType, data) + }, nil) + defer client.HardClose() + require.NoError(t, err) + defer shutdown() + + tests := []struct { + name string + backendRes string + expRes string + clientReq string + }{ + { + "ok response", + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}", + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}", + "{\"id\": 1, \"method\": \"eth_subscribe\", \"params\": [\"newHeads\"]}", + }, + { + "garbage backend response", + "gibblegabble", + "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32013,\"message\":\"backend returned an invalid response\"},\"id\":null}", + "{\"id\": 1, \"method\": \"eth_subscribe\", \"params\": [\"newHeads\"]}", + }, + { + "blacklisted RPC", + "}", + "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32001,\"message\":\"rpc method is not whitelisted\"},\"id\":1}", + "{\"id\": 1, \"method\": \"eth_whatever\", \"params\": []}", + }, + { + "garbage client request", + "{}", 
+ "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32700,\"message\":\"parse error\"},\"id\":null}", + "barf", + }, + { + "invalid client request", + "{}", + "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32700,\"message\":\"parse error\"},\"id\":null}", + "{\"jsonrpc\": \"2.0\", \"method\": true}", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + timeout := time.NewTicker(30 * time.Second) + doneCh := make(chan struct{}, 1) + backendHdlr.SetMsgCB(func(conn *websocket.Conn, msgType int, data []byte) { + require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(tt.backendRes))) + }) + clientHdlr.SetMsgCB(func(msgType int, data []byte) { + require.Equal(t, tt.expRes, string(data)) + doneCh <- struct{}{} + }) + require.NoError(t, client.WriteMessage( + websocket.TextMessage, + []byte(tt.clientReq), + )) + select { + case <-timeout.C: + t.Fatalf("timed out") + case <-doneCh: + return + } + }) + } +} + +func TestWSClientClosure(t *testing.T) { + backendHdlr := new(backendHandler) + clientHdlr := new(clientHandler) + + backend := NewMockWSBackend(nil, func(conn *websocket.Conn, msgType int, data []byte) { + backendHdlr.MsgCB(conn, msgType, data) + }, func(conn *websocket.Conn, err error) { + backendHdlr.CloseCB(conn, err) + }) + defer backend.Close() + + require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL())) + + config := ReadConfig("ws") + shutdown, err := proxyd.Start(config) + require.NoError(t, err) + defer shutdown() + + for _, closeType := range []string{"soft", "hard"} { + t.Run(closeType, func(t *testing.T) { + client, err := NewProxydWSClient("ws://127.0.0.1:8546", func(msgType int, data []byte) { + clientHdlr.MsgCB(msgType, data) + }, nil) + require.NoError(t, err) + + timeout := time.NewTicker(30 * time.Second) + doneCh := make(chan struct{}, 1) + backendHdlr.SetCloseCB(func(conn *websocket.Conn, err error) { + doneCh <- struct{}{} + }) + + if closeType == "soft" { + require.NoError(t, client.SoftClose()) + } else { + client.HardClose() + } + + select { + case <-timeout.C: + t.Fatalf("timed out") + case <-doneCh: + return + } + }) + } +} + +func TestWSClientMaxConns(t *testing.T) { + backend := NewMockWSBackend(nil, nil, nil) + defer backend.Close() + + require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL())) + + config := ReadConfig("ws") + shutdown, err := proxyd.Start(config) + require.NoError(t, err) + defer shutdown() + + doneCh := make(chan struct{}, 1) + _, err = NewProxydWSClient("ws://127.0.0.1:8546", nil, nil) + require.NoError(t, err) + _, err = NewProxydWSClient("ws://127.0.0.1:8546", nil, func(err error) { + require.Contains(t, err.Error(), "unexpected EOF") + doneCh <- struct{}{} + }) + require.NoError(t, err) + + timeout := time.NewTicker(30 * time.Second) + select { + case <-timeout.C: + t.Fatalf("timed out") + case <-doneCh: + return + } +} diff --git a/proxyd/proxyd/server.go b/proxyd/proxyd/server.go index 30559d9..f34da71 100644 --- a/proxyd/proxyd/server.go +++ b/proxyd/proxyd/server.go @@ -11,6 +11,7 @@ import ( "net/http" "strconv" "strings" + "sync" "time" "github.com/ethereum/go-ethereum/log" @@ -44,6 +45,7 @@ type Server struct { rpcServer *http.Server wsServer *http.Server cache RPCCache + srvMu sync.Mutex } func NewServer( @@ -90,6 +92,7 @@ func NewServer( } func (s *Server) RPCListenAndServe(host string, port int) error { + s.srvMu.Lock() hdlr := mux.NewRouter() hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET") hdlr.HandleFunc("/", s.HandleRPC).Methods("POST") @@ -103,10 +106,12 @@ func (s 
*Server) RPCListenAndServe(host string, port int) error { Addr: addr, } log.Info("starting HTTP server", "addr", addr) + s.srvMu.Unlock() return s.rpcServer.ListenAndServe() } func (s *Server) WSListenAndServe(host string, port int) error { + s.srvMu.Lock() hdlr := mux.NewRouter() hdlr.HandleFunc("/", s.HandleWS) hdlr.HandleFunc("/{authorization}", s.HandleWS) @@ -119,10 +124,13 @@ func (s *Server) WSListenAndServe(host string, port int) error { Addr: addr, } log.Info("starting WS server", "addr", addr) + s.srvMu.Unlock() return s.wsServer.ListenAndServe() } func (s *Server) Shutdown() { + s.srvMu.Lock() + defer s.srvMu.Unlock() if s.rpcServer != nil { _ = s.rpcServer.Shutdown(context.Background()) }