diff --git a/hathor/websocket/messages.py b/hathor/websocket/messages.py index 86058759b..d84065745 100644 --- a/hathor/websocket/messages.py +++ b/hathor/websocket/messages.py @@ -23,6 +23,12 @@ class WebSocketMessage(BaseModel): pass +class WebSocketErrorMessage(WebSocketMessage): + type: str = Field('error', const=True) + success: bool = Field(False, const=True) + errmsg: str + + class CapabilitiesMessage(WebSocketMessage): type: str = Field('capabilities', const=True) capabilities: list[str] diff --git a/hathor/websocket/protocol.py b/hathor/websocket/protocol.py index e23d2b60a..5f173151e 100644 --- a/hathor/websocket/protocol.py +++ b/hathor/websocket/protocol.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from json import JSONDecodeError from typing import TYPE_CHECKING, Any, Union from autobahn.twisted.websocket import WebSocketServerProtocol @@ -28,7 +29,7 @@ aiter_xpub_addresses, gap_limit_search, ) -from hathor.websocket.messages import CapabilitiesMessage, StreamErrorMessage, WebSocketMessage +from hathor.websocket.messages import CapabilitiesMessage, StreamErrorMessage, WebSocketErrorMessage, WebSocketMessage from hathor.websocket.streamer import HistoryStreamer if TYPE_CHECKING: @@ -103,10 +104,17 @@ def onClose(self, wasClean, code, reason): def onMessage(self, payload: Union[bytes, str], isBinary: bool) -> None: """Called by the websocket protocol when a new message is received.""" self.log.debug('new message', payload=payload.hex() if isinstance(payload, bytes) else payload) - if isinstance(payload, bytes): - message = json_loadb(payload) - else: - message = json_loads(payload) + + try: + if isinstance(payload, bytes): + message = json_loadb(payload) + else: + message = json_loads(payload) + except JSONDecodeError: + self.send_message(WebSocketErrorMessage( + errmsg='Malformed command' + )) + return _type = message.get('type')