Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,21 @@ async def receive(self) -> Message:
if self.client_state == WebSocketState.CONNECTING:
message = await self._receive()
message_type = message["type"]
assert message_type == "websocket.connect"
if message_type != "websocket.connect":
raise RuntimeError(
'Expected ASGI message "websocket.connect", '
f"but got {message_type!r}"
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we think about this kind of style for the message_type case?...

f'Expected ASGI message "websocket.connect", but got {message_type!r}'

self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
message = await self._receive()
message_type = message["type"]
assert message_type in {"websocket.receive", "websocket.disconnect"}
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
'Expected ASGI message "websocket.receive" or '
f'"websocket.disconnect", but got {message_type!r}'
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
Expand All @@ -55,15 +63,23 @@ async def send(self, message: Message) -> None:
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
assert message_type in {"websocket.accept", "websocket.close"}
if message_type not in {"websocket.accept", "websocket.close"}:
raise RuntimeError(
'Expected ASGI message "websocket.connect", '
f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
elif self.application_state == WebSocketState.CONNECTED:
message_type = message["type"]
assert message_type in {"websocket.send", "websocket.close"}
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
'Expected ASGI message "websocket.send" or "websocket.close", '
f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
Expand All @@ -89,20 +105,30 @@ def _raise_on_disconnect(self, message: Message) -> None:
raise WebSocketDisconnect(message["code"])

async def receive_text(self) -> str:
assert self.application_state == WebSocketState.CONNECTED
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
message = await self.receive()
self._raise_on_disconnect(message)
return message["text"]

async def receive_bytes(self) -> bytes:
assert self.application_state == WebSocketState.CONNECTED
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
message = await self.receive()
self._raise_on_disconnect(message)
return message["bytes"]

async def receive_json(self, mode: str = "text") -> typing.Any:
assert mode in ["text", "binary"]
assert self.application_state == WebSocketState.CONNECTED
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
message = await self.receive()
self._raise_on_disconnect(message)

Expand Down Expand Up @@ -140,7 +166,8 @@ async def send_bytes(self, data: bytes) -> None:
await self.send({"type": "websocket.send", "bytes": data})

async def send_json(self, data: typing.Any, mode: str = "text") -> None:
assert mode in ["text", "binary"]
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
text = json.dumps(data)
if mode == "text":
await self.send({"type": "websocket.send", "text": text})
Expand Down
134 changes: 133 additions & 1 deletion tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState


def test_websocket_url(test_client_factory):
Expand Down Expand Up @@ -422,3 +422,135 @@ async def asgi(receive, send):
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Going Away"


def test_send_json_invalid_mode(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.send_json({}, mode="invalid")

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_json_invalid_mode(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.receive_json(mode="invalid")

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_text_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_text()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_bytes_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_bytes()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_json_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_json()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_send_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.send({"type": "websocket.send"})

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_send_wrong_message_type(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.send({"type": "websocket.accept"})
await websocket.send({"type": "websocket.accept"})

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
websocket.client_state = WebSocketState.CONNECTING
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't test this without modifying the client state.

await websocket.receive()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/") as websocket:
websocket.send({"type": "websocket.send"})


def test_receive_wrong_message_type(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.receive()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/") as websocket:
websocket.send({"type": "websocket.connect"})