Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions dot/rpc/subscription/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ import (
"github.com/gorilla/websocket"
)

type websocketMessage struct {
ID float64 `json:"id"`
Method string `json:"method"`
Params any
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
}

type httpclient interface {
Do(*http.Request) (*http.Response, error)
}
Expand All @@ -46,7 +52,7 @@ type WSConn struct {
}

// readWebsocketMessage will read and parse the message data to a string->interface{} data
func (c *WSConn) readWebsocketMessage() ([]byte, map[string]interface{}, error) {
func (c *WSConn) readWebsocketMessage() ([]byte, *websocketMessage, error) {
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
_, mbytes, err := c.Wsconn.ReadMessage()
if err != nil {
logger.Debugf("websocket failed to read message: %s", err)
Expand All @@ -55,8 +61,7 @@ func (c *WSConn) readWebsocketMessage() ([]byte, map[string]interface{}, error)

logger.Tracef("websocket message received: %s", string(mbytes))

// determine if request is for subscribe method type
var msg map[string]interface{}
msg := new(websocketMessage)
err = json.Unmarshal(mbytes, &msg)

if err != nil {
Expand All @@ -80,53 +85,53 @@ func (c *WSConn) HandleComm() {
continue
}

params := msg["params"]
reqid := msg["id"].(float64)
method := msg["method"].(string)
if msg.Method == "" {
c.safeSendError(0, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
continue
}

logger.Debugf("ws method %s called with params %v", method, params)
logger.Debugf("ws method %s called with params %v", msg.Method, msg.Params)

if !strings.Contains(method, "_unsubscribe") && !strings.Contains(method, "_unwatch") {
setupListener := c.getSetupListener(method)
if !strings.Contains(msg.Method, "_unsubscribe") && !strings.Contains(msg.Method, "_unwatch") {
setupListener := c.getSetupListener(msg.Method)

if setupListener == nil {
c.executeRPCCall(mbytes)
continue
}

listener, err := setupListener(reqid, params)
listener, err := setupListener(msg.ID, msg.Params)
if err != nil {
logger.Warnf("failed to create listener (method=%s): %s", method, err)
logger.Warnf("failed to create listener (method=%s): %s", msg.Method, err)
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
continue
}

listener.Listen()
continue
}

listener, err := c.getUnsubListener(params)

listener, err := c.getUnsubListener(msg.Params)
if err != nil {
logger.Warnf("failed to get unsubscriber (method=%s): %s", method, err)
logger.Warnf("failed to get unsubscriber (method=%s): %s", msg.Method, err)

if errors.Is(err, errUknownParamSubscribeID) || errors.Is(err, errCannotFindUnsubsriber) {
c.safeSendError(reqid, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
c.safeSendError(msg.ID, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
continue
}

if errors.Is(err, errCannotParseID) || errors.Is(err, errCannotFindListener) {
c.safeSend(newBooleanResponseJSON(false, reqid))
c.safeSend(newBooleanResponseJSON(false, msg.ID))
continue
}
}

err = listener.Stop()
if err != nil {
logger.Warnf("failed to stop listener goroutine (method=%s): %s", method, err)
c.safeSend(newBooleanResponseJSON(false, reqid))
logger.Warnf("failed to stop listener goroutine (method=%s): %s", msg.Method, err)
c.safeSend(newBooleanResponseJSON(false, msg.ID))
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
}

c.safeSend(newBooleanResponseJSON(true, reqid))
c.safeSend(newBooleanResponseJSON(true, msg.ID))
continue
}
}
Expand Down
47 changes: 47 additions & 0 deletions dot/rpc/subscription/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,53 @@ import (
"github.com/stretchr/testify/require"
)

func TestWSConn_CheckWebsocketInvalidData(t *testing.T) {
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
wsconn, c, cancel := setupWSConn(t)
wsconn.Subscriptions = make(map[uint32]Listener)
defer cancel()

go wsconn.HandleComm()
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated

tests := []struct {
sentMessage []byte
expected []byte
}{
{
sentMessage: []byte(`{
"jsonrpc": "2.0",
"method": "",
"id": 0,
"params": []
}`),
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
},
{
sentMessage: []byte(`{
"jsonrpc": "2.0",
"params": []
}`),
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
},
{
sentMessage: []byte(`{
"jsonrpc": "2.0",
"id": "abcdef"
"method": "some_method_name"
"params": []
}`),
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
},
}

for _, tt := range tests {
c.WriteMessage(websocket.TextMessage, tt.sentMessage)

_, msg, err := c.ReadMessage()
require.NoError(t, err)
require.Equal(t, tt.expected, msg)
}
}

func TestWSConn_HandleComm(t *testing.T) {
wsconn, c, cancel := setupWSConn(t)
wsconn.Subscriptions = make(map[uint32]Listener)
Expand Down