Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
is_dbt_node_status_skipped,
is_dbt_node_status_success,
is_dbt_node_status_terminal,
is_producer_task_terminated,
safe_xcom_push,
xcom_set_lock,
)
Expand Down Expand Up @@ -583,6 +584,28 @@ def _cache_compiled_sql(self, ti: Any, context: Context) -> None:
if hasattr(self, "_override_rtif"):
self._override_rtif(context)

def _handle_retry(self, try_number: int, producer_task_state: str | None, context: Context) -> bool | None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it would be a cleaner interface if we renamed things to is_producer_still_running.
Retry is a consequence depending on other variables - including if the producer is actively running the task

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe _handle_retry still makes sense as the name here. It's called from the poke method when try_number > 1, i.e., when we're in an Airflow retry scenario, and this method, as part of handling the Airflow retry, decides whether to actually fall back (producer terminated) or keep polling (producer still active).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functions should be named based on what they do, not how they are used. The current name (_handle_retry) is related to where it is used, rather than to its name.

Naming functions after their action creates reusable, predictable code. If a function is named after its specific use case, it becomes difficult to reuse that same logic elsewhere.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit torn here. _handle_retry actually does describe what the method does; it handles the retry scenario. It checks the producer state, decides whether to fall back or keep polling, and acts accordingly. That is the method's responsibility. It's not named after its call site; it's named after the action it performs.

When the time comes for reusing (although this function is marked local to the module/class), I believe these are what the actions will be for handling retries for the case, no? Or do you see a different path there?

We already have is_producer_task_terminated for the state check. The suggested is_producer_still_running shouldn't call an action based on its name, so the fallback logic would have to move elsewhere, either into another method or back into poke. If it goes into poke, it's not reusable either.

"""Handle sensor retry by checking whether the producer is still active.

Returns the fallback result if the producer has terminated, or None if
the sensor should continue polling (producer still active).
"""
if is_producer_task_terminated(producer_task_state):
# Producer finished — this is either an automatic retry after
# the producer completed or a manual task clear from the UI.
# Fall back to a non-watcher run.
return self._fallback_to_non_watcher_run(try_number, context)
# Producer is still active — the sensor likely timed out while the
# producer was still working. Keep polling instead of launching a
# duplicate dbt run.
logger.info(
"Try #%s but producer '%s' is still %s — continuing to poll instead of fallback.",
try_number,
self.producer_task_id,
producer_task_state or "unknown",
)
return None

