Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import logging
import zlib
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any, Callable, List, Sequence, Union

import airflow
from packaging.version import Version

if TYPE_CHECKING: # pragma: no cover
try:
Expand Down Expand Up @@ -37,6 +40,9 @@
DbtSourceLocalOperator,
)

AIRFLOW_VERSION = Version(airflow.__version__)


try:
from dbt_common.events.base_types import EventMsg
except ImportError: # pragma: no cover
Expand Down Expand Up @@ -82,7 +88,27 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
task_id = kwargs.pop("task_id", "dbt_producer_watcher_operator")
kwargs.setdefault("priority_weight", PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT)
kwargs.setdefault("weight_rule", WEIGHT_RULE)
super().__init__(task_id=task_id, *args, **kwargs)
on_failure_callback = self._set_on_failure_callback(kwargs.pop("on_failure_callback", None))
super().__init__(task_id=task_id, *args, on_failure_callback=on_failure_callback, **kwargs)

def _set_on_failure_callback(
self, user_callback: Any
) -> Union[Callable[[Context], None], List[Callable[[Context], None]]]:
Comment thread
pankajastro marked this conversation as resolved.
default_callback = self._store_producer_task_state

if AIRFLOW_VERSION < Version("2.6.0"):
# Older versions only support a single callable
return default_callback
Comment thread
pankajastro marked this conversation as resolved.
else:
if user_callback is None:
# No callback provided — use default in a list
return [default_callback]
elif isinstance(user_callback, list):
# Append to existing list of callbacks (make a copy to avoid side effects)
return user_callback + [default_callback]
else:
# Single callable provided — wrap it in a list and append ours
return [user_callback, default_callback]

@staticmethod
def _serialize_event(ev: EventMsg) -> dict[str, Any]:
Expand Down Expand Up @@ -115,6 +141,10 @@ def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> N
if startup_events:
ti.xcom_push(key="dbt_startup_events", value=startup_events)

def _store_producer_task_state(self, context: Context) -> None:
ti = context["ti"]
ti.xcom_push(key="state", value="failed")
Comment thread
pankajastro marked this conversation as resolved.

def execute(self, context: Context, **kwargs: Any) -> Any:
try:
if not self.invocation_mode:
Expand Down Expand Up @@ -298,6 +328,11 @@ def poke(self, context: Context) -> bool:
status = self._get_status_from_run_results(ti)

if status is None:
producer_task_state = ti.xcom_pull(task_ids=self.producer_task_id, key="state")
if producer_task_state == "failed":
raise AirflowException(
f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details."
)
return False
elif status == "success":
return True
Expand Down
61 changes: 59 additions & 2 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timedelta
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -232,6 +232,37 @@ def fake_build_run(self, context, **kw):
assert data["results"][0]["status"] == "success"


@pytest.mark.parametrize(
"user_callback, expected_behavior",
[
(None, "none"),
([Mock(name="cb1")], "list"),
(Mock(name="cb2"), "single"),
],
)
def test_set_on_failure_callback_with_actual_airflow(user_callback, expected_behavior, tmp_path):

instance = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None)
result = instance._set_on_failure_callback(user_callback)

if AIRFLOW_VERSION < Version("2.6.0"):
# Always returns single callable regardless of input
assert callable(result)
assert result == instance._store_producer_task_state
else:
# Returns list depending on input
assert isinstance(result, list)
assert result[-1] == instance._store_producer_task_state

if expected_behavior == "none":
assert len(result) == 1
elif expected_behavior == "list":
assert len(result) == 2
elif expected_behavior == "single":
assert len(result) == 2
assert result[0] == user_callback


@patch("cosmos.dbt.runner.is_available", return_value=False)
@patch("cosmos.operators.watcher.DbtLocalBaseOperator.execute", return_value="done")
def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available):
Expand All @@ -251,6 +282,16 @@ def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available):
assert op.invocation_mode == InvocationMode.SUBPROCESS


def test_store_producer_task_state_pushes_failed_state():
mock_ti = MagicMock()
mock_context = {"ti": mock_ti}
instance = DbtProducerWatcherOperator(project_dir=".", profile_config=None)

instance._store_producer_task_state(mock_context)

mock_ti.xcom_push.assert_called_once_with(key="state", value="failed")


MODEL_UNIQUE_ID = "model.jaffle_shop.stg_orders"
ENCODED_RUN_RESULTS = base64.b64encode(
zlib.compress(b'{"results":[{"unique_id":"model.jaffle_shop.stg_orders","status":"success"}]}')
Expand Down Expand Up @@ -291,7 +332,7 @@ def test_poke_status_none_from_events(self, MockEventMsg):
sensor.invocation_mode = InvocationMode.DBT_RUNNER
ti = MagicMock()
ti.try_number = 1
ti.xcom_pull.side_effect = [None, None] # no event msg found
ti.xcom_pull.side_effect = [None, None, None] # no event msg found
context = self.make_context(ti)

result = sensor.poke(context)
Expand Down Expand Up @@ -411,6 +452,22 @@ def test_get_status_from_events_none(self):
result = sensor._get_status_from_events(ti)
assert result is None

@patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results")
def test_producer_state_failed(self, mock_run_result):
sensor = self.make_sensor()
ti = MagicMock()
ti.try_number = 1
mock_run_result.return_value = None
ti.xcom_pull.return_value = "failed"

context = self.make_context(ti)

with pytest.raises(
AirflowException,
match="The dbt build command failed in producer task. Please check the log of task dbt_producer_watcher for details.",
):
sensor.poke(context)


class TestDbtBuildWatcherOperator:

Expand Down