From 1ceeccbbc0ae6f0c36c24ef593082d1f36ca0f5a Mon Sep 17 00:00:00 2001 From: Cheng Sheng Date: Sun, 3 Mar 2024 22:38:32 +0100 Subject: [PATCH] open_tcp_server_transport: allow explicit sock as input. When a user doesn't need an exact port, but cares more about getting SOME unused port, they can do: * Create a socket outside with port=None or port=0. * Use socket.getsockname()[1] to get the allocated port and pass to the TCP client somehow. * Use the created socket to create a TCP server transport. Use-case: unit-testing embedded software that implements a BLE host. The controller will be a Bumble controller, connected to the host via a TCP channel. * The host will have a TCP-client HCI transport for testing. * The pytest setup code will allocate the TCP server and pass the port number to the host. Also add some unittests with python mock. --- .gitignore | 2 + bumble/transport/tcp_server.py | 29 +++++++++++--- tests/transport_tcp_server_test.py | 64 ++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 5 deletions(-) create mode 100644 tests/transport_tcp_server_test.py diff --git a/.gitignore b/.gitignore index 830ec1a4..f746ac5f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ dist/ docs/mkdocs/site test-results.xml __pycache__ +# Vim +.*.sw* # generated by setuptools_scm bumble/_version.py .vscode/launch.json diff --git a/bumble/transport/tcp_server.py b/bumble/transport/tcp_server.py index 77d03046..8991ead0 100644 --- a/bumble/transport/tcp_server.py +++ b/bumble/transport/tcp_server.py @@ -18,6 +18,7 @@ from __future__ import annotations import asyncio import logging +import socket from .common import Transport, StreamPacketSource @@ -28,6 +29,12 @@ # ----------------------------------------------------------------------------- + +# A pass-through function to ease mock testing. +async def _create_server(*args, **kw_args): + await asyncio.get_running_loop().create_server(*args, **kw_args) + + async def open_tcp_server_transport(spec: str) -> Transport: ''' Open a TCP server transport. @@ -38,7 +45,22 @@ async def open_tcp_server_transport(spec: str) -> Transport: Example: _:9001 ''' + local_host, local_port = spec.split(':') + return await _open_tcp_server_transport_impl( + host=local_host if local_host != '_' else None, port=int(local_port) + ) + + +async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport: + ''' + Open a TCP server transport with an existing socket. + + One reason to use this variant is to let python pick an unused port. + ''' + return await _open_tcp_server_transport_impl(sock=sock) + +async def _open_tcp_server_transport_impl(**kwargs) -> Transport: class TcpServerTransport(Transport): async def close(self): await super().close() @@ -77,13 +99,10 @@ def on_packet(self, packet): else: logger.debug('no client, dropping packet') - local_host, local_port = spec.split(':') packet_source = StreamPacketSource() packet_sink = TcpServerPacketSink() - await asyncio.get_running_loop().create_server( - lambda: TcpServerProtocol(packet_source, packet_sink), - host=local_host if local_host != '_' else None, - port=int(local_port), + await _create_server( + lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs ) return TcpServerTransport(packet_source, packet_sink) diff --git a/tests/transport_tcp_server_test.py b/tests/transport_tcp_server_test.py new file mode 100644 index 00000000..a5f015d9 --- /dev/null +++ b/tests/transport_tcp_server_test.py @@ -0,0 +1,64 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import pytest +import socket +import unittest +from unittest.mock import ANY, patch + +from bumble.transport.tcp_server import ( + open_tcp_server_transport, + open_tcp_server_transport_with_socket, +) + + +class OpenTcpServerTransportTests(unittest.TestCase): + def setUp(self): + self.patcher = patch('bumble.transport.tcp_server._create_server') + self.mock_create_server = self.patcher.start() + + def tearDown(self): + self.patcher.stop() + + def test_open_with_spec(self): + asyncio.run(open_tcp_server_transport('localhost:32100')) + self.mock_create_server.assert_awaited_once_with( + ANY, host='localhost', port=32100 + ) + + def test_open_with_port_only_spec(self): + asyncio.run(open_tcp_server_transport('_:32100')) + self.mock_create_server.assert_awaited_once_with(ANY, host=None, port=32100) + + def test_open_with_socket(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + asyncio.run(open_tcp_server_transport_with_socket(sock=sock)) + self.mock_create_server.assert_awaited_once_with(ANY, sock=sock) + + +@pytest.mark.skipif( + not os.environ.get('PYTEST_NOSKIP', 0), + reason='''\ +Not hermetic. Should only run manually with + $ PYTEST_NOSKIP=1 pytest tests +''', +) +def test_open_with_real_socket(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(('localhost', 0)) + port = sock.getsockname()[1] + assert port != 0 + asyncio.run(open_tcp_server_transport_with_socket(sock=sock))