Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,11 @@ determines how long to wait for the connection to be established.

Default: ``10``.

The low-level socket connect timeout used for the raw SSH probe can be tuned
with the ``SKYPILOT_SSH_SOCKET_CONNECT_TIMEOUT`` environment variable (seconds;
default: ``1``). This controls how long each socket connection attempt waits
before SkyPilot falls back to retrying.

.. _config-yaml-aws:

``aws``
Expand Down
24 changes: 23 additions & 1 deletion sky/provision/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,26 @@ def _shlex_join(command: List[str]) -> str:
return ' '.join(shlex.quote(arg) for arg in command)


def _socket_connect_timeout_seconds() -> float:
"""Returns the socket connect timeout used for SSH readiness probing."""
env_value = os.getenv('SKYPILOT_SSH_SOCKET_CONNECT_TIMEOUT')
if env_value is None:
return 1.0
try:
timeout = float(env_value)
except ValueError as e:
message = ('Invalid SKYPILOT_SSH_SOCKET_CONNECT_TIMEOUT value '
f'{env_value!r}; must be a positive number.')
with ux_utils.print_exception_no_traceback():
raise ValueError(message) from e
if timeout <= 0:
message = ('Invalid SKYPILOT_SSH_SOCKET_CONNECT_TIMEOUT value '
f'{env_value!r}; must be a positive number.')
with ux_utils.print_exception_no_traceback():
raise ValueError(message)
return timeout


def _wait_ssh_connection_direct(ip: str,
ssh_port: int,
ssh_user: str,
Expand All @@ -314,7 +334,9 @@ def _wait_ssh_connection_direct(ip: str,
try:
success = False
stderr = ''
with socket.create_connection((ip, ssh_port), timeout=1) as s:
connect_timeout = _socket_connect_timeout_seconds()
with socket.create_connection((ip, ssh_port),
timeout=connect_timeout) as s:
if s.recv(100).startswith(b'SSH'):
# Wait for SSH being actually ready, otherwise we may get the
# following error:
Expand Down
83 changes: 83 additions & 0 deletions tests/unit_tests/test_provisioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from unittest import mock

import pytest

from sky.provision import provisioner


class _DummySocket:
"""Simple context manager socket mock for SSH probing tests."""

def __init__(self, recv_payload: bytes = b'SSH'):
self._recv_payload = recv_payload

def recv(self, _num_bytes: int) -> bytes:
return self._recv_payload

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
return False


def _setup_successful_socket(monkeypatch):
dummy_socket = _DummySocket()
create_connection_mock = mock.MagicMock(return_value=dummy_socket)
monkeypatch.setattr(provisioner.socket, 'create_connection',
create_connection_mock)
monkeypatch.setattr(provisioner, '_wait_ssh_connection_indirect',
mock.MagicMock(return_value=(True, '')))
return create_connection_mock


def test_wait_ssh_connection_direct_uses_default_timeout(monkeypatch):
monkeypatch.delenv('SKYPILOT_SSH_SOCKET_CONNECT_TIMEOUT', raising=False)
create_connection_mock = _setup_successful_socket(monkeypatch)

success, stderr = provisioner._wait_ssh_connection_direct(
'1.2.3.4',
22,
ssh_user='user',
ssh_private_key='key',
ssh_probe_timeout=5)

assert success
assert stderr == ''
assert create_connection_mock.call_args.kwargs['timeout'] == 1.0


def test_wait_ssh_connection_direct_env_override(monkeypatch):
monkeypatch.setenv('SKYPILOT_SSH_SOCKET_CONNECT_TIMEOUT', '2.5')
create_connection_mock = _setup_successful_socket(monkeypatch)

success, _ = provisioner._wait_ssh_connection_direct('5.6.7.8',
22,
ssh_user='user',
ssh_private_key='key',
ssh_probe_timeout=5)

assert success
assert create_connection_mock.call_args.kwargs['timeout'] == pytest.approx(
2.5)


@pytest.mark.parametrize('value', ['abc', '0', '-1'])
def test_wait_ssh_connection_direct_invalid_env(monkeypatch, value):
monkeypatch.setenv('SKYPILOT_SSH_SOCKET_CONNECT_TIMEOUT', value)
create_connection_mock = mock.MagicMock()
monkeypatch.setattr(provisioner.socket, 'create_connection',
create_connection_mock)
monkeypatch.setattr(provisioner, '_wait_ssh_connection_indirect',
mock.MagicMock(return_value=(True, '')))

success, stderr = provisioner._wait_ssh_connection_direct(
'9.9.9.9',
22,
ssh_user='user',
ssh_private_key='key',
ssh_probe_timeout=5)

assert not success
assert 'Invalid SKYPILOT_SSH_SOCKET_CONNECT_TIMEOUT' in stderr
create_connection_mock.assert_not_called()
Loading