Skip to content

Commit

Permalink
Allow sending ping and pong after close.
Browse files Browse the repository at this point in the history
Fix #1429.
  • Loading branch information
aaugustin committed Jan 27, 2024
1 parent d28b71d commit 705dc85
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 34 deletions.
28 changes: 22 additions & 6 deletions src/websockets/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ def send_continuation(self, data: bytes, fin: bool) -> None:
"""
if not self.expect_continuation_frame:
raise ProtocolError("unexpected continuation frame")
if self._state is not OPEN:
raise InvalidState(f"connection is {self.state.name.lower()}")
self.expect_continuation_frame = not fin
self.send_frame(Frame(OP_CONT, data, fin))

Expand All @@ -318,6 +320,8 @@ def send_text(self, data: bytes, fin: bool = True) -> None:
"""
if self.expect_continuation_frame:
raise ProtocolError("expected a continuation frame")
if self._state is not OPEN:
raise InvalidState(f"connection is {self.state.name.lower()}")
self.expect_continuation_frame = not fin
self.send_frame(Frame(OP_TEXT, data, fin))

Expand All @@ -339,6 +343,8 @@ def send_binary(self, data: bytes, fin: bool = True) -> None:
"""
if self.expect_continuation_frame:
raise ProtocolError("expected a continuation frame")
if self._state is not OPEN:
raise InvalidState(f"connection is {self.state.name.lower()}")
self.expect_continuation_frame = not fin
self.send_frame(Frame(OP_BINARY, data, fin))

Expand All @@ -358,6 +364,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None:
without a code.
"""
# While RFC 6455 doesn't rule out sending more than one close Frame,
# websockets is conservative in what it sends and doesn't allow that.
if self._state is not OPEN:
raise InvalidState(f"connection is {self.state.name.lower()}")
if code is None:
if reason != "":
raise ProtocolError("cannot send a reason without a code")
Expand All @@ -383,6 +393,9 @@ def send_ping(self, data: bytes) -> None:
data: payload containing arbitrary binary data.
"""
# RFC 6455 allows control frames after starting the closing handshake.
if self._state is not OPEN and self._state is not CLOSING:
raise InvalidState(f"connection is {self.state.name.lower()}")
self.send_frame(Frame(OP_PING, data))

def send_pong(self, data: bytes) -> None:
Expand All @@ -396,6 +409,9 @@ def send_pong(self, data: bytes) -> None:
data: payload containing arbitrary binary data.
"""
# RFC 6455 allows control frames after starting the closing handshake.
if self._state is not OPEN and self._state is not CLOSING:
raise InvalidState(f"connection is {self.state.name.lower()}")
self.send_frame(Frame(OP_PONG, data))

def fail(self, code: int, reason: str = "") -> None:
Expand Down Expand Up @@ -675,6 +691,8 @@ def recv_frame(self, frame: Frame) -> None:
# 1.4. Closing Handshake: "after receiving a control frame
# indicating the connection should be closed, a peer discards
# any further data received."
# RFC 6455 allows reading Ping and Pong frames after a Close frame.
# However, that doesn't seem useful; websockets doesn't support it.
self.parser = self.discard()
next(self.parser) # start coroutine

Expand All @@ -687,15 +705,13 @@ def recv_frame(self, frame: Frame) -> None:
# Private methods for sending events.

def send_frame(self, frame: Frame) -> None:
if self.state is not OPEN:
raise InvalidState(
f"cannot write to a WebSocket in the {self.state.name} state"
)

if self.debug:
self.logger.debug("> %s", frame)
self.writes.append(
frame.serialize(mask=self.side is CLIENT, extensions=self.extensions)
frame.serialize(
mask=self.side is CLIENT,
extensions=self.extensions,
)
)

def send_eof(self) -> None:
Expand Down
111 changes: 83 additions & 28 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,17 @@ def test_client_sends_text_after_sending_close(self):
with self.enforce_mask(b"\x00\x00\x00\x00"):
client.send_close(CloseCode.GOING_AWAY)
self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"])
with self.assertRaises(InvalidState):
with self.assertRaises(InvalidState) as raised:
client.send_text(b"")
self.assertEqual(str(raised.exception), "connection is closing")

