Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 79 additions & 27 deletions python/ray/data/tests/test_issue_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
import pytest

import ray
from ray.data._internal.execution.interfaces.physical_operator import (
OpTask,
PhysicalOperator,
RefBundle,
)
from ray.data._internal.execution.operators.input_data_buffer import (
InputDataBuffer,
)
Expand All @@ -22,10 +27,44 @@
from ray.data._internal.issue_detection.detectors.high_memory_detector import (
HighMemoryIssueDetector,
)
from ray.data.block import BlockMetadata
from ray.data.context import DataContext
from ray.tests.conftest import * # noqa


class FakeOpTask(OpTask):
"""A fake OpTask for testing purposes."""

def __init__(self, task_index: int):
super().__init__(task_index)

def get_waitable(self):
"""Return a dummy waitable."""
return ray.put(None)


class FakeOperator(PhysicalOperator):
def __init__(self, name: str, data_context: DataContext):
super().__init__(name=name, input_dependencies=[], data_context=data_context)

def _add_input_inner(self, refs: RefBundle, input_index: int) -> None:
pass

def has_next(self) -> bool:
return False

def _get_next_inner(self) -> RefBundle:
assert False

def get_stats(self):
return {}

def get_active_tasks(self):
# Return active tasks based on what's in _running_tasks
# This ensures has_execution_finished() works correctly
return [FakeOpTask(task_idx) for task_idx in self.metrics._running_tasks]


class TestHangingExecutionIssueDetector:
def test_hanging_detector_configuration(self, restore_data_context):
"""Test hanging detector configuration and initialization."""
Expand Down Expand Up @@ -56,7 +95,7 @@ def test_hanging_detector_configuration(self, restore_data_context):
"ray.data._internal.execution.interfaces.op_runtime_metrics.TaskDurationStats"
)
def test_basic_hanging_detection(
self, mock_stats_cls, ray_start_2_cpus, restore_data_context
self, mock_stats_cls, ray_start_regular_shared, restore_data_context
):
# Set up logging capture
log_capture = io.StringIO()
Expand Down Expand Up @@ -99,41 +138,54 @@ def f2(x):
log_output = log_capture.getvalue()
assert re.search(warn_msg, log_output) is not None, log_output

@patch("time.perf_counter")
def test_hanging_detector_detects_issues(
self, caplog, propagate_logs, restore_data_context
self, mock_perf_counter, ray_start_regular_shared
):
"""Test hanging detector adaptive thresholds with real Ray Data pipelines and extreme configurations."""
"""Test that the hanging detector correctly identifies tasks that exceed the adaptive threshold."""

ctx = DataContext.get_current()
# Configure hanging detector with extreme std_factor values
ctx.issue_detectors_config.hanging_detector_config = (
HangingExecutionIssueDetectorConfig(
op_task_stats_min_count=1,
op_task_stats_std_factor=1,
detection_time_interval_s=0,
)
config = HangingExecutionIssueDetectorConfig(
op_task_stats_min_count=1,
op_task_stats_std_factor=1,
detection_time_interval_s=0,
)
op = FakeOperator("TestOperator", DataContext.get_current())
detector = HangingExecutionIssueDetector(
dataset_id="test_dataset", operators=[op], config=config
)

# Create a pipeline with many small blocks to ensure concurrent tasks
def sleep_task(x):
if x["id"] == 2:
# Issue detection is based on the mean + stdev. One of the tasks must take
# awhile, so doing it just for one of the rows.
time.sleep(1)
return x
# Create a simple RefBundle for testing
block_ref = ray.put([{"id": 0}])
metadata = BlockMetadata(
num_rows=1, size_bytes=1, exec_stats=None, input_files=None
)
input_bundle = RefBundle(
blocks=((block_ref, metadata),), owns_blocks=True, schema=None
)

with caplog.at_level(logging.WARNING):
ray.data.range(3, override_num_blocks=3).map(
sleep_task, concurrency=1
).materialize()
mock_perf_counter.return_value = 0.0

# Check if hanging detection occurred
hanging_detected = (
"has been running for" in caplog.text
and "longer than the average task duration" in caplog.text
)
# Submit three tasks. Two of them finish immediately, while the third one hangs.
op.metrics.on_task_submitted(0, input_bundle)
op.metrics.on_task_submitted(1, input_bundle)
op.metrics.on_task_submitted(2, input_bundle)
op.metrics.on_task_finished(0, exception=None)
op.metrics.on_task_finished(1, exception=None)

# Start detecting
issues = detector.detect()
assert len(issues) == 0

# Set the perf_counter to trigger the issue detection
mock_perf_counter.return_value = 10.0

assert hanging_detected, caplog.text
# On the second detect() call, the hanging task should be detected
issues = detector.detect()
assert len(issues) > 0, "Expected hanging issue to be detected"
assert issues[0].issue_type.value == "hanging"
assert "has been running for" in issues[0].message
assert "longer than the average task duration" in issues[0].message


@pytest.mark.parametrize(
Expand Down