From 3be6bbcf003cb4dda2ab4d195190dc1212a88d2a Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 28 Oct 2025 17:54:06 +0530 Subject: [PATCH 1/8] Store compiled SQL as template fielf for Watcher mode events --- cosmos/operators/watcher.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 2103dc977b..ee9225c3e5 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -3,8 +3,10 @@ import base64 import json import logging +import os import zlib from datetime import timedelta +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, List, Sequence, Union import airflow @@ -132,6 +134,22 @@ def _handle_node_finished( ti = context["ti"] uid = ev.data.node_info.unique_id ev_dict = self._serialize_event(ev) + if ev.data.node_info.resource_type == "model": + project_root = Path(os.getcwd()) + compiled_sql_path = ( + project_root + / "target" + / "compiled" + / str(uid).split(".")[1] + / "models" + / str(ev.data.node_info.node_path) + ) + logger.info("Compile sql path: %s, exists: %s", compiled_sql_path, Path.exists(compiled_sql_path)) + with compiled_sql_path.open("r") as f: + compiled_sql = f.read() + if compiled_sql: + logger.info("Uid: %s, Compiled sql: %s", uid, compiled_sql) + ev_dict["compiled_sql"] = compiled_sql payload = base64.b64encode(zlib.compress(json.dumps(ev_dict).encode())).decode() ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload) @@ -183,7 +201,7 @@ def _callback(ev: EventMsg) -> None: class DbtConsumerWatcherSensor(BaseSensorOperator, DbtRunLocalOperator): # type: ignore[misc] - template_fields = ("model_unique_id",) # type: ignore[operator] + template_fields = ("model_unique_id", "compiled_sql") # type: ignore[operator] def __init__( self, @@ -257,7 +275,7 @@ def _handle_task_retry(self, try_number: int, context: Context) -> bool: logger.info("dbt run completed successfully on retry for model '%s'", self.model_unique_id) return True - def _get_status_from_events(self, ti: Any) -> Any: + def _get_status_from_events(self, ti: Any, context: Context) -> Any: dbt_startup_events = ti.xcom_pull(task_ids=self.producer_task_id, key="dbt_startup_events") if dbt_startup_events: # pragma: no cover @@ -276,6 +294,10 @@ def _get_status_from_events(self, ti: Any) -> Any: logger.info("Node Info: %s", event_json_str) + self.compiled_sql = event_json.get("compiled_sql", "") + if self.compiled_sql: + self._override_rtif(context) + return event_json.get("data", {}).get("run_result", {}).get("status") def _get_status_from_run_results(self, ti: Any) -> Any: @@ -323,7 +345,7 @@ def poke(self, context: Context) -> bool: use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None if use_events: - status = self._get_status_from_events(ti) + status = self._get_status_from_events(ti, context) else: status = self._get_status_from_run_results(ti) @@ -354,7 +376,7 @@ class DbtSeedWatcherOperator(DbtSeedMixin, DbtConsumerWatcherSensor): # type: i Watches for the progress of dbt seed execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] + template_fields: tuple[str, str] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -365,7 +387,7 @@ class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtConsumerWatcherSensor): # Watches for the progress of dbt snapshot execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + template_fields: tuple[str, str] = DbtConsumerWatcherSensor.template_fields class DbtSourceWatcherOperator(DbtSourceLocalOperator): @@ -381,7 +403,7 @@ class DbtRunWatcherOperator(DbtConsumerWatcherSensor): Watches for the progress of dbt model execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] + template_fields: tuple[str, str] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) From 9c7cae3edc229bf54c323500023e237dcb80b544 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 28 Oct 2025 17:57:01 +0530 Subject: [PATCH 2/8] Apply suggestion from @pankajkoti --- cosmos/operators/watcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 3a95b5e729..a21f7ecf7e 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -148,7 +148,7 @@ def _handle_node_finished( with compiled_sql_path.open("r") as f: compiled_sql = f.read() if compiled_sql: - logger.info("Uid: %s, Compiled sql: %s", uid, compiled_sql) + logger.debug("Uid: %s, Compiled sql: %s", uid, compiled_sql) ev_dict["compiled_sql"] = compiled_sql payload = base64.b64encode(zlib.compress(json.dumps(ev_dict).encode())).decode() ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload) From 5a30727f430853adee52019269309c7127a7ab9d Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 28 Oct 2025 17:57:07 +0530 Subject: [PATCH 3/8] Apply suggestion from @pankajkoti --- cosmos/operators/watcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index a21f7ecf7e..04a2233e4e 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -144,7 +144,7 @@ def _handle_node_finished( / "models" / str(ev.data.node_info.node_path) ) - logger.info("Compile sql path: %s, exists: %s", compiled_sql_path, Path.exists(compiled_sql_path)) + logger.debug("Compiled sql path: %s, exists: %s", compiled_sql_path, Path.exists(compiled_sql_path)) with compiled_sql_path.open("r") as f: compiled_sql = f.read() if compiled_sql: From ef7790c78f20132a24e31474854387ec3f4cf66b Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 28 Oct 2025 20:21:10 +0530 Subject: [PATCH 4/8] Refactor --- cosmos/operators/watcher.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 04a2233e4e..64d6a87ff6 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -3,7 +3,6 @@ import base64 import json import logging -import os import zlib from datetime import timedelta from pathlib import Path @@ -125,6 +124,17 @@ def _handle_startup_event(self, ev: EventMsg, startup_events: list[dict[str, Any ts_val = raw_ts.ToJsonString() if hasattr(raw_ts, "ToJsonString") else str(raw_ts) # type: ignore[union-attr] startup_events.append({"name": info.name, "msg": info.msg, "ts": ts_val}) + def _extract_compiled_sql_for_node_event(self, ev: EventMsg) -> str | None: + if getattr(ev.data.node_info, "resource_type", None) != "model": + return None + uid = ev.data.node_info.unique_id + node_path = str(ev.data.node_info.node_path) + package = uid.split(".")[1] + compiled_sql_path = Path.cwd() / "target" / "compiled" / package / "models" / node_path + if compiled_sql_path.exists(): + return compiled_sql_path.read_text(encoding="utf-8").strip() or None + return None + def _handle_node_finished( self, ev: EventMsg, @@ -134,22 +144,9 @@ def _handle_node_finished( ti = context["ti"] uid = ev.data.node_info.unique_id ev_dict = self._serialize_event(ev) - if ev.data.node_info.resource_type == "model": - project_root = Path(os.getcwd()) - compiled_sql_path = ( - project_root - / "target" - / "compiled" - / str(uid).split(".")[1] - / "models" - / str(ev.data.node_info.node_path) - ) - logger.debug("Compiled sql path: %s, exists: %s", compiled_sql_path, Path.exists(compiled_sql_path)) - with compiled_sql_path.open("r") as f: - compiled_sql = f.read() - if compiled_sql: - logger.debug("Uid: %s, Compiled sql: %s", uid, compiled_sql) - ev_dict["compiled_sql"] = compiled_sql + compiled_sql = self._extract_compiled_sql_for_node_event(ev) + if compiled_sql: + ev_dict["compiled_sql"] = compiled_sql payload = base64.b64encode(zlib.compress(json.dumps(ev_dict).encode())).decode() ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload) From 0de41802726601bab4d5217dee64bb3ba70bffb3 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 28 Oct 2025 20:27:18 +0530 Subject: [PATCH 5/8] Add tests --- cosmos/operators/watcher.py | 4 +-- tests/operators/test_watcher.py | 51 +++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 64d6a87ff6..fddafaacd4 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -19,13 +19,13 @@ try: from airflow.sdk.bases.sensor import BaseSensorOperator -except ImportError: +except ImportError: # pragma: no cover from airflow.sensors.base import BaseSensorOperator from airflow.exceptions import AirflowException try: from airflow.providers.standard.operators.empty import EmptyOperator -except ImportError: +except ImportError: # pragma: no cover from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] from cosmos.config import ProfileConfig diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index ef40190ead..27a65b7829 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -57,7 +57,9 @@ class _MockContext(dict): pass -def _fake_event(name: str = "NodeFinished", uid: str = "model.pkg.m"): +def _fake_event( + name: str = "NodeFinished", uid: str = "model.pkg.m", resource_type: str | None = None, node_path: str | None = None +): """Create a minimal fake EventMsg-like object suitable for helper tests.""" class _Info(SimpleNamespace): @@ -70,6 +72,10 @@ class _RunResult(SimpleNamespace): pass node_info = _NodeInfo(unique_id=uid) + if resource_type is not None: + setattr(node_info, "resource_type", resource_type) + if node_path is not None: + setattr(node_info, "node_path", node_path) run_result = _RunResult(status="success", message="ok") data = SimpleNamespace(node_info=node_info, run_result=run_result) @@ -156,6 +162,31 @@ def test_handle_node_finished_pushes_xcom(): assert json.loads(raw) == {"foo": "bar"} +def test_handle_node_finished_injects_compiled_sql(tmp_path, monkeypatch): + op = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None) + ti = _MockTI() + ctx = _MockContext(ti=ti) + + # Create compiled SQL file at expected path: target/compiled/pkg/models/my_model.sql + compiled_dir = tmp_path / "target" / "compiled" / "pkg" / "models" + compiled_dir.mkdir(parents=True) + compiled_file = compiled_dir / "my_model.sql" + sql_text = "select 1" + compiled_file.write_text(sql_text, encoding="utf-8") + + # Ensure watcher looks up under this tmp project dir + monkeypatch.chdir(tmp_path) + + with patch.object(op, "_serialize_event", return_value={}): + ev = _fake_event(name="NodeFinished", uid="model.pkg.my_model", resource_type="model", node_path="my_model.sql") + op._handle_node_finished(ev, ctx) + + stored = list(ti.store.values())[0] + raw = zlib.decompress(base64.b64decode(stored)).decode() + data = json.loads(raw) + assert data.get("compiled_sql") == sql_text + + def test_execute_streaming_mode(): """Streaming path should push startup + per-model XComs.""" from contextlib import nullcontext @@ -440,18 +471,32 @@ def test_get_status_from_events_success(self): sensor = self.make_sensor() ti = MagicMock() ti.xcom_pull.side_effect = [None, ENCODED_EVENT] + context = self.make_context(ti) - result = sensor._get_status_from_events(ti) + result = sensor._get_status_from_events(ti, context) assert result == "success" def test_get_status_from_events_none(self): sensor = self.make_sensor() ti = MagicMock() ti.xcom_pull.side_effect = [None, None] + context = self.make_context(ti) - result = sensor._get_status_from_events(ti) + result = sensor._get_status_from_events(ti, context) assert result is None + def test_get_status_from_events_sets_compiled_sql(self): + sensor = self.make_sensor() + ti = MagicMock() + event_payload = {"data": {"run_result": {"status": "success"}}, "compiled_sql": "select 42"} + encoded_event = base64.b64encode(zlib.compress(json.dumps(event_payload).encode())).decode("utf-8") + ti.xcom_pull.side_effect = [None, encoded_event] + context = self.make_context(ti) + + result = sensor._get_status_from_events(ti, context) + assert result == "success" + assert sensor.compiled_sql == "select 42" + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results") def test_producer_state_failed(self, mock_run_result): sensor = self.make_sensor() From a10f14a46766ec0912e901f6b4a213e2fd562f76 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 28 Oct 2025 20:51:16 +0530 Subject: [PATCH 6/8] Apply suggestion from @pankajkoti --- cosmos/operators/watcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index fddafaacd4..dea7a011a0 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -133,7 +133,6 @@ def _extract_compiled_sql_for_node_event(self, ev: EventMsg) -> str | None: compiled_sql_path = Path.cwd() / "target" / "compiled" / package / "models" / node_path if compiled_sql_path.exists(): return compiled_sql_path.read_text(encoding="utf-8").strip() or None - return None def _handle_node_finished( self, From 61ef04657333878d218029b8e2fa068cdd2b47f8 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 28 Oct 2025 20:56:35 +0530 Subject: [PATCH 7/8] Add test for missing coverage --- cosmos/operators/watcher.py | 1 + tests/operators/test_watcher.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index dea7a011a0..fddafaacd4 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -133,6 +133,7 @@ def _extract_compiled_sql_for_node_event(self, ev: EventMsg) -> str | None: compiled_sql_path = Path.cwd() / "target" / "compiled" / package / "models" / node_path if compiled_sql_path.exists(): return compiled_sql_path.read_text(encoding="utf-8").strip() or None + return None def _handle_node_finished( self, diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 27a65b7829..e1341db90f 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -187,6 +187,24 @@ def test_handle_node_finished_injects_compiled_sql(tmp_path, monkeypatch): assert data.get("compiled_sql") == sql_text +def test_handle_node_finished_without_compiled_sql_does_not_inject(tmp_path, monkeypatch): + op = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None) + ti = _MockTI() + ctx = _MockContext(ti=ti) + + # Ensure watcher looks up under this tmp project dir, but do NOT create compiled file + monkeypatch.chdir(tmp_path) + + with patch.object(op, "_serialize_event", return_value={}): + ev = _fake_event(name="NodeFinished", uid="model.pkg.my_model", resource_type="model", node_path="my_model.sql") + op._handle_node_finished(ev, ctx) + + stored = list(ti.store.values())[0] + raw = zlib.decompress(base64.b64decode(stored)).decode() + data = json.loads(raw) + assert "compiled_sql" not in data + + def test_execute_streaming_mode(): """Streaming path should push startup + per-model XComs.""" from contextlib import nullcontext From d4a98f88f166fa48eebe0bf338ed61eceff6039d Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Wed, 29 Oct 2025 14:22:27 +0530 Subject: [PATCH 8/8] Address review comments --- cosmos/operators/watcher.py | 57 ++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index fddafaacd4..108d936991 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -112,42 +112,46 @@ def _set_on_failure_callback( return [user_callback, default_callback] @staticmethod - def _serialize_event(ev: EventMsg) -> dict[str, Any]: + def _serialize_event(event_message: EventMsg) -> dict[str, Any]: """Convert structured dbt EventMsg to plain dict.""" from google.protobuf.json_format import MessageToDict - return MessageToDict(ev, preserving_proto_field_name=True) # type: ignore[no-any-return] + return MessageToDict(event_message, preserving_proto_field_name=True) # type: ignore[no-any-return] - def _handle_startup_event(self, ev: EventMsg, startup_events: list[dict[str, Any]]) -> None: - info = ev.info # type: ignore[attr-defined] + def _handle_startup_event(self, event_message: EventMsg, startup_events: list[dict[str, Any]]) -> None: + info = event_message.info # type: ignore[attr-defined] raw_ts = getattr(info, "ts", None) ts_val = raw_ts.ToJsonString() if hasattr(raw_ts, "ToJsonString") else str(raw_ts) # type: ignore[union-attr] startup_events.append({"name": info.name, "msg": info.msg, "ts": ts_val}) - def _extract_compiled_sql_for_node_event(self, ev: EventMsg) -> str | None: - if getattr(ev.data.node_info, "resource_type", None) != "model": + def _extract_compiled_sql_for_node_event(self, event_message: EventMsg) -> str | None: + if getattr(event_message.data.node_info, "resource_type", None) != "model": return None - uid = ev.data.node_info.unique_id - node_path = str(ev.data.node_info.node_path) + uid = event_message.data.node_info.unique_id + node_path = str(event_message.data.node_info.node_path) package = uid.split(".")[1] compiled_sql_path = Path.cwd() / "target" / "compiled" / package / "models" / node_path - if compiled_sql_path.exists(): - return compiled_sql_path.read_text(encoding="utf-8").strip() or None - return None + if not compiled_sql_path.exists(): + logger.warning( + "Compiled sql path %s does not exist and hence the rendered template field compiled_sql for the model will not be populated", + compiled_sql_path, + ) + return None + return compiled_sql_path.read_text(encoding="utf-8").strip() or None def _handle_node_finished( self, - ev: EventMsg, + event_message: EventMsg, context: Context, ) -> None: - logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", ev) + logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", event_message) ti = context["ti"] - uid = ev.data.node_info.unique_id - ev_dict = self._serialize_event(ev) - compiled_sql = self._extract_compiled_sql_for_node_event(ev) + uid = event_message.data.node_info.unique_id + event_message_dict = self._serialize_event(event_message) + compiled_sql = self._extract_compiled_sql_for_node_event(event_message) if compiled_sql: - ev_dict["compiled_sql"] = compiled_sql - payload = base64.b64encode(zlib.compress(json.dumps(ev_dict).encode())).decode() + event_message_dict["compiled_sql"] = compiled_sql + payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode() ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload) def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None: @@ -172,12 +176,12 @@ def execute(self, context: Context, **kwargs: Any) -> Any: if use_events: - def _callback(ev: EventMsg) -> None: - name = ev.info.name + def _callback(event_message: EventMsg) -> None: + name = event_message.info.name if name in {"MainReportVersion", "AdapterRegistered"}: - self._handle_startup_event(ev, startup_events) + self._handle_startup_event(event_message, startup_events) elif name == "NodeFinished": - self._handle_node_finished(ev, context) + self._handle_node_finished(event_message, context) self._dbt_runner_callbacks = [_callback] result = super().execute(context=context, **kwargs) @@ -198,7 +202,7 @@ def _callback(ev: EventMsg) -> None: class DbtConsumerWatcherSensor(BaseSensorOperator, DbtRunLocalOperator): # type: ignore[misc] - template_fields = ("model_unique_id", "compiled_sql") # type: ignore[operator] + template_fields: tuple[str, ...] = ("model_unique_id", "compiled_sql") # type: ignore[operator] poke_retry_number: int = 0 def __init__( @@ -213,6 +217,7 @@ def __init__( execution_timeout: timedelta = timedelta(hours=1), **kwargs: Any, ) -> None: + self.compiled_sql = "" extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {} kwargs.setdefault("priority_weight", CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) kwargs.setdefault("weight_rule", WEIGHT_RULE) @@ -387,7 +392,7 @@ class DbtSeedWatcherOperator(DbtSeedMixin, DbtConsumerWatcherSensor): # type: i Watches for the progress of dbt seed execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str, str] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] + template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -398,7 +403,7 @@ class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtConsumerWatcherSensor): # Watches for the progress of dbt snapshot execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str, str] = DbtConsumerWatcherSensor.template_fields + template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields class DbtSourceWatcherOperator(DbtSourceLocalOperator): @@ -414,7 +419,7 @@ class DbtRunWatcherOperator(DbtConsumerWatcherSensor): Watches for the progress of dbt model execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str, str] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] + template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs)