diff --git a/python/ray/data/tests/test_issue_detection.py b/python/ray/data/tests/test_issue_detection.py
index e2c30d47a3ef..7a7bcfd1e07f 100644
--- a/python/ray/data/tests/test_issue_detection.py
+++ b/python/ray/data/tests/test_issue_detection.py
@@ -7,6 +7,11 @@
 import pytest
 
 import ray
+from ray.data._internal.execution.interfaces.physical_operator import (
+    OpTask,
+    PhysicalOperator,
+    RefBundle,
+)
 from ray.data._internal.execution.operators.input_data_buffer import (
     InputDataBuffer,
 )
@@ -22,10 +27,44 @@ from ray.data._internal.issue_detection.detectors.high_memory_detector import (
     HighMemoryIssueDetector,
 )
+from ray.data.block import BlockMetadata
 from ray.data.context import DataContext
 from ray.tests.conftest import *  # noqa
 
 
+class FakeOpTask(OpTask):
+    """A fake OpTask for testing purposes."""
+
+    def __init__(self, task_index: int):
+        super().__init__(task_index)
+
+    def get_waitable(self):
+        """Return a dummy waitable."""
+        return ray.put(None)
+
+
+class FakeOperator(PhysicalOperator):
+    def __init__(self, name: str, data_context: DataContext):
+        super().__init__(name=name, input_dependencies=[], data_context=data_context)
+
+    def _add_input_inner(self, refs: RefBundle, input_index: int) -> None:
+        pass
+
+    def has_next(self) -> bool:
+        return False
+
+    def _get_next_inner(self) -> RefBundle:
+        assert False
+
+    def get_stats(self):
+        return {}
+
+    def get_active_tasks(self):
+        # Return active tasks based on what's in _running_tasks.
+        # This ensures has_execution_finished() works correctly.
+        return [FakeOpTask(task_idx) for task_idx in self.metrics._running_tasks]
+
+
 class TestHangingExecutionIssueDetector:
     def test_hanging_detector_configuration(self, restore_data_context):
         """Test hanging detector configuration and initialization."""
@@ -56,7 +95,7 @@ def test_hanging_detector_configuration(self, restore_data_context):
         "ray.data._internal.execution.interfaces.op_runtime_metrics.TaskDurationStats"
     )
     def test_basic_hanging_detection(
-        self, mock_stats_cls, ray_start_2_cpus, restore_data_context
+        self, mock_stats_cls, ray_start_regular_shared, restore_data_context
     ):
         # Set up logging capture
         log_capture = io.StringIO()
@@ -99,41 +138,54 @@ def f2(x):
         log_output = log_capture.getvalue()
         assert re.search(warn_msg, log_output) is not None, log_output
 
+    @patch("time.perf_counter")
     def test_hanging_detector_detects_issues(
-        self, caplog, propagate_logs, restore_data_context
+        self, mock_perf_counter, ray_start_regular_shared
    ):
-        """Test hanging detector adaptive thresholds with real Ray Data pipelines and extreme configurations."""
+        """Test that the hanging detector correctly identifies tasks that exceed the adaptive threshold."""
 
-        ctx = DataContext.get_current()
         # Configure hanging detector with extreme std_factor values
-        ctx.issue_detectors_config.hanging_detector_config = (
-            HangingExecutionIssueDetectorConfig(
-                op_task_stats_min_count=1,
-                op_task_stats_std_factor=1,
-                detection_time_interval_s=0,
-            )
+        config = HangingExecutionIssueDetectorConfig(
+            op_task_stats_min_count=1,
+            op_task_stats_std_factor=1,
+            detection_time_interval_s=0,
+        )
+        op = FakeOperator("TestOperator", DataContext.get_current())
+        detector = HangingExecutionIssueDetector(
+            dataset_id="test_dataset", operators=[op], config=config
         )
 
-        # Create a pipeline with many small blocks to ensure concurrent tasks
-        def sleep_task(x):
-            if x["id"] == 2:
-                # Issue detection is based on the mean + stdev. One of the tasks must take
-                # awhile, so doing it just for one of the rows.
-                time.sleep(1)
-            return x
+        # Create a simple RefBundle for testing
+        block_ref = ray.put([{"id": 0}])
+        metadata = BlockMetadata(
+            num_rows=1, size_bytes=1, exec_stats=None, input_files=None
+        )
+        input_bundle = RefBundle(
+            blocks=((block_ref, metadata),), owns_blocks=True, schema=None
+        )
 
-        with caplog.at_level(logging.WARNING):
-            ray.data.range(3, override_num_blocks=3).map(
-                sleep_task, concurrency=1
-            ).materialize()
+        mock_perf_counter.return_value = 0.0
 
-        # Check if hanging detection occurred
-        hanging_detected = (
-            "has been running for" in caplog.text
-            and "longer than the average task duration" in caplog.text
-        )
+        # Submit three tasks. Two of them finish immediately, while the third one hangs.
+        op.metrics.on_task_submitted(0, input_bundle)
+        op.metrics.on_task_submitted(1, input_bundle)
+        op.metrics.on_task_submitted(2, input_bundle)
+        op.metrics.on_task_finished(0, exception=None)
+        op.metrics.on_task_finished(1, exception=None)
+
+        # Start detecting
+        issues = detector.detect()
+        assert len(issues) == 0
+
+        # Set the perf_counter to trigger the issue detection
+        mock_perf_counter.return_value = 10.0
 
-        assert hanging_detected, caplog.text
+        # On the second detect() call, the hanging task should be detected
+        issues = detector.detect()
+        assert len(issues) > 0, "Expected hanging issue to be detected"
+        assert issues[0].issue_type.value == "hanging"
+        assert "has been running for" in issues[0].message
+        assert "longer than the average task duration" in issues[0].message
 
     @pytest.mark.parametrize(
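For context, the rewritten test replaces the old `time.sleep`-based pipeline with a deterministic clock: `time.perf_counter` is patched so task "durations" can be advanced instantly. Below is a minimal, self-contained sketch of that pattern. The `ToyDetector` class and its threshold formula (mean + std_factor * stdev over finished-task durations) are hypothetical illustrations of the adaptive-threshold idea in this PR, not Ray Data APIs.

```python
import statistics
import time
from unittest.mock import patch


class ToyDetector:
    """Hypothetical stand-in for a hanging-task detector (not a Ray API)."""

    def __init__(self, std_factor=1.0):
        self.std_factor = std_factor
        self.start_times = {}  # task_id -> start timestamp
        self.durations = []    # wall-clock durations of finished tasks

    def on_submitted(self, task_id):
        self.start_times[task_id] = time.perf_counter()

    def on_finished(self, task_id):
        self.durations.append(time.perf_counter() - self.start_times.pop(task_id))

    def detect(self):
        # A task "hangs" once its elapsed time exceeds
        # mean + std_factor * stdev of finished-task durations
        # (needs at least two finished samples for a stdev).
        if len(self.durations) < 2:
            return []
        threshold = statistics.mean(self.durations) + (
            self.std_factor * statistics.stdev(self.durations)
        )
        now = time.perf_counter()
        return [t for t, start in self.start_times.items() if now - start > threshold]


with patch("time.perf_counter") as clock:
    clock.return_value = 0.0
    detector = ToyDetector(std_factor=1.0)
    for task_id in (0, 1, 2):  # submit three tasks...
        detector.on_submitted(task_id)
    detector.on_finished(0)    # ...two finish instantly, so the
    detector.on_finished(1)    # mean and stdev are both 0.0
    assert detector.detect() == []   # task 2 has not exceeded the threshold yet

    clock.return_value = 10.0        # advance the mocked clock by 10 "seconds"
    assert detector.detect() == [2]  # now task 2 is flagged as hanging
```

Mocking the clock rather than sleeping keeps the test fast and removes the timing flakiness that real `time.sleep` calls and a live Ray pipeline introduce.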