Skip to content
66 changes: 45 additions & 21 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import zlib
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Sequence, Union

import airflow
Expand All @@ -18,13 +19,13 @@

try:
from airflow.sdk.bases.sensor import BaseSensorOperator
except ImportError:
except ImportError: # pragma: no cover
from airflow.sensors.base import BaseSensorOperator
from airflow.exceptions import AirflowException

try:
from airflow.providers.standard.operators.empty import EmptyOperator
except ImportError:
except ImportError: # pragma: no cover
from airflow.operators.empty import EmptyOperator # type: ignore[no-redef]

from cosmos.config import ProfileConfig
Expand Down Expand Up @@ -111,28 +112,46 @@ def _set_on_failure_callback(
return [user_callback, default_callback]

@staticmethod
def _serialize_event(ev: EventMsg) -> dict[str, Any]:
def _serialize_event(event_message: EventMsg) -> dict[str, Any]:
"""Convert structured dbt EventMsg to plain dict."""
from google.protobuf.json_format import MessageToDict

return MessageToDict(ev, preserving_proto_field_name=True) # type: ignore[no-any-return]
return MessageToDict(event_message, preserving_proto_field_name=True) # type: ignore[no-any-return]

def _handle_startup_event(self, ev: EventMsg, startup_events: list[dict[str, Any]]) -> None:
info = ev.info # type: ignore[attr-defined]
def _handle_startup_event(self, event_message: EventMsg, startup_events: list[dict[str, Any]]) -> None:
info = event_message.info # type: ignore[attr-defined]
raw_ts = getattr(info, "ts", None)
ts_val = raw_ts.ToJsonString() if hasattr(raw_ts, "ToJsonString") else str(raw_ts) # type: ignore[union-attr]
startup_events.append({"name": info.name, "msg": info.msg, "ts": ts_val})

def _extract_compiled_sql_for_node_event(self, event_message: EventMsg) -> str | None:
if getattr(event_message.data.node_info, "resource_type", None) != "model":
return None
uid = event_message.data.node_info.unique_id
node_path = str(event_message.data.node_info.node_path)
package = uid.split(".")[1]
compiled_sql_path = Path.cwd() / "target" / "compiled" / package / "models" / node_path
if not compiled_sql_path.exists():
logger.warning(
"Compiled sql path %s does not exist and hence the rendered template field compiled_sql for the model will not be populated",
compiled_sql_path,
)
return None
return compiled_sql_path.read_text(encoding="utf-8").strip() or None

def _handle_node_finished(
Comment thread
pankajkoti marked this conversation as resolved.
self,
ev: EventMsg,
event_message: EventMsg,
context: Context,
) -> None:
logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", ev)
logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", event_message)
ti = context["ti"]
uid = ev.data.node_info.unique_id
ev_dict = self._serialize_event(ev)
payload = base64.b64encode(zlib.compress(json.dumps(ev_dict).encode())).decode()
uid = event_message.data.node_info.unique_id
event_message_dict = self._serialize_event(event_message)
compiled_sql = self._extract_compiled_sql_for_node_event(event_message)
if compiled_sql:
event_message_dict["compiled_sql"] = compiled_sql
payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode()
ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload)

def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None:
Expand All @@ -157,12 +176,12 @@ def execute(self, context: Context, **kwargs: Any) -> Any:

if use_events:

def _callback(ev: EventMsg) -> None:
name = ev.info.name
def _callback(event_message: EventMsg) -> None:
name = event_message.info.name
if name in {"MainReportVersion", "AdapterRegistered"}:
self._handle_startup_event(ev, startup_events)
self._handle_startup_event(event_message, startup_events)
elif name == "NodeFinished":
self._handle_node_finished(ev, context)
self._handle_node_finished(event_message, context)

self._dbt_runner_callbacks = [_callback]
result = super().execute(context=context, **kwargs)
Expand All @@ -183,7 +202,7 @@ def _callback(ev: EventMsg) -> None:


