Skip to content
38 changes: 34 additions & 4 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import base64
import gzip
import inspect
import json
import os
Expand Down Expand Up @@ -577,13 +578,40 @@ def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None:
if partial_parse_file.exists():
cache._update_partial_parse_cache(partial_parse_file, self.cache_dir)

def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None:
def _push_run_results_to_xcom(self, tmp_project_dir: str, context: Context) -> None:
run_results_path = Path(tmp_project_dir) / "target" / "run_results.json"
if not run_results_path.is_file(): # pragma: no cover
raise AirflowException(f"run_results.json not found at {run_results_path}")

try:
with run_results_path.open() as fp:
raw = json.load(fp)
except json.JSONDecodeError as exc: # pragma: no cover
raise AirflowException("Invalid JSON in run_results.json") from exc
self.log.debug("Loaded run results from %s", run_results_path)

# results_mapping = {result["unique_id"]: result for result in raw.get("results", [])}
# compressed = base64.b64encode(gzip.compress(json.dumps(results_mapping).encode())).decode()
# self.log.debug("Parsed %d entries out of run_results.json", len(results_mapping))

compressed = base64.b64encode(gzip.compress(json.dumps(raw).encode())).decode()
context["ti"].xcom_push(key="run_results", value=compressed)

self.log.info("Pushed run results to XCom")

def _handle_post_execution(
self, tmp_project_dir: str, context: Context, push_run_results_to_xcom: bool = False
) -> None:
self.store_freshness_json(tmp_project_dir, context)
self.store_compiled_sql(tmp_project_dir, context)
self._override_rtif(context)

if self.should_upload_compiled_sql:
self._upload_sql_files(tmp_project_dir, "compiled")

if push_run_results_to_xcom:
self._push_run_results_to_xcom(tmp_project_dir, context)

if self.callback:
self.callback_args.update({"context": context})
if isinstance(self.callback, list):
Expand Down Expand Up @@ -614,6 +642,7 @@ def run_command( # noqa: C901
context: Context,
run_as_async: bool = False,
async_context: dict[str, Any] | None = None,
push_run_results_to_xcom: bool = False,
) -> FullOutputSubprocessResult | dbtRunnerResult | str:
"""
Copies the dbt project to a temporary directory and runs the command.
Expand Down Expand Up @@ -667,7 +696,7 @@ def run_command( # noqa: C901
if self.partial_parse:
self._update_partial_parse_cache(tmp_dir_path)

self._handle_post_execution(tmp_project_dir, context)
self._handle_post_execution(tmp_project_dir, context, push_run_results_to_xcom)
self.handle_exception(result)

if run_as_async and async_context:
Expand Down Expand Up @@ -861,6 +890,7 @@ def build_and_run_cmd(
cmd_flags: list[str] | None = None,
run_as_async: bool = False,
async_context: dict[str, Any] | None = None,
**kwargs: Any,
) -> FullOutputSubprocessResult | dbtRunnerResult:
# If this is an async run and we're using the setup task, make sure to include the full_refresh flag if set
if run_as_async and settings.enable_setup_async_task and getattr(self, "full_refresh", False):
Expand All @@ -872,7 +902,7 @@ def build_and_run_cmd(
dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags)
dbt_cmd = dbt_cmd or []
result = self.run_command(
cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context
cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context, **kwargs
)
return result

Expand Down Expand Up @@ -974,7 +1004,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult,
self.on_warning_callback and self.on_warning_callback(warning_context)

def execute(self, context: Context, **kwargs: Any) -> None:
result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags(), **kwargs)
if self.on_warning_callback:
self._handle_warnings(result, context)

Expand Down
19 changes: 17 additions & 2 deletions cosmos/operators/virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,22 @@ def run_command(
context: Context,
run_as_async: bool = False,
async_context: dict[str, Any] | None = None,
push_run_results_to_xcom: bool = False,
) -> FullOutputSubprocessResult | dbtRunnerResult:
# No virtualenv_dir set, so create a temporary virtualenv
if self.virtualenv_dir is None or self.is_virtualenv_dir_temporary:
self.log.info("Creating temporary virtualenv")
with TemporaryDirectory(prefix="cosmos-venv") as tempdir:
self.virtualenv_dir = Path(tempdir)
self._py_bin = self._prepare_virtualenv()
return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context)
return super().run_command(
cmd,
env,
context,
run_as_async=run_as_async,
async_context=async_context,
push_run_results_to_xcom=push_run_results_to_xcom,
)

try:
self.log.info(f"Checking if the virtualenv lock {str(self._lock_file)} exists")
Expand All @@ -126,7 +134,14 @@ def run_command(
self.log.info("Acquiring the virtualenv lock")
self._acquire_venv_lock()
self._py_bin = self._prepare_virtualenv()
return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context)
return super().run_command(
cmd,
env,
context,
run_as_async=run_as_async,
async_context=async_context,
push_run_results_to_xcom=push_run_results_to_xcom,
)
finally:
self.log.info("Releasing virtualenv lock")
self._release_venv_lock()
Expand Down
158 changes: 157 additions & 1 deletion cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,33 @@
status_model_fhir_dbt_analytics_active_encounters_daily_nodefinished = {
from __future__ import annotations

import base64
import gzip
import json
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable

if TYPE_CHECKING: # pragma: no cover
try:
from airflow.sdk.definitions.context import Context
except ImportError:
from airflow.utils.context import Context # type: ignore[attr-defined]

from cosmos.config import RenderConfig
from cosmos.constants import InvocationMode
from cosmos.operators.local import DbtBuildLocalOperator
from cosmos.settings import watcher_build_coordinator_priority_weight

try:
from dbt_common.events.base_types import EventMsg
except ImportError: # pragma: no cover
EventMsg = None

logger = logging.getLogger(__name__)


# Example dbt event JSON dictionaries (kept for reference)
nodefinished_model__fhir_dbt_utils__fhir_table_list = {
"info": {
"name": "NodeFinished",
"code": "Q025",
Expand Down Expand Up @@ -90,3 +119,130 @@
"data": {"adapter_name": "bigquery", "adapter_version": "=1.9.0"},
}
}


class DbtBuildCoordinatorOperator(DbtBuildLocalOperator):
"""Run dbt build and coordinate model run statuses via XCom for *WATCHER* execution mode .

