Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dot/rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
wsc := NewWSConn(ws, h.serverConfig)
h.wsConns = append(h.wsConns, wsc)

go wsc.HandleComm()
go wsc.HandleConn()
}

// NewWSConn to create new WebSocket Connection struct
Expand Down
71 changes: 35 additions & 36 deletions dot/rpc/subscription/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,17 @@ import (
"github.com/gorilla/websocket"
)

type websocketMessage struct {
ID float64 `json:"id"`
Method string `json:"method"`
Params any
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
}

type httpclient interface {
Do(*http.Request) (*http.Response, error)
}

var errCannotReadFromWebsocket = errors.New("cannot read message from websocket")
var errCannotUnmarshalMessage = errors.New("cannot unmarshal webasocket message data")
var logger = log.NewFromGlobal(log.AddContext("pkg", "rpc/subscription"))

// WSConn struct to hold WebSocket Connection references
Expand All @@ -46,87 +51,81 @@ type WSConn struct {
}

// readWebsocketMessage will read and parse the message data to a string->interface{} data
func (c *WSConn) readWebsocketMessage() ([]byte, map[string]interface{}, error) {
_, mbytes, err := c.Wsconn.ReadMessage()
func (c *WSConn) readWebsocketMessage() (bytes []byte, err error) {
_, bytes, err = c.Wsconn.ReadMessage()
if err != nil {
logger.Debugf("websocket failed to read message: %s", err)
return nil, nil, errCannotReadFromWebsocket
}

logger.Tracef("websocket message received: %s", string(mbytes))

// determine if request is for subscribe method type
var msg map[string]interface{}
err = json.Unmarshal(mbytes, &msg)

if err != nil {
logger.Debugf("websocket failed to unmarshal request message: %s", err)
return nil, nil, errCannotUnmarshalMessage
return nil, errCannotReadFromWebsocket
}

return mbytes, msg, nil
logger.Tracef("websocket message received: %s", string(bytes))
return bytes, nil
}

//HandleComm handles messages received on websocket connections
func (c *WSConn) HandleComm() {
// HandleConn handles messages received on websocket connections
func (c *WSConn) HandleConn() {
for {
mbytes, msg, err := c.readWebsocketMessage()
mbytes, err := c.readWebsocketMessage()
if errors.Is(err, errCannotReadFromWebsocket) {
return
} else if err != nil {
c.safeSendError(0, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
}

if errors.Is(err, errCannotUnmarshalMessage) {
msg := new(websocketMessage)
err = json.Unmarshal(mbytes, &msg)
if err != nil {
c.safeSendError(0, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
continue
}

params := msg["params"]
reqid := msg["id"].(float64)
method := msg["method"].(string)
if msg.Method == "" {
c.safeSendError(0, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
continue
}
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated

logger.Debugf("ws method %s called with params %v", method, params)
logger.Debugf("ws method %s called with params %v", msg.Method, msg.Params)

if !strings.Contains(method, "_unsubscribe") && !strings.Contains(method, "_unwatch") {
setupListener := c.getSetupListener(method)
if !strings.Contains(msg.Method, "_unsubscribe") && !strings.Contains(msg.Method, "_unwatch") {
setupListener := c.getSetupListener(msg.Method)

if setupListener == nil {
c.executeRPCCall(mbytes)
continue
}

listener, err := setupListener(reqid, params)
listener, err := setupListener(msg.ID, msg.Params)
if err != nil {
logger.Warnf("failed to create listener (method=%s): %s", method, err)
logger.Warnf("failed to create listener (method=%s): %s", msg.Method, err)
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
continue
}

listener.Listen()
continue
}

listener, err := c.getUnsubListener(params)

listener, err := c.getUnsubListener(msg.Params)
if err != nil {
logger.Warnf("failed to get unsubscriber (method=%s): %s", method, err)
logger.Warnf("failed to get unsubscriber (method=%s): %s", msg.Method, err)

if errors.Is(err, errUknownParamSubscribeID) || errors.Is(err, errCannotFindUnsubsriber) {
c.safeSendError(reqid, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
c.safeSendError(msg.ID, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
continue
}

if errors.Is(err, errCannotParseID) || errors.Is(err, errCannotFindListener) {
c.safeSend(newBooleanResponseJSON(false, reqid))
c.safeSend(newBooleanResponseJSON(false, msg.ID))
continue
}
}

err = listener.Stop()
if err != nil {
logger.Warnf("failed to stop listener goroutine (method=%s): %s", method, err)
c.safeSend(newBooleanResponseJSON(false, reqid))
logger.Warnf("failed to stop listener goroutine (method=%s): %s", msg.Method, err)
c.safeSend(newBooleanResponseJSON(false, msg.ID))
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
}

c.safeSend(newBooleanResponseJSON(true, reqid))
c.safeSend(newBooleanResponseJSON(true, msg.ID))
continue
}
}
Expand Down
53 changes: 50 additions & 3 deletions dot/rpc/subscription/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,59 @@ import (
"github.com/stretchr/testify/require"
)

func TestWSConn_HandleComm(t *testing.T) {
func TestWSConn_CheckWebsocketInvalidData(t *testing.T) {
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
Comment thread
EclesioMeloJunior marked this conversation as resolved.
Outdated
wsconn, c, cancel := setupWSConn(t)
wsconn.Subscriptions = make(map[uint32]Listener)
defer cancel()

go wsconn.HandleComm()
go wsconn.HandleConn()

tests := []struct {
sentMessage []byte
expected []byte
}{
{
sentMessage: []byte(`{
"jsonrpc": "2.0",
"method": "",
"id": 0,
"params": []
}`),
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
},
{
sentMessage: []byte(`{
"jsonrpc": "2.0",
"params": []
}`),
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
},
{
sentMessage: []byte(`{
"jsonrpc": "2.0",
"id": "abcdef"
"method": "some_method_name"
"params": []
}`),
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
},
}

for _, tt := range tests {
c.WriteMessage(websocket.TextMessage, tt.sentMessage)

_, msg, err := c.ReadMessage()
require.NoError(t, err)
require.Equal(t, tt.expected, msg)
}
}

func TestWSConn_HandleConn(t *testing.T) {
wsconn, c, cancel := setupWSConn(t)
wsconn.Subscriptions = make(map[uint32]Listener)
defer cancel()

go wsconn.HandleConn()
time.Sleep(time.Second * 2)

// test storageChangeListener
Expand Down Expand Up @@ -294,7 +341,7 @@ func TestSubscribeAllHeads(t *testing.T) {
wsconn.Subscriptions = make(map[uint32]Listener)
defer cancel()

go wsconn.HandleComm()
go wsconn.HandleConn()
time.Sleep(time.Second * 2)

_, err := wsconn.initAllBlocksListerner(1, nil)
Expand Down