Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/rude-foxes-think.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@eth-optimism/proxyd': patch
---

Fix concurrent write panic in WS
81 changes: 49 additions & 32 deletions proxyd/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/ethereum/go-ethereum/log"
Expand Down Expand Up @@ -548,6 +549,7 @@ type WSProxier struct {
clientConn *websocket.Conn
backendConn *websocket.Conn
methodWhitelist *StringSet
clientConnMu sync.Mutex
}

func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier {
Expand All @@ -570,12 +572,11 @@ func (w *WSProxier) Proxy(ctx context.Context) error {

func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
for {
outConn := w.backendConn
// Block until we get a message.
msgType, msg, err := w.clientConn.ReadMessage()
if err != nil {
errC <- err
if err := outConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
if err := w.backendConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
log.Error("error writing backendConn message", "err", err)
}
return
Expand All @@ -586,7 +587,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
// Route control messages to the backend. These don't
// count towards the total RPC requests count.
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
err := outConn.WriteMessage(msgType, msg)
err := w.backendConn.WriteMessage(msgType, msg)
if err != nil {
errC <- err
return
Expand All @@ -612,20 +613,27 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
"req_id", GetReqID(ctx),
"err", err,
)
outConn = w.clientConn
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
RecordRPCError(ctx, BackendProxyd, method, err)
} else {
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
log.Info(
"forwarded WS message to backend",
"method", req.Method,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)

// Send error response to client
err = w.writeClientConn(msgType, msg)
if err != nil {
errC <- err
return
}
continue
}

err = outConn.WriteMessage(msgType, msg)
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
log.Info(
"forwarded WS message to backend",
"method", req.Method,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)

err = w.backendConn.WriteMessage(msgType, msg)
if err != nil {
errC <- err
return
Expand All @@ -639,7 +647,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
msgType, msg, err := w.backendConn.ReadMessage()
if err != nil {
errC <- err
if err := w.clientConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
if err := w.writeClientConn(websocket.CloseMessage, formatWSError(err)); err != nil {
log.Error("error writing clientConn message", "err", err)
}
return
Expand All @@ -649,7 +657,7 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {

// Route control messages directly to the client.
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
err := w.clientConn.WriteMessage(msgType, msg)
err := w.writeClientConn(msgType, msg)
if err != nil {
errC <- err
return
Expand All @@ -664,26 +672,28 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
id = res.ID
}
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
}
if res.IsError() {
log.Info(
"backend responded with RPC error",
"code", res.Error.Code,
"msg", res.Error.Message,
"source", "ws",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error)
log.Info("backend responded with error", "err", err)
} else {
log.Info(
"forwarded WS message to client",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
if res.IsError() {
log.Info(
"backend responded with RPC error",
"code", res.Error.Code,
"msg", res.Error.Message,
"source", "ws",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error)
} else {
log.Info(
"forwarded WS message to client",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
}
}

err = w.clientConn.WriteMessage(msgType, msg)
err = w.writeClientConn(msgType, msg)
if err != nil {
errC <- err
return
Expand Down Expand Up @@ -726,6 +736,13 @@ func (w *WSProxier) parseBackendMsg(msg []byte) (*RPCRes, error) {
return res, nil
}

func (w *WSProxier) writeClientConn(msgType int, msg []byte) error {
w.clientConnMu.Lock()
err := w.clientConn.WriteMessage(msgType, msg)
w.clientConnMu.Unlock()
return err
}

func mustMarshalJSON(in interface{}) []byte {
out, err := json.Marshal(in)
if err != nil {
Expand Down
71 changes: 71 additions & 0 deletions proxyd/integration_tests/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"

"github.com/ethereum-optimism/optimism/proxyd"
"github.com/gorilla/websocket"
)

type RecordedRequest struct {
Expand Down Expand Up @@ -251,3 +253,72 @@ func (m *MockBackend) wrappedHandler(w http.ResponseWriter, r *http.Request) {
m.handler.ServeHTTP(w, clone)
m.mtx.Unlock()
}

type MockWSBackend struct {
connCB MockWSBackendOnConnect
msgCB MockWSBackendOnMessage
closeCB MockWSBackendOnClose
server *httptest.Server
upgrader websocket.Upgrader
conns []*websocket.Conn
connsMu sync.Mutex
}

type MockWSBackendOnConnect func(conn *websocket.Conn)
type MockWSBackendOnMessage func(conn *websocket.Conn, msgType int, data []byte)
type MockWSBackendOnClose func(conn *websocket.Conn, err error)

func NewMockWSBackend(
connCB MockWSBackendOnConnect,
msgCB MockWSBackendOnMessage,
closeCB MockWSBackendOnClose,
) *MockWSBackend {
mb := &MockWSBackend{
connCB: connCB,
msgCB: msgCB,
closeCB: closeCB,
}
mb.server = httptest.NewServer(mb)
return mb
}

func (m *MockWSBackend) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn, err := m.upgrader.Upgrade(w, r, nil)
if err != nil {
panic(err)
}
if m.connCB != nil {
m.connCB(conn)
}
go func() {
for {
mType, msg, err := conn.ReadMessage()
if err != nil {
if m.closeCB != nil {
m.closeCB(conn, err)
}
return
}
if m.msgCB != nil {
m.msgCB(conn, mType, msg)
}
}
}()
m.connsMu.Lock()
m.conns = append(m.conns, conn)
m.connsMu.Unlock()
}

func (m *MockWSBackend) URL() string {
return strings.Replace(m.server.URL, "http://", "ws://", 1)
}

func (m *MockWSBackend) Close() {
m.server.Close()

m.connsMu.Lock()
for _, conn := range m.conns {
conn.Close()
}
m.connsMu.Unlock()
}
25 changes: 25 additions & 0 deletions proxyd/integration_tests/testdata/ws.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
ws_backend_group = "main"

ws_method_whitelist = [
"eth_subscribe"
]

[server]
rpc_port = 8545
ws_port = 8546

[backend]
response_timeout_seconds = 1

[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
max_ws_conns = 1

[backend_groups]
[backend_groups.main]
backends = ["good"]

[rpc_method_mappings]
eth_chainId = "main"
82 changes: 75 additions & 7 deletions proxyd/integration_tests/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,26 @@ import (
"net/http"
"os"
"testing"
"time"

"github.com/ethereum/go-ethereum/log"

"github.com/gorilla/websocket"

"github.com/BurntSushi/toml"
"github.com/ethereum-optimism/optimism/proxyd"
"github.com/ethereum/go-ethereum/log"
"github.com/stretchr/testify/require"
)

type ProxydClient struct {
type ProxydHTTPClient struct {
url string
}

func NewProxydClient(url string) *ProxydClient {
return &ProxydClient{url: url}
func NewProxydClient(url string) *ProxydHTTPClient {
return &ProxydHTTPClient{url: url}
}

func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
func (p *ProxydHTTPClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
rpcReq := NewRPCReq("999", method, params)
body, err := json.Marshal(rpcReq)
if err != nil {
Expand All @@ -32,15 +36,15 @@ func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int
return p.SendRequest(body)
}

func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) {
func (p *ProxydHTTPClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) {
body, err := json.Marshal(reqs)
if err != nil {
panic(err)
}
return p.SendRequest(body)
}

func (p *ProxydClient) SendRequest(body []byte) ([]byte, int, error) {
func (p *ProxydHTTPClient) SendRequest(body []byte) ([]byte, int, error) {
res, err := http.Post(p.url, "application/json", bytes.NewReader(body))
if err != nil {
return nil, -1, err
Expand Down Expand Up @@ -98,6 +102,70 @@ func NewRPCReq(id string, method string, params []interface{}) *proxyd.RPCReq {
}
}

type ProxydWSClient struct {
conn *websocket.Conn
msgCB ProxydWSClientOnMessage
closeCB ProxydWSClientOnClose
}

type WSMessage struct {
Type int
Body []byte
}

type ProxydWSClientOnMessage func(msgType int, data []byte)
type ProxydWSClientOnClose func(err error)

func NewProxydWSClient(
url string,
msgCB ProxydWSClientOnMessage,
closeCB ProxydWSClientOnClose,
) (*ProxydWSClient, error) {
conn, _, err := websocket.DefaultDialer.Dial(url, nil) // nolint:bodyclose
if err != nil {
return nil, err
}

c := &ProxydWSClient{
conn: conn,
msgCB: msgCB,
closeCB: closeCB,
}
go c.readPump()
return c, nil
}

func (h *ProxydWSClient) readPump() {
for {
mType, msg, err := h.conn.ReadMessage()
if err != nil {
if h.closeCB != nil {
h.closeCB(err)
}
return
}
if h.msgCB != nil {
h.msgCB(mType, msg)
}
}
}

func (h *ProxydWSClient) HardClose() {
h.conn.Close()
}

func (h *ProxydWSClient) SoftClose() error {
return h.WriteMessage(websocket.CloseMessage, nil)
}

func (h *ProxydWSClient) WriteMessage(msgType int, msg []byte) error {
return h.conn.WriteMessage(msgType, msg)
}

func (h *ProxydWSClient) WriteControlMessage(msgType int, msg []byte) error {
return h.conn.WriteControl(msgType, msg, time.Now().Add(time.Minute))
}

func InitLogger() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlDebug,
Expand Down
Loading