diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 3dcae1f2cb..2103dc977b 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -5,7 +5,10 @@ import logging import zlib from datetime import timedelta -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, Callable, List, Sequence, Union + +import airflow +from packaging.version import Version if TYPE_CHECKING: # pragma: no cover try: @@ -37,6 +40,9 @@ DbtSourceLocalOperator, ) +AIRFLOW_VERSION = Version(airflow.__version__) + + try: from dbt_common.events.base_types import EventMsg except ImportError: # pragma: no cover @@ -82,7 +88,27 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: task_id = kwargs.pop("task_id", "dbt_producer_watcher_operator") kwargs.setdefault("priority_weight", PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) kwargs.setdefault("weight_rule", WEIGHT_RULE) - super().__init__(task_id=task_id, *args, **kwargs) + on_failure_callback = self._set_on_failure_callback(kwargs.pop("on_failure_callback", None)) + super().__init__(task_id=task_id, *args, on_failure_callback=on_failure_callback, **kwargs) + + def _set_on_failure_callback( + self, user_callback: Any + ) -> Union[Callable[[Context], None], List[Callable[[Context], None]]]: + default_callback = self._store_producer_task_state + + if AIRFLOW_VERSION < Version("2.6.0"): + # Older versions only support a single callable + return default_callback + else: + if user_callback is None: + # No callback provided — use default in a list + return [default_callback] + elif isinstance(user_callback, list): + # Append to existing list of callbacks (make a copy to avoid side effects) + return user_callback + [default_callback] + else: + # Single callable provided — wrap it in a list and append ours + return [user_callback, default_callback] @staticmethod def _serialize_event(ev: EventMsg) -> dict[str, Any]: @@ -115,6 +141,10 @@ def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> N if startup_events: ti.xcom_push(key="dbt_startup_events", value=startup_events) + def _store_producer_task_state(self, context: Context) -> None: + ti = context["ti"] + ti.xcom_push(key="state", value="failed") + def execute(self, context: Context, **kwargs: Any) -> Any: try: if not self.invocation_mode: @@ -298,6 +328,11 @@ def poke(self, context: Context) -> bool: status = self._get_status_from_run_results(ti) if status is None: + producer_task_state = ti.xcom_pull(task_ids=self.producer_task_id, key="state") + if producer_task_state == "failed": + raise AirflowException( + f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details." + ) return False elif status == "success": return True diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index deac985e06..fcbad5d732 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from pathlib import Path from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from airflow.exceptions import AirflowException @@ -232,6 +232,37 @@ def fake_build_run(self, context, **kw): assert data["results"][0]["status"] == "success" +@pytest.mark.parametrize( + "user_callback, expected_behavior", + [ + (None, "none"), + ([Mock(name="cb1")], "list"), + (Mock(name="cb2"), "single"), + ], +) +def test_set_on_failure_callback_with_actual_airflow(user_callback, expected_behavior, tmp_path): + + instance = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None) + result = instance._set_on_failure_callback(user_callback) + + if AIRFLOW_VERSION < Version("2.6.0"): + # Always returns single callable regardless of input + assert callable(result) + assert result == instance._store_producer_task_state + else: + # Returns list depending on input + assert isinstance(result, list) + assert result[-1] == instance._store_producer_task_state + + if expected_behavior == "none": + assert len(result) == 1 + elif expected_behavior == "list": + assert len(result) == 2 + elif expected_behavior == "single": + assert len(result) == 2 + assert result[0] == user_callback + + @patch("cosmos.dbt.runner.is_available", return_value=False) @patch("cosmos.operators.watcher.DbtLocalBaseOperator.execute", return_value="done") def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available): @@ -251,6 +282,16 @@ def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available): assert op.invocation_mode == InvocationMode.SUBPROCESS +def test_store_producer_task_state_pushes_failed_state(): + mock_ti = MagicMock() + mock_context = {"ti": mock_ti} + instance = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + + instance._store_producer_task_state(mock_context) + + mock_ti.xcom_push.assert_called_once_with(key="state", value="failed") + + MODEL_UNIQUE_ID = "model.jaffle_shop.stg_orders" ENCODED_RUN_RESULTS = base64.b64encode( zlib.compress(b'{"results":[{"unique_id":"model.jaffle_shop.stg_orders","status":"success"}]}') @@ -291,7 +332,7 @@ def test_poke_status_none_from_events(self, MockEventMsg): sensor.invocation_mode = InvocationMode.DBT_RUNNER ti = MagicMock() ti.try_number = 1 - ti.xcom_pull.side_effect = [None, None] # no event msg found + ti.xcom_pull.side_effect = [None, None, None] # no event msg found context = self.make_context(ti) result = sensor.poke(context) @@ -411,6 +452,22 @@ def test_get_status_from_events_none(self): result = sensor._get_status_from_events(ti) assert result is None + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results") + def test_producer_state_failed(self, mock_run_result): + sensor = self.make_sensor() + ti = MagicMock() + ti.try_number = 1 + mock_run_result.return_value = None + ti.xcom_pull.return_value = "failed" + + context = self.make_context(ti) + + with pytest.raises( + AirflowException, + match="The dbt build command failed in producer task. Please check the log of task dbt_producer_watcher for details.", + ): + sensor.poke(context) + class TestDbtBuildWatcherOperator: