Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
is_dbt_node_status_skipped,
is_dbt_node_status_success,
is_dbt_node_status_terminal,
is_producer_task_terminated,
safe_xcom_push,
xcom_set_lock,
)
Expand Down Expand Up @@ -600,6 +601,28 @@ def _cache_compiled_sql(self, ti: Any, context: Context) -> None:
if hasattr(self, "_override_rtif"):
self._override_rtif(context)

def _handle_retry(self, try_number: int, producer_task_state: str | None, context: Context) -> bool | None:
    """Decide what a retried sensor attempt should do based on the producer's state.

    :param try_number: Attempt number of the current sensor run.
    :param producer_task_state: Last known state of the producer task, or ``None`` when unknown.
    :param context: Airflow task context forwarded to the fallback run.
    :returns: The result of ``_fallback_to_non_watcher_run`` when the producer has
        terminated; ``None`` when the producer is still active and the caller
        should carry on polling.
    """
    producer_still_active = not is_producer_task_terminated(producer_task_state)

    if producer_still_active:
        # The sensor most likely timed out while the producer was still working.
        # Resume polling so that a duplicate dbt run is not launched.
        logger.info(
            "Try #%s but producer '%s' is still %s — continuing to poll instead of fallback.",
            try_number,
            self.producer_task_id,
            producer_task_state or "unknown",
        )
        return None

    # The producer is done: this attempt is either an automatic retry after the
    # producer completed or a manual task clear from the UI, so run without the watcher.
    return self._fallback_to_non_watcher_run(try_number, context)
Comment thread
pankajkoti marked this conversation as resolved.

def _handle_no_dbt_node_status(self, producer_task_state: str | None, try_number: int, context: Context) -> bool:
"""Handle the case where no dbt node status has been reported yet."""
if producer_task_state == "failed":
Expand Down Expand Up @@ -639,10 +662,13 @@ def poke(self, context: Context) -> bool:
self.model_unique_id,
)

producer_task_state = self._get_producer_task_status(context)

if try_number > 1:
return self._fallback_to_non_watcher_run(try_number, context)
retry_result = self._handle_retry(try_number, producer_task_state, context)
if retry_result is not None:
return retry_result

producer_task_state = self._get_producer_task_status(context)
if not self.is_test_sensor:
self._log_startup_events(ti)
status = self._get_node_status(ti, context)
Expand Down
9 changes: 9 additions & 0 deletions cosmos/operators/_watcher/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
DBT_FAILED_STATUSES = frozenset({"failed", "fail", "error", "runtime error"})
DBT_SKIPPED_STATUSES = frozenset({"skipped"})

# Airflow task states that indicate the producer has finished and will not deliver any more XCom updates.
# Used to decide whether a sensor retry should fall back to a non-watcher run or keep polling.
PRODUCER_TERMINAL_STATES = frozenset({"success", "failed", "skipped", "upstream_failed", "removed"})


class DbtTestStatus(str, Enum):
"""Aggregated status of all tests for a given model."""
Expand Down Expand Up @@ -55,6 +59,11 @@ def is_dbt_node_status_terminal(status: str | None) -> bool:
return is_dbt_node_status_success(status) or is_dbt_node_status_failed(status) or is_dbt_node_status_skipped(status)


def is_producer_task_terminated(state: str | None) -> bool:
    """Report whether the producer task has reached a terminal state.

    A state counts as terminal only when it is a member of
    ``PRODUCER_TERMINAL_STATES``; ``None`` and any other value yield ``False``.
    """
    terminated = state in PRODUCER_TERMINAL_STATES
    return terminated


xcom_set_lock = Lock()


