From fecdef0bd44ff3f763030eb3f495df7cd320d991 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20K=C3=BChn?= Date: Mon, 14 Apr 2025 20:40:10 +0200 Subject: [PATCH 1/6] set bound and connection at once --- tardis/utilities/executors/sshexecutor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tardis/utilities/executors/sshexecutor.py b/tardis/utilities/executors/sshexecutor.py index 049b8d61..6422a66b 100644 --- a/tardis/utilities/executors/sshexecutor.py +++ b/tardis/utilities/executors/sshexecutor.py @@ -146,9 +146,12 @@ async def bounded_connection(self): async with self.lock: # check that connection has not been initialized in a different task while self._ssh_connection is None: - self._ssh_connection = await self._establish_connection() - max_session = await probe_max_session(self._ssh_connection) - self._session_bound = asyncio.Semaphore(value=max_session) + connection = await self._establish_connection() + max_session = await probe_max_session(connection) + self._ssh_connection, self._session_bound = ( + connection, + asyncio.Semaphore(value=max_session), + ) assert self._ssh_connection is not None assert self._session_bound is not None bound, session = self._session_bound, self._ssh_connection From 89e6991d5f67b4b790f2984482ef9d06b6b855e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20K=C3=BChn?= Date: Mon, 14 Apr 2025 20:41:44 +0200 Subject: [PATCH 2/6] use one attributed for connection state --- tardis/utilities/executors/sshexecutor.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tardis/utilities/executors/sshexecutor.py b/tardis/utilities/executors/sshexecutor.py index 6422a66b..931b21d9 100644 --- a/tardis/utilities/executors/sshexecutor.py +++ b/tardis/utilities/executors/sshexecutor.py @@ -96,10 +96,10 @@ def __init__(self, **parameters): self._parameters["client_factory"] = partial( MFASSHClient, mfa_config=mfa_config ) - # the current SSH connection or None if it must be 
(re-)established - self._ssh_connection: Optional[asyncssh.SSHClientConnection] = None - # the bound on MaxSession running concurrently - self._session_bound: Optional[asyncio.Semaphore] = None + # the current SSH connection and bound unless it must be (re-)established + self._ssh_connection: ( + "tuple[asyncssh.SSHClientConnection, asyncio.Semaphore] | None" + ) = None self._lock = None async def _establish_connection(self): @@ -123,7 +123,10 @@ def _handle_broken_ssh_connection( ): # clear broken connection to get it replaced # by a new connection during next command - if ssh_connection is self._ssh_connection: + if ( + self._ssh_connection is not None + and ssh_connection is self._ssh_connection[0] + ): self._ssh_connection = None raise CommandExecutionFailure( message=(f"Could not run command {command} due to a connection loss!"), @@ -148,13 +151,12 @@ async def bounded_connection(self): while self._ssh_connection is None: connection = await self._establish_connection() max_session = await probe_max_session(connection) - self._ssh_connection, self._session_bound = ( + self._ssh_connection = ( connection, asyncio.Semaphore(value=max_session), ) assert self._ssh_connection is not None - assert self._session_bound is not None - bound, session = self._session_bound, self._ssh_connection + session, bound = self._ssh_connection async with bound: yield session From c5f80e37861f2d28c03e5267cc7f18885d5d7dd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20K=C3=BChn?= Date: Mon, 14 Apr 2025 20:41:56 +0200 Subject: [PATCH 3/6] annotation --- tardis/utilities/executors/sshexecutor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tardis/utilities/executors/sshexecutor.py b/tardis/utilities/executors/sshexecutor.py index 931b21d9..399fd64d 100644 --- a/tardis/utilities/executors/sshexecutor.py +++ b/tardis/utilities/executors/sshexecutor.py @@ -119,7 +119,7 @@ def _handle_broken_ssh_connection( self, ssh_connection: asyncssh.SSHClientConnection, command: 
str, - chained_exception: Exception = None, + chained_exception: "Exception | None" = None, ): # clear broken connection to get it replaced # by a new connection during next command From cab354ef2f8c341710c05e842d32f57da79376de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20K=C3=BChn?= Date: Mon, 14 Apr 2025 20:48:38 +0200 Subject: [PATCH 4/6] adjust test to new layout --- tests/utilities_t/executors_t/test_sshexecutor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities_t/executors_t/test_sshexecutor.py b/tests/utilities_t/executors_t/test_sshexecutor.py index 5c6930d6..4efe5e54 100644 --- a/tests/utilities_t/executors_t/test_sshexecutor.py +++ b/tests/utilities_t/executors_t/test_sshexecutor.py @@ -199,7 +199,7 @@ async def force_connection(): self.assertIsNone(self.executor._ssh_connection) run_async(force_connection) - self.assertIsInstance(self.executor._ssh_connection, MockConnection) + self.assertIsInstance(self.executor._ssh_connection[0], MockConnection) current_ssh_connection = self.executor._ssh_connection run_async(force_connection) # make sure the connection is not needlessly replaced From 6f173bf2c159506f7ef067c2232581d2cd6093a2 Mon Sep 17 00:00:00 2001 From: Manuel Giffels Date: Tue, 15 Apr 2025 09:21:06 +0200 Subject: [PATCH 5/6] Add unittest for race condition mentioned in #369 --- CONTRIBUTORS | 1 + docs/source/changelog.rst | 4 +-- .../executors_t/test_sshexecutor.py | 29 +++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTORS b/CONTRIBUTORS index a599b2c9..1ba2aa7a 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -13,6 +13,7 @@ R. Florian von Cube Benjamin Rottler Sebastian Wozniewski mschnepf +Max Kühn swozniewski Alexander Haas <104835302+haasal@users.noreply.github.com> Dirk Sammel diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 8ce7337d..42055a7e 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,4 +1,4 @@ -.. 
Created by changelog.py at 2025-04-08, command +.. Created by changelog.py at 2025-04-16, command '/Users/giffler/.cache/pre-commit/repoecmh3ah8/py_env-python3.12/bin/changelog docs/source/changes compile --categories Added Changed Fixed Security Deprecated --output=docs/source/changelog.rst' based on the format of 'https://keepachangelog.com/' @@ -6,7 +6,7 @@ CHANGELOG ######### -[Unreleased] - 2025-04-08 +[Unreleased] - 2025-04-16 ========================= Added diff --git a/tests/utilities_t/executors_t/test_sshexecutor.py b/tests/utilities_t/executors_t/test_sshexecutor.py index 4efe5e54..5b51945c 100644 --- a/tests/utilities_t/executors_t/test_sshexecutor.py +++ b/tests/utilities_t/executors_t/test_sshexecutor.py @@ -205,6 +205,35 @@ async def force_connection(): # make sure the connection is not needlessly replaced self.assertEqual(self.executor._ssh_connection, current_ssh_connection) + def test_connection_race(self): + # see https://github.com/MatterMiners/tardis/issues/369 + waiter = asyncio.Event() + + async def mocked_probe_max_session(connection): + await waiter.wait() + return 10 + + async def run_bounded_connection(): + async with self.executor.bounded_connection as connection: + return connection + + async def run_race_condition(): + first_connection = asyncio.ensure_future(run_bounded_connection()) + await asyncio.sleep(0.1) # give some time to hit the waiter + self.assertIsNone(self.executor._ssh_connection) + second_connection = asyncio.ensure_future(run_bounded_connection()) + await asyncio.sleep(0.1) # give some time to schedule the second tasks + waiter.set() + # check that no new connection is established + self.assertEqual(await first_connection, await second_connection) + + # monkey patch probe_max_session + with patch( + "tardis.utilities.executors.sshexecutor.probe_max_session", + mocked_probe_max_session, + ): + run_async(run_race_condition) + def test_lock(self): self.assertIsInstance(self.executor.lock, asyncio.Lock) From 
728f7feec48b97d8a09e0b5840764522c6586fb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20K=C3=BChn?= Date: Wed, 16 Apr 2025 11:27:38 +0200 Subject: [PATCH 6/6] use NamedTuple for connection state Co-authored-by: Manuel Giffels --- tardis/utilities/executors/sshexecutor.py | 36 +++++++++++-------- .../executors_t/test_sshexecutor.py | 12 ++++--- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/tardis/utilities/executors/sshexecutor.py b/tardis/utilities/executors/sshexecutor.py index 399fd64d..ab59a6ec 100644 --- a/tardis/utilities/executors/sshexecutor.py +++ b/tardis/utilities/executors/sshexecutor.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, NamedTuple from ...configuration.utilities import enable_yaml_load from ...exceptions.tardisexceptions import TardisAuthError from ...exceptions.executorexceptions import CommandExecutionFailure @@ -86,6 +86,15 @@ async def kbdint_challenge_received( raise TardisAuthError(msg) from ke +class ConnectionState(NamedTuple): + """State associated with an active SSH connection""" + + #: the SSH connection itself + connection: asyncssh.SSHClientConnection + #: bound on concurrent sessions over the connection + bound: asyncio.Semaphore + + @enable_yaml_load("!SSHExecutor") @yaml_tag(eager=True) class SSHExecutor(Executor): @@ -96,10 +105,8 @@ def __init__(self, **parameters): self._parameters["client_factory"] = partial( MFASSHClient, mfa_config=mfa_config ) - # the current SSH connection and bound unless it must be (re-)established - self._ssh_connection: ( - "tuple[asyncssh.SSHClientConnection, asyncio.Semaphore] | None" - ) = None + # the current SSH connection unless it must be (re-)established + self._connection_state: "ConnectionState | None" = None self._lock = None async def _establish_connection(self): @@ -124,10 +131,10 @@ def _handle_broken_ssh_connection( # clear broken connection to get it replaced # by a new connection during next command if ( - self._ssh_connection is not 
None - and ssh_connection is self._ssh_connection[0] + self._connection_state is not None + and ssh_connection is self._connection_state.connection ): - self._ssh_connection = None + self._connection_state = None raise CommandExecutionFailure( message=(f"Could not run command {command} due to a connection loss!"), exit_code=255, @@ -145,18 +152,17 @@ async def bounded_connection(self): :py:class:`~asyncssh.SSHClientConnection` so that only `MaxSessions` commands run at once. """ - if self._ssh_connection is None: + if self._connection_state is None: async with self.lock: # check that connection has not been initialized in a different task - while self._ssh_connection is None: + while self._connection_state is None: connection = await self._establish_connection() max_session = await probe_max_session(connection) - self._ssh_connection = ( - connection, - asyncio.Semaphore(value=max_session), + self._connection_state = ConnectionState( + connection, asyncio.Semaphore(value=max_session) ) - assert self._ssh_connection is not None - session, bound = self._ssh_connection + assert self._connection_state is not None + session, bound = self._connection_state async with bound: yield session diff --git a/tests/utilities_t/executors_t/test_sshexecutor.py b/tests/utilities_t/executors_t/test_sshexecutor.py index 5b51945c..e81127f9 100644 --- a/tests/utilities_t/executors_t/test_sshexecutor.py +++ b/tests/utilities_t/executors_t/test_sshexecutor.py @@ -197,13 +197,15 @@ async def force_connection(): async with self.executor.bounded_connection as connection: return connection - self.assertIsNone(self.executor._ssh_connection) + self.assertIsNone(self.executor._connection_state) run_async(force_connection) - self.assertIsInstance(self.executor._ssh_connection[0], MockConnection) - current_ssh_connection = self.executor._ssh_connection + self.assertIsInstance( + self.executor._connection_state.connection, MockConnection + ) + current_ssh_connection = 
self.executor._connection_state run_async(force_connection) # make sure the connection is not needlessly replaced - self.assertEqual(self.executor._ssh_connection, current_ssh_connection) + self.assertEqual(self.executor._connection_state, current_ssh_connection) def test_connection_race(self): # see https://github.com/MatterMiners/tardis/issues/369 @@ -220,7 +222,7 @@ async def run_bounded_connection(): async def run_race_condition(): first_connection = asyncio.ensure_future(run_bounded_connection()) await asyncio.sleep(0.1) # give some time to hit the waiter - self.assertIsNone(self.executor._ssh_connection) + self.assertIsNone(self.executor._connection_state) second_connection = asyncio.ensure_future(run_bounded_connection()) await asyncio.sleep(0.1) # give some time to schedule the second tasks waiter.set()