def test_server_sends_text_after_sending_close(self):
server = Protocol(SERVER)
server.send_close(CloseCode.NORMAL_CLOSURE)
self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"])
with self.assertRaises(InvalidState):
with self.assertRaises(InvalidState) as raised:
server.send_text(b"")
self.assertEqual(str(raised.exception), "connection is closing")

def test_client_receives_text_after_receiving_close(self):
client = Protocol(CLIENT)
Expand Down Expand Up @@ -679,15 +681,17 @@ def test_client_sends_binary_after_sending_close(self):
with self.enforce_mask(b"\x00\x00\x00\x00"):
client.send_close(CloseCode.GOING_AWAY)
self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"])
with self.assertRaises(InvalidState):
with self.assertRaises(InvalidState) as raised:
client.send_binary(b"")
self.assertEqual(str(raised.exception), "connection is closing")

def test_server_sends_binary_after_sending_close(self):
server = Protocol(SERVER)
server.send_close(CloseCode.NORMAL_CLOSURE)
self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"])
with self.assertRaises(InvalidState):
with self.assertRaises(InvalidState) as raised:
server.send_binary(b"")
self.assertEqual(str(raised.exception), "connection is closing")

def test_client_receives_binary_after_receiving_close(self):
client = Protocol(CLIENT)
Expand Down Expand Up @@ -956,6 +960,37 @@ def test_server_receives_close_with_non_utf8_reason(self):
)
self.assertIs(server.state, CLOSING)

def test_client_sends_close_twice(self):
client = Protocol(CLIENT)
with self.enforce_mask(b"\x00\x00\x00\x00"):
client.send_close(CloseCode.GOING_AWAY)
self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"])
with self.assertRaises(InvalidState) as raised:
client.send_close(CloseCode.GOING_AWAY)
self.assertEqual(str(raised.exception), "connection is closing")

def test_server_sends_close_twice(self):
server = Protocol(SERVER)
server.send_close(CloseCode.NORMAL_CLOSURE)
self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"])
with self.assertRaises(InvalidState) as raised:
server.send_close(CloseCode.NORMAL_CLOSURE)
self.assertEqual(str(raised.exception), "connection is closing")

def test_client_sends_close_after_connection_is_closed(self):
client = Protocol(CLIENT)
client.receive_eof()
with self.assertRaises(InvalidState) as raised:
client.send_close(CloseCode.GOING_AWAY)
self.assertEqual(str(raised.exception), "connection is closed")

def test_server_sends_close_after_connection_is_closed(self):
server = Protocol(SERVER)
server.receive_eof()
with self.assertRaises(InvalidState) as raised:
server.send_close(CloseCode.NORMAL_CLOSURE)
self.assertEqual(str(raised.exception), "connection is closed")


class PingTests(ProtocolTestCase):
"""
Expand Down Expand Up @@ -1072,35 +1107,23 @@ def test_client_sends_ping_after_sending_close(self):
with self.enforce_mask(b"\x00\x00\x00\x00"):
client.send_close(CloseCode.GOING_AWAY)
self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"])
# The spec says: "An endpoint MAY send a Ping frame any time (...)
# before the connection is closed" but websockets doesn't support
# sending a Ping frame after a Close frame.
with self.assertRaises(InvalidState) as raised:
with self.enforce_mask(b"\x00\x44\x88\xcc"):
client.send_ping(b"")
self.assertEqual(
str(raised.exception),
"cannot write to a WebSocket in the CLOSING state",
)
self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"])

def test_server_sends_ping_after_sending_close(self):
server = Protocol(SERVER)
server.send_close(CloseCode.NORMAL_CLOSURE)
self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"])
# The spec says: "An endpoint MAY send a Ping frame any time (...)
# before the connection is closed" but websockets doesn't support
# sending a Ping frame after a Close frame.
with self.assertRaises(InvalidState) as raised:
server.send_ping(b"")
self.assertEqual(
str(raised.exception),
"cannot write to a WebSocket in the CLOSING state",
)
server.send_ping(b"")
self.assertEqual(server.data_to_send(), [b"\x89\x00"])

def test_client_receives_ping_after_receiving_close(self):
client = Protocol(CLIENT)
client.receive_data(b"\x88\x02\x03\xe8")
self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE)
client.receive_data(b"\x89\x04\x22\x66\xaa\xee")
# websockets ignores control frames after a close frame.
self.assertFrameReceived(client, None)
self.assertFrameSent(client, None)

Expand All @@ -1109,9 +1132,24 @@ def test_server_receives_ping_after_receiving_close(self):
server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9")
self.assertConnectionClosing(server, CloseCode.GOING_AWAY)
server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22")
# websockets ignores control frames after a close frame.
self.assertFrameReceived(server, None)
self.assertFrameSent(server, None)

def test_client_sends_ping_after_connection_is_closed(self):
client = Protocol(CLIENT)
client.receive_eof()
with self.assertRaises(InvalidState) as raised:
client.send_ping(b"")
self.assertEqual(str(raised.exception), "connection is closed")

def test_server_sends_ping_after_connection_is_closed(self):
server = Protocol(SERVER)
server.receive_eof()
with self.assertRaises(InvalidState) as raised:
server.send_ping(b"")
self.assertEqual(str(raised.exception), "connection is closed")


class PongTests(ProtocolTestCase):
"""
Expand Down Expand Up @@ -1212,23 +1250,23 @@ def test_client_sends_pong_after_sending_close(self):
with self.enforce_mask(b"\x00\x00\x00\x00"):
client.send_close(CloseCode.GOING_AWAY)
self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"])
# websockets doesn't support sending a Pong frame after a Close frame.
with self.assertRaises(InvalidState):
with self.enforce_mask(b"\x00\x44\x88\xcc"):
client.send_pong(b"")
self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"])

