proxyd: Fix concurrent WS write panic (#2711)
Fixes a panic in the websocket proxyd logic. Normally, the `clientPump` and `backendPump` methods in `WSProxier` send data in one direction. However, when the client sends an invalid RPC, the `clientPump` will send a response _directly to the client_ in order to avoid unnecessary roundtrips to the backend. This could be interleaved with concurrent writes to the client's WS in `backendPump`, and would cause a panic in the WS library. To test this, this PR includes a dedicated integration test that reliably triggers the issue. In addition, this PR adds additional testing for WS functionality.
This commit is contained in:
parent
69f189c0ea
commit
e41cfc1d94
@ -15,6 +15,7 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
@ -548,6 +549,7 @@ type WSProxier struct {
|
||||
clientConn *websocket.Conn
|
||||
backendConn *websocket.Conn
|
||||
methodWhitelist *StringSet
|
||||
clientConnMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier {
|
||||
@ -570,12 +572,11 @@ func (w *WSProxier) Proxy(ctx context.Context) error {
|
||||
|
||||
func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
|
||||
for {
|
||||
outConn := w.backendConn
|
||||
// Block until we get a message.
|
||||
msgType, msg, err := w.clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
errC <- err
|
||||
if err := outConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
|
||||
if err := w.backendConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
|
||||
log.Error("error writing backendConn message", "err", err)
|
||||
}
|
||||
return
|
||||
@ -586,7 +587,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
|
||||
// Route control messages to the backend. These don't
|
||||
// count towards the total RPC requests count.
|
||||
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||
err := outConn.WriteMessage(msgType, msg)
|
||||
err := w.backendConn.WriteMessage(msgType, msg)
|
||||
if err != nil {
|
||||
errC <- err
|
||||
return
|
||||
@ -612,20 +613,27 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
|
||||
"req_id", GetReqID(ctx),
|
||||
"err", err,
|
||||
)
|
||||
outConn = w.clientConn
|
||||
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
|
||||
RecordRPCError(ctx, BackendProxyd, method, err)
|
||||
} else {
|
||||
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
|
||||
log.Info(
|
||||
"forwarded WS message to backend",
|
||||
"method", req.Method,
|
||||
"auth", GetAuthCtx(ctx),
|
||||
"req_id", GetReqID(ctx),
|
||||
)
|
||||
|
||||
// Send error response to client
|
||||
err = w.writeClientConn(msgType, msg)
|
||||
if err != nil {
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
err = outConn.WriteMessage(msgType, msg)
|
||||
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
|
||||
log.Info(
|
||||
"forwarded WS message to backend",
|
||||
"method", req.Method,
|
||||
"auth", GetAuthCtx(ctx),
|
||||
"req_id", GetReqID(ctx),
|
||||
)
|
||||
|
||||
err = w.backendConn.WriteMessage(msgType, msg)
|
||||
if err != nil {
|
||||
errC <- err
|
||||
return
|
||||
@ -639,7 +647,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
|
||||
msgType, msg, err := w.backendConn.ReadMessage()
|
||||
if err != nil {
|
||||
errC <- err
|
||||
if err := w.clientConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
|
||||
if err := w.writeClientConn(websocket.CloseMessage, formatWSError(err)); err != nil {
|
||||
log.Error("error writing clientConn message", "err", err)
|
||||
}
|
||||
return
|
||||
@ -649,7 +657,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
|
||||
|
||||
// Route control messages directly to the client.
|
||||
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||
err := w.clientConn.WriteMessage(msgType, msg)
|
||||
err := w.writeClientConn(msgType, msg)
|
||||
if err != nil {
|
||||
errC <- err
|
||||
return
|
||||
@ -664,26 +672,28 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
|
||||
id = res.ID
|
||||
}
|
||||
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
|
||||
}
|
||||
if res.IsError() {
|
||||
log.Info(
|
||||
"backend responded with RPC error",
|
||||
"code", res.Error.Code,
|
||||
"msg", res.Error.Message,
|
||||
"source", "ws",
|
||||
"auth", GetAuthCtx(ctx),
|
||||
"req_id", GetReqID(ctx),
|
||||
)
|
||||
RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error)
|
||||
log.Info("backend responded with error", "err", err)
|
||||
} else {
|
||||
log.Info(
|
||||
"forwarded WS message to client",
|
||||
"auth", GetAuthCtx(ctx),
|
||||
"req_id", GetReqID(ctx),
|
||||
)
|
||||
if res.IsError() {
|
||||
log.Info(
|
||||
"backend responded with RPC error",
|
||||
"code", res.Error.Code,
|
||||
"msg", res.Error.Message,
|
||||
"source", "ws",
|
||||
"auth", GetAuthCtx(ctx),
|
||||
"req_id", GetReqID(ctx),
|
||||
)
|
||||
RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error)
|
||||
} else {
|
||||
log.Info(
|
||||
"forwarded WS message to client",
|
||||
"auth", GetAuthCtx(ctx),
|
||||
"req_id", GetReqID(ctx),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
err = w.clientConn.WriteMessage(msgType, msg)
|
||||
err = w.writeClientConn(msgType, msg)
|
||||
if err != nil {
|
||||
errC <- err
|
||||
return
|
||||
@ -726,6 +736,13 @@ func (w *WSProxier) parseBackendMsg(msg []byte) (*RPCRes, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (w *WSProxier) writeClientConn(msgType int, msg []byte) error {
|
||||
w.clientConnMu.Lock()
|
||||
err := w.clientConn.WriteMessage(msgType, msg)
|
||||
w.clientConnMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func mustMarshalJSON(in interface{}) []byte {
|
||||
out, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
|
@ -7,9 +7,11 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ethereum-optimism/optimism/proxyd"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type RecordedRequest struct {
|
||||
@ -251,3 +253,72 @@ func (m *MockBackend) wrappedHandler(w http.ResponseWriter, r *http.Request) {
|
||||
m.handler.ServeHTTP(w, clone)
|
||||
m.mtx.Unlock()
|
||||
}
|
||||
|
||||
type MockWSBackend struct {
|
||||
connCB MockWSBackendOnConnect
|
||||
msgCB MockWSBackendOnMessage
|
||||
closeCB MockWSBackendOnClose
|
||||
server *httptest.Server
|
||||
upgrader websocket.Upgrader
|
||||
conns []*websocket.Conn
|
||||
connsMu sync.Mutex
|
||||
}
|
||||
|
||||
type MockWSBackendOnConnect func(conn *websocket.Conn)
|
||||
type MockWSBackendOnMessage func(conn *websocket.Conn, msgType int, data []byte)
|
||||
type MockWSBackendOnClose func(conn *websocket.Conn, err error)
|
||||
|
||||
func NewMockWSBackend(
|
||||
connCB MockWSBackendOnConnect,
|
||||
msgCB MockWSBackendOnMessage,
|
||||
closeCB MockWSBackendOnClose,
|
||||
) *MockWSBackend {
|
||||
mb := &MockWSBackend{
|
||||
connCB: connCB,
|
||||
msgCB: msgCB,
|
||||
closeCB: closeCB,
|
||||
}
|
||||
mb.server = httptest.NewServer(mb)
|
||||
return mb
|
||||
}
|
||||
|
||||
func (m *MockWSBackend) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := m.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if m.connCB != nil {
|
||||
m.connCB(conn)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
mType, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if m.closeCB != nil {
|
||||
m.closeCB(conn, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if m.msgCB != nil {
|
||||
m.msgCB(conn, mType, msg)
|
||||
}
|
||||
}
|
||||
}()
|
||||
m.connsMu.Lock()
|
||||
m.conns = append(m.conns, conn)
|
||||
m.connsMu.Unlock()
|
||||
}
|
||||
|
||||
func (m *MockWSBackend) URL() string {
|
||||
return strings.Replace(m.server.URL, "http://", "ws://", 1)
|
||||
}
|
||||
|
||||
func (m *MockWSBackend) Close() {
|
||||
m.server.Close()
|
||||
|
||||
m.connsMu.Lock()
|
||||
for _, conn := range m.conns {
|
||||
conn.Close()
|
||||
}
|
||||
m.connsMu.Unlock()
|
||||
}
|
||||
|
25
proxyd/proxyd/integration_tests/testdata/ws.toml
vendored
Normal file
25
proxyd/proxyd/integration_tests/testdata/ws.toml
vendored
Normal file
@ -0,0 +1,25 @@
|
||||
ws_backend_group = "main"
|
||||
|
||||
ws_method_whitelist = [
|
||||
"eth_subscribe"
|
||||
]
|
||||
|
||||
[server]
|
||||
rpc_port = 8545
|
||||
ws_port = 8546
|
||||
|
||||
[backend]
|
||||
response_timeout_seconds = 1
|
||||
|
||||
[backends]
|
||||
[backends.good]
|
||||
rpc_url = "$GOOD_BACKEND_RPC_URL"
|
||||
ws_url = "$GOOD_BACKEND_RPC_URL"
|
||||
max_ws_conns = 1
|
||||
|
||||
[backend_groups]
|
||||
[backend_groups.main]
|
||||
backends = ["good"]
|
||||
|
||||
[rpc_method_mappings]
|
||||
eth_chainId = "main"
|
@ -8,22 +8,26 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"github.com/ethereum-optimism/optimism/proxyd"
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type ProxydClient struct {
|
||||
type ProxydHTTPClient struct {
|
||||
url string
|
||||
}
|
||||
|
||||
func NewProxydClient(url string) *ProxydClient {
|
||||
return &ProxydClient{url: url}
|
||||
func NewProxydClient(url string) *ProxydHTTPClient {
|
||||
return &ProxydHTTPClient{url: url}
|
||||
}
|
||||
|
||||
func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
|
||||
func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
|
||||
rpcReq := NewRPCReq("999", method, params)
|
||||
body, err := json.Marshal(rpcReq)
|
||||
if err != nil {
|
||||
@ -32,7 +36,7 @@ func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int
|
||||
return p.SendRequest(body)
|
||||
}
|
||||
|
||||
func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) {
|
||||
func (p *ProxydHTTPClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) {
|
||||
body, err := json.Marshal(reqs)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -40,7 +44,7 @@ func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error)
|
||||
return p.SendRequest(body)
|
||||
}
|
||||
|
||||
func (p *ProxydClient) SendRequest(body []byte) ([]byte, int, error) {
|
||||
func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) {
|
||||
res, err := http.Post(p.url, "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
@ -98,6 +102,70 @@ func NewRPCReq(id string, method string, params []interface{}) *proxyd.RPCReq {
|
||||
}
|
||||
}
|
||||
|
||||
type ProxydWSClient struct {
|
||||
conn *websocket.Conn
|
||||
msgCB ProxydWSClientOnMessage
|
||||
closeCB ProxydWSClientOnClose
|
||||
}
|
||||
|
||||
type WSMessage struct {
|
||||
Type int
|
||||
Body []byte
|
||||
}
|
||||
|
||||
type ProxydWSClientOnMessage func(msgType int, data []byte)
|
||||
type ProxydWSClientOnClose func(err error)
|
||||
|
||||
func NewProxydWSClient(
|
||||
url string,
|
||||
msgCB ProxydWSClientOnMessage,
|
||||
closeCB ProxydWSClientOnClose,
|
||||
) (*ProxydWSClient, error) {
|
||||
conn, _, err := websocket.DefaultDialer.Dial(url, nil) // nolint:bodyclose
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &ProxydWSClient{
|
||||
conn: conn,
|
||||
msgCB: msgCB,
|
||||
closeCB: closeCB,
|
||||
}
|
||||
go c.readPump()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (h *ProxydWSClient) readPump() {
|
||||
for {
|
||||
mType, msg, err := h.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if h.closeCB != nil {
|
||||
h.closeCB(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if h.msgCB != nil {
|
||||
h.msgCB(mType, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxydWSClient) HardClose() {
|
||||
h.conn.Close()
|
||||
}
|
||||
|
||||
func (h *ProxydWSClient) SoftClose() error {
|
||||
return h.WriteMessage(websocket.CloseMessage, nil)
|
||||
}
|
||||
|
||||
func (h *ProxydWSClient) WriteMessage(msgType int, msg []byte) error {
|
||||
return h.conn.WriteMessage(msgType, msg)
|
||||
}
|
||||
|
||||
func (h *ProxydWSClient) WriteControlMessage(msgType int, msg []byte) error {
|
||||
return h.conn.WriteControl(msgType, msg, time.Now().Add(time.Minute))
|
||||
}
|
||||
|
||||
func InitLogger() {
|
||||
log.Root().SetHandler(
|
||||
log.LvlFilterHandler(log.LvlDebug,
|
||||
|
281
proxyd/proxyd/integration_tests/ws_test.go
Normal file
281
proxyd/proxyd/integration_tests/ws_test.go
Normal file
@ -0,0 +1,281 @@
|
||||
package integration_tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum-optimism/optimism/proxyd"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConcurrentWSPanic tests for a panic in the websocket proxy
|
||||
// that occurred when messages were sent from the upstream to the
|
||||
// client right after the client sent an invalid request.
|
||||
func TestConcurrentWSPanic(t *testing.T) {
|
||||
var backendToProxyConn *websocket.Conn
|
||||
var setOnce sync.Once
|
||||
|
||||
readyCh := make(chan struct{}, 1)
|
||||
quitC := make(chan struct{})
|
||||
|
||||
// Pull out the backend -> proxyd conn so that we can spam it directly.
|
||||
// Use a sync.Once to make sure we only do that once, for the first
|
||||
// connection.
|
||||
backend := NewMockWSBackend(func(conn *websocket.Conn) {
|
||||
setOnce.Do(func() {
|
||||
backendToProxyConn = conn
|
||||
readyCh <- struct{}{}
|
||||
})
|
||||
}, nil, nil)
|
||||
defer backend.Close()
|
||||
|
||||
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
|
||||
|
||||
config := ReadConfig("ws")
|
||||
shutdown, err := proxyd.Start(config)
|
||||
require.NoError(t, err)
|
||||
client, err := NewProxydWSClient("ws://127.0.0.1:8546", nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer shutdown()
|
||||
|
||||
<-readyCh
|
||||
|
||||
// spam messages
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-quitC:
|
||||
return
|
||||
default:
|
||||
_ = backendToProxyConn.WriteMessage(websocket.TextMessage, []byte("garbage"))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// spam invalid RPCs
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-quitC:
|
||||
return
|
||||
default:
|
||||
_ = client.WriteMessage(websocket.TextMessage, []byte("{\"id\": 1, \"method\": \"eth_foo\", \"params\": [\"newHeads\"]}"))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 1 second is enough to trigger the panic due to
|
||||
// concurrent write to websocket connection
|
||||
time.Sleep(time.Second)
|
||||
close(quitC)
|
||||
}
|
||||
|
||||
type backendHandler struct {
|
||||
msgCB atomic.Value
|
||||
closeCB atomic.Value
|
||||
}
|
||||
|
||||
func (b *backendHandler) MsgCB(conn *websocket.Conn, msgType int, data []byte) {
|
||||
cb := b.msgCB.Load()
|
||||
if cb == nil {
|
||||
return
|
||||
}
|
||||
cb.(MockWSBackendOnMessage)(conn, msgType, data)
|
||||
}
|
||||
|
||||
func (b *backendHandler) SetMsgCB(cb MockWSBackendOnMessage) {
|
||||
b.msgCB.Store(cb)
|
||||
}
|
||||
|
||||
func (b *backendHandler) CloseCB(conn *websocket.Conn, err error) {
|
||||
cb := b.closeCB.Load()
|
||||
if cb == nil {
|
||||
return
|
||||
}
|
||||
cb.(MockWSBackendOnClose)(conn, err)
|
||||
}
|
||||
|
||||
func (b *backendHandler) SetCloseCB(cb MockWSBackendOnClose) {
|
||||
b.closeCB.Store(cb)
|
||||
}
|
||||
|
||||
type clientHandler struct {
|
||||
msgCB atomic.Value
|
||||
}
|
||||
|
||||
func (c *clientHandler) MsgCB(msgType int, data []byte) {
|
||||
cb := c.msgCB.Load().(ProxydWSClientOnMessage)
|
||||
if cb == nil {
|
||||
return
|
||||
}
|
||||
cb(msgType, data)
|
||||
}
|
||||
|
||||
func (c *clientHandler) SetMsgCB(cb ProxydWSClientOnMessage) {
|
||||
c.msgCB.Store(cb)
|
||||
}
|
||||
|
||||
func TestWS(t *testing.T) {
|
||||
backendHdlr := new(backendHandler)
|
||||
clientHdlr := new(clientHandler)
|
||||
|
||||
backend := NewMockWSBackend(nil, func(conn *websocket.Conn, msgType int, data []byte) {
|
||||
backendHdlr.MsgCB(conn, msgType, data)
|
||||
}, func(conn *websocket.Conn, err error) {
|
||||
backendHdlr.CloseCB(conn, err)
|
||||
})
|
||||
defer backend.Close()
|
||||
|
||||
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
|
||||
|
||||
config := ReadConfig("ws")
|
||||
shutdown, err := proxyd.Start(config)
|
||||
require.NoError(t, err)
|
||||
client, err := NewProxydWSClient("ws://127.0.0.1:8546", func(msgType int, data []byte) {
|
||||
clientHdlr.MsgCB(msgType, data)
|
||||
}, nil)
|
||||
defer client.HardClose()
|
||||
require.NoError(t, err)
|
||||
defer shutdown()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
backendRes string
|
||||
expRes string
|
||||
clientReq string
|
||||
}{
|
||||
{
|
||||
"ok response",
|
||||
"{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}",
|
||||
"{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}",
|
||||
"{\"id\": 1, \"method\": \"eth_subscribe\", \"params\": [\"newHeads\"]}",
|
||||
},
|
||||
{
|
||||
"garbage backend response",
|
||||
"gibblegabble",
|
||||
"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32013,\"message\":\"backend returned an invalid response\"},\"id\":null}",
|
||||
"{\"id\": 1, \"method\": \"eth_subscribe\", \"params\": [\"newHeads\"]}",
|
||||
},
|
||||
{
|
||||
"blacklisted RPC",
|
||||
"}",
|
||||
"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32001,\"message\":\"rpc method is not whitelisted\"},\"id\":1}",
|
||||
"{\"id\": 1, \"method\": \"eth_whatever\", \"params\": []}",
|
||||
},
|
||||
{
|
||||
"garbage client request",
|
||||
"{}",
|
||||
"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32700,\"message\":\"parse error\"},\"id\":null}",
|
||||
"barf",
|
||||
},
|
||||
{
|
||||
"invalid client request",
|
||||
"{}",
|
||||
"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32700,\"message\":\"parse error\"},\"id\":null}",
|
||||
"{\"jsonrpc\": \"2.0\", \"method\": true}",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
timeout := time.NewTicker(30 * time.Second)
|
||||
doneCh := make(chan struct{}, 1)
|
||||
backendHdlr.SetMsgCB(func(conn *websocket.Conn, msgType int, data []byte) {
|
||||
require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(tt.backendRes)))
|
||||
})
|
||||
clientHdlr.SetMsgCB(func(msgType int, data []byte) {
|
||||
require.Equal(t, tt.expRes, string(data))
|
||||
doneCh <- struct{}{}
|
||||
})
|
||||
require.NoError(t, client.WriteMessage(
|
||||
websocket.TextMessage,
|
||||
[]byte(tt.clientReq),
|
||||
))
|
||||
select {
|
||||
case <-timeout.C:
|
||||
t.Fatalf("timed out")
|
||||
case <-doneCh:
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSClientClosure(t *testing.T) {
|
||||
backendHdlr := new(backendHandler)
|
||||
clientHdlr := new(clientHandler)
|
||||
|
||||
backend := NewMockWSBackend(nil, func(conn *websocket.Conn, msgType int, data []byte) {
|
||||
backendHdlr.MsgCB(conn, msgType, data)
|
||||
}, func(conn *websocket.Conn, err error) {
|
||||
backendHdlr.CloseCB(conn, err)
|
||||
})
|
||||
defer backend.Close()
|
||||
|
||||
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
|
||||
|
||||
config := ReadConfig("ws")
|
||||
shutdown, err := proxyd.Start(config)
|
||||
require.NoError(t, err)
|
||||
defer shutdown()
|
||||
|
||||
for _, closeType := range []string{"soft", "hard"} {
|
||||
t.Run(closeType, func(t *testing.T) {
|
||||
client, err := NewProxydWSClient("ws://127.0.0.1:8546", func(msgType int, data []byte) {
|
||||
clientHdlr.MsgCB(msgType, data)
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
timeout := time.NewTicker(30 * time.Second)
|
||||
doneCh := make(chan struct{}, 1)
|
||||
backendHdlr.SetCloseCB(func(conn *websocket.Conn, err error) {
|
||||
doneCh <- struct{}{}
|
||||
})
|
||||
|
||||
if closeType == "soft" {
|
||||
require.NoError(t, client.SoftClose())
|
||||
} else {
|
||||
client.HardClose()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-timeout.C:
|
||||
t.Fatalf("timed out")
|
||||
case <-doneCh:
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSClientMaxConns(t *testing.T) {
|
||||
backend := NewMockWSBackend(nil, nil, nil)
|
||||
defer backend.Close()
|
||||
|
||||
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
|
||||
|
||||
config := ReadConfig("ws")
|
||||
shutdown, err := proxyd.Start(config)
|
||||
require.NoError(t, err)
|
||||
defer shutdown()
|
||||
|
||||
doneCh := make(chan struct{}, 1)
|
||||
_, err = NewProxydWSClient("ws://127.0.0.1:8546", nil, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = NewProxydWSClient("ws://127.0.0.1:8546", nil, func(err error) {
|
||||
require.Contains(t, err.Error(), "unexpected EOF")
|
||||
doneCh <- struct{}{}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
timeout := time.NewTicker(30 * time.Second)
|
||||
select {
|
||||
case <-timeout.C:
|
||||
t.Fatalf("timed out")
|
||||
case <-doneCh:
|
||||
return
|
||||
}
|
||||
}
|
@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
@ -44,6 +45,7 @@ type Server struct {
|
||||
rpcServer *http.Server
|
||||
wsServer *http.Server
|
||||
cache RPCCache
|
||||
srvMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewServer(
|
||||
@ -90,6 +92,7 @@ func NewServer(
|
||||
}
|
||||
|
||||
func (s *Server) RPCListenAndServe(host string, port int) error {
|
||||
s.srvMu.Lock()
|
||||
hdlr := mux.NewRouter()
|
||||
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
|
||||
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
|
||||
@ -103,10 +106,12 @@ func (s *Server) RPCListenAndServe(host string, port int) error {
|
||||
Addr: addr,
|
||||
}
|
||||
log.Info("starting HTTP server", "addr", addr)
|
||||
s.srvMu.Unlock()
|
||||
return s.rpcServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) WSListenAndServe(host string, port int) error {
|
||||
s.srvMu.Lock()
|
||||
hdlr := mux.NewRouter()
|
||||
hdlr.HandleFunc("/", s.HandleWS)
|
||||
hdlr.HandleFunc("/{authorization}", s.HandleWS)
|
||||
@ -119,10 +124,13 @@ func (s *Server) WSListenAndServe(host string, port int) error {
|
||||
Addr: addr,
|
||||
}
|
||||
log.Info("starting WS server", "addr", addr)
|
||||
s.srvMu.Unlock()
|
||||
return s.wsServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown() {
|
||||
s.srvMu.Lock()
|
||||
defer s.srvMu.Unlock()
|
||||
if s.rpcServer != nil {
|
||||
_ = s.rpcServer.Shutdown(context.Background())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user