diff --git a/.gitignore b/.gitignore index 1a5fb9d4..ac9f74d9 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ dist/ docs/mkdocs/site test-results.xml __pycache__ +# Vim +.*.sw* # generated by setuptools_scm bumble/_version.py .vscode/launch.json diff --git a/bumble/transport/tcp_server.py b/bumble/transport/tcp_server.py index 77d03046..8991ead0 100644 --- a/bumble/transport/tcp_server.py +++ b/bumble/transport/tcp_server.py @@ -18,6 +18,7 @@ from __future__ import annotations import asyncio import logging +import socket from .common import Transport, StreamPacketSource @@ -28,6 +29,12 @@ # ----------------------------------------------------------------------------- + +# A pass-through function to ease mock testing. +async def _create_server(*args, **kw_args): + await asyncio.get_running_loop().create_server(*args, **kw_args) + + async def open_tcp_server_transport(spec: str) -> Transport: ''' Open a TCP server transport. @@ -38,7 +45,22 @@ async def open_tcp_server_transport(spec: str) -> Transport: Example: _:9001 ''' + local_host, local_port = spec.split(':') + return await _open_tcp_server_transport_impl( + host=local_host if local_host != '_' else None, port=int(local_port) + ) + + +async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport: + ''' + Open a TCP server transport with an existing socket. + + One reason to use this variant is to let python pick an unused port. + ''' + return await _open_tcp_server_transport_impl(sock=sock) + +async def _open_tcp_server_transport_impl(**kwargs) -> Transport: class TcpServerTransport(Transport): async def close(self): await super().close() @@ -77,13 +99,10 @@ def on_packet(self, packet): else: logger.debug('no client, dropping packet') - local_host, local_port = spec.split(':') packet_source = StreamPacketSource() packet_sink = TcpServerPacketSink() - await asyncio.get_running_loop().create_server( - lambda: TcpServerProtocol(packet_source, packet_sink), - host=local_host if local_host != '_' else None, - port=int(local_port), + await _create_server( + lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs ) return TcpServerTransport(packet_source, packet_sink) diff --git a/tests/transport_tcp_server_test.py b/tests/transport_tcp_server_test.py new file mode 100644 index 00000000..a5f015d9 --- /dev/null +++ b/tests/transport_tcp_server_test.py @@ -0,0 +1,64 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import pytest +import socket +import unittest +from unittest.mock import ANY, patch + +from bumble.transport.tcp_server import ( + open_tcp_server_transport, + open_tcp_server_transport_with_socket, +) + + +class OpenTcpServerTransportTests(unittest.TestCase): + def setUp(self): + self.patcher = patch('bumble.transport.tcp_server._create_server') + self.mock_create_server = self.patcher.start() + + def tearDown(self): + self.patcher.stop() + + def test_open_with_spec(self): + asyncio.run(open_tcp_server_transport('localhost:32100')) + self.mock_create_server.assert_awaited_once_with( + ANY, host='localhost', port=32100 + ) + + def test_open_with_port_only_spec(self): + asyncio.run(open_tcp_server_transport('_:32100')) + self.mock_create_server.assert_awaited_once_with(ANY, host=None, port=32100) + + def test_open_with_socket(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + asyncio.run(open_tcp_server_transport_with_socket(sock=sock)) + self.mock_create_server.assert_awaited_once_with(ANY, sock=sock) + + +@pytest.mark.skipif( + not os.environ.get('PYTEST_NOSKIP', 0), + reason='''\ +Not hermetic. Should only run manually with + $ PYTEST_NOSKIP=1 pytest tests +''', +) +def test_open_with_real_socket(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(('localhost', 0)) + port = sock.getsockname()[1] + assert port != 0 + asyncio.run(open_tcp_server_transport_with_socket(sock=sock))