diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index aa47e81291..955746b90e 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -28,6 +28,7 @@ is_dbt_node_status_skipped, is_dbt_node_status_success, is_dbt_node_status_terminal, + is_producer_task_terminated, safe_xcom_push, xcom_set_lock, ) @@ -583,6 +584,28 @@ def _cache_compiled_sql(self, ti: Any, context: Context) -> None: if hasattr(self, "_override_rtif"): self._override_rtif(context) + def _handle_retry(self, try_number: int, producer_task_state: str | None, context: Context) -> bool | None: + """Handle sensor retry by checking whether the producer is still active. + + Returns the fallback result if the producer has terminated, or None if + the sensor should continue polling (producer still active). + """ + if is_producer_task_terminated(producer_task_state): + # Producer finished — this is either an automatic retry after + # the producer completed or a manual task clear from the UI. + # Fall back to a non-watcher run. + return self._fallback_to_non_watcher_run(try_number, context) + # Producer is still active — the sensor likely timed out while the + # producer was still working. Keep polling instead of launching a + # duplicate dbt run. + logger.info( + "Try #%s but producer '%s' is still %s — continuing to poll instead of fallback.", + try_number, + self.producer_task_id, + producer_task_state or "unknown", + ) + return None + def poke(self, context: Context) -> bool: """ Checks the status of a dbt node (model or aggregated tests) by pulling relevant XComs from the producer task. @@ -605,10 +628,13 @@ def poke(self, context: Context) -> bool: self.model_unique_id, ) + producer_task_state = self._get_producer_task_status(context) + if try_number > 1: - return self._fallback_to_non_watcher_run(try_number, context) + retry_result = self._handle_retry(try_number, producer_task_state, context) + if retry_result is not None: + return retry_result - producer_task_state = self._get_producer_task_status(context) if not self.is_test_sensor: self._log_startup_events(ti) status = self._get_node_status(ti, context) @@ -624,8 +650,13 @@ def poke(self, context: Context) -> bool: ) _log_dbt_event(dbt_events) - if status is None: + return self._evaluate_node_status(status, producer_task_state, try_number, context) + def _evaluate_node_status( + self, status: Any, producer_task_state: str | None, try_number: int, context: Context + ) -> bool: + """Evaluate the dbt node status and return the poke result.""" + if status is None: if producer_task_state == "failed": if self.poke_retry_number > 0: raise AirflowException( @@ -636,13 +667,12 @@ def poke(self, context: Context) -> bool: return self._fallback_to_non_watcher_run(try_number, context) self.poke_retry_number += 1 - return False - elif is_dbt_node_status_skipped(status): + + if is_dbt_node_status_skipped(status): raise AirflowSkipException( f"{self._resource_label} '{self.model_unique_id}' was skipped by the dbt command." ) - elif is_dbt_node_status_success(status): + if is_dbt_node_status_success(status): return True - else: - raise AirflowException(f"{self._resource_label} '{self.model_unique_id}' finished with status '{status}'") + raise AirflowException(f"{self._resource_label} '{self.model_unique_id}' finished with status '{status}'") diff --git a/cosmos/operators/_watcher/state.py b/cosmos/operators/_watcher/state.py index 27d3b488b0..fb85f469ec 100644 --- a/cosmos/operators/_watcher/state.py +++ b/cosmos/operators/_watcher/state.py @@ -25,6 +25,10 @@ DBT_FAILED_STATUSES = frozenset({"failed", "fail", "error", "runtime error"}) DBT_SKIPPED_STATUSES = frozenset({"skipped"}) +# Airflow task states that indicate the producer has finished and will not deliver any more XCom updates. +# Used to decide whether a sensor retry should fall back to a non-watcher run or keep polling. +PRODUCER_TERMINAL_STATES = frozenset({"success", "failed", "skipped", "upstream_failed", "removed"}) + class DbtTestStatus(str, Enum): """Aggregated status of all tests for a given model.""" @@ -55,6 +59,11 @@ def is_dbt_node_status_terminal(status: str | None) -> bool: return is_dbt_node_status_success(status) or is_dbt_node_status_failed(status) or is_dbt_node_status_skipped(status) +def is_producer_task_terminated(state: str | None) -> bool: + """Return True when the producer task is in a terminal state.""" + return state in PRODUCER_TERMINAL_STATES + + xcom_set_lock = Lock() diff --git a/tests/operators/_watcher/test_state.py b/tests/operators/_watcher/test_state.py index e027ae3744..54b3a01fe5 100644 --- a/tests/operators/_watcher/test_state.py +++ b/tests/operators/_watcher/test_state.py @@ -12,6 +12,7 @@ is_dbt_node_status_skipped, is_dbt_node_status_success, is_dbt_node_status_terminal, + is_producer_task_terminated, ) @@ -51,6 +52,18 @@ def test_is_dbt_node_status_terminal_false(self, status: str | None): assert is_dbt_node_status_terminal(status) is False +class TestProducerTaskTerminated: + """Tests for is_producer_task_terminated helper.""" + + @pytest.mark.parametrize("state", ["success", "failed", "skipped", "upstream_failed", "removed"]) + def test_terminal_states(self, state: str): + assert is_producer_task_terminated(state) is True + + @pytest.mark.parametrize("state", ["running", "deferred", "queued", "scheduled", "up_for_reschedule", None, ""]) + def test_non_terminal_states(self, state: str | None): + assert is_producer_task_terminated(state) is False + + @pytest.mark.parametrize( "dbt_event,expect_error,status", [ diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 62f78cb647..9083731e13 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -1032,8 +1032,10 @@ def test_poke_failure(self): sensor.poke(context) @patch("cosmos.operators.local.AbstractDbtLocalBase.build_and_run_cmd") - def test_task_retry(self, mock_build_and_run_cmd): + def test_task_retry_fallback_when_producer_terminated(self, mock_build_and_run_cmd): + """On retry, if the producer has already finished, fall back to running the model locally.""" sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "success" ti = MagicMock() ti.try_number = 2 ti.xcom_pull.return_value = None @@ -1042,6 +1044,22 @@ def test_task_retry(self, mock_build_and_run_cmd): sensor.poke(context) mock_build_and_run_cmd.assert_called_once() + @patch("cosmos.operators._watcher.base.get_xcom_val") + def test_task_retry_keeps_polling_when_producer_still_running(self, mock_get_xcom_val): + """On retry, if the producer is still running, keep polling instead of launching a duplicate dbt run.""" + sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "running" + ti = MagicMock() + ti.try_number = 2 + # _log_startup_events=None, _get_node_status=None, compiled_sql (skipped), _dbt_event=None + ti.xcom_pull.return_value = None + mock_get_xcom_val.return_value = None + context = self.make_context(ti) + + result = sensor.poke(context) + assert result is False + assert sensor.poke_retry_number == 1 + def test_fallback_to_non_watcher_run(self): sensor = self.make_sensor() ti = MagicMock() @@ -1964,8 +1982,9 @@ def test_poke_reads_correct_xcom_key(self): assert self.TESTS_STATUS_XCOM_KEY in xcom_keys_used def test_fallback_raises_on_retry(self): - """On retry (try_number > 1), the test sensor should raise since test re-execution is not yet supported.""" + """On retry (try_number > 1) with a terminated producer, the test sensor should raise since test re-execution is not yet supported.""" sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "success" ti = MagicMock() ti.try_number = 2 context = self.make_context(ti) @@ -1973,6 +1992,18 @@ def test_fallback_raises_on_retry(self): with pytest.raises(AirflowException, match="Test re-execution is not yet supported"): sensor.poke(context) + def test_retry_keeps_polling_when_producer_still_running(self): + """On retry, if the producer is still running, the test sensor should keep polling instead of raising.""" + sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "running" + ti = MagicMock() + ti.try_number = 2 + ti.xcom_pull.return_value = None + context = self.make_context(ti) + + result = sensor.poke(context) + assert result is False + class TestDefaultFreshnessCallback: """Tests for the _default_freshness_callback function.""" diff --git a/tests/operators/test_watcher_kubernetes_unit.py b/tests/operators/test_watcher_kubernetes_unit.py index 2d5f8855b8..5b823000c1 100644 --- a/tests/operators/test_watcher_kubernetes_unit.py +++ b/tests/operators/test_watcher_kubernetes_unit.py @@ -174,10 +174,11 @@ def test_first_execution_behaves_as_base_consumer_sensor(mock_startup_events): @patch("cosmos.operators.kubernetes.DbtKubernetesBaseOperator.build_and_run_cmd") def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd): """ - On retry (try_number > 1), the sensor should fall back to executing - as DbtRunKubernetesOperator by calling build_and_run_cmd. + On retry (try_number > 1) with a terminated producer, the sensor should + fall back to executing as DbtRunKubernetesOperator by calling build_and_run_cmd. """ sensor = make_sensor() + sensor._get_producer_task_status.return_value = "success" ti = MagicMock() ti.try_number = 2 @@ -191,6 +192,28 @@ def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd): mock_build_and_run_cmd.assert_called_once() +@patch("cosmos.operators._watcher.base.get_xcom_val") +@patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") +def test_retry_keeps_polling_when_producer_still_running(mock_startup_events, mock_get_xcom_val): + """ + On retry (try_number > 1) with the producer still running, the sensor + should keep polling instead of launching a duplicate dbt run. + """ + sensor = make_sensor() + sensor._get_producer_task_status.return_value = "running" + + ti = MagicMock() + ti.try_number = 2 + ti.xcom_pull.return_value = None + mock_get_xcom_val.return_value = None + context = make_context(ti) + + result = sensor.poke(context) + + assert result is False + assert sensor.poke_retry_number == 1 + + class TestCallbacksNormalization: """Tests for the callbacks normalization logic in DbtProducerWatcherKubernetesOperator."""