diff --git a/cosmos/debug.py b/cosmos/debug.py
new file mode 100644
index 0000000000..0aba093b42
--- /dev/null
+++ b/cosmos/debug.py
@@ -0,0 +1,132 @@
+"""
+Debug utilities for Cosmos.
+
+When debug mode is enabled via the `enable_debug_mode` setting, Cosmos will track
+memory utilization during task execution and push the maximum memory usage to XCom.
+"""
+
+from __future__ import annotations
+
+import os
+import threading
+from typing import TYPE_CHECKING
+
+try:
+    import psutil
+except ImportError as exc:
+    raise RuntimeError(
+        "psutil is not available. Install `https://pypi.org/project/psutil/` to enable memory tracking."
+    ) from exc
+
+from cosmos import settings
+from cosmos.log import get_logger
+
+if TYPE_CHECKING:
+    try:
+        from airflow.sdk.definitions.context import Context
+    except ImportError:
+        from airflow.utils.context import Context  # type: ignore[attr-defined]
+
+logger = get_logger(__name__)
+
+# XCom key for storing maximum memory usage in debug mode
+XCOM_DEBUG_MAX_MEMORY_MB_KEY = "cosmos_debug_max_memory_mb"
+
+# Floor for the polling interval: a zero or negative setting must not turn the sampler into a busy loop
+_MIN_POLL_INTERVAL_SECONDS = 0.01
+
+# Global dictionary to store memory trackers per task
+_memory_trackers: dict[str, MemoryTracker] = {}
+
+
+class MemoryTracker:
+    """
+    Tracks maximum RSS memory (bytes) for a process and all of its children.
+    Sampling-based to work across Airflow 2 & 3 without executor internals.
+
+    :param pid: PID of the process to sample, together with all of its descendants.
+    :param poll_interval: Seconds to wait between two consecutive samples.
+    """
+
+    def __init__(self, pid: int, poll_interval: float = 0.5):
+        self.pid = pid
+        self.poll_interval = poll_interval
+        self.max_rss_bytes = 0
+        self._stop_event = threading.Event()
+        self._thread = threading.Thread(target=self._run, daemon=True)
+
+    def start(self) -> None:
+        """Start the memory tracking thread."""
+        self._thread.start()
+
+    def stop(self) -> None:
+        """Signal the memory tracking thread to stop and wait up to 5 seconds for it to exit."""
+        self._stop_event.set()
+        if self._thread.is_alive():
+            self._thread.join(timeout=5)
+
+    def _run(self) -> None:
+        """Background thread that polls memory usage until stopped or the process can no longer be inspected."""
+        try:
+            parent = psutil.Process(self.pid)
+        except psutil.Error:
+            return
+
+        interval = max(self.poll_interval, _MIN_POLL_INTERVAL_SECONDS)
+        while not self._stop_event.is_set():
+            try:
+                processes = [parent] + parent.children(recursive=True)
+            except psutil.Error:
+                # The tracked process is gone (or cannot be inspected any more): nothing left to sample
+                break
+
+            rss = 0
+            for process in processes:
+                try:
+                    rss += process.memory_info().rss
+                except psutil.Error:
+                    # A process can exit or become inaccessible between listing and sampling; skip it
+                    continue
+            self.max_rss_bytes = max(self.max_rss_bytes, rss)
+
+            # Wait on the event instead of sleeping so that stop() interrupts the pause immediately
+            self._stop_event.wait(interval)
+
+
+def start_memory_tracking(context: Context) -> None:
+    """
+    Start memory tracking for a task.
+
+    Registers a `MemoryTracker` for the current process under a key built from the task instance's
+    DAG id, task id and run id, and starts sampling in the background.
+
+    :param context: The Airflow task context.
+    """
+    ti = context["ti"]
+    task_key = f"{ti.dag_id}.{ti.task_id}.{ti.run_id}"
+    pid = os.getpid()
+    tracker = MemoryTracker(pid=pid, poll_interval=settings.debug_memory_poll_interval_seconds)
+    _memory_trackers[task_key] = tracker
+    tracker.start()
+    logger.debug("Started memory tracking for task %s (PID: %s)", task_key, pid)
+
+
+def stop_memory_tracking(context: Context) -> None:
+    """
+    Stop memory tracking for a task and push the result to XCom.
+
+    Does nothing when no tracker was registered for the task.
+
+    :param context: The Airflow task context.
+    """
+    ti = context["ti"]
+    task_key = f"{ti.dag_id}.{ti.task_id}.{ti.run_id}"
+    tracker = _memory_trackers.pop(task_key, None)
+    if tracker is None:
+        return
+
+    tracker.stop()
+    max_mb = tracker.max_rss_bytes / 1024 / 1024
+    logger.info("Max memory usage (RSS, incl. children): %.2f MB", max_mb)
+    # Persist to XCom for observability
+    ti.xcom_push(key=XCOM_DEBUG_MAX_MEMORY_MB_KEY, value=round(max_mb, 2))
diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py
index 7ebab3ccf8..b486a476f7 100644
--- a/cosmos/operators/base.py
+++ b/cosmos/operators/base.py
@@ -24,6 +24,7 @@
 
 from airflow.utils.strings import to_boolean
 
+from cosmos import settings
 from cosmos.dbt.executable import get_system_dbt
 from cosmos.log import get_logger
 
@@ -315,7 +316,18 @@ def execute(self, context: Context, **kwargs) -> Any | None:  # type: ignore
         if self.extra_context:
             context_merge(context, self.extra_context)
 
-        self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags(), **kwargs)
+        if settings.enable_debug_mode:
+            # Imported lazily so that psutil is only required when debug mode is enabled
+            from cosmos.debug import start_memory_tracking, stop_memory_tracking
+
+            start_memory_tracking(context)
+            try:
+                self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags(), **kwargs)
+            finally:
+                # Runs on success, failure and interruption alike, so the tracker is never left registered
+                stop_memory_tracking(context)
+        else:
+            self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags(), **kwargs)
 
 
 class DbtBuildMixin:
diff --git a/cosmos/settings.py b/cosmos/settings.py
index e002d953ee..285c50999d 100644
--- a/cosmos/settings.py
+++ b/cosmos/settings.py
@@ -85,3 +85,7 @@ def convert_to_boolean(value: str | None) -> bool:
 enable_telemetry = conf.getboolean("cosmos", "enable_telemetry", fallback=True)
 do_not_track = convert_to_boolean(os.getenv("DO_NOT_TRACK"))
 no_analytics = convert_to_boolean(os.getenv("SCARF_NO_ANALYTICS"))
+
+# Debug mode - when enabled, Cosmos will track and push memory utilization to XCom
+enable_debug_mode = conf.getboolean("cosmos", "enable_debug_mode", fallback=False)
+debug_memory_poll_interval_seconds = conf.getfloat("cosmos", "debug_memory_poll_interval_seconds", fallback=0.5)
diff --git a/docs/configuration/cosmos-conf.rst b/docs/configuration/cosmos-conf.rst
index f8db11e593..1ed4c7bb56 100644
--- a/docs/configuration/cosmos-conf.rst
+++ b/docs/configuration/cosmos-conf.rst
@@ -261,6 +261,24 @@ This page lists all available Airflow configurations that affect ``astronomer-co
         :start-after: [START cosmos_init_imports]
         :end-before: [END cosmos_init_imports]
 
+.. _enable_debug_mode:
+
+`enable_debug_mode`_:
+    Enable or disable debug mode. When enabled, Cosmos will track memory utilization for its tasks and push the peak
+    memory usage (in MB) to XCom under the key ``cosmos_debug_max_memory_mb``. This is useful for profiling and
+    optimizing resource allocation for dbt tasks. Requires ``psutil`` to be installed.
+
+    - Default: ``False``
+    - Environment Variable: ``AIRFLOW__COSMOS__ENABLE_DEBUG_MODE``
+
+.. _debug_memory_poll_interval_seconds:
+
+`debug_memory_poll_interval_seconds`_:
+    The interval (in seconds) at which memory utilization is polled when debug mode is enabled. Lower values provide
+    more accurate peak memory measurements but may add slight overhead.
+
+    - Default: ``0.5``
+    - Environment Variable: ``AIRFLOW__COSMOS__DEBUG_MEMORY_POLL_INTERVAL_SECONDS``
 
 [openlineage]
 ~~~~~~~~~~~~~
diff --git a/tests/operators/test_base.py b/tests/operators/test_base.py
index 400a97a06b..d9e605eeb7 100644
--- a/tests/operators/test_base.py
+++ b/tests/operators/test_base.py
@@ -181,3 +181,51 @@ def test_abstract_dbt_base_init_no_super():
 
     source = inspect.getsource(init_method)
     assert "super().__init__" not in source
