diff --git a/cosmos/_triggers/watcher.py b/cosmos/_triggers/watcher.py
index 98f9f47625..bfbf4ca98d 100644
--- a/cosmos/_triggers/watcher.py
+++ b/cosmos/_triggers/watcher.py
@@ -115,10 +115,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                 return
             elif node_status == "failed":
                 self.log.warning("Model '%s' failed", self.model_unique_id)
-                yield TriggerEvent({"status": "failed"})  # type: ignore[no-untyped-call]
+                yield TriggerEvent({"status": "failed", "reason": "model_failed"})  # type: ignore[no-untyped-call]
                 return
             elif producer_task_state == "failed":
-                yield TriggerEvent({"status": "failed"})  # type: ignore[no-untyped-call]
+                self.log.error(
+                    "Watcher producer task '%s' failed before delivering results for model '%s'",
+                    self.producer_task_id,
+                    self.model_unique_id,
+                )
+                yield TriggerEvent({"status": "failed", "reason": "producer_failed"})  # type: ignore[no-untyped-call]
                 return
 
             # Sleep briefly before re-polling
diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py
index 456a9bc5d6..30bfdb2670 100644
--- a/cosmos/operators/watcher.py
+++ b/cosmos/operators/watcher.py
@@ -393,9 +393,31 @@ def execute(self, context: Context, **kwargs: Any) -> None:
         )
 
     def execute_complete(self, context: Context, event: dict[str, str]) -> None:
-        if event.get("status") == "failed":
+        """
+        Raise ``AirflowException`` when the trigger event reports a failure.
+
+        The ``reason`` key of a failed event selects the error message. A failed event whose
+        ``reason`` is missing or unrecognised still fails the task, with a generic message.
+        """
+        status = event.get("status")
+        if status != "failed":
+            return
+
+        reason = event.get("reason")
+        if reason == "model_failed":
+            raise AirflowException(
+                f"dbt model '{self.model_unique_id}' failed. Review the producer task '{self.producer_task_id}' logs for details."
+            )
+
+        if reason == "producer_failed":
             raise AirflowException(
-                f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details."
+                f"Watcher producer task '{self.producer_task_id}' failed before reporting model results. Check its logs for the underlying error."
             )
+
+        # Never let a failed event fall through: returning here would mark the task successful.
+        # Events without a known reason get the generic failure message.
+        raise AirflowException(
+            f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details."
+        )
 
     def _use_event(self) -> bool:
diff --git a/tests/_triggers/test_watcher.py b/tests/_triggers/test_watcher.py
index 357cc70c9c..96c17890a1 100644
--- a/tests/_triggers/test_watcher.py
+++ b/tests/_triggers/test_watcher.py
@@ -99,24 +99,24 @@ async def test_get_xcom_val_branches(self, airflow_version, expected_val):
         assert val == "af3"
 
     @pytest.mark.parametrize(
-        "node_status, dr_state, expected",
+        "node_status, producer_state, expected",
         [
-            ("success", "running", "success"),
-            ("failed", "running", "failed"),
-            (None, "failed", "failed"),
+            ("success", "running", {"status": "success"}),
+            ("failed", "running", {"status": "failed", "reason": "model_failed"}),
+            (None, "failed", {"status": "failed", "reason": "producer_failed"}),
         ],
     )
-    async def test_run_various_outcomes(self, node_status, dr_state, expected):
+    async def test_run_various_outcomes(self, node_status, producer_state, expected):
         async def fake_get_xcom_val(key):
-            return dr_state if key == "state" else "compressed_data"
+            return producer_state if key == "state" else "compressed_data"
 
         with patch.object(self.trigger, "get_xcom_val", side_effect=fake_get_xcom_val), patch(
             "cosmos._triggers.watcher._parse_compressed_xcom",
             return_value={"data": {"run_result": {"status": node_status}}} if node_status else {},
         ):
             events = [event async for event in self.trigger.run()]
 
-            assert events[0].payload["status"] == expected
+            assert events[0].payload == expected
 
     @pytest.mark.asyncio
     async def test_run_poke_interval_and_debug_log(self, caplog):
diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py
index 213c4f83c2..37857a42f4 100644
--- a/tests/operators/test_watcher.py
+++ b/tests/operators/test_watcher.py
@@ -214,6 +214,48 @@ def test_dbt_producer_watcher_operator_blocks_retry_attempt(caplog):
     assert any("does not support Airflow retries" in message for message in caplog.messages)
 
 
+@pytest.mark.parametrize(
+    "event, expected_message",
+    [
+        ({"status": "success"}, None),
+        (
+            {"status": "failed", "reason": "model_failed"},
+            "dbt model 'model.pkg.m' failed. Review the producer task 'dbt_producer_watcher_operator' logs for details.",
+        ),
+        (
+            {"status": "failed", "reason": "producer_failed"},
+            "Watcher producer task 'dbt_producer_watcher_operator' failed before reporting model results. Check its logs for the underlying error.",
+        ),
+        (
+            {"status": "failed"},
+            "The dbt build command failed in producer task. Please check the log of task dbt_producer_watcher_operator for details.",
+        ),
+    ],
+)
+def test_dbt_consumer_watcher_sensor_execute_complete(event, expected_message):
+    sensor = DbtConsumerWatcherSensor(
+        project_dir=".",
+        profiles_dir=".",
+        profile_config=profile_config,
+        model_unique_id="model.pkg.m",
+        poke_interval=1,
+        producer_task_id="dbt_producer_watcher_operator",
+        task_id="consumer_sensor",
+    )
+    sensor.model_unique_id = "model.pkg.m"
+
+    context = {"dag_run": MagicMock()}
+
+    if expected_message is None:
+        sensor.execute_complete(context, event)
+        return
+
+    with pytest.raises(AirflowException) as excinfo:
+        sensor.execute_complete(context, event)
+
+    assert str(excinfo.value) == expected_message
+
+
 def test_handle_node_finished_pushes_xcom():
     op = DbtProducerWatcherOperator(project_dir=".", profile_config=None)
     ti = _MockTI()
@@ -700,7 +742,8 @@ def test_sensor_not_deferred(self, mock_poke):
     @pytest.mark.parametrize(
         "mock_event",
         [
-            {"status": "failed"},
+            {"status": "failed", "reason": "model_failed"},
+            {"status": "failed", "reason": "producer_failed"},
             {"status": "success"},
         ],
     )