diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 89d1655c6e..108d936991 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -5,6 +5,7 @@ import logging import zlib from datetime import timedelta +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, List, Sequence, Union import airflow @@ -18,13 +19,13 @@ try: from airflow.sdk.bases.sensor import BaseSensorOperator -except ImportError: +except ImportError: # pragma: no cover from airflow.sensors.base import BaseSensorOperator from airflow.exceptions import AirflowException try: from airflow.providers.standard.operators.empty import EmptyOperator -except ImportError: +except ImportError: # pragma: no cover from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] from cosmos.config import ProfileConfig @@ -111,28 +112,46 @@ def _set_on_failure_callback( return [user_callback, default_callback] @staticmethod - def _serialize_event(ev: EventMsg) -> dict[str, Any]: + def _serialize_event(event_message: EventMsg) -> dict[str, Any]: """Convert structured dbt EventMsg to plain dict.""" from google.protobuf.json_format import MessageToDict - return MessageToDict(ev, preserving_proto_field_name=True) # type: ignore[no-any-return] + return MessageToDict(event_message, preserving_proto_field_name=True) # type: ignore[no-any-return] - def _handle_startup_event(self, ev: EventMsg, startup_events: list[dict[str, Any]]) -> None: - info = ev.info # type: ignore[attr-defined] + def _handle_startup_event(self, event_message: EventMsg, startup_events: list[dict[str, Any]]) -> None: + info = event_message.info # type: ignore[attr-defined] raw_ts = getattr(info, "ts", None) ts_val = raw_ts.ToJsonString() if hasattr(raw_ts, "ToJsonString") else str(raw_ts) # type: ignore[union-attr] startup_events.append({"name": info.name, "msg": info.msg, "ts": ts_val}) + def _extract_compiled_sql_for_node_event(self, event_message: EventMsg) -> str | None: + if getattr(event_message.data.node_info, "resource_type", None) != "model": + return None + uid = event_message.data.node_info.unique_id + node_path = str(event_message.data.node_info.node_path) + package = uid.split(".")[1] + compiled_sql_path = Path.cwd() / "target" / "compiled" / package / "models" / node_path + if not compiled_sql_path.exists(): + logger.warning( + "Compiled sql path %s does not exist and hence the rendered template field compiled_sql for the model will not be populated", + compiled_sql_path, + ) + return None + return compiled_sql_path.read_text(encoding="utf-8").strip() or None + def _handle_node_finished( self, - ev: EventMsg, + event_message: EventMsg, context: Context, ) -> None: - logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", ev) + logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", event_message) ti = context["ti"] - uid = ev.data.node_info.unique_id - ev_dict = self._serialize_event(ev) - payload = base64.b64encode(zlib.compress(json.dumps(ev_dict).encode())).decode() + uid = event_message.data.node_info.unique_id + event_message_dict = self._serialize_event(event_message) + compiled_sql = self._extract_compiled_sql_for_node_event(event_message) + if compiled_sql: + event_message_dict["compiled_sql"] = compiled_sql + payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode() ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload) def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None: @@ -157,12 +176,12 @@ def execute(self, context: Context, **kwargs: Any) -> Any: if use_events: - def _callback(ev: EventMsg) -> None: - name = ev.info.name + def _callback(event_message: EventMsg) -> None: + name = event_message.info.name if name in {"MainReportVersion", "AdapterRegistered"}: - self._handle_startup_event(ev, startup_events) + self._handle_startup_event(event_message, startup_events) elif name == "NodeFinished": - self._handle_node_finished(ev, context) + self._handle_node_finished(event_message, context) self._dbt_runner_callbacks = [_callback] result = super().execute(context=context, **kwargs) @@ -183,7 +202,7 @@ def _callback(ev: EventMsg) -> None: class DbtConsumerWatcherSensor(BaseSensorOperator, DbtRunLocalOperator): # type: ignore[misc] - template_fields = ("model_unique_id",) # type: ignore[operator] + template_fields: tuple[str, ...] = ("model_unique_id", "compiled_sql") # type: ignore[operator] poke_retry_number: int = 0 def __init__( @@ -198,6 +217,7 @@ def __init__( execution_timeout: timedelta = timedelta(hours=1), **kwargs: Any, ) -> None: + self.compiled_sql = "" extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {} kwargs.setdefault("priority_weight", CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) kwargs.setdefault("weight_rule", WEIGHT_RULE) @@ -258,7 +278,7 @@ def _fallback_to_local_run(self, try_number: int, context: Context) -> bool: logger.info("dbt run completed successfully on retry for model '%s'", self.model_unique_id) return True - def _get_status_from_events(self, ti: Any) -> Any: + def _get_status_from_events(self, ti: Any, context: Context) -> Any: dbt_startup_events = ti.xcom_pull(task_ids=self.producer_task_id, key="dbt_startup_events") if dbt_startup_events: # pragma: no cover @@ -277,6 +297,10 @@ def _get_status_from_events(self, ti: Any) -> Any: logger.info("Node Info: %s", event_json_str) + self.compiled_sql = event_json.get("compiled_sql", "") + if self.compiled_sql: + self._override_rtif(context) + return event_json.get("data", {}).get("run_result", {}).get("status") def _get_status_from_run_results(self, ti: Any) -> Any: @@ -330,7 +354,7 @@ def poke(self, context: Context) -> bool: producer_task_state = self._get_producer_task_state(ti) if use_events: - status = self._get_status_from_events(ti) + status = self._get_status_from_events(ti, context) else: status = self._get_status_from_run_results(ti) @@ -368,7 +392,7 @@ class DbtSeedWatcherOperator(DbtSeedMixin, DbtConsumerWatcherSensor): # type: i Watches for the progress of dbt seed execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] + template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -379,7 +403,7 @@ class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtConsumerWatcherSensor): # Watches for the progress of dbt snapshot execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields class DbtSourceWatcherOperator(DbtSourceLocalOperator): @@ -395,7 +419,7 @@ class DbtRunWatcherOperator(DbtConsumerWatcherSensor): Watches for the progress of dbt model execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] + template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index ef40190ead..e1341db90f 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -57,7 +57,9 @@ class _MockContext(dict): pass -def _fake_event(name: str = "NodeFinished", uid: str = "model.pkg.m"): +def _fake_event( + name: str = "NodeFinished", uid: str = "model.pkg.m", resource_type: str | None = None, node_path: str | None = None +): """Create a minimal fake EventMsg-like object suitable for helper tests.""" class _Info(SimpleNamespace): @@ -70,6 +72,10 @@ class _RunResult(SimpleNamespace): pass node_info = _NodeInfo(unique_id=uid) + if resource_type is not None: + setattr(node_info, "resource_type", resource_type) + if node_path is not None: + setattr(node_info, "node_path", node_path) run_result = _RunResult(status="success", message="ok") data = SimpleNamespace(node_info=node_info, run_result=run_result) @@ -156,6 +162,49 @@ def test_handle_node_finished_pushes_xcom(): assert json.loads(raw) == {"foo": "bar"} +def test_handle_node_finished_injects_compiled_sql(tmp_path, monkeypatch): + op = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None) + ti = _MockTI() + ctx = _MockContext(ti=ti) + + # Create compiled SQL file at expected path: target/compiled/pkg/models/my_model.sql + compiled_dir = tmp_path / "target" / "compiled" / "pkg" / "models" + compiled_dir.mkdir(parents=True) + compiled_file = compiled_dir / "my_model.sql" + sql_text = "select 1" + compiled_file.write_text(sql_text, encoding="utf-8") + + # Ensure watcher looks up under this tmp project dir + monkeypatch.chdir(tmp_path) + + with patch.object(op, "_serialize_event", return_value={}): + ev = _fake_event(name="NodeFinished", uid="model.pkg.my_model", resource_type="model", node_path="my_model.sql") + op._handle_node_finished(ev, ctx) + + stored = list(ti.store.values())[0] + raw = zlib.decompress(base64.b64decode(stored)).decode() + data = json.loads(raw) + assert data.get("compiled_sql") == sql_text + + +def test_handle_node_finished_without_compiled_sql_does_not_inject(tmp_path, monkeypatch): + op = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None) + ti = _MockTI() + ctx = _MockContext(ti=ti) + + # Ensure watcher looks up under this tmp project dir, but do NOT create compiled file + monkeypatch.chdir(tmp_path) + + with patch.object(op, "_serialize_event", return_value={}): + ev = _fake_event(name="NodeFinished", uid="model.pkg.my_model", resource_type="model", node_path="my_model.sql") + op._handle_node_finished(ev, ctx) + + stored = list(ti.store.values())[0] + raw = zlib.decompress(base64.b64decode(stored)).decode() + data = json.loads(raw) + assert "compiled_sql" not in data + + def test_execute_streaming_mode(): """Streaming path should push startup + per-model XComs.""" from contextlib import nullcontext @@ -440,18 +489,32 @@ def test_get_status_from_events_success(self): sensor = self.make_sensor() ti = MagicMock() ti.xcom_pull.side_effect = [None, ENCODED_EVENT] + context = self.make_context(ti) - result = sensor._get_status_from_events(ti) + result = sensor._get_status_from_events(ti, context) assert result == "success" def test_get_status_from_events_none(self): sensor = self.make_sensor() ti = MagicMock() ti.xcom_pull.side_effect = [None, None] + context = self.make_context(ti) - result = sensor._get_status_from_events(ti) + result = sensor._get_status_from_events(ti, context) assert result is None + def test_get_status_from_events_sets_compiled_sql(self): + sensor = self.make_sensor() + ti = MagicMock() + event_payload = {"data": {"run_result": {"status": "success"}}, "compiled_sql": "select 42"} + encoded_event = base64.b64encode(zlib.compress(json.dumps(event_payload).encode())).decode("utf-8") + ti.xcom_pull.side_effect = [None, encoded_event] + context = self.make_context(ti) + + result = sensor._get_status_from_events(ti, context) + assert result == "success" + assert sensor.compiled_sql == "select 42" + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results") def test_producer_state_failed(self, mock_run_result): sensor = self.make_sensor()