Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions cosmos/dbt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable

from cosmos.dbt.project import change_working_directory, environ
from cosmos.exceptions import CosmosDbtRunError
Expand Down Expand Up @@ -39,7 +39,7 @@ def is_available() -> bool:


@cache
def get_runner() -> dbtRunner:
def _get_cached_dbt_runner() -> dbtRunner:
"""
Retrieves a dbtRunner instance.
"""
Expand All @@ -48,7 +48,21 @@ def get_runner() -> dbtRunner:
return dbtRunner()


def run_command(command: list[str], env: dict[str, str], cwd: str) -> dbtRunnerResult:
def get_runner(callbacks: list[Callable] | None = None) -> dbtRunner: # type: ignore[type-arg]
Comment thread
tatiana marked this conversation as resolved.
"""
Retrieves a dbtRunner instance.
"""
if callbacks and isinstance(callbacks, list):
from dbt.cli.main import dbtRunner

return dbtRunner(callbacks=callbacks)

return _get_cached_dbt_runner()


def run_command(
command: list[str], env: dict[str, str], cwd: str, callbacks: list[Callable] | None = None # type: ignore[type-arg]
) -> dbtRunnerResult:
"""
Invokes the dbt command programmatically.
"""
Expand All @@ -58,7 +72,7 @@ def run_command(command: list[str], env: dict[str, str], cwd: str) -> dbtRunnerR
cli_args = command[1:]
with change_working_directory(cwd), environ(env):
logger.info("Trying to run dbtRunner with:\n %s\n in %s", cli_args, cwd)
runner = get_runner()
runner = get_runner(callbacks=callbacks)
result = runner.invoke(cli_args)
return result

Expand Down
37 changes: 32 additions & 5 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(
should_store_compiled_sql: bool = True,
should_upload_compiled_sql: bool = False,
append_env: bool = True,
dbt_runner_callbacks: list[Callable] | None = None, # type: ignore[type-arg]
**kwargs: Any,
) -> None:
self.task_id = task_id
Expand All @@ -199,6 +200,7 @@ def __init__(
self.openlineage_events_completes: list[RunEvent] = []
self.invocation_mode = invocation_mode
self._dbt_runner: dbtRunner | None = None
self._dbt_runner_callbacks = dbt_runner_callbacks

super().__init__(task_id=task_id, **kwargs)

Expand Down Expand Up @@ -470,7 +472,7 @@ def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str) -> d
"Could not import dbt core. Ensure that dbt-core >= v1.5 is installed and available in the environment where the operator is running."
)

return dbt_runner.run_command(command, env, cwd)
return dbt_runner.run_command(command, env, cwd, callbacks=self._dbt_runner_callbacks)

def _cache_package_lockfile(self, tmp_project_dir: Path) -> None:
project_dir = Path(self.project_dir)
Expand Down Expand Up @@ -577,13 +579,36 @@ def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None:
if partial_parse_file.exists():
cache._update_partial_parse_cache(partial_parse_file, self.cache_dir)

def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None:
def _push_run_results_to_xcom(self, tmp_project_dir: str, context: Context) -> None:
run_results_path = Path(tmp_project_dir) / "target" / "run_results.json"
if not run_results_path.is_file():
raise AirflowException(f"run_results.json not found at {run_results_path}")
Comment thread
pankajkoti marked this conversation as resolved.

try:
with run_results_path.open() as fp:
raw = json.load(fp)
except json.JSONDecodeError as exc:
raise AirflowException("Invalid JSON in run_results.json") from exc
self.log.debug("Loaded run results from %s", run_results_path)

compressed = base64.b64encode(zlib.compress(json.dumps(raw).encode())).decode()
context["ti"].xcom_push(key="run_results", value=compressed)

self.log.info("Pushed run results to XCom")

def _handle_post_execution(
self, tmp_project_dir: str, context: Context, push_run_results_to_xcom: bool = False
) -> None:
self.store_freshness_json(tmp_project_dir, context)
self.store_compiled_sql(tmp_project_dir, context)
self._override_rtif(context)

if self.should_upload_compiled_sql:
self._upload_sql_files(tmp_project_dir, "compiled")