Expand Down
13 changes: 13 additions & 0 deletions tests/operators/_watcher/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
is_dbt_node_status_skipped,
is_dbt_node_status_success,
is_dbt_node_status_terminal,
is_producer_task_terminated,
safe_xcom_push,
)
from cosmos.operators._watcher.xcom import (
Expand Down Expand Up @@ -63,6 +64,18 @@ def test_is_dbt_node_status_terminal_false(self, status: str | None):
assert is_dbt_node_status_terminal(status) is False


class TestProducerTaskTerminated:
    """Exercise the ``is_producer_task_terminated`` helper with terminal and non-terminal states."""

    @pytest.mark.parametrize(
        "state",
        [
            "success",
            "failed",
            "skipped",
            "upstream_failed",
            "removed",
        ],
    )
    def test_terminal_states(self, state: str):
        # Each of these states must be reported as terminated.
        outcome = is_producer_task_terminated(state)
        assert outcome is True

    @pytest.mark.parametrize(
        "state",
        [
            "running",
            "deferred",
            "queued",
            "scheduled",
            "up_for_reschedule",
            None,
            "",
        ],
    )
    def test_non_terminal_states(self, state: str | None):
        # In-flight states, a missing state and the empty string are not terminated.
        outcome = is_producer_task_terminated(state)
        assert outcome is False


@pytest.mark.parametrize(
"dbt_event,expect_error,status",
[
Expand Down
35 changes: 33 additions & 2 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,8 +1068,10 @@ def test_poke_failure(self):
sensor.poke(context)

@patch("cosmos.operators.local.AbstractDbtLocalBase.build_and_run_cmd")
def test_task_retry(self, mock_build_and_run_cmd):
def test_task_retry_fallback_when_producer_terminated(self, mock_build_and_run_cmd):
"""On retry, if the producer has already finished, fall back to running the model locally."""
sensor = self.make_sensor()
sensor._get_producer_task_status.return_value = "success"
ti = MagicMock()
ti.try_number = 2
ti.xcom_pull.return_value = None
Expand All @@ -1078,6 +1080,22 @@ def test_task_retry(self, mock_build_and_run_cmd):
sensor.poke(context)
mock_build_and_run_cmd.assert_called_once()

@patch("cosmos.operators._watcher.base.get_xcom_val")
def test_task_retry_keeps_polling_when_producer_still_running(self, mock_get_xcom_val):
    """A retried sensor must keep poking, not start a second dbt run, while the producer is running."""
    sensor = self.make_sensor()
    sensor._get_producer_task_status.return_value = "running"

    # Every XCom lookup (through the task instance or the patched helper) yields None.
    mock_get_xcom_val.return_value = None
    task_instance = MagicMock(try_number=2)
    task_instance.xcom_pull.return_value = None

    poke_outcome = sensor.poke(self.make_context(task_instance))

    assert poke_outcome is False
    assert sensor.poke_retry_number == 1

def test_fallback_to_non_watcher_run(self):
sensor = self.make_sensor()
ti = MagicMock()
Expand Down Expand Up @@ -2049,15 +2067,28 @@ def test_poke_reads_correct_xcom_key(self):
assert self.TESTS_STATUS_XCOM_KEY in xcom_keys_used

def test_fallback_raises_on_retry(self):
"""On retry (try_number > 1), the test sensor should raise since test re-execution is not yet supported."""
"""On retry (try_number > 1) with a terminated producer, the test sensor should raise since test re-execution is not yet supported."""
sensor = self.make_sensor()
sensor._get_producer_task_status.return_value = "success"
ti = MagicMock()
ti.try_number = 2
context = self.make_context(ti)

with pytest.raises(AirflowException, match="Test re-execution is not yet supported"):
sensor.poke(context)

def test_retry_keeps_polling_when_producer_still_running(self):
    """With the producer still running, a retried test sensor keeps polling rather than raising."""
    sensor = self.make_sensor()
    sensor._get_producer_task_status.return_value = "running"

    # Second attempt, with no XCom value available to the sensor.
    task_instance = MagicMock(try_number=2)
    task_instance.xcom_pull.return_value = None

    poke_outcome = sensor.poke(self.make_context(task_instance))

    assert poke_outcome is False


class TestDefaultFreshnessCallback:
"""Tests for the _default_freshness_callback function."""
Expand Down
27 changes: 25 additions & 2 deletions tests/operators/test_watcher_kubernetes_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,11 @@ def test_first_execution_behaves_as_base_consumer_sensor(mock_startup_events):
@patch("cosmos.operators.kubernetes.DbtKubernetesBaseOperator.build_and_run_cmd")
def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd):
"""
On retry (try_number > 1), the sensor should fall back to executing
as DbtRunKubernetesOperator by calling build_and_run_cmd.
On retry (try_number > 1) with a terminated producer, the sensor should
fall back to executing as DbtRunKubernetesOperator by calling build_and_run_cmd.
"""
sensor = make_sensor()
sensor._get_producer_task_status.return_value = "success"

ti = MagicMock()
ti.try_number = 2
Expand All @@ -237,6 +238,28 @@ def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd):
mock_build_and_run_cmd.assert_called_once()


@patch("cosmos.operators._watcher.base.get_xcom_val")
@patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events")
def test_retry_keeps_polling_when_producer_still_running(mock_startup_events, mock_get_xcom_val):
    """
    When a retry (try_number > 1) happens while the producer is still running,
    the sensor is expected to continue polling rather than start a duplicate dbt run.
    """
    sensor = make_sensor()
    sensor._get_producer_task_status.return_value = "running"

    # Second attempt; both XCom access paths are stubbed to return nothing.
    mock_get_xcom_val.return_value = None
    task_instance = MagicMock(try_number=2)
    task_instance.xcom_pull.return_value = None

    poke_outcome = sensor.poke(make_context(task_instance))

    assert poke_outcome is False
    assert sensor.poke_retry_number == 1


class TestCallbacksNormalization:
"""Tests for the callbacks normalization logic in DbtProducerWatcherKubernetesOperator."""

Expand Down
Loading