def test_server_sends_pong_after_sending_close(self):
server = Protocol(SERVER)
server.send_close(CloseCode.NORMAL_CLOSURE)
self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"])
# websockets doesn't support sending a Pong frame after a Close frame.
with self.assertRaises(InvalidState):
server.send_pong(b"")
server.send_pong(b"")
self.assertEqual(server.data_to_send(), [b"\x8a\x00"])

def test_client_receives_pong_after_receiving_close(self):
client = Protocol(CLIENT)
client.receive_data(b"\x88\x02\x03\xe8")
self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE)
client.receive_data(b"\x8a\x04\x22\x66\xaa\xee")
# websockets ignores control frames after a close frame.
self.assertFrameReceived(client, None)
self.assertFrameSent(client, None)

Expand All @@ -1237,9 +1275,24 @@ def test_server_receives_pong_after_receiving_close(self):
server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9")
self.assertConnectionClosing(server, CloseCode.GOING_AWAY)
server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22")
# websockets ignores control frames after a close frame.
self.assertFrameReceived(server, None)
self.assertFrameSent(server, None)

def test_client_sends_pong_after_connection_is_closed(self):
client = Protocol(CLIENT)
client.receive_eof()
with self.assertRaises(InvalidState) as raised:
client.send_pong(b"")
self.assertEqual(str(raised.exception), "connection is closed")

def test_server_sends_pong_after_connection_is_closed(self):
server = Protocol(SERVER)
server.receive_eof()
with self.assertRaises(InvalidState) as raised:
server.send_pong(b"")
self.assertEqual(str(raised.exception), "connection is closed")


class FailTests(ProtocolTestCase):
"""
Expand Down Expand Up @@ -1370,8 +1423,9 @@ def test_client_send_close_in_fragmented_message(self):
client.send_close()
self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"])
self.assertIs(client.state, CLOSING)
with self.assertRaises(InvalidState):
with self.assertRaises(InvalidState) as raised:
client.send_continuation(b"Eggs", fin=True)
self.assertEqual(str(raised.exception), "connection is closing")

def test_server_send_close_in_fragmented_message(self):
server = Protocol(SERVER)
Expand All @@ -1380,8 +1434,9 @@ def test_server_send_close_in_fragmented_message(self):
server.send_close()
self.assertEqual(server.data_to_send(), [b"\x88\x00"])
self.assertIs(server.state, CLOSING)
with self.assertRaises(InvalidState):
with self.assertRaises(InvalidState) as raised:
server.send_continuation(b"Eggs", fin=True)
self.assertEqual(str(raised.exception), "connection is closing")

def test_client_receive_close_in_fragmented_message(self):
client = Protocol(CLIENT)
Expand Down

0 comments on commit 705dc85

Please sign in to comment.