Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 2 additions & 28 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datetime import timedelta
from pathlib import Path
from threading import Lock
from typing import TYPE_CHECKING, Any, Callable, List, Union
from typing import TYPE_CHECKING, Any

from cosmos._triggers.watcher import WatcherTrigger, _parse_compressed_xcom

Expand Down Expand Up @@ -110,27 +110,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["default_args"] = default_args
kwargs["retries"] = 0

on_failure_callback = self._set_on_failure_callback(kwargs.pop("on_failure_callback", None))
super().__init__(task_id=task_id, *args, on_failure_callback=on_failure_callback, **kwargs)

def _set_on_failure_callback(
self, user_callback: Any
) -> Union[Callable[[Context], None], List[Callable[[Context], None]]]:
default_callback = self._store_producer_task_state

if AIRFLOW_VERSION < Version("2.6.0"):
# Older versions only support a single callable
return default_callback
else:
if user_callback is None:
# No callback provided — use default in a list
return [default_callback]
elif isinstance(user_callback, list):
# Append to existing list of callbacks (make a copy to avoid side effects)
return user_callback + [default_callback]
else:
# Single callable provided — wrap it in a list and append ours
return [user_callback, default_callback]
super().__init__(task_id=task_id, *args, **kwargs)

@staticmethod
def _serialize_event(event_message: EventMsg) -> dict[str, Any]:
Expand Down Expand Up @@ -179,9 +159,6 @@ def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> N
if startup_events:
safe_xcom_push(task_instance=context["ti"], key="dbt_startup_events", value=startup_events)

def _store_producer_task_state(self, context: Context) -> None:
safe_xcom_push(task_instance=context["ti"], key="state", value="failed")

def execute(self, context: Context, **kwargs: Any) -> Any:
task_instance = context.get("ti")
if task_instance is None:
Expand Down Expand Up @@ -371,9 +348,6 @@ def _get_status_from_run_results(self, ti: Any, context: Context) -> Any:

return node_result.get("status")

def _get_producer_task_state(self, ti: Any) -> Any:
return ti.xcom_pull(task_ids=self.producer_task_id, key="state")

def _get_producer_task_status(self, context: Context) -> str | None:
"""
Get the task status of the producer task for both Airflow 2 and Airflow 3.
Expand Down
43 changes: 1 addition & 42 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,37 +425,6 @@ def fake_build_run(self, context, **kw):
assert data["results"][0]["status"] == "success"


@pytest.mark.parametrize(
"user_callback, expected_behavior",
[
(None, "none"),
([Mock(name="cb1")], "list"),
(Mock(name="cb2"), "single"),
],
)
def test_set_on_failure_callback_with_actual_airflow(user_callback, expected_behavior, tmp_path):

instance = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None)
result = instance._set_on_failure_callback(user_callback)

if AIRFLOW_VERSION < Version("2.6.0"):
# Always returns single callable regardless of input
assert callable(result)
assert result == instance._store_producer_task_state
else:
# Returns list depending on input
assert isinstance(result, list)
assert result[-1] == instance._store_producer_task_state

if expected_behavior == "none":
assert len(result) == 1
elif expected_behavior == "list":
assert len(result) == 2
elif expected_behavior == "single":
assert len(result) == 2
assert result[0] == user_callback


@patch("cosmos.dbt.runner.is_available", return_value=False)
@patch("cosmos.operators.watcher.DbtLocalBaseOperator.execute", return_value="done")
def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available):
Expand All @@ -475,16 +444,6 @@ def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available):
assert op.invocation_mode == InvocationMode.SUBPROCESS


def test_store_producer_task_state_pushes_failed_state():
mock_ti = MagicMock()
mock_context = {"ti": mock_ti}
instance = DbtProducerWatcherOperator(project_dir=".", profile_config=None)

instance._store_producer_task_state(mock_context)

mock_ti.xcom_push.assert_called_once_with(key="state", value="failed")


MODEL_UNIQUE_ID = "model.jaffle_shop.stg_orders"
ENCODED_RUN_RESULTS = base64.b64encode(
zlib.compress(b'{"results":[{"unique_id":"model.jaffle_shop.stg_orders","status":"success"}]}')
Expand Down Expand Up @@ -652,7 +611,7 @@ def test_poke_success_from_run_results(self):
assert result is True

@patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_status", return_value=None)
def _fallback_to_local_run(self, mock_get_producer_task_state):
def _fallback_to_local_run(self, mock_get_producer_task_status):
sensor = self.make_sensor()
sensor.invocation_mode = None

Expand Down