class DbtConsumerWatcherSensor(BaseSensorOperator, DbtRunLocalOperator): # type: ignore[misc]
template_fields = ("model_unique_id",) # type: ignore[operator]
template_fields: tuple[str, ...] = ("model_unique_id", "compiled_sql") # type: ignore[operator]
poke_retry_number: int = 0

def __init__(
Expand All @@ -198,6 +217,7 @@ def __init__(
execution_timeout: timedelta = timedelta(hours=1),
**kwargs: Any,
) -> None:
self.compiled_sql = ""
extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {}
kwargs.setdefault("priority_weight", CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT)
kwargs.setdefault("weight_rule", WEIGHT_RULE)
Expand Down Expand Up @@ -258,7 +278,7 @@ def _fallback_to_local_run(self, try_number: int, context: Context) -> bool:
logger.info("dbt run completed successfully on retry for model '%s'", self.model_unique_id)
return True

def _get_status_from_events(self, ti: Any) -> Any:
def _get_status_from_events(self, ti: Any, context: Context) -> Any:

dbt_startup_events = ti.xcom_pull(task_ids=self.producer_task_id, key="dbt_startup_events")
if dbt_startup_events: # pragma: no cover
Expand All @@ -277,6 +297,10 @@ def _get_status_from_events(self, ti: Any) -> Any:

logger.info("Node Info: %s", event_json_str)

self.compiled_sql = event_json.get("compiled_sql", "")
Comment thread
pankajkoti marked this conversation as resolved.
if self.compiled_sql:
self._override_rtif(context)

return event_json.get("data", {}).get("run_result", {}).get("status")

def _get_status_from_run_results(self, ti: Any) -> Any:
Expand Down Expand Up @@ -330,7 +354,7 @@ def poke(self, context: Context) -> bool:

producer_task_state = self._get_producer_task_state(ti)
if use_events:
status = self._get_status_from_events(ti)
status = self._get_status_from_events(ti, context)
else:
status = self._get_status_from_run_results(ti)

Expand Down Expand Up @@ -368,7 +392,7 @@ class DbtSeedWatcherOperator(DbtSeedMixin, DbtConsumerWatcherSensor): # type: i
Watches for the progress of dbt seed execution, run by the producer task (DbtProducerWatcherOperator).
"""

template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator]
template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
Expand All @@ -379,7 +403,7 @@ class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtConsumerWatcherSensor): #
Watches for the progress of dbt snapshot execution, run by the producer task (DbtProducerWatcherOperator).
"""

template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields
template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields


class DbtSourceWatcherOperator(DbtSourceLocalOperator):
Expand All @@ -395,7 +419,7 @@ class DbtRunWatcherOperator(DbtConsumerWatcherSensor):
Watches for the progress of dbt model execution, run by the producer task (DbtProducerWatcherOperator).
"""

template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator]
template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
Expand Down
69 changes: 66 additions & 3 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ class _MockContext(dict):
pass


def _fake_event(name: str = "NodeFinished", uid: str = "model.pkg.m"):
def _fake_event(
name: str = "NodeFinished", uid: str = "model.pkg.m", resource_type: str | None = None, node_path: str | None = None
):
"""Create a minimal fake EventMsg-like object suitable for helper tests."""

class _Info(SimpleNamespace):
Expand All @@ -70,6 +72,10 @@ class _RunResult(SimpleNamespace):
pass

node_info = _NodeInfo(unique_id=uid)
if resource_type is not None:
setattr(node_info, "resource_type", resource_type)
if node_path is not None:
setattr(node_info, "node_path", node_path)
run_result = _RunResult(status="success", message="ok")

data = SimpleNamespace(node_info=node_info, run_result=run_result)
Expand Down Expand Up @@ -156,6 +162,49 @@ def test_handle_node_finished_pushes_xcom():
assert json.loads(raw) == {"foo": "bar"}


def test_handle_node_finished_injects_compiled_sql(tmp_path, monkeypatch):
    """A finished model whose compiled file exists gets ``compiled_sql`` in its XCom payload."""
    producer = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None)
    task_instance = _MockTI()
    context = _MockContext(ti=task_instance)

    # Lay out target/compiled/pkg/models/my_model.sql beneath the temporary project.
    expected_sql = "select 1"
    models_dir = tmp_path / "target" / "compiled" / "pkg" / "models"
    models_dir.mkdir(parents=True)
    (models_dir / "my_model.sql").write_text(expected_sql, encoding="utf-8")

    # The producer resolves the compiled path from the current working directory.
    monkeypatch.chdir(tmp_path)

    with patch.object(producer, "_serialize_event", return_value={}):
        event = _fake_event(
            name="NodeFinished",
            uid="model.pkg.my_model",
            resource_type="model",
            node_path="my_model.sql",
        )
        producer._handle_node_finished(event, context)

    # The pushed value is base64(zlib(json)); unwrap it back into a dict.
    encoded = next(iter(task_instance.store.values()))
    payload = json.loads(zlib.decompress(base64.b64decode(encoded)).decode())
    assert payload.get("compiled_sql") == expected_sql


def test_handle_node_finished_without_compiled_sql_does_not_inject(tmp_path, monkeypatch):
    """With no compiled file on disk, the XCom payload carries no ``compiled_sql`` key."""
    producer = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None)
    task_instance = _MockTI()
    context = _MockContext(ti=task_instance)

    # Work from the temporary project directory but deliberately create no compiled file.
    monkeypatch.chdir(tmp_path)

    with patch.object(producer, "_serialize_event", return_value={}):
        event = _fake_event(
            name="NodeFinished",
            uid="model.pkg.my_model",
            resource_type="model",
            node_path="my_model.sql",
        )
        producer._handle_node_finished(event, context)

    # The pushed value is base64(zlib(json)); unwrap it back into a dict.
    encoded = next(iter(task_instance.store.values()))
    payload = json.loads(zlib.decompress(base64.b64decode(encoded)).decode())
    assert "compiled_sql" not in payload


def test_execute_streaming_mode():
"""Streaming path should push startup + per-model XComs."""
from contextlib import nullcontext
Expand Down Expand Up @@ -440,18 +489,32 @@ def test_get_status_from_events_success(self):
sensor = self.make_sensor()
ti = MagicMock()
ti.xcom_pull.side_effect = [None, ENCODED_EVENT]
context = self.make_context(ti)

result = sensor._get_status_from_events(ti)
result = sensor._get_status_from_events(ti, context)
assert result == "success"

def test_get_status_from_events_none(self):
sensor = self.make_sensor()
ti = MagicMock()
ti.xcom_pull.side_effect = [None, None]
context = self.make_context(ti)

result = sensor._get_status_from_events(ti)
result = sensor._get_status_from_events(ti, context)
assert result is None

def test_get_status_from_events_sets_compiled_sql(self):
    """``compiled_sql`` carried in the event payload is copied onto the sensor."""
    # Encode the event as base64(zlib(json)), the same wrapping the test decodes elsewhere.
    payload = {"data": {"run_result": {"status": "success"}}, "compiled_sql": "select 42"}
    compressed = zlib.compress(json.dumps(payload).encode())
    encoded_event = base64.b64encode(compressed).decode("utf-8")

    ti = MagicMock()
    # First pull (startup events) yields nothing; the second yields the node event.
    ti.xcom_pull.side_effect = [None, encoded_event]

    sensor = self.make_sensor()
    context = self.make_context(ti)

    status = sensor._get_status_from_events(ti, context)

    assert status == "success"
    assert sensor.compiled_sql == "select 42"

@patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results")
def test_producer_state_failed(self, mock_run_result):
sensor = self.make_sensor()
Expand Down