if push_run_results_to_xcom:
self._push_run_results_to_xcom(tmp_project_dir, context)

if self.callback:
self.callback_args.update({"context": context})
if isinstance(self.callback, list):
Expand Down Expand Up @@ -614,6 +639,7 @@ def run_command( # noqa: C901
context: Context,
run_as_async: bool = False,
async_context: dict[str, Any] | None = None,
push_run_results_to_xcom: bool = False,
) -> FullOutputSubprocessResult | dbtRunnerResult | str:
"""
Copies the dbt project to a temporary directory and runs the command.
Expand Down Expand Up @@ -667,7 +693,7 @@ def run_command( # noqa: C901
if self.partial_parse:
self._update_partial_parse_cache(tmp_dir_path)

self._handle_post_execution(tmp_project_dir, context)
self._handle_post_execution(tmp_project_dir, context, push_run_results_to_xcom)
self.handle_exception(result)

if run_as_async and async_context:
Expand Down Expand Up @@ -861,6 +887,7 @@ def build_and_run_cmd(
cmd_flags: list[str] | None = None,
run_as_async: bool = False,
async_context: dict[str, Any] | None = None,
**kwargs: Any,
) -> FullOutputSubprocessResult | dbtRunnerResult:
# If this is an async run and we're using the setup task, make sure to include the full_refresh flag if set
if run_as_async and settings.enable_setup_async_task and getattr(self, "full_refresh", False):
Expand All @@ -872,7 +899,7 @@ def build_and_run_cmd(
dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags)
dbt_cmd = dbt_cmd or []
result = self.run_command(
cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context
cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context, **kwargs
)
return result

Expand Down Expand Up @@ -974,7 +1001,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult,
self.on_warning_callback and self.on_warning_callback(warning_context)

def execute(self, context: Context, **kwargs: Any) -> None:
result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags(), **kwargs)
if self.on_warning_callback:
self._handle_warnings(result, context)

Expand Down
19 changes: 17 additions & 2 deletions cosmos/operators/virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,22 @@ def run_command(
context: Context,
run_as_async: bool = False,
async_context: dict[str, Any] | None = None,
push_run_results_to_xcom: bool = False,
) -> FullOutputSubprocessResult | dbtRunnerResult:
# No virtualenv_dir set, so create a temporary virtualenv
if self.virtualenv_dir is None or self.is_virtualenv_dir_temporary:
self.log.info("Creating temporary virtualenv")
with TemporaryDirectory(prefix="cosmos-venv") as tempdir:
self.virtualenv_dir = Path(tempdir)
self._py_bin = self._prepare_virtualenv()
return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context)
return super().run_command(
cmd,
env,
context,
run_as_async=run_as_async,
async_context=async_context,
push_run_results_to_xcom=push_run_results_to_xcom,
)

try:
self.log.info(f"Checking if the virtualenv lock {str(self._lock_file)} exists")
Expand All @@ -126,7 +134,14 @@ def run_command(
self.log.info("Acquiring the virtualenv lock")
self._acquire_venv_lock()
self._py_bin = self._prepare_virtualenv()
return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context)
return super().run_command(
cmd,
env,
context,
run_as_async=run_as_async,
async_context=async_context,
push_run_results_to_xcom=push_run_results_to_xcom,
)
finally:
self.log.info("Releasing virtualenv lock")
self._release_venv_lock()
Expand Down
119 changes: 118 additions & 1 deletion cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,30 @@
status_model_fhir_dbt_analytics_active_encounters_daily_nodefinished = {
from __future__ import annotations

import base64
import json
import logging
import zlib
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING: # pragma: no cover
try:
from airflow.sdk.definitions.context import Context
except ImportError:
from airflow.utils.context import Context # type: ignore[attr-defined]

from cosmos.constants import InvocationMode
from cosmos.operators.local import DbtLocalBaseOperator

try:
from dbt_common.events.base_types import EventMsg
except ImportError: # pragma: no cover
EventMsg = None

logger = logging.getLogger(__name__)


# Example dbt event JSON dictionaries (kept for reference)
nodefinished_model__fhir_dbt_utils__fhir_table_list = {
Comment thread
pankajkoti marked this conversation as resolved.
"info": {
"name": "NodeFinished",
"code": "Q025",
Expand Down Expand Up @@ -90,3 +116,94 @@
"data": {"adapter_name": "bigquery", "adapter_version": "=1.9.0"},
}
}


Comment thread
pankajkoti marked this conversation as resolved.
class DbtProducerWatcherOperator(DbtLocalBaseOperator):
"""Run dbt build and update XCom with the progress of each model, as part of the *WATCHER* execution mode.

