Skip to content
122 changes: 122 additions & 0 deletions cosmos/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Debug utilities for Cosmos.

When debug mode is enabled via the `enable_debug_mode` setting, Cosmos will track
memory utilization during task execution and push the maximum memory usage to XCom.
"""

from __future__ import annotations

import os
import threading
import time
from typing import TYPE_CHECKING

try:
import psutil
except ImportError:
raise RuntimeError(
"psutil is not available. Install `https://pypi.org/project/psutil/` to enable memory tracking."
)

from cosmos import settings
from cosmos.log import get_logger

if TYPE_CHECKING:
try:
from airflow.sdk.definitions.context import Context
except ImportError:
from airflow.utils.context import Context # type: ignore[attr-defined]

logger = get_logger(__name__)

# XCom key for storing maximum memory usage in debug mode
XCOM_DEBUG_MAX_MEMORY_MB_KEY = "cosmos_debug_max_memory_mb"

# Global dictionary to store memory trackers per task
_memory_trackers: dict[str, MemoryTracker] = {}


class MemoryTracker:
"""
Tracks maximum RSS memory (bytes) for a process and all of its children.
Sampling-based to work across Airflow 2 & 3 without executor internals.
"""

def __init__(self, pid: int, poll_interval: float = 0.5):
self.pid = pid
self.poll_interval = poll_interval
self.max_rss_bytes = 0
self._stop_event = threading.Event()
self._thread = threading.Thread(target=self._run, daemon=True)

def start(self) -> None:
"""Start the memory tracking thread."""
self._thread.start()

def stop(self) -> None:
"""Stop the memory tracking thread."""
self._stop_event.set()
if self._thread.is_alive():
self._thread.join(timeout=5)

def _run(self) -> None:
"""Background thread that polls memory usage."""
try:
parent = psutil.Process(self.pid)
except psutil.NoSuchProcess:
return

while not self._stop_event.is_set():
rss = 0
try:
processes = [parent] + parent.children(recursive=True)
for p in processes:
try:
rss += p.memory_info().rss
except psutil.NoSuchProcess:
continue
self.max_rss_bytes = max(self.max_rss_bytes, rss)
except psutil.NoSuchProcess:
break

time.sleep(self.poll_interval)


def start_memory_tracking(context: Context) -> None:
"""
Callback to start memory tracking for a task.

This function should be used as an `on_execute_callback` for Cosmos operators
when debug mode is enabled.

:param context: The Airflow task context.
"""
ti = context["ti"]
task_key = f"{ti.dag_id}.{ti.task_id}.{ti.run_id}"
pid = os.getpid()
tracker = MemoryTracker(pid=pid, poll_interval=settings.debug_memory_poll_interval_seconds)
_memory_trackers[task_key] = tracker
tracker.start()
logger.debug("Started memory tracking for task %s (PID: %s)", task_key, pid)


def stop_memory_tracking(context: Context) -> None:
"""
Callback to stop memory tracking for a task and push the result to XCom.

This function should be used as an `on_success_callback` or `on_failure_callback`
for Cosmos operators when debug mode is enabled.

:param context: The Airflow task context.
"""
ti = context["ti"]
task_key = f"{ti.dag_id}.{ti.task_id}.{ti.run_id}"
tracker = _memory_trackers.pop(task_key, None)

if tracker:
tracker.stop()
max_mb = tracker.max_rss_bytes / 1024 / 1024
logger.info("Max memory usage (RSS, incl. children): %.2f MB", max_mb)
# Persist to XCom for observability
ti.xcom_push(key=XCOM_DEBUG_MAX_MEMORY_MB_KEY, value=round(max_mb, 2))
14 changes: 13 additions & 1 deletion cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from airflow.utils.strings import to_boolean

from cosmos import settings
from cosmos.dbt.executable import get_system_dbt
from cosmos.log import get_logger

Expand Down Expand Up @@ -315,7 +316,18 @@ def execute(self, context: Context, **kwargs) -> Any | None: # type: ignore
if self.extra_context:
context_merge(context, self.extra_context)

self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags(), **kwargs)
if settings.enable_debug_mode:
from cosmos.debug import start_memory_tracking, stop_memory_tracking

start_memory_tracking(context)
try:
self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags(), **kwargs)
stop_memory_tracking(context)
except Exception:
stop_memory_tracking(context)
raise
else:
self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags(), **kwargs)


class DbtBuildMixin:
Expand Down
4 changes: 4 additions & 0 deletions cosmos/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,7 @@ def convert_to_boolean(value: str | None) -> bool:
enable_telemetry = conf.getboolean("cosmos", "enable_telemetry", fallback=True)
do_not_track = convert_to_boolean(os.getenv("DO_NOT_TRACK"))
no_analytics = convert_to_boolean(os.getenv("SCARF_NO_ANALYTICS"))

# Debug mode - when enabled, Cosmos will track and push memory utilization to XCom
enable_debug_mode = conf.getboolean("cosmos", "enable_debug_mode", fallback=False)
debug_memory_poll_interval_seconds = conf.getfloat("cosmos", "debug_memory_poll_interval_seconds", fallback=0.5)
18 changes: 18 additions & 0 deletions docs/configuration/cosmos-conf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,24 @@ This page lists all available Airflow configurations that affect ``astronomer-co
:start-after: [START cosmos_init_imports]
:end-before: [END cosmos_init_imports]

.. _enable_debug_mode:

`enable_debug_mode`_:
Enable or disable debug mode. When enabled, Cosmos will track memory utilization for its tasks and push the peak
memory usage (in MB) to XCom under the key ``cosmos_debug_max_memory_mb``. This is useful for profiling and
optimizing resource allocation for dbt tasks. Requires ``psutil`` to be installed.

- Default: ``False``
- Environment Variable: ``AIRFLOW__COSMOS__ENABLE_DEBUG_MODE``

.. _debug_memory_poll_interval_seconds:

`debug_memory_poll_interval_seconds`_:
The interval (in seconds) at which memory utilization is polled when debug mode is enabled. Lower values provide
more accurate peak memory measurements but may add slight overhead.

- Default: ``0.5``
- Environment Variable: ``AIRFLOW__COSMOS__DEBUG_MEMORY_POLL_INTERVAL_SECONDS``

[openlineage]
~~~~~~~~~~~~~
Expand Down
48 changes: 48 additions & 0 deletions tests/operators/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,51 @@ def test_abstract_dbt_base_init_no_super():

source = inspect.getsource(init_method)
assert "super().__init__" not in source


@patch("cosmos.operators.base.settings")
@patch("cosmos.operators.base.AbstractDbtBase.build_and_run_cmd")
def test_dbt_base_operator_execute_debug_mode_exception_stops_memory_tracking(
mock_build_and_run_cmd, mock_settings, monkeypatch
):
"""Tests that stop_memory_tracking is called when an exception occurs during debug mode execution."""
mock_settings.enable_debug_mode = True
mock_build_and_run_cmd.side_effect = RuntimeError("Test exception")

monkeypatch.setattr(AbstractDbtBase, "add_cmd_flags", lambda _: [])
AbstractDbtBase.__abstractmethods__ = set()

base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir")

with patch("cosmos.debug.start_memory_tracking") as mock_start_tracking:
with patch("cosmos.debug.stop_memory_tracking") as mock_stop_tracking:
with pytest.raises(RuntimeError, match="Test exception"):
base_operator.execute(context={})

# Verify memory tracking was started
mock_start_tracking.assert_called_once_with({})
# Verify memory tracking was stopped even though an exception occurred
mock_stop_tracking.assert_called_once_with({})


@patch("cosmos.operators.base.settings")
@patch("cosmos.operators.base.AbstractDbtBase.build_and_run_cmd")
def test_dbt_base_operator_execute_debug_mode_success_stops_memory_tracking(
mock_build_and_run_cmd, mock_settings, monkeypatch
):
"""Tests that stop_memory_tracking is called when execution succeeds in debug mode."""
mock_settings.enable_debug_mode = True

monkeypatch.setattr(AbstractDbtBase, "add_cmd_flags", lambda _: [])
AbstractDbtBase.__abstractmethods__ = set()

base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir")

with patch("cosmos.debug.start_memory_tracking") as mock_start_tracking:
with patch("cosmos.debug.stop_memory_tracking") as mock_stop_tracking:
base_operator.execute(context={})

# Verify memory tracking was started
mock_start_tracking.assert_called_once_with({})
# Verify memory tracking was stopped after successful execution
mock_stop_tracking.assert_called_once_with({})
Loading