proxyd: Fix concurrent WS write panic (#2711)
Fixes a panic in the websocket proxyd logic. Normally, the `clientPump` and `backendPump` methods in `WSProxier` send data in one direction. However, when the client sends an invalid RPC, the `clientPump` will send a response _directly to the client_ in order to avoid unnecessary roundtrips to the backend. This could be interleaved with concurrent writes to the client's WS in `backendPump`, and would cause a panic in the WS library. To test this, this PR includes a dedicated integration test that reliably triggers the issue. In addition, this PR adds additional testing for WS functionality.
This commit is contained in:
parent
69f189c0ea
commit
e41cfc1d94
@ -15,6 +15,7 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/log"
|
"github.com/ethereum/go-ethereum/log"
|
||||||
@ -548,6 +549,7 @@ type WSProxier struct {
|
|||||||
clientConn *websocket.Conn
|
clientConn *websocket.Conn
|
||||||
backendConn *websocket.Conn
|
backendConn *websocket.Conn
|
||||||
methodWhitelist *StringSet
|
methodWhitelist *StringSet
|
||||||
|
clientConnMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier {
|
func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier {
|
||||||
@ -570,12 +572,11 @@ func (w *WSProxier) Proxy(ctx context.Context) error {
|
|||||||
|
|
||||||
func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
|
func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
|
||||||
for {
|
for {
|
||||||
outConn := w.backendConn
|
|
||||||
// Block until we get a message.
|
// Block until we get a message.
|
||||||
msgType, msg, err := w.clientConn.ReadMessage()
|
msgType, msg, err := w.clientConn.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errC <- err
|
errC <- err
|
||||||
if err := outConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
|
if err := w.backendConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
|
||||||
log.Error("error writing backendConn message", "err", err)
|
log.Error("error writing backendConn message", "err", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -586,7 +587,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
|
|||||||
// Route control messages to the backend. These don't
|
// Route control messages to the backend. These don't
|
||||||
// count towards the total RPC requests count.
|
// count towards the total RPC requests count.
|
||||||
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||||
err := outConn.WriteMessage(msgType, msg)
|
err := w.backendConn.WriteMessage(msgType, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errC <- err
|
errC <- err
|
||||||
return
|
return
|
||||||
@ -612,10 +613,18 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
|
|||||||
"req_id", GetReqID(ctx),
|
"req_id", GetReqID(ctx),
|
||||||
"err", err,
|
"err", err,
|
||||||
)
|
)
|
||||||
outConn = w.clientConn
|
|
||||||
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
|
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
|
||||||
RecordRPCError(ctx, BackendProxyd, method, err)
|
RecordRPCError(ctx, BackendProxyd, method, err)
|
||||||
} else {
|
|
||||||
|
// Send error response to client
|
||||||
|
err = w.writeClientConn(msgType, msg)
|
||||||
|
if err != nil {
|
||||||
|
errC <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
|
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
|
||||||
log.Info(
|
log.Info(
|
||||||
"forwarded WS message to backend",
|
"forwarded WS message to backend",
|
||||||
@ -623,9 +632,8 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
|
|||||||
"auth", GetAuthCtx(ctx),
|
"auth", GetAuthCtx(ctx),
|
||||||
"req_id", GetReqID(ctx),
|
"req_id", GetReqID(ctx),
|
||||||
)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
err = outConn.WriteMessage(msgType, msg)
|
err = w.backendConn.WriteMessage(msgType, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errC <- err
|
errC <- err
|
||||||
return
|
return
|
||||||
@ -639,7 +647,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
|
|||||||
msgType, msg, err := w.backendConn.ReadMessage()
|
msgType, msg, err := w.backendConn.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errC <- err
|
errC <- err
|
||||||
if err := w.clientConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
|
if err := w.writeClientConn(websocket.CloseMessage, formatWSError(err)); err != nil {
|
||||||
log.Error("error writing clientConn message", "err", err)
|
log.Error("error writing clientConn message", "err", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -649,7 +657,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
|
|||||||
|
|
||||||
// Route control messages directly to the client.
|
// Route control messages directly to the client.
|
||||||
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||||
err := w.clientConn.WriteMessage(msgType, msg)
|
err := w.writeClientConn(msgType, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errC <- err
|
errC <- err
|
||||||
return
|
return
|
||||||
@ -664,7 +672,8 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
|
|||||||
id = res.ID
|
id = res.ID
|
||||||
}
|
}
|
||||||
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
|
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
|
||||||
}
|
log.Info("backend responded with error", "err", err)
|
||||||
|
} else {
|
||||||
if res.IsError() {
|
if res.IsError() {
|
||||||
log.Info(
|
log.Info(
|
||||||
"backend responded with RPC error",
|
"backend responded with RPC error",
|
||||||
@ -682,8 +691,9 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
|
|||||||
"req_id", GetReqID(ctx),
|
"req_id", GetReqID(ctx),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = w.clientConn.WriteMessage(msgType, msg)
|
err = w.writeClientConn(msgType, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errC <- err
|
errC <- err
|
||||||
return
|
return
|
||||||
@ -726,6 +736,13 @@ func (w *WSProxier) parseBackendMsg(msg []byte) (*RPCRes, error) {
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WSProxier) writeClientConn(msgType int, msg []byte) error {
|
||||||
|
w.clientConnMu.Lock()
|
||||||
|
err := w.clientConn.WriteMessage(msgType, msg)
|
||||||
|
w.clientConnMu.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func mustMarshalJSON(in interface{}) []byte {
|
func mustMarshalJSON(in interface{}) []byte {
|
||||||
out, err := json.Marshal(in)
|
out, err := json.Marshal(in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -7,9 +7,11 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/ethereum-optimism/optimism/proxyd"
|
"github.com/ethereum-optimism/optimism/proxyd"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecordedRequest struct {
|
type RecordedRequest struct {
|
||||||
@ -251,3 +253,72 @@ func (m *MockBackend) wrappedHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
m.handler.ServeHTTP(w, clone)
|
m.handler.ServeHTTP(w, clone)
|
||||||
m.mtx.Unlock()
|
m.mtx.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MockWSBackend struct {
|
||||||
|
connCB MockWSBackendOnConnect
|
||||||
|
msgCB MockWSBackendOnMessage
|
||||||
|
closeCB MockWSBackendOnClose
|
||||||
|
server *httptest.Server
|
||||||
|
upgrader websocket.Upgrader
|
||||||
|
conns []*websocket.Conn
|
||||||
|
connsMu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockWSBackendOnConnect func(conn *websocket.Conn)
|
||||||
|
type MockWSBackendOnMessage func(conn *websocket.Conn, msgType int, data []byte)
|
||||||
|
type MockWSBackendOnClose func(conn *websocket.Conn, err error)
|
||||||
|
|
||||||
|
func NewMockWSBackend(
|
||||||
|
connCB MockWSBackendOnConnect,
|
||||||
|
msgCB MockWSBackendOnMessage,
|
||||||
|
closeCB MockWSBackendOnClose,
|
||||||
|
) *MockWSBackend {
|
||||||
|
mb := &MockWSBackend{
|
||||||
|
connCB: connCB,
|
||||||
|
msgCB: msgCB,
|
||||||
|
closeCB: closeCB,
|
||||||
|
}
|
||||||
|
mb.server = httptest.NewServer(mb)
|
||||||
|
return mb
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWSBackend) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := m.upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if m.connCB != nil {
|
||||||
|
m.connCB(conn)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
mType, msg, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if m.closeCB != nil {
|
||||||
|
m.closeCB(conn, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.msgCB != nil {
|
||||||
|
m.msgCB(conn, mType, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
m.connsMu.Lock()
|
||||||
|
m.conns = append(m.conns, conn)
|
||||||
|
m.connsMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWSBackend) URL() string {
|
||||||
|
return strings.Replace(m.server.URL, "http://", "ws://", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWSBackend) Close() {
|
||||||
|
m.server.Close()
|
||||||
|
|
||||||
|
m.connsMu.Lock()
|
||||||
|
for _, conn := range m.conns {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
m.connsMu.Unlock()
|
||||||
|
}
|
||||||
|
25
proxyd/proxyd/integration_tests/testdata/ws.toml
vendored
Normal file
25
proxyd/proxyd/integration_tests/testdata/ws.toml
vendored
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
ws_backend_group = "main"
|
||||||
|
|
||||||
|
ws_method_whitelist = [
|
||||||
|
"eth_subscribe"
|
||||||
|
]
|
||||||
|
|
||||||
|
[server]
|
||||||
|
rpc_port = 8545
|
||||||
|
ws_port = 8546
|
||||||
|
|
||||||
|
[backend]
|
||||||
|
response_timeout_seconds = 1
|
||||||
|
|
||||||
|
[backends]
|
||||||
|
[backends.good]
|
||||||
|
rpc_url = "$GOOD_BACKEND_RPC_URL"
|
||||||
|
ws_url = "$GOOD_BACKEND_RPC_URL"
|
||||||
|
max_ws_conns = 1
|
||||||
|
|
||||||
|
[backend_groups]
|
||||||
|
[backend_groups.main]
|
||||||
|
backends = ["good"]
|
||||||
|
|
||||||
|
[rpc_method_mappings]
|
||||||
|
eth_chainId = "main"
|
@ -8,22 +8,26 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/log"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
"github.com/ethereum-optimism/optimism/proxyd"
|
"github.com/ethereum-optimism/optimism/proxyd"
|
||||||
"github.com/ethereum/go-ethereum/log"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProxydClient struct {
|
type ProxydHTTPClient struct {
|
||||||
url string
|
url string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxydClient(url string) *ProxydClient {
|
func NewProxydClient(url string) *ProxydHTTPClient {
|
||||||
return &ProxydClient{url: url}
|
return &ProxydHTTPClient{url: url}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
|
func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
|
||||||
rpcReq := NewRPCReq("999", method, params)
|
rpcReq := NewRPCReq("999", method, params)
|
||||||
body, err := json.Marshal(rpcReq)
|
body, err := json.Marshal(rpcReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -32,7 +36,7 @@ func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int
|
|||||||
return p.SendRequest(body)
|
return p.SendRequest(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) {
|
func (p *ProxydHTTPClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) {
|
||||||
body, err := json.Marshal(reqs)
|
body, err := json.Marshal(reqs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@ -40,7 +44,7 @@ func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error)
|
|||||||
return p.SendRequest(body)
|
return p.SendRequest(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxydClient) SendRequest(body []byte) ([]byte, int, error) {
|
func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) {
|
||||||
res, err := http.Post(p.url, "application/json", bytes.NewReader(body))
|
res, err := http.Post(p.url, "application/json", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, -1, err
|
return nil, -1, err
|
||||||
@ -98,6 +102,70 @@ func NewRPCReq(id string, method string, params []interface{}) *proxyd.RPCReq {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ProxydWSClient struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
msgCB ProxydWSClientOnMessage
|
||||||
|
closeCB ProxydWSClientOnClose
|
||||||
|
}
|
||||||
|
|
||||||
|
type WSMessage struct {
|
||||||
|
Type int
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProxydWSClientOnMessage func(msgType int, data []byte)
|
||||||
|
type ProxydWSClientOnClose func(err error)
|
||||||
|
|
||||||
|
func NewProxydWSClient(
|
||||||
|
url string,
|
||||||
|
msgCB ProxydWSClientOnMessage,
|
||||||
|
closeCB ProxydWSClientOnClose,
|
||||||
|
) (*ProxydWSClient, error) {
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(url, nil) // nolint:bodyclose
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &ProxydWSClient{
|
||||||
|
conn: conn,
|
||||||
|
msgCB: msgCB,
|
||||||
|
closeCB: closeCB,
|
||||||
|
}
|
||||||
|
go c.readPump()
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxydWSClient) readPump() {
|
||||||
|
for {
|
||||||
|
mType, msg, err := h.conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if h.closeCB != nil {
|
||||||
|
h.closeCB(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.msgCB != nil {
|
||||||
|
h.msgCB(mType, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxydWSClient) HardClose() {
|
||||||
|
h.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxydWSClient) SoftClose() error {
|
||||||
|
return h.WriteMessage(websocket.CloseMessage, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxydWSClient) WriteMessage(msgType int, msg []byte) error {
|
||||||
|
return h.conn.WriteMessage(msgType, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProxydWSClient) WriteControlMessage(msgType int, msg []byte) error {
|
||||||
|
return h.conn.WriteControl(msgType, msg, time.Now().Add(time.Minute))
|
||||||
|
}
|
||||||
|
|
||||||
func InitLogger() {
|
func InitLogger() {
|
||||||
log.Root().SetHandler(
|
log.Root().SetHandler(
|
||||||
log.LvlFilterHandler(log.LvlDebug,
|
log.LvlFilterHandler(log.LvlDebug,
|
||||||
|
281
proxyd/proxyd/integration_tests/ws_test.go
Normal file
281
proxyd/proxyd/integration_tests/ws_test.go
Normal file
@ -0,0 +1,281 @@
|
|||||||
|
package integration_tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum-optimism/optimism/proxyd"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConcurrentWSPanic tests for a panic in the websocket proxy
|
||||||
|
// that occurred when messages were sent from the upstream to the
|
||||||
|
// client right after the client sent an invalid request.
|
||||||
|
func TestConcurrentWSPanic(t *testing.T) {
|
||||||
|
var backendToProxyConn *websocket.Conn
|
||||||
|
var setOnce sync.Once
|
||||||
|
|
||||||
|
readyCh := make(chan struct{}, 1)
|
||||||
|
quitC := make(chan struct{})
|
||||||
|
|
||||||
|
// Pull out the backend -> proxyd conn so that we can spam it directly.
|
||||||
|
// Use a sync.Once to make sure we only do that once, for the first
|
||||||
|
// connection.
|
||||||
|
backend := NewMockWSBackend(func(conn *websocket.Conn) {
|
||||||
|
setOnce.Do(func() {
|
||||||
|
backendToProxyConn = conn
|
||||||
|
readyCh <- struct{}{}
|
||||||
|
})
|
||||||
|
}, nil, nil)
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
|
||||||
|
|
||||||
|
config := ReadConfig("ws")
|
||||||
|
shutdown, err := proxyd.Start(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
client, err := NewProxydWSClient("ws://127.0.0.1:8546", nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
|
<-readyCh
|
||||||
|
|
||||||
|
// spam messages
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-quitC:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
_ = backendToProxyConn.WriteMessage(websocket.TextMessage, []byte("garbage"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// spam invalid RPCs
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-quitC:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
_ = client.WriteMessage(websocket.TextMessage, []byte("{\"id\": 1, \"method\": \"eth_foo\", \"params\": [\"newHeads\"]}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 1 second is enough to trigger the panic due to
|
||||||
|
// concurrent write to websocket connection
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
close(quitC)
|
||||||
|
}
|
||||||
|
|
||||||
|
type backendHandler struct {
|
||||||
|
msgCB atomic.Value
|
||||||
|
closeCB atomic.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backendHandler) MsgCB(conn *websocket.Conn, msgType int, data []byte) {
|
||||||
|
cb := b.msgCB.Load()
|
||||||
|
if cb == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cb.(MockWSBackendOnMessage)(conn, msgType, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backendHandler) SetMsgCB(cb MockWSBackendOnMessage) {
|
||||||
|
b.msgCB.Store(cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backendHandler) CloseCB(conn *websocket.Conn, err error) {
|
||||||
|
cb := b.closeCB.Load()
|
||||||
|
if cb == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cb.(MockWSBackendOnClose)(conn, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backendHandler) SetCloseCB(cb MockWSBackendOnClose) {
|
||||||
|
b.closeCB.Store(cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientHandler struct {
|
||||||
|
msgCB atomic.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientHandler) MsgCB(msgType int, data []byte) {
|
||||||
|
cb := c.msgCB.Load().(ProxydWSClientOnMessage)
|
||||||
|
if cb == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cb(msgType, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientHandler) SetMsgCB(cb ProxydWSClientOnMessage) {
|
||||||
|
c.msgCB.Store(cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWS(t *testing.T) {
|
||||||
|
backendHdlr := new(backendHandler)
|
||||||
|
clientHdlr := new(clientHandler)
|
||||||
|
|
||||||
|
backend := NewMockWSBackend(nil, func(conn *websocket.Conn, msgType int, data []byte) {
|
||||||
|
backendHdlr.MsgCB(conn, msgType, data)
|
||||||
|
}, func(conn *websocket.Conn, err error) {
|
||||||
|
backendHdlr.CloseCB(conn, err)
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
|
||||||
|
|
||||||
|
config := ReadConfig("ws")
|
||||||
|
shutdown, err := proxyd.Start(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
client, err := NewProxydWSClient("ws://127.0.0.1:8546", func(msgType int, data []byte) {
|
||||||
|
clientHdlr.MsgCB(msgType, data)
|
||||||
|
}, nil)
|
||||||
|
defer client.HardClose()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
backendRes string
|
||||||
|
expRes string
|
||||||
|
clientReq string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"ok response",
|
||||||
|
"{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}",
|
||||||
|
"{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"0xcd0c3e8af590364c09d0fa6a1210faf5\"}",
|
||||||
|
"{\"id\": 1, \"method\": \"eth_subscribe\", \"params\": [\"newHeads\"]}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"garbage backend response",
|
||||||
|
"gibblegabble",
|
||||||
|
"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32013,\"message\":\"backend returned an invalid response\"},\"id\":null}",
|
||||||
|
"{\"id\": 1, \"method\": \"eth_subscribe\", \"params\": [\"newHeads\"]}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"blacklisted RPC",
|
||||||
|
"}",
|
||||||
|
"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32001,\"message\":\"rpc method is not whitelisted\"},\"id\":1}",
|
||||||
|
"{\"id\": 1, \"method\": \"eth_whatever\", \"params\": []}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"garbage client request",
|
||||||
|
"{}",
|
||||||
|
"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32700,\"message\":\"parse error\"},\"id\":null}",
|
||||||
|
"barf",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid client request",
|
||||||
|
"{}",
|
||||||
|
"{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32700,\"message\":\"parse error\"},\"id\":null}",
|
||||||
|
"{\"jsonrpc\": \"2.0\", \"method\": true}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
timeout := time.NewTicker(30 * time.Second)
|
||||||
|
doneCh := make(chan struct{}, 1)
|
||||||
|
backendHdlr.SetMsgCB(func(conn *websocket.Conn, msgType int, data []byte) {
|
||||||
|
require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(tt.backendRes)))
|
||||||
|
})
|
||||||
|
clientHdlr.SetMsgCB(func(msgType int, data []byte) {
|
||||||
|
require.Equal(t, tt.expRes, string(data))
|
||||||
|
doneCh <- struct{}{}
|
||||||
|
})
|
||||||
|
require.NoError(t, client.WriteMessage(
|
||||||
|
websocket.TextMessage,
|
||||||
|
[]byte(tt.clientReq),
|
||||||
|
))
|
||||||
|
select {
|
||||||
|
case <-timeout.C:
|
||||||
|
t.Fatalf("timed out")
|
||||||
|
case <-doneCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWSClientClosure(t *testing.T) {
|
||||||
|
backendHdlr := new(backendHandler)
|
||||||
|
clientHdlr := new(clientHandler)
|
||||||
|
|
||||||
|
backend := NewMockWSBackend(nil, func(conn *websocket.Conn, msgType int, data []byte) {
|
||||||
|
backendHdlr.MsgCB(conn, msgType, data)
|
||||||
|
}, func(conn *websocket.Conn, err error) {
|
||||||
|
backendHdlr.CloseCB(conn, err)
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
|
||||||
|
|
||||||
|
config := ReadConfig("ws")
|
||||||
|
shutdown, err := proxyd.Start(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
|
for _, closeType := range []string{"soft", "hard"} {
|
||||||
|
t.Run(closeType, func(t *testing.T) {
|
||||||
|
client, err := NewProxydWSClient("ws://127.0.0.1:8546", func(msgType int, data []byte) {
|
||||||
|
clientHdlr.MsgCB(msgType, data)
|
||||||
|
}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
timeout := time.NewTicker(30 * time.Second)
|
||||||
|
doneCh := make(chan struct{}, 1)
|
||||||
|
backendHdlr.SetCloseCB(func(conn *websocket.Conn, err error) {
|
||||||
|
doneCh <- struct{}{}
|
||||||
|
})
|
||||||
|
|
||||||
|
if closeType == "soft" {
|
||||||
|
require.NoError(t, client.SoftClose())
|
||||||
|
} else {
|
||||||
|
client.HardClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-timeout.C:
|
||||||
|
t.Fatalf("timed out")
|
||||||
|
case <-doneCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWSClientMaxConns(t *testing.T) {
|
||||||
|
backend := NewMockWSBackend(nil, nil, nil)
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
|
||||||
|
|
||||||
|
config := ReadConfig("ws")
|
||||||
|
shutdown, err := proxyd.Start(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
|
doneCh := make(chan struct{}, 1)
|
||||||
|
_, err = NewProxydWSClient("ws://127.0.0.1:8546", nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = NewProxydWSClient("ws://127.0.0.1:8546", nil, func(err error) {
|
||||||
|
require.Contains(t, err.Error(), "unexpected EOF")
|
||||||
|
doneCh <- struct{}{}
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
timeout := time.NewTicker(30 * time.Second)
|
||||||
|
select {
|
||||||
|
case <-timeout.C:
|
||||||
|
t.Fatalf("timed out")
|
||||||
|
case <-doneCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
@ -11,6 +11,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/log"
|
"github.com/ethereum/go-ethereum/log"
|
||||||
@ -44,6 +45,7 @@ type Server struct {
|
|||||||
rpcServer *http.Server
|
rpcServer *http.Server
|
||||||
wsServer *http.Server
|
wsServer *http.Server
|
||||||
cache RPCCache
|
cache RPCCache
|
||||||
|
srvMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(
|
func NewServer(
|
||||||
@ -90,6 +92,7 @@ func NewServer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RPCListenAndServe(host string, port int) error {
|
func (s *Server) RPCListenAndServe(host string, port int) error {
|
||||||
|
s.srvMu.Lock()
|
||||||
hdlr := mux.NewRouter()
|
hdlr := mux.NewRouter()
|
||||||
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
|
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
|
||||||
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
|
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
|
||||||
@ -103,10 +106,12 @@ func (s *Server) RPCListenAndServe(host string, port int) error {
|
|||||||
Addr: addr,
|
Addr: addr,
|
||||||
}
|
}
|
||||||
log.Info("starting HTTP server", "addr", addr)
|
log.Info("starting HTTP server", "addr", addr)
|
||||||
|
s.srvMu.Unlock()
|
||||||
return s.rpcServer.ListenAndServe()
|
return s.rpcServer.ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) WSListenAndServe(host string, port int) error {
|
func (s *Server) WSListenAndServe(host string, port int) error {
|
||||||
|
s.srvMu.Lock()
|
||||||
hdlr := mux.NewRouter()
|
hdlr := mux.NewRouter()
|
||||||
hdlr.HandleFunc("/", s.HandleWS)
|
hdlr.HandleFunc("/", s.HandleWS)
|
||||||
hdlr.HandleFunc("/{authorization}", s.HandleWS)
|
hdlr.HandleFunc("/{authorization}", s.HandleWS)
|
||||||
@ -119,10 +124,13 @@ func (s *Server) WSListenAndServe(host string, port int) error {
|
|||||||
Addr: addr,
|
Addr: addr,
|
||||||
}
|
}
|
||||||
log.Info("starting WS server", "addr", addr)
|
log.Info("starting WS server", "addr", addr)
|
||||||
|
s.srvMu.Unlock()
|
||||||
return s.wsServer.ListenAndServe()
|
return s.wsServer.ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Shutdown() {
|
func (s *Server) Shutdown() {
|
||||||
|
s.srvMu.Lock()
|
||||||
|
defer s.srvMu.Unlock()
|
||||||
if s.rpcServer != nil {
|
if s.rpcServer != nil {
|
||||||
_ = s.rpcServer.Shutdown(context.Background())
|
_ = s.rpcServer.Shutdown(context.Background())
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user