This commit is contained in:
Felipe Andrade 2023-07-27 13:29:44 -07:00
parent a65810b467
commit 3f48703f26
2 changed files with 49 additions and 18 deletions

@ -884,13 +884,6 @@ func (w *WSProxier) Proxy(ctx context.Context) error {
func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
for { for {
err := w.clientConn.SetReadDeadline(time.Now().Add(w.readTimeout))
if err != nil {
log.Error("ws client read timeout", "err", err)
errC <- err
return
}
// Block until we get a message. // Block until we get a message.
msgType, msg, err := w.clientConn.ReadMessage() msgType, msg, err := w.clientConn.ReadMessage()
if err != nil { if err != nil {
@ -974,13 +967,6 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
func (w *WSProxier) backendPump(ctx context.Context, errC chan error) { func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
for { for {
err := w.backendConn.SetReadDeadline(time.Now().Add(w.readTimeout))
if err != nil {
log.Error("ws backend read timeout", "err", err)
errC <- err
return
}
// Block until we get a message. // Block until we get a message.
msgType, msg, err := w.backendConn.ReadMessage() msgType, msg, err := w.backendConn.ReadMessage()
if err != nil { if err != nil {
@ -1085,7 +1071,7 @@ func (w *WSProxier) writeBackendConn(msgType int, msg []byte) error {
log.Error("ws backend write timeout", "err", err) log.Error("ws backend write timeout", "err", err)
return err return err
} }
err := w.writeBackendConn(msgType, msg) err := w.backendConn.WriteMessage(msgType, msg)
return err return err
} }

@ -2,16 +2,17 @@ package integration_tests
import ( import (
"os" "os"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum-optimism/optimism/proxyd" "github.com/ethereum-optimism/optimism/proxyd"
"github.com/ethereum/go-ethereum/log"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/syndtr/goleveldb/leveldb/opt"
) )
// TestConcurrentWSPanic tests for a panic in the websocket proxy // TestConcurrentWSPanic tests for a panic in the websocket proxy
@ -201,7 +202,7 @@ func TestWS(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
timeout := time.NewTicker(30 * time.Second) timeout := time.NewTicker(10 * time.Second)
doneCh := make(chan struct{}, 1) doneCh := make(chan struct{}, 1)
backendHdlr.SetMsgCB(func(conn *websocket.Conn, msgType int, data []byte) { backendHdlr.SetMsgCB(func(conn *websocket.Conn, msgType int, data []byte) {
require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(tt.backendRes))) require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(tt.backendRes)))
@ -270,3 +271,47 @@ func TestWSClientClosure(t *testing.T) {
}) })
} }
} }
func TestWSClientExceedReadLimit(t *testing.T) {
backendHdlr := new(backendHandler)
clientHdlr := new(clientHandler)
backend := NewMockWSBackend(nil, func(conn *websocket.Conn, msgType int, data []byte) {
backendHdlr.MsgCB(conn, msgType, data)
}, func(conn *websocket.Conn, err error) {
backendHdlr.CloseCB(conn, err)
})
defer backend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
config := ReadConfig("ws")
_, shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
client, err := NewProxydWSClient("ws://127.0.0.1:8546", func(msgType int, data []byte) {
clientHdlr.MsgCB(msgType, data)
}, nil)
require.NoError(t, err)
closed := false
originalHandler := client.conn.CloseHandler()
client.conn.SetCloseHandler(func(code int, text string) error {
closed = true
return originalHandler(code, text)
})
backendHdlr.SetMsgCB(func(conn *websocket.Conn, msgType int, data []byte) {
t.Fatalf("backend should not get the large message")
})
clientReq := "{\"id\": 1, \"method\": \"eth_subscribe\", \"params\": [\"" + strings.Repeat("barf", 256*opt.KiB+1) + "\"]}"
err = client.WriteMessage(
websocket.TextMessage,
[]byte(clientReq),
)
require.Error(t, err)
require.True(t, closed)
}