def poke(self, context: Context) -> bool:
"""
Checks the status of a dbt node (model or aggregated tests) by pulling relevant XComs from the producer task.
Expand All @@ -605,10 +628,13 @@ def poke(self, context: Context) -> bool:
self.model_unique_id,
)

producer_task_state = self._get_producer_task_status(context)

if try_number > 1:
return self._fallback_to_non_watcher_run(try_number, context)
retry_result = self._handle_retry(try_number, producer_task_state, context)
if retry_result is not None:
return retry_result

producer_task_state = self._get_producer_task_status(context)
if not self.is_test_sensor:
self._log_startup_events(ti)
status = self._get_node_status(ti, context)
Expand All @@ -624,8 +650,13 @@ def poke(self, context: Context) -> bool:
)
_log_dbt_event(dbt_events)

if status is None:
return self._evaluate_node_status(status, producer_task_state, try_number, context)

def _evaluate_node_status(
Comment on lines +653 to +655
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this to a new method, as this was getting complained about code complexity

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pankajkoti Would there be an alternative way to refactor?

The function name (_evaluate_node_status) does not seem accurate with what it is doing, which is actually to:

  • check the status of the node
  • check the status of the producer airflow task
  • maybe execute dbt itself to run, as a fallback

These steps feel quite critical and relevant in the poke method, and it is probably hard to find a suitable name in a single method. There may be other parts of the function that could be refactored and isolated that are not so core to the logic (maybe the retrieval of status, or the logging handling?)

Copy link
Copy Markdown
Contributor Author

@pankajkoti pankajkoti Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the name broadly reflects the method’s primary responsibility: it evaluates the node status and determines the appropriate poke outcome. By "evaluate/evaluation", I’m including checking dependencies like the producer task state and deciding on actions such as fallback, retry, raise, or succeed. The producer state check and fallback are part of handling cases where no direct status is available.

That said, I agree it’s doing a few important things, so I’m open to renaming if you have a more precise suggestion.

On extracting retrieval/logging: those are already separated into helper methods (_log_startup_events, _get_node_status, _cache_compiled_sql). Wrapping them further would likely add indirection without much benefit. The evaluation logic, on the other hand, forms a cohesive decision tree, which is why it felt reasonable to group it into a single method.

Happy to refactor further if you feel a different split would improve readability.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function combines decision-making, state retrieval, and side effects (running dbt). While it is small, it has a lot of responsibility. I'm worried about future readability and testability.

Maybe we could break it down into two smaller methods with clear responsibilities, something along the lines:

  • _evaluate_node_status: implements the decision tree logic
  • _handle_node_decision(decision, ...): executes side effects - run dbt or others

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the method again, it doesn’t actually do state retrieval. Both status and producer_task_state are passed in; the retrieval happens in poke. So this method is really decision-making and taking actions based on the decisions.

I feel splitting it would introduce indirection without reducing coupling; the action method would still need the same parameters (try_number, context) for the fallback path.

Right now, the full flow is visible in one place:

  • status is None -> check producer -> fallback or keep polling
  • skipped -> skip
  • success -> done

If we split this into decision + handler, a reader would need to jump between methods and map decisions to actions, which I think hurts readability more than it helps at this size.

I’m happy to rename the method if the current name doesn’t reflect its behaviour well, open to suggestions there. But I’d prefer not to split it unless the logic grows or we see reuse emerging.

self, status: Any, producer_task_state: str | None, try_number: int, context: Context
) -> bool:
"""Evaluate the dbt node status and return the poke result."""
if status is None:
if producer_task_state == "failed":
if self.poke_retry_number > 0:
raise AirflowException(
Expand All @@ -636,13 +667,12 @@ def poke(self, context: Context) -> bool:
return self._fallback_to_non_watcher_run(try_number, context)

self.poke_retry_number += 1

return False
elif is_dbt_node_status_skipped(status):

if is_dbt_node_status_skipped(status):
raise AirflowSkipException(
f"{self._resource_label} '{self.model_unique_id}' was skipped by the dbt command."
)
elif is_dbt_node_status_success(status):
if is_dbt_node_status_success(status):
return True
else:
raise AirflowException(f"{self._resource_label} '{self.model_unique_id}' finished with status '{status}'")
raise AirflowException(f"{self._resource_label} '{self.model_unique_id}' finished with status '{status}'")
9 changes: 9 additions & 0 deletions cosmos/operators/_watcher/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
DBT_FAILED_STATUSES = frozenset({"failed", "fail", "error", "runtime error"})
DBT_SKIPPED_STATUSES = frozenset({"skipped"})

# Airflow task states that indicate the producer has finished and will not deliver any more XCom updates.
# Used to decide whether a sensor retry should fall back to a non-watcher run or keep polling.
PRODUCER_TERMINAL_STATES = frozenset({"success", "failed", "skipped", "upstream_failed", "removed"})


class DbtTestStatus(str, Enum):
"""Aggregated status of all tests for a given model."""
Expand Down Expand Up @@ -55,6 +59,11 @@ def is_dbt_node_status_terminal(status: str | None) -> bool:
return is_dbt_node_status_success(status) or is_dbt_node_status_failed(status) or is_dbt_node_status_skipped(status)


def is_producer_task_terminated(state: str | None) -> bool:
"""Return True when the producer task is in a terminal state."""
return state in PRODUCER_TERMINAL_STATES


xcom_set_lock = Lock()


Expand Down
13 changes: 13 additions & 0 deletions tests/operators/_watcher/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
is_dbt_node_status_skipped,
is_dbt_node_status_success,
is_dbt_node_status_terminal,
is_producer_task_terminated,
)


Expand Down Expand Up @@ -51,6 +52,18 @@ def test_is_dbt_node_status_terminal_false(self, status: str | None):
assert is_dbt_node_status_terminal(status) is False


class TestProducerTaskTerminated:
"""Tests for is_producer_task_terminated helper."""

@pytest.mark.parametrize("state", ["success", "failed", "skipped", "upstream_failed", "removed"])
def test_terminal_states(self, state: str):
assert is_producer_task_terminated(state) is True

@pytest.mark.parametrize("state", ["running", "deferred", "queued", "scheduled", "up_for_reschedule", None, ""])
def test_non_terminal_states(self, state: str | None):
assert is_producer_task_terminated(state) is False


@pytest.mark.parametrize(
"dbt_event,expect_error,status",
[
Expand Down
35 changes: 33 additions & 2 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,8 +1032,10 @@ def test_poke_failure(self):
sensor.poke(context)

@patch("cosmos.operators.local.AbstractDbtLocalBase.build_and_run_cmd")
def test_task_retry(self, mock_build_and_run_cmd):
def test_task_retry_fallback_when_producer_terminated(self, mock_build_and_run_cmd):
"""On retry, if the producer has already finished, fall back to running the model locally."""
sensor = self.make_sensor()
sensor._get_producer_task_status.return_value = "success"
ti = MagicMock()
ti.try_number = 2
ti.xcom_pull.return_value = None
Expand All @@ -1042,6 +1044,22 @@ def test_task_retry(self, mock_build_and_run_cmd):
sensor.poke(context)
mock_build_and_run_cmd.assert_called_once()

@patch("cosmos.operators._watcher.base.get_xcom_val")
def test_task_retry_keeps_polling_when_producer_still_running(self, mock_get_xcom_val):
"""On retry, if the producer is still running, keep polling instead of launching a duplicate dbt run."""
sensor = self.make_sensor()
sensor._get_producer_task_status.return_value = "running"
ti = MagicMock()
ti.try_number = 2
# _log_startup_events=None, _get_node_status=None, compiled_sql (skipped), _dbt_event=None
ti.xcom_pull.return_value = None
mock_get_xcom_val.return_value = None
context = self.make_context(ti)

result = sensor.poke(context)
assert result is False
assert sensor.poke_retry_number == 1

def test_fallback_to_non_watcher_run(self):
sensor = self.make_sensor()
ti = MagicMock()
Expand Down Expand Up @@ -1964,15 +1982,28 @@ def test_poke_reads_correct_xcom_key(self):
assert self.TESTS_STATUS_XCOM_KEY in xcom_keys_used

def test_fallback_raises_on_retry(self):
"""On retry (try_number > 1), the test sensor should raise since test re-execution is not yet supported."""
"""On retry (try_number > 1) with a terminated producer, the test sensor should raise since test re-execution is not yet supported."""
sensor = self.make_sensor()
sensor._get_producer_task_status.return_value = "success"
ti = MagicMock()
ti.try_number = 2
context = self.make_context(ti)

with pytest.raises(AirflowException, match="Test re-execution is not yet supported"):
sensor.poke(context)

def test_retry_keeps_polling_when_producer_still_running(self):
"""On retry, if the producer is still running, the test sensor should keep polling instead of raising."""
sensor = self.make_sensor()
sensor._get_producer_task_status.return_value = "running"
ti = MagicMock()
ti.try_number = 2
ti.xcom_pull.return_value = None
context = self.make_context(ti)

result = sensor.poke(context)
assert result is False


class TestDefaultFreshnessCallback:
"""Tests for the _default_freshness_callback function."""
Expand Down
27 changes: 25 additions & 2 deletions tests/operators/test_watcher_kubernetes_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,11 @@ def test_first_execution_behaves_as_base_consumer_sensor(mock_startup_events):
@patch("cosmos.operators.kubernetes.DbtKubernetesBaseOperator.build_and_run_cmd")
def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd):
"""
On retry (try_number > 1), the sensor should fall back to executing
as DbtRunKubernetesOperator by calling build_and_run_cmd.
On retry (try_number > 1) with a terminated producer, the sensor should
fall back to executing as DbtRunKubernetesOperator by calling build_and_run_cmd.
"""
sensor = make_sensor()
sensor._get_producer_task_status.return_value = "success"

ti = MagicMock()
ti.try_number = 2
Expand All @@ -191,6 +192,28 @@ def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd):
mock_build_and_run_cmd.assert_called_once()


@patch("cosmos.operators._watcher.base.get_xcom_val")
@patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events")
def test_retry_keeps_polling_when_producer_still_running(mock_startup_events, mock_get_xcom_val):
"""
On retry (try_number > 1) with the producer still running, the sensor
should keep polling instead of launching a duplicate dbt run.
"""
sensor = make_sensor()
sensor._get_producer_task_status.return_value = "running"

ti = MagicMock()
ti.try_number = 2
ti.xcom_pull.return_value = None
mock_get_xcom_val.return_value = None
context = make_context(ti)

result = sensor.poke(context)

assert result is False
assert sensor.poke_retry_number == 1


class TestCallbacksNormalization:
"""Tests for the callbacks normalization logic in DbtProducerWatcherKubernetesOperator."""

Expand Down
Loading