From 41e86eaff555f1dc1376cf0690a389505e7da5c3 Mon Sep 17 00:00:00 2001 From: Michal Mrazek Date: Sun, 8 Mar 2026 18:16:06 +0100 Subject: [PATCH 1/3] feat: implement DbtTestWatcherOperator as a functional consumer sensor with deferrable support --- cosmos/operators/_watcher/base.py | 65 ++++++++--- cosmos/operators/_watcher/triggerer.py | 32 +++++- cosmos/operators/watcher.py | 43 +++++--- tests/operators/_watcher/test_triggerer.py | 27 +++-- tests/operators/test_watcher.py | 119 +++++++++++++++++++-- 5 files changed, 233 insertions(+), 53 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index e5d176b56c..9399591b80 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -17,7 +17,7 @@ WATCHER_TASK_WEIGHT_RULE, ) from cosmos.log import get_logger -from cosmos.operators._watcher.aggregation import push_test_result_or_aggregate +from cosmos.operators._watcher.aggregation import get_tests_status_xcom_key, push_test_result_or_aggregate from cosmos.operators._watcher.state import ( DBT_FAILED_STATUSES, _iso_to_string, @@ -28,7 +28,7 @@ is_dbt_node_status_terminal, safe_xcom_push, ) -from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom +from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger, _parse_compressed_xcom try: from airflow.sdk.bases.sensor import BaseSensorOperator @@ -288,6 +288,16 @@ def __init__( self.deferrable = deferrable self.model_unique_id = extra_context.get("dbt_node_config", {}).get("unique_id") + @property + def is_test_sensor(self) -> bool: + """Whether this sensor watches aggregated test results instead of individual node results.""" + return False + + @property + def _resource_label(self) -> str: + """Human-readable label for log and error messages.""" + return "Tests for model" if self.is_test_sensor else "Model" + @staticmethod def _filter_flags(flags: list[str]) -> list[str]: """Filters out dbt flags that are incompatible with retry (e.g., --select, --exclude).""" @@ -310,7 +320,16 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo Handles logic for retrying a failed dbt model execution. Reconstructs the dbt command by cloning the project and re-running the model with appropriate flags, while ensuring flags like `--select` or `--exclude` are excluded. + + For test sensors, re-execution is not supported in watcher mode; retries are skipped. """ + if self.is_test_sensor: + logger.info( + "Test sensor retry detected (try #%s). Skipping — test re-execution is not supported in watcher mode.", + try_number, + ) + return True + logger.info( f"Retry attempt #%s – Running model '%s' from project '%s' using {self.__class__.__name__}", try_number - 1, @@ -401,6 +420,7 @@ def execute(self, context: Context, **kwargs: Any) -> None: map_index=context["task_instance"].map_index, use_event=self.use_event(), poke_interval=self.poke_interval, + is_test_sensor=self.is_test_sensor, ), timeout=self.execution_timeout, method_name=self.execute_complete.__name__, @@ -410,9 +430,10 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: status = event.get("status") reason = event.get("reason") - if status == "success" and reason == "model_not_run": + if status == "success" and reason == WatcherEventReason.NODE_NOT_RUN: logger.info( - "Model '%s' was skipped by the dbt command. This may happen if it is an ephemeral model or if the model sql file is empty.", + "%s '%s' was skipped by the dbt command. This may happen if it is an ephemeral model or if the model sql file is empty.", + self._resource_label, self.model_unique_id, ) @@ -432,14 +453,14 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: task_ids=self.producer_task_id, ) _log_dbt_event(dbt_events) - if reason == "model_failed": + if reason == WatcherEventReason.NODE_FAILED: raise AirflowException( - f"dbt model '{self.model_unique_id}' failed. Review the producer task '{self.producer_task_id}' logs for details." + f"dbt {self._resource_label.lower()} '{self.model_unique_id}' failed. Review the producer task '{self.producer_task_id}' logs for details." ) - if reason == "producer_failed": + if reason == WatcherEventReason.PRODUCER_FAILED: raise AirflowException( - f"Watcher producer task '{self.producer_task_id}' failed before reporting model results. Check its logs for the underlying error." + f"Watcher producer task '{self.producer_task_id}' failed before reporting results for {self._resource_label.lower()} '{self.model_unique_id}'. Check its logs for the underlying error." ) def use_event(self) -> bool: @@ -457,19 +478,33 @@ def _log_startup_events(self, ti: Any) -> None: # Adding debug level to avoid redundant logs for non-deferrable mode logger.debug("%s", event.get("msg")) + def _get_node_status(self, ti: Any, context: Context) -> Any: + """Return the current status of the watched dbt node from XCom. + + For test sensors, reads the aggregated ``_tests_status`` key. + For model sensors, reads from event-based or subprocess-based keys. + """ + if self.is_test_sensor: + xcom_key = get_tests_status_xcom_key(self.model_unique_id) + return get_xcom_val(ti, self.producer_task_id, xcom_key) + if self.use_event(): + return self._get_status_from_events(ti, context) + return get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") + def poke(self, context: Context) -> bool: """ - Checks the status of a dbt model run by pulling relevant XComs from the master task. - Handles retries and checks for successful completion of the model execution. + Checks the status of a dbt node (model or aggregated tests) by pulling relevant XComs from the producer task. + Handles retries and checks for successful completion. """ ti = context["ti"] try_number = ti.try_number logger.info( - "Try number #%s, poke attempt #%s: Pulling status from task_id '%s' for model '%s'", + "Try number #%s, poke attempt #%s: Pulling status from task_id '%s' for %s '%s'", try_number, self.poke_retry_number, self.producer_task_id, + self._resource_label.lower(), self.model_unique_id, ) @@ -478,11 +513,9 @@ def poke(self, context: Context) -> bool: # We have assumption here that both the build producer and the sensor task will have same invocation mode producer_task_state = self._get_producer_task_status(context) - if self.use_event(): - status = self._get_status_from_events(ti, context) - else: + if not self.use_event() and not self.is_test_sensor: self._log_startup_events(ti) - status = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") + status = self._get_node_status(ti, context) # compiled_sql is always in the canonical per-model XCom key (same for event and subprocess modes) if status is not None: @@ -518,4 +551,4 @@ def poke(self, context: Context) -> bool: elif is_dbt_node_status_success(status): return True else: - raise AirflowException(f"Model '{self.model_unique_id}' finished with status '{status}'") + raise AirflowException(f"{self._resource_label} '{self.model_unique_id}' finished with status '{status}'") diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index 3b1f41fa01..b9a91f4f79 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -5,6 +5,7 @@ import json import zlib from collections.abc import AsyncIterator +from enum import Enum from typing import Any from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -25,6 +26,14 @@ logger = get_logger(__name__) +class WatcherEventReason(str, Enum): + """Reason codes used in TriggerEvent payloads between WatcherTrigger and BaseConsumerSensor.execute_complete.""" + + NODE_FAILED = "node_failed" + PRODUCER_FAILED = "producer_failed" + NODE_NOT_RUN = "node_not_run" + + class WatcherTrigger(BaseTrigger): def __init__( @@ -36,6 +45,7 @@ def __init__( map_index: int | None, use_event: bool, poke_interval: float = 5.0, + is_test_sensor: bool = False, ): self.model_unique_id = model_unique_id self.producer_task_id = producer_task_id @@ -44,6 +54,7 @@ def __init__( self.map_index = map_index self.use_event = use_event self.poke_interval = poke_interval + self.is_test_sensor = is_test_sensor def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -56,6 +67,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "map_index": self.map_index, "use_event": self.use_event, "poke_interval": self.poke_interval, + "is_test_sensor": self.is_test_sensor, }, ) @@ -127,9 +139,21 @@ async def _parse_dbt_node_status_and_compiled_sql(self) -> tuple[str | None, str Parse node status and compiled_sql from XCom. Returns a tuple of (status, compiled_sql). - Status comes from mode-specific keys (nodefinished_* for event, *_status for subprocess). + + For test sensors (``is_test_sensor=True``), the aggregated test status is + read from the ``_tests_status`` key. No compiled_sql is relevant. + + For regular sensors, status comes from mode-specific keys + (nodefinished_* for event, *_status for subprocess). compiled_sql is always read from the canonical per-model key (same for both modes). """ + if self.is_test_sensor: + from cosmos.operators._watcher.aggregation import get_tests_status_xcom_key + + status_key = get_tests_status_xcom_key(self.model_unique_id) + status = await self.get_xcom_val(status_key) + return status, None + compiled_sql_key = f"{self.model_unique_id.replace('.', '__')}_compiled_sql" status = await self._get_node_status() @@ -208,7 +232,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: return elif is_dbt_node_status_failed(dbt_node_status): logger.warning("dbt node '%s' failed", self.model_unique_id) - event_data = {"status": EventStatus.FAILED, "reason": "model_failed"} + event_data = {"status": EventStatus.FAILED, "reason": WatcherEventReason.NODE_FAILED} if compiled_sql: event_data["compiled_sql"] = compiled_sql yield TriggerEvent(event_data) # type: ignore[no-untyped-call] @@ -219,7 +243,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.producer_task_id, self.model_unique_id, ) - yield TriggerEvent({"status": EventStatus.FAILED, "reason": "producer_failed"}) # type: ignore[no-untyped-call] + yield TriggerEvent({"status": EventStatus.FAILED, "reason": WatcherEventReason.PRODUCER_FAILED}) # type: ignore[no-untyped-call] return elif producer_task_state == "success" and dbt_node_status is None: logger.info( @@ -227,7 +251,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.producer_task_id, self.model_unique_id, ) - yield TriggerEvent({"status": EventStatus.SUCCESS, "reason": "model_not_run"}) # type: ignore[no-untyped-call] + yield TriggerEvent({"status": EventStatus.SUCCESS, "reason": WatcherEventReason.NODE_NOT_RUN}) # type: ignore[no-untyped-call] return # Sleep briefly before re-polling diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 98b11a59d6..9b49631803 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -11,15 +11,6 @@ from airflow.exceptions import AirflowException from cosmos.config import ProfileConfig -from cosmos.operators._watcher import _parse_compressed_xcom, safe_xcom_push -from cosmos.operators._watcher.state import DBT_FAILED_STATUSES -from cosmos.settings import watcher_dbt_execution_queue - -try: - from airflow.providers.standard.operators.empty import EmptyOperator -except ImportError: # pragma: no cover - from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] - from cosmos.constants import ( _DBT_STARTUP_EVENTS_XCOM_KEY, PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, @@ -28,6 +19,7 @@ InvocationMode, ) from cosmos.log import get_logger +from cosmos.operators._watcher import _parse_compressed_xcom, safe_xcom_push from cosmos.operators._watcher.aggregation import push_test_result_or_aggregate from cosmos.operators._watcher.base import ( BaseConsumerSensor, @@ -35,6 +27,7 @@ store_compiled_sql_for_model, store_dbt_resource_status_from_log, ) +from cosmos.operators._watcher.state import DBT_FAILED_STATUSES from cosmos.operators.base import ( DbtBuildMixin, DbtRunMixin, @@ -46,6 +39,7 @@ DbtRunLocalOperator, DbtSourceLocalOperator, ) +from cosmos.settings import watcher_dbt_execution_queue try: from dbt_common.events.base_types import EventMsg @@ -345,13 +339,28 @@ def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) -class DbtTestWatcherOperator(EmptyOperator): - """ - As a starting point, this operator does nothing. - We'll be implementing this operator as part of: https://github.com/astronomer/astronomer-cosmos/issues/1974 +class DbtTestWatcherOperator(DbtConsumerWatcherSensor): # type: ignore[misc] + """Sensor that watches the aggregated test status for a dbt model in watcher execution mode. + + The producer task (``DbtProducerWatcherOperator``) collects individual test + results as they finish and, once every test for a given model has reported, + pushes a single aggregated XCom (``"pass"`` or ``"fail"``) under the key + ``_tests_status``. + + This sensor polls that key and: + * returns success when the value is ``"pass"``, + * raises ``AirflowException`` when the value is ``"fail"``. + + Deferral is fully supported: the ``WatcherTrigger`` receives + ``is_test_sensor=True`` and polls the correct aggregated key. """ - def __init__(self, *args: Any, **kwargs: Any): - desired_keys = ("dag", "task_group", "task_id") - new_kwargs = {key: value for key, value in kwargs.items() if key in desired_keys} - super().__init__(**new_kwargs) # type: ignore[no-untyped-call] + template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields # type: ignore[operator] + + @property + def is_test_sensor(self) -> bool: + return True + + def use_event(self) -> bool: + """This sensor relies on the producer task pushing aggregated test results to XCom, so it does not use real-time events.""" + return False diff --git a/tests/operators/_watcher/test_triggerer.py b/tests/operators/_watcher/test_triggerer.py index b9cb8036ec..015e8d8b1e 100644 --- a/tests/operators/_watcher/test_triggerer.py +++ b/tests/operators/_watcher/test_triggerer.py @@ -4,7 +4,7 @@ from packaging.version import Version from cosmos.constants import _DBT_STARTUP_EVENTS_XCOM_KEY, AIRFLOW_VERSION -from cosmos.operators._watcher.triggerer import WatcherTrigger +from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger _STARTUP_EVENTS = [{"name": "MainReportVersion", "msg": "Running with dbt=1.0.0", "ts": ""}] @@ -118,6 +118,19 @@ async def mock_get_xcom_val(key): assert status == expected_status assert compiled_sql == expected_compiled_sql + async def test_parse_dbt_node_status_for_test_sensor(self): + """When is_test_sensor=True, _parse_dbt_node_status_and_compiled_sql reads the aggregated tests_status key.""" + self.trigger.is_test_sensor = True + self.trigger.model_unique_id = "model.jaffle_shop.stg_orders" + + mock_get_xcom_val = AsyncMock(return_value="pass") + with patch.object(self.trigger, "get_xcom_val", mock_get_xcom_val): + status, compiled_sql = await self.trigger._parse_dbt_node_status_and_compiled_sql() + + assert status == "pass" + assert compiled_sql is None + mock_get_xcom_val.assert_called_once_with("model__jaffle_shop__stg_orders_tests_status") + @pytest.mark.parametrize( "airflow_version, expected_val", [ @@ -140,9 +153,9 @@ async def test_get_xcom_val_branches(self, airflow_version, expected_val): "dbt_node_status, producer_state, expected", [ ("success", "running", {"status": "success"}), - ("failed", "running", {"status": "failed", "reason": "model_failed"}), - (None, "failed", {"status": "failed", "reason": "producer_failed"}), - (None, "success", {"status": "success", "reason": "model_not_run"}), + ("failed", "running", {"status": "failed", "reason": WatcherEventReason.NODE_FAILED}), + (None, "failed", {"status": "failed", "reason": WatcherEventReason.PRODUCER_FAILED}), + (None, "success", {"status": "success", "reason": WatcherEventReason.NODE_NOT_RUN}), ], ) @patch("cosmos.operators._watcher.triggerer.WatcherTrigger._log_startup_events") @@ -236,7 +249,7 @@ def _import_side_effect(name: str, *args, **kwargs): @pytest.mark.asyncio async def test_run_producer_success_model_not_run(self, caplog): - """Test that when producer succeeds but model has no status, trigger yields success with model_not_run reason.""" + """Test that when producer succeeds but model has no status, trigger yields success with node_not_run reason.""" get_xcom_val_mock = AsyncMock( side_effect=lambda key: _STARTUP_EVENTS if key == _DBT_STARTUP_EVENTS_XCOM_KEY else None ) @@ -257,7 +270,7 @@ async def test_run_producer_success_model_not_run(self, caplog): events.append(event) assert len(events) == 1 - assert events[0].payload == {"status": "success", "reason": "model_not_run"} + assert events[0].payload == {"status": "success", "reason": WatcherEventReason.NODE_NOT_RUN} assert "The producer task 'task_1' succeeded" in caplog.text assert "There is no information about the node 'model.test' execution" in caplog.text @@ -309,7 +322,7 @@ async def get_xcom_val_side_effect(key): events = [event async for event in self.trigger.run()] assert len(events) == 1 assert events[0].payload["status"] == "failed" - assert events[0].payload["reason"] == "model_failed" + assert events[0].payload["reason"] == WatcherEventReason.NODE_FAILED assert events[0].payload["compiled_sql"] == "SELECT * FROM broken_model" @patch("cosmos.operators._watcher.triggerer.logger") diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 71ef484181..958dc06240 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -26,7 +26,7 @@ from cosmos.config import InvocationMode from cosmos.constants import _DBT_STARTUP_EVENTS_XCOM_KEY, PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, ExecutionMode from cosmos.operators._watcher.base import _store_startup_event_from_log, store_compiled_sql_for_model -from cosmos.operators._watcher.triggerer import WatcherTrigger +from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger from cosmos.operators.watcher import ( DbtBuildWatcherOperator, DbtConsumerWatcherSensor, @@ -276,7 +276,7 @@ def test_handle_startup_event(): def test_dbt_consumer_watcher_sensor_execute_complete_model_not_run_logs_message(caplog): - """Test that execute_complete logs an info message when model was skipped (model_not_run).""" + """Test that execute_complete logs an info message when model was skipped (node_not_run).""" sensor = DbtConsumerWatcherSensor( project_dir=".", profiles_dir=".", @@ -289,7 +289,7 @@ def test_dbt_consumer_watcher_sensor_execute_complete_model_not_run_logs_message sensor.model_unique_id = "model.pkg.skipped_model" context = {"dag_run": MagicMock()} - event = {"status": "success", "reason": "model_not_run"} + event = {"status": "success", "reason": WatcherEventReason.NODE_NOT_RUN} with caplog.at_level(logging.INFO): sensor.execute_complete(context, event) @@ -334,14 +334,14 @@ def test_dbt_producer_watcher_operator_skips_retry_attempt(caplog): "event, expected_message", [ ({"status": "success"}, None), - ({"status": "success", "reason": "model_not_run"}, None), + ({"status": "success", "reason": WatcherEventReason.NODE_NOT_RUN}, None), ( - {"status": "failed", "reason": "model_failed"}, + {"status": "failed", "reason": WatcherEventReason.NODE_FAILED}, "dbt model 'model.pkg.m' failed. Review the producer task 'dbt_producer_watcher_operator' logs for details.", ), ( - {"status": "failed", "reason": "producer_failed"}, - "Watcher producer task 'dbt_producer_watcher_operator' failed before reporting model results. Check its logs for the underlying error.", + {"status": "failed", "reason": WatcherEventReason.PRODUCER_FAILED}, + "Watcher producer task 'dbt_producer_watcher_operator' failed before reporting results for model 'model.pkg.m'. Check its logs for the underlying error.", ), ], ) @@ -1522,8 +1522,8 @@ def test_deferrable_false_via_constructor_does_not_defer(self, mock_poke): @pytest.mark.parametrize( "mock_event", [ - {"status": "failed", "reason": "model_failed"}, - {"status": "failed", "reason": "producer_failed"}, + {"status": "failed", "reason": WatcherEventReason.NODE_FAILED}, + {"status": "failed", "reason": WatcherEventReason.PRODUCER_FAILED}, {"status": "success"}, ], ) @@ -2174,3 +2174,104 @@ def test_dbt_source_watcher_operator_template_fields(): # DbtSourceWatcherOperator should inherit template_fields from DbtSourceLocalOperator assert DbtSourceWatcherOperator.template_fields == DbtSourceLocalOperator.template_fields + + +class TestDbtTestWatcherOperator: + """Tests for DbtTestWatcherOperator — the sensor that watches aggregated test results.""" + + MODEL_UID = "model.jaffle_shop.stg_orders" + TESTS_STATUS_XCOM_KEY = "model__jaffle_shop__stg_orders_tests_status" + + def make_sensor(self, **overrides): + extra_context = {"dbt_node_config": {"unique_id": self.MODEL_UID}} + sensor = DbtTestWatcherOperator( + task_id="stg_orders.test", + project_dir="/tmp/project", + profile_config=None, + deferrable=overrides.pop("deferrable", True), + extra_context=extra_context, + **overrides, + ) + sensor._get_producer_task_status = MagicMock(return_value=None) + return sensor + + def make_context(self, ti_mock, *, run_id="test-run", map_index=0): + return { + "ti": ti_mock, + "run_id": run_id, + "task_instance": MagicMock(map_index=map_index), + } + + # -- property checks -- + + def test_is_test_sensor_returns_true(self): + sensor = self.make_sensor() + assert sensor.is_test_sensor is True + + def test_resource_label_is_tests_for_model(self): + sensor = self.make_sensor() + assert sensor._resource_label == "Tests for model" + + def test_use_event_returns_false(self): + sensor = self.make_sensor() + assert sensor.use_event() is False + + # -- poke behaviour -- + + def test_poke_returns_true_when_tests_pass(self): + """When the aggregated test status is 'pass', poke should return True.""" + sensor = self.make_sensor() + ti = MagicMock() + ti.try_number = 1 + ti.xcom_pull.return_value = "pass" + context = self.make_context(ti) + + assert sensor.poke(context) is True + + def test_poke_raises_when_tests_fail(self): + """When the aggregated test status is 'fail', poke should raise AirflowException.""" + sensor = self.make_sensor() + ti = MagicMock() + ti.try_number = 1 + ti.xcom_pull.return_value = "fail" + context = self.make_context(ti) + + with pytest.raises(AirflowException, match="Tests for model"): + sensor.poke(context) + + def test_poke_returns_false_when_no_status_yet(self): + """When the aggregated test status has not been pushed yet, poke should return False.""" + sensor = self.make_sensor() + ti = MagicMock() + ti.try_number = 1 + ti.xcom_pull.return_value = None + context = self.make_context(ti) + + assert sensor.poke(context) is False + + def test_poke_reads_correct_xcom_key(self): + """Poke should pull from the _tests_status XCom key, not the regular _status key.""" + sensor = self.make_sensor() + ti = MagicMock() + ti.try_number = 1 + ti.xcom_pull.return_value = "pass" + context = self.make_context(ti) + + sensor.poke(context) + + # Verify xcom_pull was called with the aggregated tests_status key + calls = ti.xcom_pull.call_args_list + xcom_keys_used = [call.kwargs.get("key") or call[1].get("key") for call in calls] + assert self.TESTS_STATUS_XCOM_KEY in xcom_keys_used + + # -- retry / fallback -- + + def test_fallback_skips_reexecution_on_retry(self): + """On retry (try_number > 1), the test sensor should skip re-execution and return True.""" + sensor = self.make_sensor() + ti = MagicMock() + ti.try_number = 2 + context = self.make_context(ti) + + result = sensor.poke(context) + assert result is True From c27db6daac9412c7eb5b7ab2d989cb55a1ca4e21 Mon Sep 17 00:00:00 2001 From: Michal Mrazek Date: Fri, 20 Mar 2026 13:57:09 +0100 Subject: [PATCH 2/3] dont retry test for now --- cosmos/operators/_watcher/base.py | 8 +- tests/operators/test_watcher.py | 234 +++--------------------------- 2 files changed, 28 insertions(+), 214 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 9399591b80..58e07cea60 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -324,11 +324,11 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo For test sensors, re-execution is not supported in watcher mode; retries are skipped. """ if self.is_test_sensor: - logger.info( - "Test sensor retry detected (try #%s). Skipping — test re-execution is not supported in watcher mode.", - try_number, + raise AirflowException( + f"Test re-execution is not yet supported in watcher mode. " + f"{self._resource_label} '{self.model_unique_id}' cannot be retried. " + f"A future release will add fallback to local test execution." ) - return True logger.info( f"Retry attempt #%s – Running model '%s' from project '%s' using {self.__class__.__name__}", diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 958dc06240..8141c99bc7 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -8,7 +8,6 @@ from datetime import datetime, timedelta from pathlib import Path from types import SimpleNamespace -from typing import Any from unittest.mock import ANY, MagicMock, Mock, patch import pytest @@ -24,8 +23,8 @@ from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig, TestBehavior from cosmos.config import InvocationMode -from cosmos.constants import _DBT_STARTUP_EVENTS_XCOM_KEY, PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, ExecutionMode -from cosmos.operators._watcher.base import _store_startup_event_from_log, store_compiled_sql_for_model +from cosmos.constants import PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, ExecutionMode +from cosmos.operators._watcher.base import store_compiled_sql_for_model from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger from cosmos.operators.watcher import ( DbtBuildWatcherOperator, @@ -62,17 +61,12 @@ class _MockTI: def __init__(self) -> None: - self.store: dict[str, Any] = {} + self.store: dict[str, str] = {} self.try_number = 1 - def xcom_push(self, key: str, value: Any, **_): + def xcom_push(self, key: str, value: str, **_): self.store[key] = value - def xcom_pull(self, task_ids=None, key=None, **kwargs): - if key is not None: - return self.store.get(key) - return None - class _MockContext(dict): pass @@ -121,12 +115,6 @@ def test_dbt_producer_watcher_operator_priority_weight_default(): assert op.priority_weight == PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT -def test_dbt_producer_watcher_operator_does_not_store_compiled_sql_on_template(): - """Producer does not publish compiled_sql to its rendered_template; sensors show it per model.""" - op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - assert op.should_store_compiled_sql is False - - @pytest.mark.parametrize( "queue_override, expected_queue", [ @@ -345,8 +333,7 @@ def test_dbt_producer_watcher_operator_skips_retry_attempt(caplog): ), ], ) -@patch("cosmos.operators._watcher.base._log_dbt_event") -def test_dbt_consumer_watcher_sensor_execute_complete(mock_dbt_event, event, expected_message): +def test_dbt_consumer_watcher_sensor_execute_complete(event, expected_message): sensor = DbtConsumerWatcherSensor( project_dir=".", profiles_dir=".", @@ -358,7 +345,7 @@ def test_dbt_consumer_watcher_sensor_execute_complete(mock_dbt_event, event, exp ) sensor.model_unique_id = "model.pkg.m" - context = {"dag_run": MagicMock(), "ti": MagicMock()} + context = {"dag_run": MagicMock()} if expected_message is None: sensor.execute_complete(context, event) @@ -469,10 +456,7 @@ def fake_base_execute(self, context=None, **_): # type: ignore[override] ): op.execute(context=ctx) - assert _DBT_STARTUP_EVENTS_XCOM_KEY in ti.store - startup_events = ti.store[_DBT_STARTUP_EVENTS_XCOM_KEY] - assert isinstance(startup_events, list) and len(startup_events) == 1 - assert startup_events[0]["name"] == "MainReportVersion" + assert "dbt_startup_events" in ti.store node_key = "nodefinished_model__pkg__x" assert node_key in ti.store @@ -731,59 +715,6 @@ def test_store_dbt_resource_status_from_log_handles_missing_node_info(self): # No status should be stored assert len(ti.store) == 0 - def test_store_startup_event_from_log_appends_to_dbt_startup_events(self): - """Test that _store_startup_event_from_log appends MainReportVersion/AdapterRegistered to dbt_startup_events.""" - # Use a minimal mock with xcom_pull (needed by _store_startup_event_from_log); _MockTI does not have it. - store = {} - - class _TIMock: - def xcom_push(self, key, value, **_): - store[key] = value - - def xcom_pull(self, task_ids=None, key=None, **_): - return store.get(key) if key else None - - ti = _TIMock() - _store_startup_event_from_log( - ti, - {"info": {"name": "MainReportVersion", "msg": "Running with dbt=1.10.0", "ts": "2025-01-01T12:00:00Z"}}, - ) - assert _DBT_STARTUP_EVENTS_XCOM_KEY in store - events = store[_DBT_STARTUP_EVENTS_XCOM_KEY] - assert len(events) == 1 - assert events[0]["name"] == "MainReportVersion" - assert events[0]["msg"] == "Running with dbt=1.10.0" - - _store_startup_event_from_log( - ti, - { - "info": { - "name": "AdapterRegistered", - "msg": "Registered adapter: postgres=1.10.0", - "ts": "2025-01-01T12:00:01Z", - } - }, - ) - events = store[_DBT_STARTUP_EVENTS_XCOM_KEY] - assert len(events) == 2 - assert events[1]["name"] == "AdapterRegistered" - assert events[1]["msg"] == "Registered adapter: postgres=1.10.0" - - def test_store_startup_event_from_log_ignores_other_events(self): - """Test that _store_startup_event_from_log ignores events other than MainReportVersion/AdapterRegistered.""" - store = {} - - class _TIMock: - def xcom_push(self, key, value, **_): - store[key] = value - - def xcom_pull(self, task_ids=None, key=None, **_): - return store.get(key) if key else None - - ti = _TIMock() - _store_startup_event_from_log(ti, {"info": {"name": "NodeFinished", "msg": "Done"}}) - assert _DBT_STARTUP_EVENTS_XCOM_KEY not in store - @pytest.mark.parametrize( "msg, level", [ @@ -891,8 +822,7 @@ def test_process_log_line_callable_integration_with_subprocess_pattern(self): assert ti.store.get("model__pkg__model_a_status") == "success" assert ti.store.get("model__pkg__model_b_status") == "failed" - assert "model__pkg__model_a_status" in ti.store - assert "model__pkg__model_b_status" in ti.store + assert len(ti.store) == 2 # Only success and failed statuses are stored def test_store_dbt_resource_status_from_log_pushes_compiled_sql_for_models(self, tmp_path): """Test that compiled_sql is pushed to XCom for successful model nodes.""" @@ -1205,8 +1135,7 @@ def test_poke_status_none_from_events(self, MockEventMsg): result = sensor.poke(context) assert result is False - @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") - def test_poke_success_from_run_results(self, mock_startup_events): + def test_poke_success_from_run_results(self): sensor = self.make_sensor() sensor.invocation_mode = "SUBPROCESS" @@ -1219,8 +1148,7 @@ def test_poke_success_from_run_results(self, mock_startup_events): assert result is True @patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif") - @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") - def test_poke_subprocess_mode_extracts_compiled_sql_from_xcom(self, mock_startup_events, mock_override_rtif): + def test_poke_subprocess_mode_extracts_compiled_sql_from_xcom(self, mock_override_rtif): """Test that in subprocess mode, poke extracts compiled_sql from per-model XCom key when status is success.""" sensor = self.make_sensor() sensor.invocation_mode = "SUBPROCESS" @@ -1228,8 +1156,8 @@ def test_poke_subprocess_mode_extracts_compiled_sql_from_xcom(self, mock_startup ti = MagicMock() ti.try_number = 1 - # First call returns status from per-model XCom key, second call returns compiled_sql, third is _dbt_event for _log_dbt_event - ti.xcom_pull.side_effect = ["success", "SELECT * FROM orders", None] + # First call returns status from per-model XCom key, second call returns compiled_sql from per-model XCom key + ti.xcom_pull.side_effect = ["success", "SELECT * FROM orders"] context = self.make_context(ti) assert sensor.compiled_sql == "" # Initially empty @@ -1250,8 +1178,8 @@ def test_poke_event_mode_extracts_compiled_sql_from_canonical_key( ti = MagicMock() ti.try_number = 1 - # _get_status_from_events: dbt_startup_events=None, nodefinished_*=ENCODED_EVENT; then get_xcom_val(compiled_sql_key), get_xcom_val(_dbt_event) - ti.xcom_pull.side_effect = [None, ENCODED_EVENT, "SELECT * FROM orders", None] + # _get_status_from_events: dbt_startup_events=None, nodefinished_*=ENCODED_EVENT; then get_xcom_val(compiled_sql_key) + ti.xcom_pull.side_effect = [None, ENCODED_EVENT, "SELECT * FROM orders"] context = self.make_context(ti) assert sensor.compiled_sql == "" # Initially empty @@ -1272,8 +1200,7 @@ def _fallback_to_non_watcher_run(self, mock_get_producer_task_status): result = sensor.poke(context) assert result is True - @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") - def test_poke_failure_from_run_results(self, mock_startup_events): + def test_poke_failure_from_run_results(self): sensor = self.make_sensor() sensor.invocation_mode = "OTHER_MODE" @@ -1339,32 +1266,6 @@ def test_get_status_from_run_results_success(self): result = sensor._get_status_from_run_results(ti, _MockContext(ti=ti)) assert result == "success" - def test_get_status_from_run_results_logs_error(self): - sensor = self.make_sensor() - ti = MagicMock() - - run_results_payload = { - "results": [ - { - "unique_id": sensor.model_unique_id, - "status": "error", - "message": "dbt model failed", - "compiled_code": "select 1", - } - ] - } - - encoded = base64.b64encode(zlib.compress(json.dumps(run_results_payload).encode())).decode("utf-8") - - ti.xcom_pull.return_value = encoded - context = self.make_context(ti) - - with patch("cosmos.operators._watcher.base.logger") as mock_logger: - result = sensor._get_status_from_run_results(ti, context) - - assert result == "error" - mock_logger.error.assert_called_once() - def test_get_status_from_run_results_none(self): sensor = self.make_sensor() ti = MagicMock() @@ -1404,27 +1305,8 @@ def test_get_status_from_events_does_not_set_compiled_sql_from_event(self): assert result == "success" assert sensor.compiled_sql == "" # not set from event; poke() will get it from canonical key - @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") - def test_get_status_logs_error(self, mock_log_startup_events): - sensor = self.make_sensor() - ti = MagicMock() - - event_payload = {"data": {"run_result": {"status": "error", "message": "dbt model failed"}}} - - encoded_event = base64.b64encode(zlib.compress(json.dumps(event_payload).encode())).decode("utf-8") - - ti.xcom_pull.return_value = encoded_event - context = self.make_context(ti) - - with patch("cosmos.operators.watcher.logger.error") as mock_log_error: - result = sensor._get_status_from_events(ti, context) - - assert result == "error" - mock_log_error.assert_called_once_with("%s", "dbt model failed") - @patch("cosmos.operators._watcher.base.get_xcom_val") - @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") - def test_producer_state_failed(self, mock_startup_events, mock_get_xcom_val): + def test_producer_state_failed(self, mock_get_xcom_val): sensor = self.make_sensor() sensor._get_producer_task_status.return_value = "failed" ti = MagicMock() @@ -1443,9 +1325,8 @@ def test_producer_state_failed(self, mock_startup_events, mock_get_xcom_val): @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._fallback_to_non_watcher_run") @patch("cosmos.operators._watcher.base.get_xcom_val") - @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") def test_producer_state_does_not_fail_if_previously_upstream_failed( - self, mock_startup_events, mock_get_xcom_val, mock_fallback_to_non_watcher_run + self, mock_get_xcom_val, mock_fallback_to_non_watcher_run ): """ Attempt to run the task using ExecutionMode.LOCAL if State.UPSTREAM_FAILED happens. @@ -1527,19 +1408,13 @@ def test_deferrable_false_via_constructor_does_not_defer(self, mock_poke): {"status": "success"}, ], ) - @patch("cosmos.operators._watcher.base._log_dbt_event") - def test_execute_complete(self, mock_log_dbt_event, mock_event): + def test_execute_complete(self, mock_event): sensor = self.make_sensor() - - ti = MagicMock() - context = {"ti": ti} - if mock_event.get("status") == "failed": with pytest.raises(AirflowException): - sensor.execute_complete(context=context, event=mock_event) + sensor.execute_complete(context=Mock(), event=mock_event) else: - result = sensor.execute_complete(context=context, event=mock_event) - assert result is None + assert sensor.execute_complete(context=Mock(), event=mock_event) is None @patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif") def test_execute_complete_extracts_compiled_sql(self, mock_override_rtif): @@ -1654,57 +1529,6 @@ async def mock_get_xcom_val(key): assert status == "success" assert compiled_sql == "SELECT id FROM users" - @pytest.mark.asyncio - async def test_log_startup_events_returns_when_events_available(self, caplog): - """Test that _log_startup_events returns once dbt_startup_events is available and logs.""" - trigger = self.make_trigger() - events = [ - {"name": "MainReportVersion", "msg": "Running with dbt=1.10.0", "ts": "2025-01-01T12:00:00Z"}, - {"name": "AdapterRegistered", "msg": "Registered adapter: postgres=1.10.0", "ts": "2025-01-01T12:00:01Z"}, - ] - call_count = 0 - - async def mock_get_xcom_val(key): - nonlocal call_count - call_count += 1 - if key == _DBT_STARTUP_EVENTS_XCOM_KEY: - return events - return None - - async def mock_producer_running(): - return None # not failed - - trigger.get_xcom_val = mock_get_xcom_val - trigger._get_producer_task_status = mock_producer_running - - with caplog.at_level(logging.INFO): - await trigger._log_startup_events() - - assert "Running with dbt=1.10.0" in caplog.text - assert "Registered adapter: postgres=1.10.0" in caplog.text - assert call_count >= 1 - - @pytest.mark.asyncio - async def test_wait_and_log_startup_events_returns_when_producer_failed(self): - """Test that _log_startup_events returns without blocking when producer task failed.""" - trigger = self.make_trigger() - call_count = 0 - - async def mock_get_xcom_val(key): - nonlocal call_count - call_count += 1 - return None # no events yet - - async def mock_producer_failed(): - return "failed" - - trigger.get_xcom_val = mock_get_xcom_val - trigger._get_producer_task_status = mock_producer_failed - - await trigger._log_startup_events() - - assert call_count >= 1 - @pytest.mark.integration def test_dbt_dag_with_watcher(capsys): @@ -2202,22 +2026,14 @@ def make_context(self, ti_mock, *, run_id="test-run", map_index=0): "task_instance": MagicMock(map_index=map_index), } - # -- property checks -- - def test_is_test_sensor_returns_true(self): sensor = self.make_sensor() assert sensor.is_test_sensor is True - def test_resource_label_is_tests_for_model(self): - sensor = self.make_sensor() - assert sensor._resource_label == "Tests for model" - def test_use_event_returns_false(self): sensor = self.make_sensor() assert sensor.use_event() is False - # -- poke behaviour -- - def test_poke_returns_true_when_tests_pass(self): """When the aggregated test status is 'pass', poke should return True.""" sensor = self.make_sensor() @@ -2264,14 +2080,12 @@ def test_poke_reads_correct_xcom_key(self): xcom_keys_used = [call.kwargs.get("key") or call[1].get("key") for call in calls] assert self.TESTS_STATUS_XCOM_KEY in xcom_keys_used - # -- retry / fallback -- - - def test_fallback_skips_reexecution_on_retry(self): - """On retry (try_number > 1), the test sensor should skip re-execution and return True.""" + def test_fallback_raises_on_retry(self): + """On retry (try_number > 1), the test sensor should raise since test re-execution is not yet supported.""" sensor = self.make_sensor() ti = MagicMock() ti.try_number = 2 context = self.make_context(ti) - result = sensor.poke(context) - assert result is True + with pytest.raises(AirflowException, match="Test re-execution is not yet supported"): + sensor.poke(context) From 614ada301b197e404c8269f1cc187abde52a2c03 Mon Sep 17 00:00:00 2001 From: Michal Mrazek Date: Tue, 24 Mar 2026 08:52:42 +0100 Subject: [PATCH 3/3] tests --- tests/operators/test_watcher.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 43351dba76..5a2b0173b2 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -314,7 +314,9 @@ def test_dbt_consumer_watcher_sensor_execute_complete(event, expected_message): ) sensor.model_unique_id = "model.pkg.m" - context = {"dag_run": MagicMock()} + ti = MagicMock() + ti.xcom_pull.return_value = None + context = {"dag_run": MagicMock(), "ti": ti} if expected_message is None: sensor.execute_complete(context, event) @@ -1110,7 +1112,8 @@ def test_poke_success_from_run_results(self): ti = MagicMock() ti.try_number = 1 - ti.xcom_pull.return_value = "success" + # xcom_pull calls: _log_startup_events=None, _get_node_status="success", compiled_sql=None, _dbt_event=None + ti.xcom_pull.side_effect = [None, "success", None, None] context = self.make_context(ti) result = sensor.poke(context) @@ -1125,8 +1128,8 @@ def test_poke_subprocess_mode_extracts_compiled_sql_from_xcom(self, mock_overrid ti = MagicMock() ti.try_number = 1 - # First call returns status from per-model XCom key, second call returns compiled_sql from per-model XCom key - ti.xcom_pull.side_effect = ["success", "SELECT * FROM orders"] + # xcom_pull calls: _log_startup_events=None, _get_node_status="success", compiled_sql="SELECT * FROM orders", _dbt_event=None + ti.xcom_pull.side_effect = [None, "success", "SELECT * FROM orders", None] context = self.make_context(ti) assert sensor.compiled_sql == "" # Initially empty @@ -1147,8 +1150,8 @@ def test_poke_event_mode_extracts_compiled_sql_from_canonical_key( ti = MagicMock() ti.try_number = 1 - # _get_status_from_events: dbt_startup_events=None, nodefinished_*=ENCODED_EVENT; then get_xcom_val(compiled_sql_key) - ti.xcom_pull.side_effect = [None, ENCODED_EVENT, "SELECT * FROM orders"] + # _get_status_from_events: dbt_startup_events=None, nodefinished_*=ENCODED_EVENT; then compiled_sql, then _dbt_event=None + ti.xcom_pull.side_effect = [None, ENCODED_EVENT, "SELECT * FROM orders", None] context = self.make_context(ti) assert sensor.compiled_sql == "" # Initially empty @@ -1175,7 +1178,8 @@ def test_poke_failure_from_run_results(self): ti = MagicMock() ti.try_number = 1 - ti.xcom_pull.return_value = ENCODED_RUN_RESULTS_FAILED + # xcom_pull calls: _log_startup_events=None, _get_node_status=ENCODED_RUN_RESULTS_FAILED, compiled_sql=None, _dbt_event=None + ti.xcom_pull.side_effect = [None, ENCODED_RUN_RESULTS_FAILED, None, None] context = self.make_context(ti) with pytest.raises(AirflowException): @@ -1282,7 +1286,8 @@ def test_producer_state_failed(self, mock_get_xcom_val): ti.try_number = 1 sensor.poke_retry_number = 1 mock_get_xcom_val.return_value = None - ti.xcom_pull.return_value = "failed" + # _log_startup_events still calls ti.xcom_pull directly + ti.xcom_pull.return_value = None context = self.make_context(ti) @@ -1307,7 +1312,8 @@ def test_producer_state_does_not_fail_if_previously_upstream_failed( ti.try_number = 1 sensor.poke_retry_number = 0 mock_get_xcom_val.return_value = None - ti.xcom_pull.return_value = "failed" + # _log_startup_events still calls ti.xcom_pull directly + ti.xcom_pull.return_value = None context = self.make_context(ti) @@ -1379,11 +1385,14 @@ def test_deferrable_false_via_constructor_does_not_defer(self, mock_poke): ) def test_execute_complete(self, mock_event): sensor = self.make_sensor() + ti = MagicMock() + ti.xcom_pull.return_value = None + context = self.make_context(ti) if mock_event.get("status") == "failed": with pytest.raises(AirflowException): - sensor.execute_complete(context=Mock(), event=mock_event) + sensor.execute_complete(context=context, event=mock_event) else: - assert sensor.execute_complete(context=Mock(), event=mock_event) is None + assert sensor.execute_complete(context=context, event=mock_event) is None @patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif") def test_execute_complete_extracts_compiled_sql(self, mock_override_rtif):