From d9763bed13d75cd8b2ed4ba7f015dfdd7bf8146c Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Thu, 13 Nov 2025 19:24:31 +0530 Subject: [PATCH 1/7] Fail consumer sensors when producer task failure observed using Airflow context --- cosmos/operators/watcher.py | 45 ++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index e448f9ddc4..d93cef676d 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -374,6 +374,49 @@ def _get_status_from_run_results(self, ti: Any, context: Context) -> Any: def _get_producer_task_state(self, ti: Any) -> Any: return ti.xcom_pull(task_ids=self.producer_task_id, key="state") + def _get_producer_task_status(self, context: Context) -> str | None: + """ + Get the task status of the producer task for both Airflow 2 and Airflow 3. + + Returns the state of the producer task instance, or None if not found. + """ + ti = context["ti"] + run_id = context["run_id"] + dag_id = ti.dag_id + + if AIRFLOW_VERSION < Version("3.0.0"): + # Airflow 2: Query TaskInstance from database + from airflow.models import TaskInstance + from airflow.utils.session import create_session + + with create_session() as session: + producer_ti = ( + session.query(TaskInstance) + .filter_by( + dag_id=dag_id, + task_id=self.producer_task_id, + run_id=run_id, + ) + .first() + ) + if producer_ti: + return str(producer_ti.state) + return None + else: + # Airflow 3: Use RuntimeTaskInstance.get_task_states + try: + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + + task_states = RuntimeTaskInstance.get_task_states( + dag_id=dag_id, + task_ids=[self.producer_task_id], + run_ids=[run_id], + ) + return str(task_states.get(run_id, {}).get(self.producer_task_id, "")) + except ImportError: + logger.warning("Could not get producer task status, falling back to XCom state check") + return None + def execute(self, context: Context, **kwargs: Any) -> None: if not self.deferrable: super().execute(context) @@ -433,7 +476,7 @@ def poke(self, context: Context) -> bool: return self._fallback_to_local_run(try_number, context) # We have assumption here that both the build producer and the sensor task will have same invocation mode - producer_task_state = self._get_producer_task_state(ti) + producer_task_state = self._get_producer_task_status(context) if self._use_event(): status = self._get_status_from_events(ti, context) else: From 63f5ca7fc6cafde529d777c2821450e28a58a855 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Fri, 14 Nov 2025 15:56:25 +0300 Subject: [PATCH 2/7] Add tests --- cosmos/operators/watcher.py | 8 ++- tests/operators/test_watcher.py | 91 +++++++++++++++++++++++++++++++-- 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index d93cef676d..dbdea847dc 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -386,8 +386,12 @@ def _get_producer_task_status(self, context: Context) -> str | None: if AIRFLOW_VERSION < Version("3.0.0"): # Airflow 2: Query TaskInstance from database - from airflow.models import TaskInstance - from airflow.utils.session import create_session + try: + from airflow.models import TaskInstance + from airflow.utils.session import create_session + except Exception as exc: # pragma: no cover - defensive fallback for tests without DB + logger.warning("Could not import create_session to read producer state: %s", exc) + return None with create_session() as session: producer_ti = ( diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 1cf1ba0c2e..0c978cf0be 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -511,10 +511,93 @@ def make_sensor(self, **kwargs): ) sensor.invocation_mode = "DBT_RUNNER" + sensor._get_producer_task_status = MagicMock(return_value=None) return sensor - def make_context(self, ti_mock): - return {"ti": ti_mock} + def make_context(self, ti_mock, *, run_id: str = "test-run", map_index: int = 0): + return { + "ti": ti_mock, + "run_id": run_id, + "task_instance": MagicMock(map_index=map_index), + } + + @pytest.mark.skipif(AIRFLOW_VERSION >= Version("3.0.0"), reason="RuntimeTaskInstance path in Airflow >= 3.0") + @patch("airflow.utils.session.create_session") + @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("2.7.0")) + def test_get_producer_task_status_airflow2(self, mock_create_session): + sensor = self.make_sensor() + sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__( + sensor, DbtConsumerWatcherSensor + ) + ti = MagicMock() + ti.dag_id = "example_dag" + context = self.make_context(ti, run_id="run_1") + + mock_state_ti = MagicMock() + mock_state_ti.state = "success" + session_cm = mock_create_session.return_value + session_cm.__enter__.return_value.query.return_value.filter_by.return_value.first.return_value = mock_state_ti + + status = sensor._get_producer_task_status(context) + + assert status == "success" + mock_create_session.assert_called_once() + + @pytest.mark.skipif(AIRFLOW_VERSION >= Version("3.0.0"), reason="RuntimeTaskInstance path in Airflow >= 3.0") + @patch("airflow.utils.session.create_session") + @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("2.7.0")) + def test_get_producer_task_status_airflow2_missing_instance(self, mock_create_session): + sensor = self.make_sensor() + sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__( + sensor, DbtConsumerWatcherSensor + ) + ti = MagicMock() + ti.dag_id = "example_dag" + context = self.make_context(ti, run_id="run_2") + + session_cm = mock_create_session.return_value + session_cm.__enter__.return_value.query.return_value.filter_by.return_value.first.return_value = None + + status = sensor._get_producer_task_status(context) + + assert status is None + + @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0") + @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0")) + @patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states") + def test_get_producer_task_status_airflow3(self, mock_get_task_states): + sensor = self.make_sensor() + sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__( + sensor, DbtConsumerWatcherSensor + ) + ti = MagicMock() + ti.dag_id = "example_dag" + context = self.make_context(ti, run_id="run_3") + + mock_get_task_states.return_value = {"run_3": {sensor.producer_task_id: "running"}} + + status = sensor._get_producer_task_status(context) + + assert status == "running" + mock_get_task_states.assert_called_once_with( + dag_id="example_dag", task_ids=[sensor.producer_task_id], run_ids=["run_3"] + ) + + @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0") + @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0")) + @patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states", side_effect=ImportError("missing runtime")) + def test_get_producer_task_status_airflow3_import_error(self, _mock_get_task_states): + sensor = self.make_sensor() + sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__( + sensor, DbtConsumerWatcherSensor + ) + ti = MagicMock() + ti.dag_id = "example_dag" + context = self.make_context(ti, run_id="run_4") + + status = sensor._get_producer_task_status(context) + + assert status is None @patch("cosmos.operators.watcher.EventMsg") def test_poke_status_none_from_events(self, MockEventMsg): @@ -544,7 +627,7 @@ def test_poke_success_from_run_results(self): result = sensor.poke(context) assert result is True - @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_state", return_value=None) + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_status", return_value=None) def _fallback_to_local_run(self, mock_get_producer_task_state): sensor = self.make_sensor() sensor.invocation_mode = None @@ -663,6 +746,7 @@ def test_get_status_from_events_sets_compiled_sql(self): @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results") def test_producer_state_failed(self, mock_run_result): sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "failed" ti = MagicMock() ti.try_number = 1 sensor.poke_retry_number = 1 @@ -687,6 +771,7 @@ def test_producer_state_does_not_fail_if_previously_upstream_failed( More details: https://github.com/astronomer/astronomer-cosmos/pull/2062 """ sensor = self.make_sensor() + sensor._get_producer_task_status.return_value = "failed" ti = MagicMock() ti.try_number = 1 sensor.poke_retry_number = 0 From 1f8765cb7381bdf24c39613f3575af7135819613 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:58:02 +0000 Subject: [PATCH 3/7] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/operators/test_watcher.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 0c978cf0be..2ed9b98358 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -585,7 +585,10 @@ def test_get_producer_task_status_airflow3(self, mock_get_task_states): @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0") @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0")) - @patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states", side_effect=ImportError("missing runtime")) + @patch( + "airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states", + side_effect=ImportError("missing runtime"), + ) def test_get_producer_task_status_airflow3_import_error(self, _mock_get_task_states): sensor = self.make_sensor() sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__( From e9e107d97a47a501599d19df764965f56bc65885 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 18 Nov 2025 13:45:44 +0300 Subject: [PATCH 4/7] Guard watcher producer status lookup for Airflow 3 --- cosmos/operators/watcher.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index dbdea847dc..36e09f5e09 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -417,9 +417,11 @@ def _get_producer_task_status(self, context: Context) -> str | None: run_ids=[run_id], ) return str(task_states.get(run_id, {}).get(self.producer_task_id, "")) - except ImportError: - logger.warning("Could not get producer task status, falling back to XCom state check") - return None + except (ImportError, NameError) as exc: + logger.warning( + "Could not retrieve producer task status via RuntimeTaskInstance: %s", exc + ) + return None def execute(self, context: Context, **kwargs: Any) -> None: if not self.deferrable: From 593acb94e195ee649fb0791d6cd3c2e01f51b522 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Nov 2025 10:46:08 +0000 Subject: [PATCH 5/7] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/operators/watcher.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 36e09f5e09..8e6059d5e3 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -418,9 +418,7 @@ def _get_producer_task_status(self, context: Context) -> str | None: ) return str(task_states.get(run_id, {}).get(self.producer_task_id, "")) except (ImportError, NameError) as exc: - logger.warning( - "Could not retrieve producer task status via RuntimeTaskInstance: %s", exc - ) + logger.warning("Could not retrieve producer task status via RuntimeTaskInstance: %s", exc) return None def execute(self, context: Context, **kwargs: Any) -> None: From 36a0c7635c8322bfc7af5bca72eb91b61511645b Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 18 Nov 2025 22:20:57 +0530 Subject: [PATCH 6/7] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- cosmos/operators/watcher.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 8e6059d5e3..9c781490e2 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -389,7 +389,7 @@ def _get_producer_task_status(self, context: Context) -> str | None: try: from airflow.models import TaskInstance from airflow.utils.session import create_session - except Exception as exc: # pragma: no cover - defensive fallback for tests without DB + except ImportError as exc: # pragma: no cover - defensive fallback for tests without DB logger.warning("Could not import create_session to read producer state: %s", exc) return None @@ -416,7 +416,10 @@ def _get_producer_task_status(self, context: Context) -> str | None: task_ids=[self.producer_task_id], run_ids=[run_id], ) - return str(task_states.get(run_id, {}).get(self.producer_task_id, "")) + state = task_states.get(run_id, {}).get(self.producer_task_id) + if state is not None: + return str(state) + return None except (ImportError, NameError) as exc: logger.warning("Could not retrieve producer task status via RuntimeTaskInstance: %s", exc) return None From be89e16fa09727ca357dc7a9036a5f49d63e57a3 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Wed, 19 Nov 2025 12:48:04 +0300 Subject: [PATCH 7/7] Test: cover watcher producer missing state --- tests/operators/test_watcher.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 2ed9b98358..396f620492 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -583,6 +583,27 @@ def test_get_producer_task_status_airflow3(self, mock_get_task_states): dag_id="example_dag", task_ids=[sensor.producer_task_id], run_ids=["run_3"] ) + @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0") + @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0")) + @patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states") + def test_get_producer_task_status_airflow3_missing_state(self, mock_get_task_states): + sensor = self.make_sensor() + sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__( + sensor, DbtConsumerWatcherSensor + ) + ti = MagicMock() + ti.dag_id = "example_dag" + context = self.make_context(ti, run_id="run_3_missing") + + mock_get_task_states.return_value = {"run_3_missing": {}} + + status = sensor._get_producer_task_status(context) + + assert status is None + mock_get_task_states.assert_called_once_with( + dag_id="example_dag", task_ids=[sensor.producer_task_id], run_ids=["run_3_missing"] + ) + @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0") @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0")) @patch(