diff --git a/cosmos/__init__.py b/cosmos/__init__.py index 30c4e09696..dfe5b9d605 100644 --- a/cosmos/__init__.py +++ b/cosmos/__init__.py @@ -9,7 +9,7 @@ from cosmos import settings -__version__ = "1.11.0a11" +__version__ = "1.11.0a13" if not settings.enable_memory_optimised_imports: from cosmos.airflow.dag import DbtDag diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 2103dc977b..89d1655c6e 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -184,6 +184,7 @@ def _callback(ev: EventMsg) -> None: class DbtConsumerWatcherSensor(BaseSensorOperator, DbtRunLocalOperator): # type: ignore[misc] template_fields = ("model_unique_id",) # type: ignore[operator] + poke_retry_number: int = 0 def __init__( self, @@ -229,14 +230,14 @@ def _filter_flags(flags: list[str]) -> list[str]: filtered.append(token) return filtered - def _handle_task_retry(self, try_number: int, context: Context) -> bool: + def _fallback_to_local_run(self, try_number: int, context: Context) -> bool: """ Handles logic for retrying a failed dbt model execution. Reconstructs the dbt command by cloning the project and re-running the model with appropriate flags, while ensuring flags like `--select` or `--exclude` are excluded. """ logger.info( - "Retry attempt #%s – Re-running model '%s' from project '%s'", + "Retry attempt #%s – Running model '%s' from project '%s' using ExecutionMode.LOCAL", try_number - 1, self.model_unique_id, self.project_dir, @@ -300,6 +301,9 @@ def _get_status_from_run_results(self, ti: Any) -> Any: logger.info("Node Info: %s", run_results_str) return node_result.get("status") + def _get_producer_task_state(self, ti: Any) -> Any: + return ti.xcom_pull(task_ids=self.producer_task_id, key="state") + def poke(self, context: Context) -> bool: """ Checks the status of a dbt model run by pulling relevant XComs from the master task. @@ -308,31 +312,41 @@ def poke(self, context: Context) -> bool: ti = context["ti"] try_number = ti.try_number - if try_number > 1: - return self._handle_task_retry(try_number, context) - logger.info( - "Pulling status from task_id '%s' for model '%s'", + "Try number #%s, poke attempt #%s: Pulling status from task_id '%s' for model '%s'", + try_number, + self.poke_retry_number, self.producer_task_id, self.model_unique_id, ) + if try_number > 1: + return self._fallback_to_local_run(try_number, context) + # We have assumption here that both the build producer and the sensor task will have same invocation mode if not self.invocation_mode: self._discover_invocation_mode() use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None + producer_task_state = self._get_producer_task_state(ti) if use_events: status = self._get_status_from_events(ti) else: status = self._get_status_from_run_results(ti) if status is None: - producer_task_state = ti.xcom_pull(task_ids=self.producer_task_id, key="state") + if producer_task_state == "failed": - raise AirflowException( - f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details." - ) + if self.poke_retry_number > 0: + raise AirflowException( + f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details." + ) + else: + # This handles the scenario of tasks that failed with `State.UPSTREAM_FAILED` + return self._fallback_to_local_run(try_number, context) + + self.poke_retry_number += 1 + return False elif status == "success": return True diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index fcbad5d732..ef40190ead 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -350,7 +350,8 @@ def test_poke_success_from_run_results(self): result = sensor.poke(context) assert result is True - def test_invocation_mode_none(self): + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_state", return_value=None) + def _fallback_to_local_run(self, mock_get_producer_task_state): sensor = self.make_sensor() sensor.invocation_mode = None @@ -358,7 +359,6 @@ def test_invocation_mode_none(self): ti.try_number = 1 ti.xcom_pull.return_value = ENCODED_RUN_RESULTS context = self.make_context(ti) - result = sensor.poke(context) assert result is True @@ -397,14 +397,14 @@ def test_task_retry(self, mock_build_and_run_cmd): sensor.poke(context) mock_build_and_run_cmd.assert_called_once() - def test_handle_task_retry(self): + def test_fallback_to_local_run(self): sensor = self.make_sensor() ti = MagicMock() ti.task.dag.get_task.return_value.add_cmd_flags.return_value = ["--select", "some_model", "--threads", "2"] context = self.make_context(ti) sensor.build_and_run_cmd = MagicMock() - result = sensor._handle_task_retry(2, context) + result = sensor._fallback_to_local_run(2, context) assert result is True sensor.build_and_run_cmd.assert_called_once() @@ -457,6 +457,7 @@ def test_producer_state_failed(self, mock_run_result): sensor = self.make_sensor() ti = MagicMock() ti.try_number = 1 + sensor.poke_retry_number = 1 mock_run_result.return_value = None ti.xcom_pull.return_value = "failed" @@ -468,6 +469,27 @@ def test_producer_state_failed(self, mock_run_result): ): sensor.poke(context) + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._fallback_to_local_run") + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results") + def test_producer_state_does_not_fail_if_previously_upstream_failed( + self, mock_run_result, mock_fallback_to_local_run + ): + """ + Attempt to run the task using ExecutionMode.LOCAL if State.UPSTREAM_FAILED happens. + More details: https://github.com/astronomer/astronomer-cosmos/pull/2062 + """ + sensor = self.make_sensor() + ti = MagicMock() + ti.try_number = 1 + sensor.poke_retry_number = 0 + mock_run_result.return_value = None + ti.xcom_pull.return_value = "failed" + + context = self.make_context(ti) + + sensor.poke(context) + mock_fallback_to_local_run.assert_called_once() + class TestDbtBuildWatcherOperator: @@ -591,7 +613,8 @@ def test_dbt_task_group_with_watcher(): operator_args=operator_args, ) - pre_dbt >> dbt_task_group + pre_dbt + dbt_task_group # Unfortunately, due to a bug in Airflow, we are not being able to set the producer task as an upstream task of the other TaskGroup tasks: # https://github.com/apache/airflow/issues/56723