Skip to content
Merged
2 changes: 1 addition & 1 deletion cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from cosmos import settings

__version__ = "1.11.0a11"
__version__ = "1.11.0a13"

if not settings.enable_memory_optimised_imports:
from cosmos.airflow.dag import DbtDag
Expand Down
34 changes: 24 additions & 10 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def _callback(ev: EventMsg) -> None:

class DbtConsumerWatcherSensor(BaseSensorOperator, DbtRunLocalOperator): # type: ignore[misc]
template_fields = ("model_unique_id",) # type: ignore[operator]
poke_retry_number: int = 0
Comment thread
tatiana marked this conversation as resolved.

def __init__(
self,
Expand Down Expand Up @@ -229,14 +230,14 @@ def _filter_flags(flags: list[str]) -> list[str]:
filtered.append(token)
return filtered

def _handle_task_retry(self, try_number: int, context: Context) -> bool:
def _fallback_to_local_run(self, try_number: int, context: Context) -> bool:
"""
Handles logic for retrying a failed dbt model execution.
Reconstructs the dbt command by cloning the project and re-running the model
with appropriate flags, while ensuring flags like `--select` or `--exclude` are excluded.
"""
logger.info(
"Retry attempt #%s – Re-running model '%s' from project '%s'",
"Retry attempt #%s – Running model '%s' from project '%s' using ExecutionMode.LOCAL",
try_number - 1,
self.model_unique_id,
self.project_dir,
Expand Down Expand Up @@ -300,6 +301,9 @@ def _get_status_from_run_results(self, ti: Any) -> Any:
logger.info("Node Info: %s", run_results_str)
return node_result.get("status")

def _get_producer_task_state(self, ti: Any) -> Any:
return ti.xcom_pull(task_ids=self.producer_task_id, key="state")

def poke(self, context: Context) -> bool:
"""
Checks the status of a dbt model run by pulling relevant XComs from the master task.
Expand All @@ -308,31 +312,41 @@ def poke(self, context: Context) -> bool:
ti = context["ti"]
try_number = ti.try_number

if try_number > 1:
return self._handle_task_retry(try_number, context)

logger.info(
"Pulling status from task_id '%s' for model '%s'",
"Try number #%s, poke attempt #%s: Pulling status from task_id '%s' for model '%s'",
try_number,
self.poke_retry_number,
self.producer_task_id,
self.model_unique_id,
)

if try_number > 1:
return self._fallback_to_local_run(try_number, context)

# We have assumption here that both the build producer and the sensor task will have same invocation mode
if not self.invocation_mode:
self._discover_invocation_mode()
use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None

producer_task_state = self._get_producer_task_state(ti)
if use_events:
status = self._get_status_from_events(ti)
else:
status = self._get_status_from_run_results(ti)

if status is None:
producer_task_state = ti.xcom_pull(task_ids=self.producer_task_id, key="state")

if producer_task_state == "failed":
raise AirflowException(
f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details."
)
if self.poke_retry_number > 0:
raise AirflowException(
f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details."
)
else:
# This handles the scenario of tasks that failed with `State.UPSTREAM_FAILED`
return self._fallback_to_local_run(try_number, context)

self.poke_retry_number += 1

return False
elif status == "success":
return True
Expand Down
33 changes: 28 additions & 5 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,15 +350,15 @@ def test_poke_success_from_run_results(self):
result = sensor.poke(context)
assert result is True

def test_invocation_mode_none(self):
@patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_state", return_value=None)
def _fallback_to_local_run(self, mock_get_producer_task_state):
sensor = self.make_sensor()
sensor.invocation_mode = None

ti = MagicMock()
ti.try_number = 1
ti.xcom_pull.return_value = ENCODED_RUN_RESULTS
context = self.make_context(ti)

result = sensor.poke(context)
assert result is True

Expand Down Expand Up @@ -397,14 +397,14 @@ def test_task_retry(self, mock_build_and_run_cmd):
sensor.poke(context)
mock_build_and_run_cmd.assert_called_once()

def test_handle_task_retry(self):
def test_fallback_to_local_run(self):
sensor = self.make_sensor()
ti = MagicMock()
ti.task.dag.get_task.return_value.add_cmd_flags.return_value = ["--select", "some_model", "--threads", "2"]
context = self.make_context(ti)
sensor.build_and_run_cmd = MagicMock()

result = sensor._handle_task_retry(2, context)
result = sensor._fallback_to_local_run(2, context)

assert result is True
sensor.build_and_run_cmd.assert_called_once()
Expand Down Expand Up @@ -457,6 +457,7 @@ def test_producer_state_failed(self, mock_run_result):
sensor = self.make_sensor()
ti = MagicMock()
ti.try_number = 1
sensor.poke_retry_number = 1
mock_run_result.return_value = None
ti.xcom_pull.return_value = "failed"

Expand All @@ -468,6 +469,27 @@ def test_producer_state_failed(self, mock_run_result):
):
sensor.poke(context)

@patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._fallback_to_local_run")
@patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results")
def test_producer_state_does_not_fail_if_previously_upstream_failed(
self, mock_run_result, mock_fallback_to_local_run
):
"""
Attempt to run the task using ExecutionMode.LOCAL if State.UPSTREAM_FAILED happens.
More details: https://github.com/astronomer/astronomer-cosmos/pull/2062
"""
sensor = self.make_sensor()
ti = MagicMock()
ti.try_number = 1
sensor.poke_retry_number = 0
mock_run_result.return_value = None
ti.xcom_pull.return_value = "failed"

context = self.make_context(ti)

sensor.poke(context)
mock_fallback_to_local_run.assert_called_once()


class TestDbtBuildWatcherOperator:

Expand Down Expand Up @@ -591,7 +613,8 @@ def test_dbt_task_group_with_watcher():
operator_args=operator_args,
)

pre_dbt >> dbt_task_group
pre_dbt
dbt_task_group

# Unfortunately, due to a bug in Airflow, we are not being able to set the producer task as an upstream task of the other TaskGroup tasks:
# https://github.com/apache/airflow/issues/56723
Expand Down