Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ R. Florian von Cube <florian.voncube@gmail.com>
Benjamin Rottler <benjamin.rottler@cern.ch>
Sebastian Wozniewski <sebastian.wozniewski@uni-goettingen.de>
mschnepf <matthias.schnepf@kit.edu>
Max Kühn <maxfischer2781@gmail.com>
swozniewski <sebastian.wozniewski@uni-goettingen.de>
Alexander Haas <104835302+haasal@users.noreply.github.com>
Dirk Sammel <dirk.sammel@cern.ch>
Expand Down
4 changes: 2 additions & 2 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
.. Created by changelog.py at 2025-04-08, command
.. Created by changelog.py at 2025-04-16, command
'/Users/giffler/.cache/pre-commit/repoecmh3ah8/py_env-python3.12/bin/changelog docs/source/changes compile --categories Added Changed Fixed Security Deprecated --output=docs/source/changelog.rst'
based on the format of 'https://keepachangelog.com/'

#########
CHANGELOG
#########

[Unreleased] - 2025-04-08
[Unreleased] - 2025-04-16
=========================

Added
Expand Down
43 changes: 27 additions & 16 deletions tardis/utilities/executors/sshexecutor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, NamedTuple
from ...configuration.utilities import enable_yaml_load
from ...exceptions.tardisexceptions import TardisAuthError
from ...exceptions.executorexceptions import CommandExecutionFailure
Expand Down Expand Up @@ -86,6 +86,15 @@ async def kbdint_challenge_received(
raise TardisAuthError(msg) from ke


class ConnectionState(NamedTuple):
    """State associated with an active SSH connection

    Bundles a connection with the semaphore limiting its concurrent sessions
    in a single tuple, so that both can be stored and replaced as one value.
    """

    # NOTE: field order is part of the interface, since instances may be
    # unpacked positionally as ``connection, bound``.

    #: the SSH connection itself
    connection: asyncssh.SSHClientConnection
    #: bound on concurrent sessions over the connection
    bound: asyncio.Semaphore


@enable_yaml_load("!SSHExecutor")
@yaml_tag(eager=True)
class SSHExecutor(Executor):
Expand All @@ -96,10 +105,8 @@ def __init__(self, **parameters):
self._parameters["client_factory"] = partial(
MFASSHClient, mfa_config=mfa_config
)
# the current SSH connection or None if it must be (re-)established
self._ssh_connection: Optional[asyncssh.SSHClientConnection] = None
# the bound on MaxSession running concurrently
self._session_bound: Optional[asyncio.Semaphore] = None
# the current SSH connection unless it must be (re-)established
self._connection_state: "ConnectionState | None" = None
self._lock = None

async def _establish_connection(self):
Expand All @@ -119,12 +126,15 @@ def _handle_broken_ssh_connection(
self,
ssh_connection: asyncssh.SSHClientConnection,
command: str,
chained_exception: Exception = None,
chained_exception: "Exception | None" = None,
):
# clear broken connection to get it replaced
# by a new connection during next command
if ssh_connection is self._ssh_connection:
self._ssh_connection = None
if (
self._connection_state is not None
and ssh_connection is self._connection_state.connection
):
self._connection_state = None
raise CommandExecutionFailure(
message=(f"Could not run command {command} due to a connection loss!"),
exit_code=255,
Expand All @@ -142,16 +152,17 @@ async def bounded_connection(self):
:py:class:`~asyncssh.SSHClientConnection`
so that only `MaxSessions` commands run at once.
"""
if self._ssh_connection is None:
if self._connection_state is None:
async with self.lock:
# check that connection has not been initialized in a different task
while self._ssh_connection is None:
self._ssh_connection = await self._establish_connection()
max_session = await probe_max_session(self._ssh_connection)
self._session_bound = asyncio.Semaphore(value=max_session)
assert self._ssh_connection is not None
assert self._session_bound is not None
bound, session = self._session_bound, self._ssh_connection
while self._connection_state is None:
connection = await self._establish_connection()
max_session = await probe_max_session(connection)
self._connection_state = ConnectionState(
connection, asyncio.Semaphore(value=max_session)
)
assert self._connection_state is not None
session, bound = self._connection_state
async with bound:
yield session

Expand Down
39 changes: 35 additions & 4 deletions tests/utilities_t/executors_t/test_sshexecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,44 @@ async def force_connection():
async with self.executor.bounded_connection as connection:
return connection

self.assertIsNone(self.executor._ssh_connection)
self.assertIsNone(self.executor._connection_state)
run_async(force_connection)
self.assertIsInstance(self.executor._ssh_connection, MockConnection)
current_ssh_connection = self.executor._ssh_connection
self.assertIsInstance(
self.executor._connection_state.connection, MockConnection
)
current_ssh_connection = self.executor._connection_state
run_async(force_connection)
# make sure the connection is not needlessly replaced
self.assertEqual(self.executor._ssh_connection, current_ssh_connection)
self.assertEqual(self.executor._connection_state, current_ssh_connection)

def test_connection_race(self):
    # Two tasks request the bounded connection while the session probe is
    # stalled; both must end up with the same connection.
    # see https://github.com/MatterMiners/tardis/issues/369
    # gate that keeps the mocked probe suspended until the test releases it
    waiter = asyncio.Event()

    async def mocked_probe_max_session(connection):
        # stand-in for probe_max_session: wait for the gate, then report 10
        await waiter.wait()
        return 10

    async def run_bounded_connection():
        # enter bounded_connection once and hand back the yielded connection
        async with self.executor.bounded_connection as connection:
            return connection

    async def run_race_condition():
        first_connection = asyncio.ensure_future(run_bounded_connection())
        await asyncio.sleep(0.1)  # give some time to hit the waiter
        # expected to still be unset while the first task waits in the probe
        self.assertIsNone(self.executor._connection_state)
        second_connection = asyncio.ensure_future(run_bounded_connection())
        await asyncio.sleep(0.1)  # give some time to schedule the second task
        waiter.set()
        # check that no new connection is established
        self.assertEqual(await first_connection, await second_connection)

    # monkey patch probe_max_session for the duration of the race
    with patch(
        "tardis.utilities.executors.sshexecutor.probe_max_session",
        mocked_probe_max_session,
    ):
        run_async(run_race_condition)

def test_lock(self):
self.assertIsInstance(self.executor.lock, asyncio.Lock)
Expand Down