Recover unreachable instances (#2043)
* Check instances regardless of termination_policy -- this gives unreachable
  instances a chance to become reachable again
* Stop a previous job container that may still be running on the instance when
  submitting a new one

Fixes: #2041
un-def authored Dec 2, 2024
1 parent 25d87c3 commit a286816
Showing 6 changed files with 151 additions and 7 deletions.
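
In effect, the commit makes two behavioral changes: the idle-termination check no longer gates the rest of instance processing on termination_policy (so unreachable instances keep getting health-checked and can recover), and a job submission that finds the shim busy force-stops the stale container instead of failing for good. A minimal sketch of the submission side (the wrapper and its control flow are illustrative; only the ShimClient calls come from the diff below):

# Sketch only: `shim_client` exposes the submit()/stop() methods shown in the
# diff; the retry loop around this wrapper is dstack's regular job-processing
# iteration, paraphrased here.
def try_submit(shim_client, **task_config) -> bool:
    if shim_client.submit(**task_config):
        return True
    # submit() returned False: the shim answered 409 Conflict because a
    # container from a previous job is still running. Force-stop it so the
    # shim goes back to pending; the next processing pass will retry.
    shim_client.stop(force=True)
    return False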
@@ -158,8 +158,10 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
and instance.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
and instance.job_id is None
):
await _terminate_idle_instance(instance)
elif instance.status == InstanceStatus.PENDING:
# terminates the instance and sets instance.status to TERMINATED (along with
# other fields) if termination_idle_time is reached; otherwise a no-op
await _maybe_terminate_idle_instance(instance)
if instance.status == InstanceStatus.PENDING:
if instance.remote_connection_info is not None:
await _add_remote(instance)
else:
@@ -180,7 +182,7 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
await session.commit()


async def _terminate_idle_instance(instance: InstanceModel):
async def _maybe_terminate_idle_instance(instance: InstanceModel):
current_time = get_current_datetime()
idle_duration = _get_instance_idle_duration(instance)
idle_seconds = instance.termination_idle_time
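
The diff cuts the renamed helper off here; a minimal sketch of its contract, following the comment in the first hunk (the termination call is a hypothetical placeholder, and the exact fields it sets are assumptions):

async def _maybe_terminate_idle_instance(instance: InstanceModel) -> None:
    # No-op unless the instance has been idle for at least
    # termination_idle_time seconds.
    idle_duration = _get_instance_idle_duration(instance)
    if idle_duration.total_seconds() >= instance.termination_idle_time:
        # Terminates the instance and sets instance.status to TERMINATED
        # (along with other fields).
        await _terminate(instance)  # hypothetical helper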
@@ -427,7 +427,7 @@ def _process_provisioning_with_shim(
for volume, volume_mount in zip(volumes, volume_mounts):
volume_mount.name = volume.name

shim_client.submit(
submitted = shim_client.submit(
username=username,
password=password,
image_name=job_spec.image_name,
@@ -444,6 +444,20 @@
volumes=volumes,
instance_mounts=instance_mounts,
)
if not submitted:
# This can happen when we lost connection to the runner (e.g., network issues), marked
# the job as failed, and released the instance (status=BUSY->IDLE, job_id={id}->None),
# while the job container is in fact still alive, running the previous job. Since we
# force-stop the container via the shim API when cancelling the current job anyway
# (when either the user aborts the submission or the submission deadline is reached),
# it's safe to kill the previous job container now, making the shim available
# (state=running->pending) for the next try.
logger.warning(
"%s: failed to sumbit, shim is already running a job, stopping it now, retry later",
fmt(job_model),
)
shim_client.stop(force=True)
return False

job_model.status = JobStatus.PULLING
logger.info("%s: now is %s", fmt(job_model), job_model.status.name)
8 changes: 8 additions & 0 deletions src/dstack/_internal/server/services/runner/client.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from http import HTTPStatus
from typing import BinaryIO, Dict, List, Optional, Union

import requests
@@ -154,6 +155,10 @@ def submit(
volumes: List[Volume],
instance_mounts: List[InstanceMountPoint],
):
"""
Returns `True` if submitted and `False` if the shim already has a job (`409 Conflict`).
Other error statuses raise an exception.
"""
_shm_size = int(shm_size * 1024 * 1024 * 1024) if shm_size else 0
volume_infos = [_volume_to_shim_volume_info(v) for v in volumes]
post_body = TaskConfigBody(
@@ -176,7 +181,10 @@
json=post_body,
timeout=REQUEST_TIMEOUT,
)
if resp.status_code == HTTPStatus.CONFLICT:
return False
resp.raise_for_status()
return True

def stop(self, force: bool = False):
body = StopBody(force=force)
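
At the wire level, the new submit contract is: POST the task config, treat 409 Conflict as "the shim is already running a job", and raise for any other error status. A standalone sketch of the same pattern (the endpoint path and payload shape are assumptions, not the shim's actual API):

import requests
from http import HTTPStatus

def submit_task(base_url: str, task_config: dict) -> bool:
    # Returns True if the task was accepted, False if a task is already
    # running on the shim (HTTP 409); raises for any other error status.
    resp = requests.post(f"{base_url}/api/submit", json=task_config, timeout=10)
    if resp.status_code == HTTPStatus.CONFLICT:
        return False
    resp.raise_for_status()
    return True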
9 changes: 7 additions & 2 deletions src/dstack/_internal/server/testing/common.py
@@ -30,6 +30,7 @@
DEFAULT_POOL_NAME,
DEFAULT_POOL_TERMINATION_IDLE_TIME,
Profile,
TerminationPolicy,
)
from dstack._internal.core.models.repos.base import RepoType
from dstack._internal.core.models.repos.local import LocalRunRepoData
@@ -443,6 +444,7 @@ async def create_instance(
pool: PoolModel,
fleet: Optional[FleetModel] = None,
status: InstanceStatus = InstanceStatus.IDLE,
unreachable: bool = False,
created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
finished_at: Optional[datetime] = None,
spot: bool = False,
@@ -453,6 +455,8 @@
job: Optional[JobModel] = None,
instance_num: int = 0,
backend: BackendType = BackendType.DATACRUNCH,
termination_policy: Optional[TerminationPolicy] = None,
termination_idle_time: int = DEFAULT_POOL_TERMINATION_IDLE_TIME,
region: str = "eu-west",
remote_connection_info: Optional[RemoteConnectionInfo] = None,
) -> InstanceModel:
@@ -523,7 +527,7 @@
fleet=fleet,
project=project,
status=status,
unreachable=False,
unreachable=unreachable,
created_at=created_at,
started_at=created_at,
finished_at=finished_at,
@@ -532,7 +536,8 @@
price=1,
region=region,
backend=backend,
termination_idle_time=DEFAULT_POOL_TERMINATION_IDLE_TIME,
termination_policy=termination_policy,
termination_idle_time=termination_idle_time,
profile=profile.json(),
requirements=requirements.json(),
instance_configuration=instance_configuration.json(),
@@ -220,6 +220,71 @@ async def test_check_shim_terminate_instance_by_dedaline(self, test_db, session:
assert instance.termination_reason == "Termination deadline"
assert instance.health_status == health_status

@pytest.mark.asyncio
@pytest.mark.parametrize(
["termination_policy", "has_job"],
[
pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, False, id="destroy-no-job"),
pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, True, id="destroy-with-job"),
pytest.param(TerminationPolicy.DONT_DESTROY, False, id="dont-destroy-no-job"),
pytest.param(TerminationPolicy.DONT_DESTROY, True, id="dont-destroy-with-job"),
],
)
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_check_shim_process_unreachable_state(
self,
test_db,
session: AsyncSession,
termination_policy: TerminationPolicy,
has_job: bool,
):
# see https://github.com/dstackai/dstack/issues/2041
project = await create_project(session=session)
pool = await create_pool(session, project)
if has_job:
user = await create_user(session=session)
repo = await create_repo(
session=session,
project_id=project.id,
)
run = await create_run(
session=session,
project=project,
repo=repo,
user=user,
)
job = await create_job(
session=session,
run=run,
status=JobStatus.SUBMITTED,
)
else:
job = None
instance = await create_instance(
session,
project,
pool,
created_at=get_current_datetime(),
termination_policy=termination_policy,
status=InstanceStatus.IDLE,
unreachable=True,
job=job,
)

await session.commit()

with patch(
"dstack._internal.server.background.tasks.process_instances._instance_healthcheck"
) as healthcheck:
healthcheck.return_value = HealthStatus(healthy=True, reason="OK")
await process_instances()

await session.refresh(instance)

assert instance is not None
assert instance.status == InstanceStatus.IDLE
assert not instance.unreachable


class TestTerminateIdleTime:
@pytest.mark.asyncio
@@ -1,6 +1,6 @@
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
from sqlalchemy.ext.asyncio import AsyncSession
@@ -36,6 +36,7 @@
get_run_spec,
get_volume_configuration,
)
from dstack._internal.utils.common import get_current_datetime


def get_job_provisioning_data(dockerized: bool) -> JobProvisioningData:
@@ -371,3 +372,52 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession):
assert job.status == JobStatus.TERMINATING
assert job.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
assert job.remove_at is None

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_provisioning_shim_force_stop_if_already_running(
self,
monkeypatch: pytest.MonkeyPatch,
test_db,
session: AsyncSession,
):
project = await create_project(session=session)
user = await create_user(session=session)
repo = await create_repo(session=session, project_id=project.id)
run_spec = get_run_spec(run_name="test-run", repo_id=repo.name)
run_spec.configuration.image = "debian"
run = await create_run(
session=session,
project=project,
repo=repo,
user=user,
run_name="test-run",
run_spec=run_spec,
)
job = await create_job(
session=session,
run=run,
status=JobStatus.PROVISIONING,
job_provisioning_data=get_job_provisioning_data(dockerized=True),
submitted_at=get_current_datetime(),
)
monkeypatch.setattr(
"dstack._internal.server.services.runner.ssh.SSHTunnel", Mock(return_value=MagicMock())
)
shim_client_mock = Mock()
monkeypatch.setattr(
"dstack._internal.server.services.runner.client.ShimClient",
Mock(return_value=shim_client_mock),
)
shim_client_mock.healthcheck.return_value = HealthcheckResponse(
service="dstack-shim", version="0.0.1.dev2"
)
shim_client_mock.submit.return_value = False

await process_running_jobs()

shim_client_mock.healthcheck.assert_called_once()
shim_client_mock.submit.assert_called_once()
shim_client_mock.stop.assert_called_once_with(force=True)
await session.refresh(job)
assert job.status == JobStatus.PROVISIONING