Executes **one** ``dbt build`` covering the whole selection.

- **When ``InvocationMode.DBT_RUNNER`` is set** we patch
``dbtRunner`` so we receive structured events *while* dbt is running. In
this real-time mode the operator:
– pushes startup metadata events (``MainReportVersion``,
``AdapterRegistered``) together under XCom key
``dbt_startup_events``;
– pushes each ``NodeFinished`` event immediately to XCom under
``nodefinished_<unique_id>`` (gzipped+base64 JSON) so downstream
sensors can react with near-zero latency.

- **When ``dbtRunner`` is *not* available** (older dbt or
``InvocationMode=SUBPROCESS``) we fallback to delayed strategy: after
dbt exits we read ``target/run_results.json`` and push the whole mapping
once under key ``run_results`` to XCom. Sensors can poll this key but will not
get per-model updates until the build completes.

This keeps the heavy dbt work centralised while providing near real-time
feedback and granular task-level observability downstream.
"""

def __init__(
self,
*,
render_config: RenderConfig | None = None,
**kwargs: Any,
) -> None:
# Store so we can honour select/exclude when building flags
self.render_config: RenderConfig | None = render_config

task_id = kwargs.pop("task_id", "dbt_build_coordinator")

super().__init__(task_id=task_id, priority_weight=watcher_build_coordinator_priority_weight, **kwargs)

def add_cmd_flags(self) -> list[str]:
flags: list[str] = super().add_cmd_flags()

self.log.info("DbtBuildCoordinatorOperator: render_config: %s", self.render_config)
if self.render_config is not None and self.render_config.exclude:
flags.extend(["--exclude", *self.render_config.exclude])
if self.render_config is not None and self.render_config.select:
flags.extend(["--select", *self.render_config.select])
return flags

@staticmethod
def _serialize_event(ev: EventMsg) -> dict[str, Any]:
"""Convert structured dbt EventMsg to plain dict."""
from google.protobuf.json_format import MessageToDict

return MessageToDict(ev, preserving_proto_field_name=True) # type: ignore[no-any-return]

def _handle_startup_event(self, ev: EventMsg, startup_events: list[dict[str, Any]]) -> None:
info = ev.info # type: ignore[attr-defined]
raw_ts = getattr(info, "ts", None)
ts_val = raw_ts.ToJsonString() if getattr(raw_ts, "ToJsonString", None) else str(raw_ts) # type: ignore[union-attr]
startup_events.append({"name": info.name, "msg": info.msg, "ts": ts_val})

def _handle_node_finished(
self,
ev: EventMsg,
context: Context,
) -> None:
self.log.debug("DbtBuildCoordinatorOperator: handling node finished event: %s", ev)
ti = context["ti"]
uid = ev.data.node_info.unique_id
ev_dict = self._serialize_event(ev)
payload = base64.b64encode(gzip.compress(json.dumps(ev_dict).encode())).decode()
ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload)

@contextmanager
def _patch_runner(self, callback: Callable[[EventMsg], None]) -> Any:
import cosmos.dbt.runner as _dbt_runner_mod

original = _dbt_runner_mod.get_runner

def _patched_get_runner() -> Any:
from dbt.cli.main import dbtRunner

return dbtRunner(callbacks=[callback])

_dbt_runner_mod.get_runner = _patched_get_runner # type: ignore[assignment]
if hasattr(original, "cache_clear"):
original.cache_clear()
try:
yield
finally:
_dbt_runner_mod.get_runner = original

def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None:
ti = context["ti"]
# Only push startup events; per-model statuses are available via individual nodefinished_<uid> entries.
if startup_events:
ti.xcom_push(key="dbt_startup_events", value=startup_events)

def execute(self, context: Context, **kwargs: Any) -> Any:
if not self.invocation_mode:
self._discover_invocation_mode()

use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None
self.log.debug("DbtBuildCoordinatorOperator: use_events=%s", use_events)

startup_events: list[dict[str, Any]] = []

if use_events:

def _cb(ev: EventMsg) -> None:
name = ev.info.name
if name in {"MainReportVersion", "AdapterRegistered"}:
self._handle_startup_event(ev, startup_events)
elif name == "NodeFinished":
self._handle_node_finished(ev, context)

with self._patch_runner(_cb):
result = super().execute(context=context, **kwargs)

self._finalize(context, startup_events)
return result

# Fallback – push run_results.json via base class helper
kwargs["push_run_results_to_xcom"] = True
return super().execute(context=context, **kwargs)
5 changes: 5 additions & 0 deletions cosmos/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@
enable_setup_async_task = conf.getboolean("cosmos", "enable_setup_async_task", fallback=True)
enable_teardown_async_task = conf.getboolean("cosmos", "enable_teardown_async_task", fallback=True)

# Airflow task priority weight for the DbtBuildCoordinatorOperator task in WATCHER Execution Mode
watcher_build_coordinator_priority_weight = conf.getint(
"cosmos", "watcher_build_coordinator_priority_weight", fallback=9999
)

AIRFLOW_IO_AVAILABLE = Version(airflow_version) >= Version("2.8.0")

# The following environment variable is populated in Astro Cloud
Expand Down
Loading