Skip to content

Commit

Permalink
Add API for the set of active connections.
Browse files Browse the repository at this point in the history
Fix #1486.
  • Loading branch information
aaugustin committed Aug 22, 2024
1 parent 4920a58 commit d341bba
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 32 deletions.
10 changes: 10 additions & 0 deletions docs/howto/upgrade.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ implementation.
Depending on your use case, adopting this method may improve performance when
streaming large messages. Specifically, it could reduce memory usage.

Tracking open connections
.........................

The new implementation of :class:`~asyncio.server.Server` provides a
:attr:`~asyncio.server.Server.connections` property, which is a set of all open
connections. This didn't exist in the original implementation.

If you were keeping track of open connections, you may be able to simplify your
code by using this property.

.. _basic-auth:

Performing HTTP Basic Authentication
Expand Down
7 changes: 5 additions & 2 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ notice.
New features
............

* The new :mod:`asyncio` and :mod:`threading` implementations provide an API for
enforcing HTTP Basic Auth on the server side.
* Made the set of active connections available in the :attr:`Server.connections
<asyncio.server.Server.connections>` property.

* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading`
implementations of servers.

.. _13.0:

Expand Down
2 changes: 2 additions & 0 deletions docs/reference/asyncio/server.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Running a server

.. autoclass:: Server

.. autoattribute:: connections

.. automethod:: close

.. automethod:: wait_closed
Expand Down
14 changes: 13 additions & 1 deletion src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
validate_subprotocols,
)
from ..http11 import SERVER, Request, Response
from ..protocol import CONNECTING, Event
from ..protocol import CONNECTING, OPEN, Event
from ..server import ServerProtocol
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
from .compatibility import asyncio_timeout
Expand Down Expand Up @@ -313,6 +313,18 @@ def __init__(
# Completed when the server is closed and connections are terminated.
self.closed_waiter: asyncio.Future[None] = self.loop.create_future()

@property
def connections(self) -> set[ServerConnection]:
"""
Set of active connections.
This property contains all connections that completed the opening
handshake successfully and didn't start the closing handshake yet.
It can be useful in combination with :func:`~broadcast`.
"""
return {connection for connection in self.handlers if connection.state is OPEN}

def wrap(self, server: asyncio.Server) -> None:
"""
Attach to a given :class:`asyncio.Server`.
Expand Down
24 changes: 16 additions & 8 deletions tests/asyncio/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ async def test_disable_keepalive(self):
latency = eval(await client.recv())
self.assertEqual(latency, 0)

async def test_logger(self):
"""Server accepts a logger argument."""
logger = logging.getLogger("test")
async with run_server(logger=logger) as server:
self.assertIs(server.logger, logger)

async def test_custom_connection_factory(self):
"""Server runs ServerConnection factory provided in create_connection."""

Expand All @@ -362,6 +368,16 @@ def create_connection(*args, **kwargs):
async with run_client(server) as client:
await self.assertEval(client, "ws.create_connection_ran", "True")

async def test_connections(self):
"""Server provides a connections property."""
async with run_server() as server:
self.assertEqual(server.connections, set())
async with run_client(server) as client:
self.assertEqual(len(server.connections), 1)
ws_id = str(next(iter(server.connections)).id)
await self.assertEval(client, "ws.id", ws_id)
self.assertEqual(server.connections, set())

async def test_handshake_fails(self):
"""Server receives connection from client but the handshake fails."""

Expand Down Expand Up @@ -555,14 +571,6 @@ async def test_unsupported_compression(self):
)


class WebSocketServerTests(unittest.IsolatedAsyncioTestCase):
async def test_logger(self):
"""Server accepts a logger argument."""
logger = logging.getLogger("test")
async with run_server(logger=logger) as server:
self.assertIs(server.logger, logger)


class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase):
async def test_valid_authorization(self):
"""basic_auth authenticates client with HTTP Basic Authentication."""
Expand Down
40 changes: 19 additions & 21 deletions tests/sync/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ def test_disable_compression(self):
with run_client(server) as client:
self.assertEval(client, "ws.protocol.extensions", "[]")

def test_logger(self):
"""Server accepts a logger argument."""
logger = logging.getLogger("test")
with run_server(logger=logger) as server:
self.assertIs(server.logger, logger)

def test_custom_connection_factory(self):
"""Server runs ServerConnection factory provided in create_connection."""

Expand All @@ -247,6 +253,19 @@ def create_connection(*args, **kwargs):
with run_client(server) as client:
self.assertEval(client, "ws.create_connection_ran", "True")

def test_fileno(self):
"""Server provides a fileno attribute."""
with run_server() as server:
self.assertIsInstance(server.fileno(), int)

def test_shutdown(self):
"""Server provides a shutdown method."""
with run_server() as server:
server.shutdown()
# Check that the server socket is closed.
with self.assertRaises(OSError):
server.socket.accept()

def test_handshake_fails(self):
"""Server receives connection from client but the handshake fails."""

Expand Down Expand Up @@ -393,27 +412,6 @@ def test_unsupported_compression(self):
)


class WebSocketServerTests(unittest.TestCase):
def test_logger(self):
"""Server accepts a logger argument."""
logger = logging.getLogger("test")
with run_server(logger=logger) as server:
self.assertIs(server.logger, logger)

def test_fileno(self):
"""Server provides a fileno attribute."""
with run_server() as server:
self.assertIsInstance(server.fileno(), int)

def test_shutdown(self):
"""Server provides a shutdown method."""
with run_server() as server:
server.shutdown()
# Check that the server socket is closed.
with self.assertRaises(OSError):
server.socket.accept()


class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase):
def test_valid_authorization(self):
"""basic_auth authenticates client with HTTP Basic Authentication."""
Expand Down

0 comments on commit d341bba

Please sign in to comment.