Skip to content

Commit f5abbb8

Browse files
authored
[train] after_worker_group_poll_status errors result in ControllerError (#57869)
# Summary

We observed that whenever `after_worker_group_poll_status` raised an exception, the Train Run would fail ungracefully and show up as `ABORTED` in the dashboard. This happened in the following situations:

1) Different workers report remote checkpoints with different paths -> `(TrainController pid=46993) RuntimeError: The storage path of the checkpoints in the training results is not the same. This means the checkpoints are not consistent. Got a mix of the following checkpoint paths: {'/tmp/tmpl95kv7ax', '/tmp/tmp__8e6etk'}` -> `ABORTED` Train Run
2) `ray.train.report({"loss": ...}, checkpoint=checkpoint)` in `train_func` -> `TypeError: Object of type 'ellipsis' is not JSON serializable` in `CheckpointManager._save_state` -> `ABORTED` Train Run

This PR catches these exceptions, wraps them in a `ControllerError`, and goes through the `FailurePolicy`, ultimately resulting in an `ERRORED` Train Run, which is more intuitive because it happened due to an error in the training workers (`The Train run failed due to an error in the training workers.` is the comment associated with `RunStatus.ERRORED`).

I considered implementing a more general solution that caught all `WorkerGroupCallback` errors and resurfaced them as `ControllerError`s, but decided against it because:

* Callbacks occur in many different places and we might want to add custom try/catch logic in each case.
* `after_worker_group_poll_status` is the only offender so far and most of its errors are from user mistakes; other callback errors could be legitimate bugs that should result in `ABORTED`.

# Testing

Unit tests

---------

Signed-off-by: Timothy Seah <[email protected]>
1 parent 91685a7 commit f5abbb8

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

python/ray/train/v2/_internal/execution/controller/controller.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,16 @@ async def _step(self) -> TrainControllerLoopIterationResult:
408408
assert isinstance(controller_state.scaling_decision, ResizeDecision)
409409
return self._execute_resize_decision(controller_state.scaling_decision)
410410
elif isinstance(controller_state, RunningState):
411-
worker_group_status: WorkerGroupPollStatus = await self._poll_workers()
411+
try:
412+
worker_group_status: WorkerGroupPollStatus = await self._poll_workers()
413+
except Exception as e:
414+
training_failed_error = ControllerError(e)
415+
failure_decision = self._failure_policy.make_decision(
416+
training_failed_error=training_failed_error,
417+
)
418+
return self._execute_failure_decision(
419+
failure_decision, training_failed_error=training_failed_error
420+
)
412421

413422
if worker_group_status.finished and not worker_group_status.errors:
414423
return TrainControllerLoopIterationResult(

python/ray/train/v2/tests/test_controller.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import MagicMock, create_autospec
1+
from unittest.mock import create_autospec
22

33
import pytest
44

@@ -27,6 +27,7 @@
2727
NoopDecision,
2828
ResizeDecision,
2929
)
30+
from ray.train.v2._internal.execution.worker_group import WorkerGroupPollStatus
3031
from ray.train.v2.api.config import ScalingConfig
3132
from ray.train.v2.tests.util import (
3233
DummyObjectRefWrapper,
@@ -45,6 +46,8 @@ def patch_worker_group(monkeypatch):
4546
# Make polling interval 0 to speed up tests
4647
monkeypatch.setenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, "0")
4748
yield
49+
DummyWorkerGroup.set_poll_failure(None)
50+
DummyWorkerGroup.set_start_failure(None)
4851

4952

5053
@pytest.fixture(autouse=True)
@@ -167,7 +170,7 @@ async def test_failure_handling():
167170
await controller._run_control_loop_iteration()
168171
assert isinstance(controller.get_state(), RunningState)
169172

170-
controller.get_worker_group().error_worker(3)
173+
DummyWorkerGroup.set_poll_failure(RuntimeError("Simulated poll failure"))
171174
failure_policy.queue_decision(FailureDecision.RAISE)
172175
await controller._run_control_loop_iteration()
173176
assert isinstance(controller.get_state(), ErroredState)
@@ -177,7 +180,7 @@ async def test_failure_handling():
177180
"error_type", [WorkerGroupStartupFailedError, WorkerGroupStartupTimeoutError(2)]
178181
)
179182
@pytest.mark.asyncio
180-
async def test_worker_group_start_failure(monkeypatch, error_type):
183+
async def test_worker_group_start_failure(error_type):
181184
"""Check that controller can gracefully handle worker group start failures."""
182185
scaling_policy = MockScalingPolicy(scaling_config=ScalingConfig())
183186
failure_policy = MockFailurePolicy(failure_config=None)
@@ -189,7 +192,6 @@ async def test_worker_group_start_failure(monkeypatch, error_type):
189192
failure_policy=failure_policy,
190193
)
191194
DummyWorkerGroup.set_start_failure(error_type)
192-
monkeypatch.setattr(TrainController, "worker_group_cls", DummyWorkerGroup)
193195

194196
assert isinstance(controller.get_state(), InitializingState)
195197

@@ -208,7 +210,6 @@ async def test_worker_group_start_failure(monkeypatch, error_type):
208210

209211
# Let the worker group start successfully the 2nd time.
210212
DummyWorkerGroup.set_start_failure(None)
211-
monkeypatch.setattr(TrainController, "worker_group_cls", DummyWorkerGroup)
212213
scaling_policy.queue_recovery_decision(
213214
ResizeDecision(num_workers=2, resources_per_worker={})
214215
)
@@ -239,7 +240,10 @@ async def sleep_mock(t):
239240
failure_policy=None,
240241
)
241242
# Mock worker group to avoid actual polling
242-
controller._worker_group = MagicMock()
243+
controller._worker_group = create_autospec(DummyWorkerGroup, instance=True)
244+
controller._worker_group.poll_status.return_value = WorkerGroupPollStatus(
245+
worker_statuses={}
246+
)
243247

244248
num_polls = 5
245249
for _ in range(num_polls):

python/ray/train/v2/tests/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
class DummyWorkerGroup(WorkerGroup):
4646

4747
_start_failure = None
48+
_poll_failure = None
4849

4950
# TODO: Clean this up and use Mocks instead.
5051
def __init__(
@@ -58,6 +59,8 @@ def __init__(
5859
self._worker_statuses = {}
5960

6061
def poll_status(self, *args, **kwargs) -> WorkerGroupPollStatus:
62+
if self._poll_failure:
63+
raise self._poll_failure
6164
return WorkerGroupPollStatus(
6265
worker_statuses=self._worker_statuses,
6366
)
@@ -97,6 +100,10 @@ def finish_worker(self, worker_index):
97100
def set_start_failure(cls, start_failure):
98101
cls._start_failure = start_failure
99102

103+
@classmethod
104+
def set_poll_failure(cls, poll_failure):
105+
cls._poll_failure = poll_failure
106+
100107

101108
class MockScalingPolicy(ScalingPolicy):
102109
def __init__(self, scaling_config):

0 commit comments

Comments
 (0)