def _get_producer_task_status(self, context: Context) -> str | None:
    """
    Get the task status of the producer task for both Airflow 2 and Airflow 3.

    On Airflow < 3.0 the producer's ``TaskInstance`` row is queried through a
    database session. On Airflow >= 3.0 the state is requested through
    ``RuntimeTaskInstance.get_task_states``.

    :param context: Airflow task context. ``context["ti"]`` (for its ``dag_id``)
        and ``context["run_id"]`` are read.
    :returns: The producer task state as a string, or ``None`` when the task
        instance is not found, has no state yet, or the lookup API could not
        be imported.
    """
    ti = context["ti"]
    run_id = context["run_id"]
    dag_id = ti.dag_id

    if AIRFLOW_VERSION < Version("3.0.0"):
        # Airflow 2: Query TaskInstance from database
        try:
            from airflow.models import TaskInstance
            from airflow.utils.session import create_session
        except ImportError as exc:  # pragma: no cover - defensive fallback for tests without DB
            logger.warning("Could not import create_session to read producer state: %s", exc)
            return None

        with create_session() as session:
            producer_ti = (
                session.query(TaskInstance)
                .filter_by(
                    dag_id=dag_id,
                    task_id=self.producer_task_id,
                    run_id=run_id,
                )
                .first()
            )
            # Read the attribute while the session is still open.
            state = producer_ti.state if producer_ti else None
    else:
        # Airflow 3: Use RuntimeTaskInstance.get_task_states
        try:
            from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

            task_states = RuntimeTaskInstance.get_task_states(
                dag_id=dag_id,
                task_ids=[self.producer_task_id],
                run_ids=[run_id],
            )
        except (ImportError, NameError) as exc:
            logger.warning("Could not retrieve producer task status via RuntimeTaskInstance: %s", exc)
            return None
        state = task_states.get(run_id, {}).get(self.producer_task_id)

    # A task instance that exists but has no state yet must map to None,
    # not to the literal string "None" that str(None) would produce.
    if state is None:
        return None
    return str(state)
def make_context(self, ti_mock, *, run_id: str = "test-run", map_index: int = 0):
    """
    Build a minimal context dictionary for sensor tests.

    :param ti_mock: Object stored under the ``"ti"`` key.
    :param run_id: Value stored under the ``"run_id"`` key.
    :param map_index: ``map_index`` attribute given to the ``MagicMock`` stored
        under the ``"task_instance"`` key.
    :returns: Dict with the keys ``"ti"``, ``"run_id"`` and ``"task_instance"``.
    """
    mapped_task_instance = MagicMock(map_index=map_index)
    context = {"ti": ti_mock}
    context["run_id"] = run_id
    context["task_instance"] = mapped_task_instance
    return context
@pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0")
@patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0"))
@patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
def test_get_producer_task_status_airflow3(self, mock_get_task_states):
    """On Airflow 3 the state reported by ``get_task_states`` is returned as a string."""
    run_id = "run_3"
    sensor = self.make_sensor()
    # Bind the real implementation to this sensor instance so the actual
    # method body is exercised.
    real_method = DbtConsumerWatcherSensor._get_producer_task_status
    sensor._get_producer_task_status = real_method.__get__(sensor, DbtConsumerWatcherSensor)

    mock_get_task_states.return_value = {run_id: {sensor.producer_task_id: "running"}}

    task_instance = MagicMock()
    task_instance.dag_id = "example_dag"
    context = self.make_context(task_instance, run_id=run_id)

    assert sensor._get_producer_task_status(context) == "running"
    mock_get_task_states.assert_called_once_with(
        dag_id="example_dag",
        task_ids=[sensor.producer_task_id],
        run_ids=[run_id],
    )
@pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0")
@patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0"))
@patch(
    "airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states",
    side_effect=ImportError("missing runtime"),
)
def test_get_producer_task_status_airflow3_import_error(self, _mock_get_task_states):
    """An ``ImportError`` raised by the Airflow 3 lookup results in ``None``."""
    sensor = self.make_sensor()
    # Bind the real implementation to this sensor instance so the actual
    # method body is exercised.
    real_method = DbtConsumerWatcherSensor._get_producer_task_status
    sensor._get_producer_task_status = real_method.__get__(sensor, DbtConsumerWatcherSensor)

    task_instance = MagicMock()
    task_instance.dag_id = "example_dag"
    context = self.make_context(task_instance, run_id="run_4")

    assert sensor._get_producer_task_status(context) is None
test_get_status_from_events_sets_compiled_sql(self): @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results") def test_producer_state_failed(self, mock_run_result): sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "failed" ti = MagicMock() ti.try_number = 1 sensor.poke_retry_number = 1 @@ -687,6 +795,7 @@ def test_producer_state_does_not_fail_if_previously_upstream_failed( More details: https://github.com/astronomer/astronomer-cosmos/pull/2062 """ sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "failed" ti = MagicMock() ti.try_number = 1 sensor.poke_retry_number = 0