Executes **one** ``dbt build`` covering the whole selection.

- **When ``InvocationMode.DBT_RUNNER`` is set** we patch
``dbtRunner`` so we receive structured events *while* dbt is running. In
this real-time mode the operator:
– pushes startup metadata events (``MainReportVersion``,
``AdapterRegistered``) together under XCom key
``dbt_startup_events``;
– pushes each ``NodeFinished`` event immediately to XCom under
``nodefinished_<unique_id>`` (zlib zipped+base64 JSON) so downstream
sensors can react with near-zero latency.

- **When ``dbtRunner`` is *not* available** (older dbt or
``InvocationMode=SUBPROCESS``) we fallback to delayed strategy: after
dbt exits we read ``target/run_results.json`` and push the whole mapping
once under key ``run_results`` to XCom. Sensors can poll this key but will not
get per-model updates until the build completes - by the end of the execution of all dbt nodes.

This keeps the heavy dbt work centralised while providing near real-time
feedback and granular task-level observability downstream.
"""

base_cmd = ["build"]

def __init__(self, *args: Any, **kwargs: Any) -> None:
task_id = kwargs.pop("task_id", "dbt_producer_watcher_operator")
super().__init__(task_id=task_id, *args, **kwargs)

@staticmethod
def _serialize_event(ev: EventMsg) -> dict[str, Any]:
"""Convert structured dbt EventMsg to plain dict."""
from google.protobuf.json_format import MessageToDict
Comment thread
pankajkoti marked this conversation as resolved.

return MessageToDict(ev, preserving_proto_field_name=True) # type: ignore[no-any-return]

def _handle_startup_event(self, ev: EventMsg, startup_events: list[dict[str, Any]]) -> None:
info = ev.info # type: ignore[attr-defined]
raw_ts = getattr(info, "ts", None)
ts_val = raw_ts.ToJsonString() if hasattr(raw_ts, "ToJsonString") else str(raw_ts) # type: ignore[union-attr]
startup_events.append({"name": info.name, "msg": info.msg, "ts": ts_val})

def _handle_node_finished(
self,
ev: EventMsg,
context: Context,
) -> None:
self.log.debug("DbtProducerWatcherOperator: handling node finished event: %s", ev)
ti = context["ti"]
uid = ev.data.node_info.unique_id
ev_dict = self._serialize_event(ev)
payload = base64.b64encode(zlib.compress(json.dumps(ev_dict).encode())).decode()
ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload)

def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None:
ti = context["ti"]
# Only push startup events; per-model statuses are available via individual nodefinished_<uid> entries.
if startup_events:
ti.xcom_push(key="dbt_startup_events", value=startup_events)

def execute(self, context: Context, **kwargs: Any) -> Any:
if not self.invocation_mode:
self._discover_invocation_mode()

use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None
self.log.debug("DbtProducerWatcherOperator: use_events=%s", use_events)

startup_events: list[dict[str, Any]] = []

if use_events:

def _callback(ev: EventMsg) -> None:
name = ev.info.name
if name in {"MainReportVersion", "AdapterRegistered"}:
self._handle_startup_event(ev, startup_events)
elif name == "NodeFinished":
self._handle_node_finished(ev, context)
Comment thread
pankajkoti marked this conversation as resolved.

self._dbt_runner_callbacks = [_callback]
result = super().execute(context=context, **kwargs)

self._finalize(context, startup_events)
return result

# Fallback – push run_results.json via base class helper
kwargs["push_run_results_to_xcom"] = True
return super().execute(context=context, **kwargs)
Loading