diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 34efad55ff..4a33d70495 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import gzip import inspect import json import os @@ -577,13 +578,40 @@ def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None: if partial_parse_file.exists(): cache._update_partial_parse_cache(partial_parse_file, self.cache_dir) - def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None: + def _push_run_results_to_xcom(self, tmp_project_dir: str, context: Context) -> None: + run_results_path = Path(tmp_project_dir) / "target" / "run_results.json" + if not run_results_path.is_file(): # pragma: no cover + raise AirflowException(f"run_results.json not found at {run_results_path}") + + try: + with run_results_path.open() as fp: + raw = json.load(fp) + except json.JSONDecodeError as exc: # pragma: no cover + raise AirflowException("Invalid JSON in run_results.json") from exc + self.log.debug("Loaded run results from %s", run_results_path) + + # results_mapping = {result["unique_id"]: result for result in raw.get("results", [])} + # compressed = base64.b64encode(gzip.compress(json.dumps(results_mapping).encode())).decode() + # self.log.debug("Parsed %d entries out of run_results.json", len(results_mapping)) + + compressed = base64.b64encode(gzip.compress(json.dumps(raw).encode())).decode() + context["ti"].xcom_push(key="run_results", value=compressed) + + self.log.info("Pushed run results to XCom") + + def _handle_post_execution( + self, tmp_project_dir: str, context: Context, push_run_results_to_xcom: bool = False + ) -> None: self.store_freshness_json(tmp_project_dir, context) self.store_compiled_sql(tmp_project_dir, context) self._override_rtif(context) if self.should_upload_compiled_sql: self._upload_sql_files(tmp_project_dir, "compiled") + + if push_run_results_to_xcom: + self._push_run_results_to_xcom(tmp_project_dir, context) + if self.callback: self.callback_args.update({"context": context}) if isinstance(self.callback, list): @@ -614,6 +642,7 @@ def run_command( # noqa: C901 context: Context, run_as_async: bool = False, async_context: dict[str, Any] | None = None, + push_run_results_to_xcom: bool = False, ) -> FullOutputSubprocessResult | dbtRunnerResult | str: """ Copies the dbt project to a temporary directory and runs the command. @@ -667,7 +696,7 @@ def run_command( # noqa: C901 if self.partial_parse: self._update_partial_parse_cache(tmp_dir_path) - self._handle_post_execution(tmp_project_dir, context) + self._handle_post_execution(tmp_project_dir, context, push_run_results_to_xcom) self.handle_exception(result) if run_as_async and async_context: @@ -861,6 +890,7 @@ def build_and_run_cmd( cmd_flags: list[str] | None = None, run_as_async: bool = False, async_context: dict[str, Any] | None = None, + **kwargs: Any, ) -> FullOutputSubprocessResult | dbtRunnerResult: # If this is an async run and we're using the setup task, make sure to include the full_refresh flag if set if run_as_async and settings.enable_setup_async_task and getattr(self, "full_refresh", False): @@ -872,7 +902,7 @@ def build_and_run_cmd( dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags) dbt_cmd = dbt_cmd or [] result = self.run_command( - cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context + cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context, **kwargs ) return result @@ -974,7 +1004,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, self.on_warning_callback and self.on_warning_callback(warning_context) def execute(self, context: Context, **kwargs: Any) -> None: - result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) + result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags(), **kwargs) if self.on_warning_callback: self._handle_warnings(result, context) diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index f33bd4185f..ec5c78acb3 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -107,6 +107,7 @@ def run_command( context: Context, run_as_async: bool = False, async_context: dict[str, Any] | None = None, + push_run_results_to_xcom: bool = False, ) -> FullOutputSubprocessResult | dbtRunnerResult: # No virtualenv_dir set, so create a temporary virtualenv if self.virtualenv_dir is None or self.is_virtualenv_dir_temporary: @@ -114,7 +115,14 @@ def run_command( with TemporaryDirectory(prefix="cosmos-venv") as tempdir: self.virtualenv_dir = Path(tempdir) self._py_bin = self._prepare_virtualenv() - return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context) + return super().run_command( + cmd, + env, + context, + run_as_async=run_as_async, + async_context=async_context, + push_run_results_to_xcom=push_run_results_to_xcom, + ) try: self.log.info(f"Checking if the virtualenv lock {str(self._lock_file)} exists") @@ -126,7 +134,14 @@ def run_command( self.log.info("Acquiring the virtualenv lock") self._acquire_venv_lock() self._py_bin = self._prepare_virtualenv() - return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context) + return super().run_command( + cmd, + env, + context, + run_as_async=run_as_async, + async_context=async_context, + push_run_results_to_xcom=push_run_results_to_xcom, + ) finally: self.log.info("Releasing virtualenv lock") self._release_venv_lock() diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index cffd7e6ee3..2834d6fbc5 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -1,4 +1,33 @@ -status_model_fhir_dbt_analytics_active_encounters_daily_nodefinished = { +from __future__ import annotations + +import base64 +import gzip +import json +import logging +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: # pragma: no cover + try: + from airflow.sdk.definitions.context import Context + except ImportError: + from airflow.utils.context import Context # type: ignore[attr-defined] + +from cosmos.config import RenderConfig +from cosmos.constants import InvocationMode +from cosmos.operators.local import DbtBuildLocalOperator +from cosmos.settings import watcher_build_coordinator_priority_weight + +try: + from dbt_common.events.base_types import EventMsg +except ImportError: # pragma: no cover + EventMsg = None + +logger = logging.getLogger(__name__) + + +# Example dbt event JSON dictionaries (kept for reference) +nodefinished_model__fhir_dbt_utils__fhir_table_list = { "info": { "name": "NodeFinished", "code": "Q025", @@ -90,3 +119,130 @@ "data": {"adapter_name": "bigquery", "adapter_version": "=1.9.0"}, } } + + +class DbtBuildCoordinatorOperator(DbtBuildLocalOperator): + """Run dbt build and coordinate model run statuses via XCom for *WATCHER* execution mode . + + Executes **one** ``dbt build`` covering the whole selection. + + - **When ``InvocationMode.DBT_RUNNER`` is set** we patch + ``dbtRunner`` so we receive structured events *while* dbt is running. In + this real-time mode the operator: + – pushes startup metadata events (``MainReportVersion``, + ``AdapterRegistered``) together under XCom key + ``dbt_startup_events``; + – pushes each ``NodeFinished`` event immediately to XCom under + ``nodefinished_`` (gzipped+base64 JSON) so downstream + sensors can react with near-zero latency. + + - **When ``dbtRunner`` is *not* available** (older dbt or + ``InvocationMode=SUBPROCESS``) we fallback to delayed strategy: after + dbt exits we read ``target/run_results.json`` and push the whole mapping + once under key ``run_results`` to XCom. Sensors can poll this key but will not + get per-model updates until the build completes. + + This keeps the heavy dbt work centralised while providing near real-time + feedback and granular task-level observability downstream. + """ + + def __init__( + self, + *, + render_config: RenderConfig | None = None, + **kwargs: Any, + ) -> None: + # Store so we can honour select/exclude when building flags + self.render_config: RenderConfig | None = render_config + + task_id = kwargs.pop("task_id", "dbt_build_coordinator") + + super().__init__(task_id=task_id, priority_weight=watcher_build_coordinator_priority_weight, **kwargs) + + def add_cmd_flags(self) -> list[str]: + flags: list[str] = super().add_cmd_flags() + + self.log.info("DbtBuildCoordinatorOperator: render_config: %s", self.render_config) + if self.render_config is not None and self.render_config.exclude: + flags.extend(["--exclude", *self.render_config.exclude]) + if self.render_config is not None and self.render_config.select: + flags.extend(["--select", *self.render_config.select]) + return flags + + @staticmethod + def _serialize_event(ev: EventMsg) -> dict[str, Any]: + """Convert structured dbt EventMsg to plain dict.""" + from google.protobuf.json_format import MessageToDict + + return MessageToDict(ev, preserving_proto_field_name=True) # type: ignore[no-any-return] + + def _handle_startup_event(self, ev: EventMsg, startup_events: list[dict[str, Any]]) -> None: + info = ev.info # type: ignore[attr-defined] + raw_ts = getattr(info, "ts", None) + ts_val = raw_ts.ToJsonString() if getattr(raw_ts, "ToJsonString", None) else str(raw_ts) # type: ignore[union-attr] + startup_events.append({"name": info.name, "msg": info.msg, "ts": ts_val}) + + def _handle_node_finished( + self, + ev: EventMsg, + context: Context, + ) -> None: + self.log.debug("DbtBuildCoordinatorOperator: handling node finished event: %s", ev) + ti = context["ti"] + uid = ev.data.node_info.unique_id + ev_dict = self._serialize_event(ev) + payload = base64.b64encode(gzip.compress(json.dumps(ev_dict).encode())).decode() + ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload) + + @contextmanager + def _patch_runner(self, callback: Callable[[EventMsg], None]) -> Any: + import cosmos.dbt.runner as _dbt_runner_mod + + original = _dbt_runner_mod.get_runner + + def _patched_get_runner() -> Any: + from dbt.cli.main import dbtRunner + + return dbtRunner(callbacks=[callback]) + + _dbt_runner_mod.get_runner = _patched_get_runner # type: ignore[assignment] + if hasattr(original, "cache_clear"): + original.cache_clear() + try: + yield + finally: + _dbt_runner_mod.get_runner = original + + def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None: + ti = context["ti"] + # Only push startup events; per-model statuses are available via individual nodefinished_ entries. + if startup_events: + ti.xcom_push(key="dbt_startup_events", value=startup_events) + + def execute(self, context: Context, **kwargs: Any) -> Any: + if not self.invocation_mode: + self._discover_invocation_mode() + + use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None + self.log.debug("DbtBuildCoordinatorOperator: use_events=%s", use_events) + + startup_events: list[dict[str, Any]] = [] + + if use_events: + + def _cb(ev: EventMsg) -> None: + name = ev.info.name + if name in {"MainReportVersion", "AdapterRegistered"}: + self._handle_startup_event(ev, startup_events) + elif name == "NodeFinished": + self._handle_node_finished(ev, context) + + with self._patch_runner(_cb): + result = super().execute(context=context, **kwargs) + + self._finalize(context, startup_events) + return result + + # Fallback – push run_results.json via base class helper + kwargs["push_run_results_to_xcom"] = True + return super().execute(context=context, **kwargs) diff --git a/cosmos/settings.py b/cosmos/settings.py index a4f35aa74d..9a5174c53d 100644 --- a/cosmos/settings.py +++ b/cosmos/settings.py @@ -56,6 +56,11 @@ enable_setup_async_task = conf.getboolean("cosmos", "enable_setup_async_task", fallback=True) enable_teardown_async_task = conf.getboolean("cosmos", "enable_teardown_async_task", fallback=True) +# Airflow task priority weight for the DbtBuildCoordinatorOperator task in WATCHER Execution Mode +watcher_build_coordinator_priority_weight = conf.getint( + "cosmos", "watcher_build_coordinator_priority_weight", fallback=9999 +) + AIRFLOW_IO_AVAILABLE = Version(airflow_version) >= Version("2.8.0") # The following environment variable is populated in Astro Cloud diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py new file mode 100644 index 0000000000..06153f7b3d --- /dev/null +++ b/tests/operators/test_watcher.py @@ -0,0 +1,217 @@ +import base64 +import gzip +import json +from types import SimpleNamespace +from unittest.mock import patch + +from cosmos.config import InvocationMode, RenderConfig +from cosmos.operators.watcher import DbtBuildCoordinatorOperator + + +class _MockTI: + def __init__(self) -> None: + self.store: dict[str, str] = {} + + def xcom_push(self, key: str, value: str, **_): + self.store[key] = value + + +class _MockContext(dict): + pass + + +def _fake_event(name: str = "NodeFinished", uid: str = "model.pkg.m"): + """Create a minimal fake EventMsg-like object suitable for helper tests.""" + + class _Info(SimpleNamespace): + pass + + class _NodeInfo(SimpleNamespace): + pass + + class _RunResult(SimpleNamespace): + pass + + node_info = _NodeInfo(unique_id=uid) + run_result = _RunResult(status="success", message="ok") + + data = SimpleNamespace(node_info=node_info, run_result=run_result) + info = _Info(name=name, code="X", msg="msg") + return SimpleNamespace(info=info, data=data) + + +@patch("google.protobuf.json_format.MessageToDict") +def test_serialize_event(mock_mtd): + op = DbtBuildCoordinatorOperator(project_dir=".", profile_config=None) + + mock_mtd.side_effect = lambda ev, **kwargs: {"dummy": True} + + out = op._serialize_event(_fake_event()) + assert out == {"dummy": True} + mock_mtd.assert_called() + + +def test_handle_startup_event(): + op = DbtBuildCoordinatorOperator(project_dir=".", profile_config=None) + lst: list[dict] = [] + ev = _fake_event("MainReportVersion") + op._handle_startup_event(ev, lst) + assert lst and lst[0]["name"] == "MainReportVersion" + + +def test_handle_node_finished_pushes_xcom(): + op = DbtBuildCoordinatorOperator(project_dir=".", profile_config=None) + ti = _MockTI() + ctx = _MockContext(ti=ti) + + with patch.object(op, "_serialize_event", return_value={"foo": "bar"}): + ev = _fake_event() + op._handle_node_finished(ev, ctx) + + stored = list(ti.store.values())[0] + raw = gzip.decompress(base64.b64decode(stored)).decode() + assert json.loads(raw) == {"foo": "bar"} + + +def test_execute_streaming_mode(): + """Streaming path should push startup + per-model XComs.""" + + op = DbtBuildCoordinatorOperator(project_dir=".", profile_config=None) + op.invocation_mode = InvocationMode.DBT_RUNNER + + import cosmos.operators.watcher as _watch_mod + + if _watch_mod.EventMsg is None: + + class _DummyEv: + pass + + _watch_mod.EventMsg = _DummyEv + + ti = _MockTI() + ctx = {"ti": ti, "run_id": "dummy"} + + main_rep = _fake_event("MainReportVersion") + node_evt = _fake_event("NodeFinished", uid="model.pkg.x") + + def fake_patch(self, cb): + cb(main_rep) + cb(node_evt) + from contextlib import nullcontext + + return nullcontext() + + with patch.object(DbtBuildCoordinatorOperator, "_patch_runner", fake_patch), patch.object( + DbtBuildCoordinatorOperator, "_serialize_event", lambda self, ev: {"dummy": True} + ), patch("cosmos.operators.watcher.DbtBuildLocalOperator.execute", lambda *_, **__: None): + op.execute(context=ctx) + + assert "dbt_startup_events" in ti.store + + node_key = "nodefinished_model__pkg__x" + assert node_key in ti.store + + +def test_execute_fallback_mode(tmp_path): + """Fallback path pushes compressed run_results once.""" + + tgt = tmp_path / "target" + tgt.mkdir() + with (tgt / "run_results.json").open("w") as fp: + json.dump({"results": [{"unique_id": "a", "status": "success"}]}, fp) + + op = DbtBuildCoordinatorOperator(project_dir=str(tmp_path), profile_config=None) + op.invocation_mode = InvocationMode.SUBPROCESS # force fallback + + ti = _MockTI() + ctx = {"ti": ti, "run_id": "x"} + + def fake_build_run(self, context, **kw): + from cosmos.operators.local import AbstractDbtLocalBase + + AbstractDbtLocalBase._handle_post_execution(self, self.project_dir, context, True) + return None + + with patch("cosmos.operators.local.DbtBuildLocalOperator.build_and_run_cmd", fake_build_run): + op.execute(context=ctx) + + compressed = ti.store.get("run_results") + assert compressed + data = json.loads(gzip.decompress(base64.b64decode(compressed)).decode()) + assert data["results"][0]["status"] == "success" + + +@patch( + "cosmos.operators.watcher.DbtBuildLocalOperator.add_cmd_flags", + return_value=["cmd"], +) +def test_add_cmd_flags_includes_select_and_exclude(_mock_add): + """add_cmd_flags should append --exclude/--select from RenderConfig.""" + + rc = RenderConfig(select=["tag:nightly"], exclude=["model.old"]) + op = DbtBuildCoordinatorOperator(project_dir=".", profile_config=None, render_config=rc) + + flags = op.add_cmd_flags() + + assert flags[0] == "cmd" + assert "--exclude" in flags and "model.old" in flags + assert "--select" in flags and "tag:nightly" in flags + + +def test_patch_runner_patches_and_restores(): + """_patch_runner should temporarily replace cosmos.dbt.runner.get_runner.""" + import cosmos.dbt.runner as dr + + op = DbtBuildCoordinatorOperator(project_dir=".", profile_config=None) + + original = dr.get_runner + called = {} + + def dummy_callback(ev): + called["cb"] = True + + class _FakeRunner: + def __init__(self, callbacks=None): + self.callbacks = callbacks or [] + + import sys + import types + + fake_main = types.ModuleType("dbt.cli.main") + fake_main.dbtRunner = _FakeRunner + with patch.dict( + sys.modules, {"dbt": types.ModuleType("dbt"), "dbt.cli": types.ModuleType("dbt.cli"), "dbt.cli.main": fake_main} + ): + with op._patch_runner(dummy_callback): + runner_instance = dr.get_runner() + assert isinstance(runner_instance, _FakeRunner) + assert dummy_callback in runner_instance.callbacks + + assert dr.get_runner is original + + +@patch("cosmos.dbt.runner.is_available", return_value=False) +@patch("cosmos.operators.watcher.DbtBuildLocalOperator.execute", return_value="done") +def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available): + """If invocation_mode is unset, execute() should discover and set it.""" + + from cosmos.config import InvocationMode + + op = DbtBuildCoordinatorOperator(project_dir=".", profile_config=None) + assert op.invocation_mode is None # precondition + + ti = _MockTI() + ctx = {"ti": ti, "run_id": "xyz"} + + result = op.execute(context=ctx) + + assert result == "done" + assert op.invocation_mode == InvocationMode.SUBPROCESS + + +@patch("cosmos.operators.watcher.watcher_build_coordinator_priority_weight", 4242) +def test_build_coordinator_uses_priority_weight_setting(): + """Operator should honour watcher_build_coordinator_priority_weight setting.""" + op = DbtBuildCoordinatorOperator(project_dir=".", profile_config=None) + + assert op.priority_weight == 4242