+
+
+@patch("cosmos.operators.base.settings")
+@patch("cosmos.operators.base.AbstractDbtBase.build_and_run_cmd")
+def test_dbt_base_operator_execute_debug_mode_exception_stops_memory_tracking(
+    mock_build_and_run_cmd, mock_settings, monkeypatch
+):
+    """Tests that stop_memory_tracking is called when an exception occurs during debug mode execution."""
+    mock_settings.enable_debug_mode = True
+    mock_build_and_run_cmd.side_effect = RuntimeError("Test exception")
+
+    monkeypatch.setattr(AbstractDbtBase, "add_cmd_flags", lambda _: [])
+    monkeypatch.setattr(AbstractDbtBase, "__abstractmethods__", set())
+
+    base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir")
+
+    with patch("cosmos.debug.start_memory_tracking") as mock_start_tracking:
+        with patch("cosmos.debug.stop_memory_tracking") as mock_stop_tracking:
+            with pytest.raises(RuntimeError, match="Test exception"):
+                base_operator.execute(context={})
+
+    # Verify memory tracking was started
+    mock_start_tracking.assert_called_once_with({})
+    # Verify memory tracking was stopped even though an exception occurred
+    mock_stop_tracking.assert_called_once_with({})
+
+
+@patch("cosmos.operators.base.settings")
+@patch("cosmos.operators.base.AbstractDbtBase.build_and_run_cmd")
+def test_dbt_base_operator_execute_debug_mode_success_stops_memory_tracking(
+    mock_build_and_run_cmd, mock_settings, monkeypatch
+):
+    """Tests that stop_memory_tracking is called when execution succeeds in debug mode."""
+    mock_settings.enable_debug_mode = True
+
+    monkeypatch.setattr(AbstractDbtBase, "add_cmd_flags", lambda _: [])
+    monkeypatch.setattr(AbstractDbtBase, "__abstractmethods__", set())
+
+    base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir")
+
+    with patch("cosmos.debug.start_memory_tracking") as mock_start_tracking:
+        with patch("cosmos.debug.stop_memory_tracking") as mock_stop_tracking:
+            base_operator.execute(context={})
+
+    # Verify memory tracking was started
+    mock_start_tracking.assert_called_once_with({})
+    # Verify memory tracking was stopped after successful execution
+    mock_stop_tracking.assert_called_once_with({})
diff --git a/tests/test_debug.py b/tests/test_debug.py
new file mode 100644
index 0000000000..2809202cd7
--- /dev/null
+++ b/tests/test_debug.py
@@ -0,0 +1,240 @@
+"""Tests for the cosmos.debug module."""
+
+from __future__ import annotations
+
+import os
+import sys
+import time
+from datetime import datetime
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+from airflow import DAG
+
+from cosmos import debug, settings
+from cosmos.config import ProfileConfig
+from cosmos.operators.local import DbtRunLocalOperator
+from tests.utils import test_dag as run_test_dag
+
+
+class TestPsutilImport:
+    """Tests for psutil import handling."""
+
+    def test_import_raises_runtime_error_when_psutil_unavailable(self):
+        """Test that importing cosmos.debug raises RuntimeError when psutil is not available."""
+        from importlib import import_module
+
+        # Remove cosmos.debug from sys.modules to allow re-import
+        modules_to_remove = [key for key in sys.modules if key.startswith("cosmos.debug")]
+        original_modules = {key: sys.modules.pop(key) for key in modules_to_remove}
+
+        try:
+            # Mock psutil to simulate it not being installed
+            with patch.dict(sys.modules, {"psutil": None}):
+                with pytest.raises(RuntimeError) as exc_info:
+                    import_module("cosmos.debug")
+
+            assert "psutil is not available" in str(exc_info.value)
+        finally:
+            # Restore the original modules even if the checks above fail, so later tests see the real module
+            sys.modules.update(original_modules)
+
+
+class TestMemoryTracker:
+    """Tests for the MemoryTracker class."""
+
+    def test_memory_tracker_initialization(self):
+        """Test MemoryTracker initializes with correct values."""
+        tracker = debug.MemoryTracker(pid=os.getpid(), poll_interval=0.1)
+        assert tracker.pid == os.getpid()
+        assert tracker.poll_interval == 0.1
+        assert tracker.max_rss_bytes == 0
+
+    def test_memory_tracker_tracks_memory(self):
+        """Test MemoryTracker actually tracks memory usage."""
+        tracker = debug.MemoryTracker(pid=os.getpid(), poll_interval=0.05)
+        tracker.start()
+        # Give it time to sample
+        time.sleep(0.2)
+        tracker.stop()
+        # Memory should be tracked (current process uses some memory)
+        assert tracker.max_rss_bytes > 0
+
+    def test_memory_tracker_stop_without_start(self):
+        """Test MemoryTracker.stop() doesn't raise if not started."""
+        tracker = debug.MemoryTracker(pid=os.getpid())
+        # Should not raise
+        tracker.stop()
+
+    def test_memory_tracker_with_nonexistent_pid(self):
+        """Test MemoryTracker handles non-existent PID gracefully."""
+        # Use a very high PID that's unlikely to exist
+        tracker = debug.MemoryTracker(pid=999999999, poll_interval=0.05)
+        tracker.start()
+        time.sleep(0.1)
+        tracker.stop()
+        # Should have 0 bytes since process doesn't exist
+        assert tracker.max_rss_bytes == 0
+
+    def test_memory_tracker_handles_child_process_termination(self):
+        """Test MemoryTracker continues when a child process terminates during memory_info() call."""
+        import psutil
+
+        tracker = debug.MemoryTracker(pid=os.getpid(), poll_interval=0.05)
+
+        # Mock a child process that raises NoSuchProcess when memory_info() is called
+        mock_child = MagicMock()
+        mock_child.memory_info.side_effect = psutil.NoSuchProcess(pid=12345)
+
+        mock_parent = MagicMock()
+        mock_parent.memory_info.return_value = MagicMock(rss=1024 * 1024)  # 1 MB
+        mock_parent.children.return_value = [mock_child]
+
+        with patch("psutil.Process", return_value=mock_parent):
+            tracker.start()
+            time.sleep(0.15)
+            tracker.stop()
+
+        # Should have tracked memory from parent only (child raised NoSuchProcess)
+        assert tracker.max_rss_bytes >= 1024 * 1024
+
+    def test_memory_tracker_handles_parent_children_call_failure(self):
+        """Test MemoryTracker breaks loop when parent.children() raises NoSuchProcess."""
+        import psutil
+
+        tracker = debug.MemoryTracker(pid=os.getpid(), poll_interval=0.05)
+
+        mock_parent = MagicMock()
+        # First call succeeds, second call raises NoSuchProcess
+        mock_parent.memory_info.return_value = MagicMock(rss=1024 * 1024)
+        mock_parent.children.side_effect = [[], psutil.NoSuchProcess(pid=os.getpid())]
+
+        with patch("psutil.Process", return_value=mock_parent):
+            tracker.start()
+            time.sleep(0.15)
+            tracker.stop()
+
+        # Should have tracked some memory before the exception
+        assert tracker.max_rss_bytes >= 1024 * 1024
+
+
+class TestStartMemoryTracking:
+    """Tests for the start_memory_tracking function."""
+
+    @pytest.fixture
+    def mock_context(self):
+        """Create a mock Airflow context."""
+        mock_ti = MagicMock()
+        mock_ti.dag_id = "test_dag"
+        mock_ti.task_id = "test_task"
+        mock_ti.run_id = "test_run_123"
+        return {"ti": mock_ti}
+
+    def test_start_memory_tracking_creates_tracker(self, mock_context):
+        """Test start_memory_tracking creates tracker."""
+        debug.start_memory_tracking(mock_context)
+        task_key = f"{mock_context['ti'].dag_id}.{mock_context['ti'].task_id}.{mock_context['ti'].run_id}"
+        assert task_key in debug._memory_trackers
+        # Cleanup
+        tracker = debug._memory_trackers.pop(task_key)
+        tracker.stop()
+
+
+class TestStopMemoryTracking:
+    """Tests for the stop_memory_tracking function."""
+
+    @pytest.fixture
+    def mock_context(self):
+        """Create a mock Airflow context."""
+        mock_ti = MagicMock()
+        mock_ti.dag_id = "test_dag"
+        mock_ti.task_id = "test_task"
+        mock_ti.run_id = "test_run_123"
+        return {"ti": mock_ti}
+
+    def test_stop_memory_tracking_pushes_xcom(self, mock_context):
+        """Test stop_memory_tracking pushes memory data to XCom."""
+        # Start tracking first
+        debug.start_memory_tracking(mock_context)
+        time.sleep(0.1)  # Let it sample
+        # Stop and check XCom push
+        debug.stop_memory_tracking(mock_context)
+        mock_context["ti"].xcom_push.assert_called_once()
+        call_args = mock_context["ti"].xcom_push.call_args
+        assert call_args[1]["key"] == "cosmos_debug_max_memory_mb"
+        assert isinstance(call_args[1]["value"], float)
+        assert call_args[1]["value"] > 0
+
+    def test_stop_memory_tracking_no_tracker(self, mock_context):
+        """Test stop_memory_tracking handles missing tracker gracefully."""
+        # Don't start tracking, just stop
+        debug.stop_memory_tracking(mock_context)
+        # Should not raise and xcom_push should not be called
+        mock_context["ti"].xcom_push.assert_not_called()
+
+
+class TestIntegration:
+    """Integration tests for the full debug flow."""
+
+    @pytest.fixture
+    def mock_context(self):
+        """Create a mock Airflow context."""
+        mock_ti = MagicMock()
+        mock_ti.dag_id = "test_dag"
+        mock_ti.task_id = "test_task"
+        mock_ti.run_id = "test_run_integration"
+        return {"ti": mock_ti}
+
+    def test_full_tracking_lifecycle(self, mock_context):
+        """Test complete memory tracking lifecycle."""
+        # Start
+        debug.start_memory_tracking(mock_context)
+        task_key = f"{mock_context['ti'].dag_id}.{mock_context['ti'].task_id}.{mock_context['ti'].run_id}"
+        assert task_key in debug._memory_trackers
+
+        # Simulate some work
+        time.sleep(0.2)
+
+        # Stop
+        debug.stop_memory_tracking(mock_context)
+        assert task_key not in debug._memory_trackers
+        mock_context["ti"].xcom_push.assert_called_once()
+
+
+MINI_DBT_PROJ_DIR = Path(__file__).parent / "sample" / "mini"
+MINI_DBT_PROJ_PROFILE = MINI_DBT_PROJ_DIR / "profiles.yml"
+
+mini_profile_config = ProfileConfig(
+    profile_name="mini",
+    target_name="dev",
+    profiles_yml_filepath=MINI_DBT_PROJ_PROFILE,
+)
+
+
+@pytest.mark.integration
+def test_dbt_run_local_operator_stores_memory_in_xcom_when_debug_enabled():
+    """
+    Integration test that DbtRunLocalOperator pushes peak memory utilization to XCom
+    when debug mode is enabled.
+    """
+    with patch.object(settings, "enable_debug_mode", True):
+        with DAG("test-debug-memory", start_date=datetime(2022, 1, 1)) as dag:
+            run_operator = DbtRunLocalOperator(
+                profile_config=mini_profile_config,
+                project_dir=MINI_DBT_PROJ_DIR,
+                task_id="run",
+                append_env=True,
+                emit_datasets=False,
+            )
+            run_operator
+
+        dag_run = run_test_dag(dag)
+
+    # Get the task instance to check XCom
+    ti = dag_run.get_task_instance(task_id="run")
+    memory_value = ti.xcom_pull(key="cosmos_debug_max_memory_mb")
+
+    assert memory_value is not None, "Expected cosmos_debug_max_memory_mb in XCom"
+    assert isinstance(memory_value, float), "Memory value should be a float"
+    assert memory_value > 0, "Memory value should be greater than 0"
diff --git a/tests/test_settings.py b/tests/test_settings.py
index 1b8f79fa33..8efe66f1de 100644
--- a/tests/test_settings.py
+++ b/tests/test_settings.py
@@ -37,3 +37,27 @@ def test_enable_memory_optimised_imports_false(monkeypatch):
 
     result = subprocess.run(["python", "-c", script], capture_output=True, text=True)
     assert result.returncode == 0, result.stderr
+
+
+def test_enable_debug_mode_env_var():
+    try:
+        with patch.dict(os.environ, {"AIRFLOW__COSMOS__ENABLE_DEBUG_MODE": "True"}, clear=True):
+            reload(settings)
+            assert settings.enable_debug_mode is True
+    finally:
+        # Reload once the environment is restored so the debug flag does not leak into other tests
+        reload(settings)
+
+
+def test_debug_memory_poll_interval_env_var():
+    try:
+        with patch.dict(
+            os.environ,
+            {"AIRFLOW__COSMOS__DEBUG_MEMORY_POLL_INTERVAL_SECONDS": "0.25"},
+            clear=True,
+        ):
+            reload(settings)
+            assert settings.debug_memory_poll_interval_seconds == 0.25
+    finally:
+        # Reload once the environment is restored so the patched interval does not leak into other tests
+        reload(settings)