Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/websockets.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary da

### Closing the connection

* `await websocket.close(code=1000)`
* `await websocket.close(code=1000, reason=None)`

### Sending and receiving messages

Expand Down
2 changes: 1 addition & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ async def _asgi_send(self, message: Message) -> None:

def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(message.get("code", 1000))
raise WebSocketDisconnect(message.get("code", 1000), message.get("reason"))

def send(self, message: Message) -> None:
self._receive_queue.put(message)
Expand Down
20 changes: 15 additions & 5 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ class WebSocketState(enum.Enum):


class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = 1000, reason: str = None) -> None:
self.code = code
self.reason = reason


class WebSocket(HTTPConnection):
Expand Down Expand Up @@ -144,13 +145,22 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None:
else:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

async def close(self, code: int = 1000) -> None:
await self.send({"type": "websocket.close", "code": code})
async def close(self, code: int = 1000, reason: str = None) -> None:
if reason is None:
await self.send({"type": "websocket.close", "code": code})
else:
await self.send({"type": "websocket.close", "code": code, "reason": reason})


class WebSocketClose:
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = 1000, reason: str = None) -> None:
self.code = code
self.reason = reason

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "websocket.close", "code": self.code})
if self.reason is None:
await send({"type": "websocket.close", "code": self.code})
else:
await send(
{"type": "websocket.close", "code": self.code, "reason": self.reason}
)
38 changes: 37 additions & 1 deletion tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocket, WebSocketClose, WebSocketDisconnect


def test_websocket_url(test_client_factory):
Expand Down Expand Up @@ -391,3 +391,39 @@ async def mock_send(message):
assert websocket == websocket
assert websocket in {websocket}
assert {websocket} == {websocket}


def test_websocket_close_reason(test_client_factory) -> None:
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.close(code=1001, reason="Closing")

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Closing"


def test_websocket_close_reason_manual(test_client_factory) -> None:
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()

websocket_close = WebSocketClose(code=1001, reason="Closing")
await websocket_close(scope, receive, send)

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Closing"