diff --git a/dot/rpc/http.go b/dot/rpc/http.go index fb50e79503..9427c8ef67 100644 --- a/dot/rpc/http.go +++ b/dot/rpc/http.go @@ -233,7 +233,7 @@ func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { wsc := NewWSConn(ws, h.serverConfig) h.wsConns = append(h.wsConns, wsc) - go wsc.HandleComm() + go wsc.HandleConn() } // NewWSConn to create new WebSocket Connection struct diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index cb53b91438..faf951db14 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -22,12 +22,21 @@ import ( "github.com/gorilla/websocket" ) +type websocketMessage struct { + ID float64 `json:"id"` + Method string `json:"method"` + Params any `json:"params"` +} + type httpclient interface { Do(*http.Request) (*http.Response, error) } -var errCannotReadFromWebsocket = errors.New("cannot read message from websocket") -var errCannotUnmarshalMessage = errors.New("cannot unmarshal webasocket message data") +var ( + errCannotReadFromWebsocket = errors.New("cannot read message from websocket") + errEmptyMethod = errors.New("empty method") +) + var logger = log.NewFromGlobal(log.AddContext("pkg", "rpc/subscription")) // WSConn struct to hold WebSocket Connection references @@ -46,57 +55,53 @@ type WSConn struct { } // readWebsocketMessage will read and parse the message data to a string->interface{} data -func (c *WSConn) readWebsocketMessage() ([]byte, map[string]interface{}, error) { - _, mbytes, err := c.Wsconn.ReadMessage() +func (c *WSConn) readWebsocketMessage() (rawBytes []byte, wsMessage *websocketMessage, err error) { + _, rawBytes, err = c.Wsconn.ReadMessage() if err != nil { - logger.Debugf("websocket failed to read message: %s", err) - return nil, nil, errCannotReadFromWebsocket + return nil, nil, fmt.Errorf("%w: %s", errCannotReadFromWebsocket, err.Error()) } - logger.Tracef("websocket message received: %s", string(mbytes)) - - // determine if request is for subscribe method type - var msg map[string]interface{} - err = json.Unmarshal(mbytes, &msg) - + wsMessage = new(websocketMessage) + err = json.Unmarshal(rawBytes, wsMessage) if err != nil { - logger.Debugf("websocket failed to unmarshal request message: %s", err) - return nil, nil, errCannotUnmarshalMessage + return nil, nil, err } - return mbytes, msg, nil + if wsMessage.Method == "" { + return nil, nil, errEmptyMethod + } + + return rawBytes, wsMessage, nil } -//HandleComm handles messages received on websocket connections -func (c *WSConn) HandleComm() { +// HandleConn handles messages received on websocket connections +func (c *WSConn) HandleConn() { for { - mbytes, msg, err := c.readWebsocketMessage() - if errors.Is(err, errCannotReadFromWebsocket) { - return - } + rawBytes, wsMessage, err := c.readWebsocketMessage() + if err != nil { + logger.Debugf("websocket failed to read message: %s", err) + if errors.Is(err, errCannotReadFromWebsocket) { + return + } - if errors.Is(err, errCannotUnmarshalMessage) { c.safeSendError(0, big.NewInt(InvalidRequestCode), InvalidRequestMessage) continue } - params := msg["params"] - reqid := msg["id"].(float64) - method := msg["method"].(string) - - logger.Debugf("ws method %s called with params %v", method, params) + logger.Tracef("websocket message received: %s", string(rawBytes)) + logger.Debugf("ws method %s called with params %v", wsMessage.Method, wsMessage.Params) - if !strings.Contains(method, "_unsubscribe") && !strings.Contains(method, "_unwatch") { - setupListener := c.getSetupListener(method) + if !strings.Contains(wsMessage.Method, "_unsubscribe") && !strings.Contains(wsMessage.Method, "_unwatch") { + setupListener := c.getSetupListener(wsMessage.Method) if setupListener == nil { - c.executeRPCCall(mbytes) + c.executeRPCCall(rawBytes) continue } - listener, err := setupListener(reqid, params) + listener, err := setupListener(wsMessage.ID, wsMessage.Params) if err != nil { - logger.Warnf("failed to create listener (method=%s): %s", method, err) + logger.Warnf("failed to create listener (method=%s): %s", wsMessage.Method, err) continue } @@ -104,29 +109,28 @@ func (c *WSConn) HandleComm() { continue } - listener, err := c.getUnsubListener(params) - + listener, err := c.getUnsubListener(wsMessage.Params) if err != nil { - logger.Warnf("failed to get unsubscriber (method=%s): %s", method, err) + logger.Warnf("failed to get unsubscriber (method=%s): %s", wsMessage.Method, err) if errors.Is(err, errUknownParamSubscribeID) || errors.Is(err, errCannotFindUnsubsriber) { - c.safeSendError(reqid, big.NewInt(InvalidRequestCode), InvalidRequestMessage) + c.safeSendError(wsMessage.ID, big.NewInt(InvalidRequestCode), InvalidRequestMessage) continue } if errors.Is(err, errCannotParseID) || errors.Is(err, errCannotFindListener) { - c.safeSend(newBooleanResponseJSON(false, reqid)) + c.safeSend(newBooleanResponseJSON(false, wsMessage.ID)) continue } } err = listener.Stop() if err != nil { - logger.Warnf("failed to stop listener goroutine (method=%s): %s", method, err) - c.safeSend(newBooleanResponseJSON(false, reqid)) + logger.Warnf("failed to stop listener goroutine (method=%s): %s", wsMessage.Method, err) + c.safeSend(newBooleanResponseJSON(false, wsMessage.ID)) } - c.safeSend(newBooleanResponseJSON(true, reqid)) + c.safeSend(newBooleanResponseJSON(true, wsMessage.ID)) continue } } diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index a085dc6f2e..5ad9c6db4b 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -21,12 +21,12 @@ import ( "github.com/stretchr/testify/require" ) -func TestWSConn_HandleComm(t *testing.T) { +func TestWSConn_HandleConn(t *testing.T) { wsconn, c, cancel := setupWSConn(t) wsconn.Subscriptions = make(map[uint32]Listener) defer cancel() - go wsconn.HandleComm() + go wsconn.HandleConn() time.Sleep(time.Second * 2) // test storageChangeListener @@ -294,7 +294,7 @@ func TestSubscribeAllHeads(t *testing.T) { wsconn.Subscriptions = make(map[uint32]Listener) defer cancel() - go wsconn.HandleComm() + go wsconn.HandleConn() time.Sleep(time.Second * 2) _, err := wsconn.initAllBlocksListerner(1, nil) @@ -372,3 +372,50 @@ func TestSubscribeAllHeads(t *testing.T) { require.NoError(t, l.Stop()) mockBlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) } + +func TestWSConn_CheckWebsocketInvalidData(t *testing.T) { + wsconn, c, cancel := setupWSConn(t) + wsconn.Subscriptions = make(map[uint32]Listener) + defer cancel() + + go wsconn.HandleConn() + + tests := []struct { + sentMessage []byte + expected []byte + }{ + { + sentMessage: []byte(`{ + "jsonrpc": "2.0", + "method": "", + "id": 0, + "params": [] + }`), + expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"), + }, + { + sentMessage: []byte(`{ + "jsonrpc": "2.0", + "params": [] + }`), + expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"), + }, + { + sentMessage: []byte(`{ + "jsonrpc": "2.0", + "id": "abcdef" + "method": "some_method_name" + "params": [] + }`), + expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"), + }, + } + + for _, tt := range tests { + c.WriteMessage(websocket.TextMessage, tt.sentMessage) + + _, msg, err := c.ReadMessage() + require.NoError(t, err) + require.Equal(t, tt.expected, msg) + } +}