From 1170fc669d262971e8ccb3c9869099d084c96d0b Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 23 Mar 2026 21:05:10 +0000 Subject: [PATCH 01/13] First version: unify watcher to always use SUBPROCESS --- cosmos/operators/_watcher/__init__.py | 3 +- cosmos/operators/_watcher/base.py | 95 +--- cosmos/operators/_watcher/triggerer.py | 34 +- cosmos/operators/watcher.py | 172 +------ cosmos/operators/watcher_kubernetes.py | 3 - tests/operators/_watcher/test_triggerer.py | 37 +- tests/operators/_watcher/test_watcher_base.py | 17 - tests/operators/test_watcher.py | 481 +----------------- .../operators/test_watcher_kubernetes_unit.py | 9 - 9 files changed, 71 insertions(+), 780 deletions(-) diff --git a/cosmos/operators/_watcher/__init__.py b/cosmos/operators/_watcher/__init__.py index 0eac9da614..02ed3411e2 100644 --- a/cosmos/operators/_watcher/__init__.py +++ b/cosmos/operators/_watcher/__init__.py @@ -8,7 +8,6 @@ "is_dbt_node_status_failed", "is_dbt_node_status_terminal", "WatcherTrigger", - "_parse_compressed_xcom", ] from cosmos.operators._watcher.state import ( @@ -19,4 +18,4 @@ is_dbt_node_status_terminal, safe_xcom_push, ) -from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom +from cosmos.operators._watcher.triggerer import WatcherTrigger diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 217326ed09..90ad8ba1e3 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -18,7 +18,6 @@ from cosmos.log import get_logger from cosmos.operators._watcher.aggregation import push_test_result_or_aggregate from cosmos.operators._watcher.state import ( - DBT_FAILED_STATUSES, _iso_to_string, _log_dbt_event, build_producer_state_fetcher, @@ -27,7 +26,7 @@ is_dbt_node_status_terminal, safe_xcom_push, ) -from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom +from cosmos.operators._watcher.triggerer import WatcherTrigger try: from airflow.sdk.bases.sensor import BaseSensorOperator @@ -36,11 +35,6 @@ from airflow.sensors.base import BaseSensorOperator from airflow.utils.context import Context # type: ignore[attr-defined] -try: - from dbt_common.events.base_types import EventMsg -except ImportError: # pragma: no cover - EventMsg = None - logger = get_logger(__name__) # Subset of dbt event types that represent errors/failures. @@ -82,31 +76,20 @@ _DBT_EVENT_ALLOWLIST = _DBT_ERROR_EVENTS_TYPES | _DBT_NODE_STATUS_EVENT_TYPES -def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any] | EventMsg) -> None: +def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any]) -> None: logger.debug("dbt_log: %s", dbt_log) - if isinstance(dbt_log, dict): # Subprocess - data = dbt_log.get("data", {}) - info = dbt_log.get("info", {}) + data = dbt_log.get("data", {}) + info = dbt_log.get("info", {}) - event_name = info.get("name") - if event_name not in _DBT_EVENT_ALLOWLIST: - return None - node_info = data.get("node_info") - status = node_info.get("node_status") if node_info else None - unique_id = node_info.get("unique_id") if node_info else None - start_time = node_info.get("node_started_at") if node_info else None - finish_time = node_info.get("node_finished_at") if node_info else None - msg = data.get("msg") or info.get("msg") or None - else: # Runner - event_name = getattr(getattr(dbt_log, "info", None), "name", None) - if event_name not in _DBT_EVENT_ALLOWLIST: - return None - node_info = getattr(dbt_log.data, "node_info", None) - unique_id = getattr(node_info, "unique_id") if node_info else None - status = getattr(node_info, "node_status", None) if node_info else None - start_time = getattr(node_info, "node_started_at", None) if node_info else None - finish_time = getattr(node_info, "node_finished_at", None) if node_info else None - msg = getattr(dbt_log.info, "msg", None) + event_name = info.get("name") + if event_name not in _DBT_EVENT_ALLOWLIST: + return None + node_info = data.get("node_info") + status = node_info.get("node_status") if node_info else None + unique_id = node_info.get("unique_id") if node_info else None + start_time = node_info.get("node_started_at") if node_info else None + finish_time = node_info.get("node_finished_at") if node_info else None + msg = data.get("msg") or info.get("msg") or None if unique_id: dbt_event = { @@ -126,9 +109,7 @@ def _extract_compiled_sql( """ Extract compiled SQL from the target directory for a given dbt node. - Used by both the subprocess strategy (via store_dbt_resource_status_from_log) - and the node-event strategy (via DbtProducerWatcherOperator._handle_node_finished); - both consume from the same target/compiled layout under project_dir. + Used by store_dbt_resource_status_from_log; reads from the target/compiled layout under project_dir. Assumes inputs come from dbt (relative node_path, unique_id like model.package.name). """ @@ -357,39 +338,6 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo logger.info("dbt run completed successfully on retry for model '%s'", self.model_unique_id) return True - def _get_status_from_run_results(self, ti: Any, context: Context) -> Any: - compressed_b64_run_results = ti.xcom_pull(task_ids=self.producer_task_id, key="run_results") - - if not compressed_b64_run_results: - return None - - run_results_json = _parse_compressed_xcom(compressed_b64_run_results) - - logger.debug("Run results: %s", run_results_json) - - results = run_results_json.get("results", []) - node_result = next((r for r in results if r.get("unique_id") == self.model_unique_id), None) - - if not node_result: # pragma: no cover - logger.warning( - "The dbt node with unique_id '%s' was not executed by the dbt command run in the producer task. This may happen if it is an ephemeral model or if the model sql file is empty.", - self.model_unique_id, - ) - return None - - logger.info("Node Info: %s", run_results_json) - - status = node_result.get("status") - - if status in DBT_FAILED_STATUSES: - logger.error("%s", node_result.get("message")) - - self.compiled_sql = node_result.get("compiled_code") - if self.compiled_sql and hasattr(self, "_override_rtif"): - self._override_rtif(context) - - return status - def _get_producer_task_status(self, context: Context) -> str | None: """ Get the task status of the producer task for both Airflow 2 and Airflow 3. @@ -423,7 +371,6 @@ def execute(self, context: Context, **kwargs: Any) -> None: dag_id=self.dag_id, run_id=context["run_id"], map_index=context["task_instance"].map_index, - use_event=self.use_event(), poke_interval=self.poke_interval, ), timeout=self.execution_timeout, @@ -466,12 +413,6 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: f"Watcher producer task '{self.producer_task_id}' failed before reporting model results. Check its logs for the underlying error." ) - def use_event(self) -> bool: - raise NotImplementedError("Subclasses must implement this method") - - def _get_status_from_events(self, ti: Any, context: Context) -> Any: - raise NotImplementedError("Subclasses should implement this method if `use_event` may return True") - def _log_startup_events(self, ti: Any) -> None: dbt_startup_events: list[dict[str, Any]] = ti.xcom_pull( task_ids=self.producer_task_id, key=_DBT_STARTUP_EVENTS_XCOM_KEY @@ -500,13 +441,9 @@ def poke(self, context: Context) -> bool: if try_number > 1: return self._fallback_to_non_watcher_run(try_number, context) - # We have assumption here that both the build producer and the sensor task will have same invocation mode producer_task_state = self._get_producer_task_status(context) - if self.use_event(): - status = self._get_status_from_events(ti, context) - else: - self._log_startup_events(ti) - status = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") + self._log_startup_events(ti) + status = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") # compiled_sql is always in the canonical per-model XCom key (same for event and subprocess modes) if status is not None: diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index 3b1f41fa01..f30f94d224 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -1,9 +1,6 @@ from __future__ import annotations import asyncio -import base64 -import json -import zlib from collections.abc import AsyncIterator from typing import Any @@ -34,7 +31,6 @@ def __init__( dag_id: str, run_id: str, map_index: int | None, - use_event: bool, poke_interval: float = 5.0, ): self.model_unique_id = model_unique_id @@ -42,7 +38,6 @@ def __init__( self.dag_id = dag_id self.run_id = run_id self.map_index = map_index - self.use_event = use_event self.poke_interval = poke_interval def serialize(self) -> tuple[str, dict[str, Any]]: @@ -54,7 +49,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "dag_id": self.dag_id, "run_id": self.run_id, "map_index": self.map_index, - "use_event": self.use_event, "poke_interval": self.poke_interval, }, ) @@ -106,29 +100,16 @@ async def get_xcom_val(self, key: str) -> Any | None: return await self.get_xcom_val_af3(key) async def _get_node_status(self) -> Any | None: - status_key = ( - f"nodefinished_{self.model_unique_id.replace('.', '__')}" - if self.use_event - else f"{self.model_unique_id.replace('.', '__')}_status" - ) - - if self.use_event: - compressed_xcom_val = await self.get_xcom_val(status_key) - if not compressed_xcom_val: - return None - data_json = _parse_compressed_xcom(compressed_xcom_val) - status = data_json.get("data", {}).get("run_result", {}).get("status") - else: - status = await self.get_xcom_val(status_key) - return status + status_key = f"{self.model_unique_id.replace('.', '__')}_status" + return await self.get_xcom_val(status_key) async def _parse_dbt_node_status_and_compiled_sql(self) -> tuple[str | None, str | None]: """ Parse node status and compiled_sql from XCom. Returns a tuple of (status, compiled_sql). - Status comes from mode-specific keys (nodefinished_* for event, *_status for subprocess). - compiled_sql is always read from the canonical per-model key (same for both modes). + Status is read from the per-model ``*_status`` XCom key pushed by store_dbt_resource_status_from_log. + compiled_sql is read from the canonical per-model ``*_compiled_sql`` key. """ compiled_sql_key = f"{self.model_unique_id.replace('.', '__')}_compiled_sql" @@ -233,10 +214,3 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # Sleep briefly before re-polling await asyncio.sleep(self.poke_interval) logger.debug("Polling again for node '%s' status...", self.model_unique_id) - - -def _parse_compressed_xcom(compressed_b64_event_msg: str) -> Any: - """Decode and decompress a base64-encoded, zlib-compressed XCom payload.""" - compressed_bytes = base64.b64decode(compressed_b64_event_msg) - event_json_str = zlib.decompress(compressed_bytes).decode("utf-8") - return json.loads(event_json_str) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 077822f514..a2b972a6c2 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -1,17 +1,13 @@ from __future__ import annotations -import base64 import functools -import json -import zlib from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException from cosmos.config import ProfileConfig -from cosmos.operators._watcher import _parse_compressed_xcom, safe_xcom_push -from cosmos.operators._watcher.state import DBT_FAILED_STATUSES +from cosmos.operators._watcher import safe_xcom_push from cosmos.settings import watcher_dbt_execution_queue try: @@ -20,18 +16,14 @@ from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] from cosmos.constants import ( - _DBT_STARTUP_EVENTS_XCOM_KEY, PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, PRODUCER_WATCHER_TASK_ID, WATCHER_TASK_WEIGHT_RULE, InvocationMode, ) from cosmos.log import get_logger -from cosmos.operators._watcher.aggregation import push_test_result_or_aggregate from cosmos.operators._watcher.base import ( BaseConsumerSensor, - _process_dbt_log_event, - store_compiled_sql_for_model, store_dbt_resource_status_from_log, ) from cosmos.operators.base import ( @@ -46,12 +38,6 @@ DbtSourceLocalOperator, ) -try: - from dbt_common.events.base_types import EventMsg -except ImportError: # pragma: no cover - EventMsg = None - - if TYPE_CHECKING: # pragma: no cover try: from airflow.sdk.definitions.context import Context @@ -71,21 +57,11 @@ class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator): Executes **one** ``dbt build`` covering the whole selection. - - **When ``InvocationMode.DBT_RUNNER`` is set** we patch - ``dbtRunner`` so we receive structured events *while* dbt is running. In - this real-time mode the operator: - – pushes startup metadata events (``MainReportVersion``, - ``AdapterRegistered``) together under XCom key - ``dbt_startup_events``; - – pushes each ``NodeFinished`` event immediately to XCom under - ``nodefinished_`` (zlib zipped+base64 JSON) so downstream - sensors can react with near-zero latency. - - - **When ``dbtRunner`` is *not* available** (older dbt or - ``InvocationMode=SUBPROCESS``) we fallback to delayed strategy: after - dbt exits we read ``target/run_results.json`` and push the whole mapping - once under key ``run_results`` to XCom. Sensors can poll this key but will not - get per-model updates until the build completes - by the end of the execution of all dbt nodes. + dbt is always invoked via ``InvocationMode.SUBPROCESS`` with ``--log-format json`` so that each + log line can be parsed in real-time by ``store_dbt_resource_status_from_log``. As each + ``NodeFinished`` event arrives the operator pushes the per-model status to XCom under key + ``_status`` so downstream sensors can react without waiting for the full build to + complete. This keeps the heavy dbt work centralised while providing near real-time feedback and granular task-level observability downstream. @@ -107,70 +83,24 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.setdefault("priority_weight", PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT) kwargs.setdefault("weight_rule", WATCHER_TASK_WEIGHT_RULE) kwargs["queue"] = watcher_dbt_execution_queue or kwargs.get("queue") or DEFAULT_QUEUE + # WATCHER always uses SUBPROCESS so that dbt JSON logs can be parsed line-by-line. + # This is the single, unified approach for real-time per-model status tracking. + kwargs["invocation_mode"] = InvocationMode.SUBPROCESS super().__init__(task_id=task_id, *args, **kwargs) + self.log_format = "json" - if self.invocation_mode == InvocationMode.SUBPROCESS: - self.log_format = "json" - - @staticmethod - def _serialize_event(event_message: EventMsg) -> dict[str, Any]: - """Convert structured dbt EventMsg to plain dict.""" - from google.protobuf.json_format import MessageToDict - - return MessageToDict(event_message, preserving_proto_field_name=True) # type: ignore[no-any-return] - - def _handle_startup_event(self, event_message: EventMsg, startup_events: list[dict[str, Any]]) -> None: - info = event_message.info # type: ignore[attr-defined] - raw_ts = getattr(info, "ts", None) - ts_val = raw_ts.ToJsonString() if hasattr(raw_ts, "ToJsonString") else str(raw_ts) # type: ignore[union-attr] - startup_events.append({"name": info.name, "msg": info.msg, "ts": ts_val}) - - def _handle_node_finished( - self, - event_message: EventMsg, - context: Context, - ) -> None: - logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", event_message) - uid = event_message.data.node_info.unique_id - node_path_val = getattr(event_message.data.node_info, "node_path", None) - node_path = str(node_path_val) if node_path_val is not None else None - resource_type = getattr(event_message.data.node_info, "resource_type", None) - event_message_dict = self._serialize_event(event_message) - store_compiled_sql_for_model(context["ti"], self.project_dir, uid, node_path, resource_type) - - if resource_type == "test" and self.tests_per_model: - status: str = getattr(event_message.data.node_info, "node_status", None) or "" - logger.debug("Test '%s' finished with status '%s'", uid, status) - push_test_result_or_aggregate(uid, status, self.tests_per_model, self.test_results_per_model, context["ti"]) - else: - payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode() - safe_xcom_push(task_instance=context["ti"], key=f"nodefinished_{uid.replace('.', '__')}", value=payload) - - def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None: - # Only push startup events; per-model statuses are available via individual nodefinished_ entries. - if startup_events: - safe_xcom_push(task_instance=context["ti"], key=_DBT_STARTUP_EVENTS_XCOM_KEY, value=startup_events) - - def _set_invocation_mode_if_not_set(self) -> None: - if not self.invocation_mode: - logger.info("No invocation mode provided, discovering it") - self._discover_invocation_mode() - - def _set_process_log_line_callable_if_subprocess(self) -> None: - if self.invocation_mode == InvocationMode.SUBPROCESS: - logger.info( - "DbtProducerWatcherOperator: Setting log_format to json and process_log_line_callable to store_dbt_resource_status_from_log" - ) - self.log_format = "json" - self._process_log_line_callable = functools.partial( - store_dbt_resource_status_from_log, - tests_per_model=self.tests_per_model, - test_results_per_model=self.test_results_per_model, - ) + def _set_process_log_line_callable(self) -> None: + logger.info( + "DbtProducerWatcherOperator: Setting process_log_line_callable to store_dbt_resource_status_from_log" + ) + self._process_log_line_callable = functools.partial( + store_dbt_resource_status_from_log, + tests_per_model=self.tests_per_model, + test_results_per_model=self.test_results_per_model, + ) def execute(self, context: Context, **kwargs: Any) -> Any: - self._set_invocation_mode_if_not_set() - self._set_process_log_line_callable_if_subprocess() + self._set_process_log_line_callable() task_instance = context.get("ti") if task_instance is None: @@ -187,41 +117,7 @@ def execute(self, context: Context, **kwargs: Any) -> Any: return None try: - use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None - logger.debug("DbtProducerWatcherOperator: use_events=%s", use_events) - - startup_events: list[dict[str, Any]] = [] - - if use_events: - - def _callback(event_message: EventMsg) -> None: - try: - _process_dbt_log_event(context["ti"], event_message) - name = event_message.info.name - if name in {"MainReportVersion", "AdapterRegistered"}: - self._handle_startup_event(event_message, startup_events) - safe_xcom_push( - task_instance=context["ti"], key=_DBT_STARTUP_EVENTS_XCOM_KEY, value=startup_events - ) - elif name == "NodeFinished": - self._handle_node_finished(event_message, context) - except Exception: - event_name = getattr(getattr(event_message, "info", None), "name", "unknown") - logger.exception( - "DbtProducerWatcherOperator: error while handling dbt event '%s'", - event_name, - ) - - self._dbt_runner_callbacks = [_callback] - result = super().execute(context=context, **kwargs) - - self._finalize(context, startup_events) - return_value = result - else: - # Fallback – push run_results.json via base class helper - kwargs["push_run_results_to_xcom"] = True - return_value = super().execute(context=context, **kwargs) - + return_value = super().execute(context=context, **kwargs) safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed") return return_value @@ -254,32 +150,6 @@ def __init__( **kwargs, ) - def _get_status_from_events(self, ti: Any, context: Context) -> Any: - - self._log_startup_events(ti) - - node_finished_key = f"nodefinished_{self.model_unique_id.replace('.', '__')}" - logger.info("Pulling from producer task_id: %s, key: %s", self.producer_task_id, node_finished_key) - compressed_b64_event_msg = ti.xcom_pull(task_ids=self.producer_task_id, key=node_finished_key) - - if not compressed_b64_event_msg: - return None - - event_json = _parse_compressed_xcom(compressed_b64_event_msg) - - logger.info("Node Info: %s", event_json) - node_result = event_json.get("data", {}).get("run_result", {}) - status = node_result.get("status") - if status in DBT_FAILED_STATUSES: - logger.error("%s", node_result.get("message")) - - return status - - def use_event(self) -> bool: - if not self.invocation_mode: - self._discover_invocation_mode() - return self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None - # This Operator does not seem to make sense for this particular execution mode, since build is executed by the producer task. # That said, it is important to raise an exception if users attempt to use TestBehavior.BUILD, until we have a better experience. diff --git a/cosmos/operators/watcher_kubernetes.py b/cosmos/operators/watcher_kubernetes.py index 76b4da4e18..1d72defc9e 100644 --- a/cosmos/operators/watcher_kubernetes.py +++ b/cosmos/operators/watcher_kubernetes.py @@ -124,9 +124,6 @@ def execute(self, context: Context, **kwargs: Any) -> Any: class DbtConsumerWatcherKubernetesSensor(BaseConsumerSensor, DbtRunKubernetesOperator): template_fields: tuple[str, ...] = BaseConsumerSensor.template_fields + DbtRunKubernetesOperator.template_fields # type: ignore[operator] - def use_event(self) -> bool: - return False - # This Operator does not seem to make sense for this particular execution mode, since build is executed by the producer task. # That said, it is important to raise an exception if users attempt to use TestBehavior.BUILD, until we have a better experience. diff --git a/tests/operators/_watcher/test_triggerer.py b/tests/operators/_watcher/test_triggerer.py index b9cb8036ec..0fec9ca313 100644 --- a/tests/operators/_watcher/test_triggerer.py +++ b/tests/operators/_watcher/test_triggerer.py @@ -21,7 +21,6 @@ def setup_method(self): dag_id="dag_1", run_id="run_123", map_index=None, - use_event=True, poke_interval=0.001, # fast polling ) @@ -30,6 +29,7 @@ def test_serialize(self): assert classpath.endswith("WatcherTrigger") assert args["model_unique_id"] == "model.test" assert args["poke_interval"] == 0.001 + assert "use_event" not in args @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Require Airflow < 3.0.0") @pytest.mark.asyncio @@ -83,37 +83,22 @@ async def runner(*args, **kwargs): assert none_result is None @pytest.mark.parametrize( - "use_event, xcom_val, expected_status, expected_compiled_sql", + "xcom_val, expected_status, expected_compiled_sql", [ - # Event mode: status from event payload; compiled_sql from canonical *_compiled_sql key only - (True, {"data": {"run_result": {"status": "success"}}}, "success", "SELECT 1"), - (True, {"data": {"run_result": {"status": "success"}}}, "success", None), - (True, None, None, None), - # Subprocess mode: status from *_status key; compiled_sql from canonical key - (False, "failed", "failed", None), - (False, "success", "success", "SELECT * FROM table"), + ("failed", "failed", None), + ("success", "success", "SELECT * FROM table"), + (None, None, None), ], ) - async def test_parse_dbt_node_status_and_compiled_sql( - self, use_event, xcom_val, expected_status, expected_compiled_sql - ): - self.trigger.use_event = use_event - + async def test_parse_dbt_node_status_and_compiled_sql(self, xcom_val, expected_status, expected_compiled_sql): async def mock_get_xcom_val(key): - # compiled_sql is always read from the canonical key (same for both modes) if key.endswith("_compiled_sql"): return expected_compiled_sql - if use_event: - return xcom_val if xcom_val else None - # Subprocess mode: status from per-model key if key.endswith("_status"): return xcom_val return None - with ( - patch("cosmos.operators._watcher.triggerer._parse_compressed_xcom", return_value=xcom_val), - patch.object(self.trigger, "get_xcom_val", AsyncMock(side_effect=mock_get_xcom_val)), - ): + with patch.object(self.trigger, "get_xcom_val", AsyncMock(side_effect=mock_get_xcom_val)): status, compiled_sql = await self.trigger._parse_dbt_node_status_and_compiled_sql() assert status == expected_status assert compiled_sql == expected_compiled_sql @@ -155,15 +140,13 @@ async def fake_get_xcom_val(key): return _STARTUP_EVENTS if key.endswith("_compiled_sql"): return None - return "compressed_data" + if key.endswith("_status"): + return dbt_node_status + return None with ( patch.object(self.trigger, "get_xcom_val", side_effect=fake_get_xcom_val), patch.object(self.trigger, "_get_producer_task_status", AsyncMock(return_value=producer_state)), - patch( - "cosmos.operators._watcher.triggerer._parse_compressed_xcom", - return_value={"data": {"run_result": {"status": dbt_node_status}}} if dbt_node_status else {}, - ), ): events = [event async for event in self.trigger.run()] assert events[0].payload == expected diff --git a/tests/operators/_watcher/test_watcher_base.py b/tests/operators/_watcher/test_watcher_base.py index 2892ca7b8d..a409497888 100644 --- a/tests/operators/_watcher/test_watcher_base.py +++ b/tests/operators/_watcher/test_watcher_base.py @@ -8,23 +8,6 @@ class TestBaseConsumerSensor: - def test__methods_to_be_implemented(self): - class SubclassBaseConsumerSensor(BaseConsumerSensor, DbtRunLocalOperator): - something_to_be_implemented = True - - sensor = SubclassBaseConsumerSensor( - task_id="test_sensor", - model_unique_id="model.jaffle_shop.stg_orders", - producer_task_id="dbt_run_local", - profile_config=None, - project_dir="/tmp/sample_project", - ) - with pytest.raises(NotImplementedError): - sensor.use_event() - - with pytest.raises(NotImplementedError): - assert sensor._get_status_from_events(None, None) is None - def test_extra_context_is_stored_on_instance(self): """Consumer sensor stores extra_context so it is available at runtime.""" diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 261fe5c94d..60c46529b4 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -1,13 +1,9 @@ from __future__ import annotations -import base64 import json import logging -import zlib -from contextlib import nullcontext from datetime import datetime, timedelta from pathlib import Path -from types import SimpleNamespace from typing import Any from unittest.mock import ANY, MagicMock, Mock, patch @@ -79,43 +75,6 @@ class _MockContext(dict): pass -def _fake_event( - name: str = "NodeFinished", uid: str = "model.pkg.m", resource_type: str | None = None, node_path: str | None = None -): - """Create a minimal fake EventMsg-like object suitable for helper tests.""" - - class _Info(SimpleNamespace): - pass - - class _NodeInfo(SimpleNamespace): - pass - - class _RunResult(SimpleNamespace): - pass - - node_info = _NodeInfo(unique_id=uid) - if resource_type is not None: - setattr(node_info, "resource_type", resource_type) - if node_path is not None: - setattr(node_info, "node_path", node_path) - run_result = _RunResult(status="success", message="ok") - - data = SimpleNamespace(node_info=node_info, run_result=run_result) - info = _Info(name=name, code="X", msg="msg") - return SimpleNamespace(info=info, data=data) - - -@patch("google.protobuf.json_format.MessageToDict") -def test_serialize_event(mock_mtd): - op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - - mock_mtd.side_effect = lambda ev, **kwargs: {"dummy": True} - - out = op._serialize_event(_fake_event()) - assert out == {"dummy": True} - mock_mtd.assert_called() - - def test_dbt_producer_watcher_operator_priority_weight_default(): """Test that DbtProducerWatcherOperator uses default priority_weight of 9999.""" op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) @@ -192,16 +151,10 @@ def test_dbt_producer_watcher_operator_priority_weight_override(): assert op.priority_weight == 100 -@pytest.mark.parametrize( - "invocation_mode, expected_log_format", - ( - (InvocationMode.SUBPROCESS, "json"), - (InvocationMode.DBT_RUNNER, None), - ), -) -def test_dbt_producer_log_format_adjusts_with_invocation(invocation_mode, expected_log_format): - op = DbtProducerWatcherOperator(project_dir=".", profile_config=None, invocation_mode=invocation_mode) - assert getattr(op, "log_format", None) == expected_log_format +def test_dbt_producer_log_format_always_json(): + """WATCHER always uses --log-format json regardless of any invocation_mode hint passed.""" + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + assert op.log_format == "json" def test_dbt_producer_watcher_operator_pushes_completion_status(): @@ -250,14 +203,6 @@ def test_dbt_producer_watcher_operator_requires_task_instance(): assert "expects a task instance" in str(excinfo.value) -def test_handle_startup_event(): - op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - lst: list[dict] = [] - ev = _fake_event("MainReportVersion") - op._handle_startup_event(ev, lst) - assert lst and lst[0]["name"] == "MainReportVersion" - - def test_dbt_consumer_watcher_sensor_execute_complete_model_not_run_logs_message(caplog): """Test that execute_complete logs an info message when model was skipped (model_not_run).""" sensor = DbtConsumerWatcherSensor( @@ -339,181 +284,6 @@ def test_dbt_consumer_watcher_sensor_execute_complete(mock_dbt_event, event, exp assert str(excinfo.value) == expected_message -def test_handle_node_finished_pushes_xcom(): - op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - ti = _MockTI() - ctx = _MockContext(ti=ti) - - with patch.object(op, "_serialize_event", return_value={"foo": "bar"}): - ev = _fake_event() - op._handle_node_finished(ev, ctx) - - stored = list(ti.store.values())[0] - raw = zlib.decompress(base64.b64decode(stored)).decode() - assert json.loads(raw) == {"foo": "bar"} - - -def test_handle_node_finished_injects_compiled_sql(tmp_path, monkeypatch): - op = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None) - ti = _MockTI() - ctx = _MockContext(ti=ti) - - # Create compiled SQL file at expected path: target/compiled/pkg/models/my_model.sql - compiled_dir = tmp_path / "target" / "compiled" / "pkg" / "models" - compiled_dir.mkdir(parents=True) - compiled_file = compiled_dir / "my_model.sql" - sql_text = "select 1" - compiled_file.write_text(sql_text, encoding="utf-8") - - # Ensure watcher looks up under this tmp project dir - monkeypatch.chdir(tmp_path) - - with patch.object(op, "_serialize_event", return_value={}): - ev = _fake_event(name="NodeFinished", uid="model.pkg.my_model", resource_type="model", node_path="my_model.sql") - op._handle_node_finished(ev, ctx) - - # Compiled SQL is pushed to the canonical XCom key (single strategy for both event and subprocess) - assert ti.store.get("model__pkg__my_model_compiled_sql") == sql_text - - -def test_handle_node_finished_without_compiled_sql_does_not_inject(tmp_path, monkeypatch): - op = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None) - ti = _MockTI() - ctx = _MockContext(ti=ti) - - # Ensure watcher looks up under this tmp project dir, but do NOT create compiled file - monkeypatch.chdir(tmp_path) - - with patch.object(op, "_serialize_event", return_value={}): - ev = _fake_event(name="NodeFinished", uid="model.pkg.my_model", resource_type="model", node_path="my_model.sql") - op._handle_node_finished(ev, ctx) - - # Event payload does not contain compiled_sql; canonical key is only set when extraction succeeds - stored = list(ti.store.values())[0] - data = json.loads(zlib.decompress(base64.b64decode(stored)).decode()) - assert "compiled_sql" not in data - assert "model__pkg__my_model_compiled_sql" not in ti.store - - -def test_execute_streaming_mode(): - """Streaming path should push startup + per-model XComs.""" - op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - op.invocation_mode = InvocationMode.DBT_RUNNER - - import cosmos.operators.watcher as _watch_mod - - # Ensure EventMsg symbol exists without permanently altering the module - if _watch_mod.EventMsg is None: - - class _DummyEv: - pass - - eventmsg_patch = patch("cosmos.operators.watcher.EventMsg", _DummyEv, create=True) - else: - eventmsg_patch = nullcontext() # type: ignore - - ti = _MockTI() - ctx = {"ti": ti, "run_id": "dummy"} - - main_rep = _fake_event("MainReportVersion") - node_evt = _fake_event("NodeFinished", uid="model.pkg.x") - - def fake_base_execute(self, context=None, **_): # type: ignore[override] - for cb in getattr(self, "_dbt_runner_callbacks", []): - cb(main_rep) - cb(node_evt) - return None - - with ( - eventmsg_patch, - patch.object( - DbtProducerWatcherOperator, - "_serialize_event", - lambda self, ev: {"dummy": True}, - ), - patch( - "cosmos.operators.watcher.DbtLocalBaseOperator.execute", - fake_base_execute, - ), - ): - op.execute(context=ctx) - - assert _DBT_STARTUP_EVENTS_XCOM_KEY in ti.store - startup_events = ti.store[_DBT_STARTUP_EVENTS_XCOM_KEY] - assert isinstance(startup_events, list) and len(startup_events) == 1 - assert startup_events[0]["name"] == "MainReportVersion" - - node_key = "nodefinished_model__pkg__x" - assert node_key in ti.store - - -def test_execute_callback_exception_is_logged(caplog): - """Errors inside dbt callback should be logged instead of bubbling up.""" - - op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - op.invocation_mode = InvocationMode.DBT_RUNNER - - import cosmos.operators.watcher as _watch_mod - - if _watch_mod.EventMsg is None: - - class _DummyEv: - pass - - eventmsg_patch = patch("cosmos.operators.watcher.EventMsg", _DummyEv, create=True) - else: - eventmsg_patch = nullcontext() # type: ignore - - ti = _MockTI() - ctx = {"ti": ti, "run_id": "dummy"} - - def fake_base_execute(self, context=None, **_): # type: ignore[override] - for cb in getattr(self, "_dbt_runner_callbacks", []): - cb(_fake_event("MainReportVersion")) - return "ok" - - with ( - eventmsg_patch, - patch.object(DbtProducerWatcherOperator, "_handle_startup_event", side_effect=RuntimeError("boom")), - patch("cosmos.operators.watcher.DbtLocalBaseOperator.execute", fake_base_execute), - caplog.at_level("ERROR"), - ): - result = op.execute(context=ctx) - - assert result == "ok" - assert "error while handling dbt event" in caplog.text - assert ti.store.get("task_status") == "completed" - - -def test_execute_fallback_mode(tmp_path): - """Fallback path pushes compressed run_results once.""" - - tgt = tmp_path / "target" - tgt.mkdir() - with (tgt / "run_results.json").open("w") as fp: - json.dump({"results": [{"unique_id": "a", "status": "success"}]}, fp) - - op = DbtProducerWatcherOperator(project_dir=str(tmp_path), profile_config=None) - op.invocation_mode = InvocationMode.SUBPROCESS # force fallback - - ti = _MockTI() - ctx = {"ti": ti, "run_id": "x"} - - def fake_build_run(self, context, **kw): - from cosmos.operators.local import AbstractDbtLocalBase - - AbstractDbtLocalBase._handle_post_execution(self, self.project_dir, context, True) - return None - - with patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd", fake_build_run): - op.execute(context=ctx) - - compressed = ti.store.get("run_results") - assert compressed - data = json.loads(zlib.decompress(base64.b64decode(compressed)).decode()) - assert data["results"][0]["status"] == "success" - - class TestStoreDbtStatusFromLog: """Tests for store_dbt_resource_status_from_log and _process_log_line_callable.""" @@ -994,35 +764,16 @@ def test_compiled_sql_flat_path_pushed(self, tmp_path): assert ti.store.get("model__pkg__foo_compiled_sql") == "SELECT 1" -@patch("cosmos.dbt.runner.is_available", return_value=False) -@patch("cosmos.operators.watcher.DbtLocalBaseOperator.execute", return_value="done") -def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available): - """If invocation_mode is unset, execute() should discover and set it.""" - +def test_producer_always_uses_subprocess_invocation_mode(): + """DbtProducerWatcherOperator always forces InvocationMode.SUBPROCESS regardless of what was passed.""" from cosmos.config import InvocationMode op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - assert op.invocation_mode is None # precondition - - ti = _MockTI() - ctx = {"ti": ti, "run_id": "xyz"} - - result = op.execute(context=ctx) - - assert result == "done" assert op.invocation_mode == InvocationMode.SUBPROCESS + assert op.log_format == "json" MODEL_UNIQUE_ID = "model.jaffle_shop.stg_orders" -ENCODED_RUN_RESULTS = base64.b64encode( - zlib.compress(b'{"results":[{"unique_id":"model.jaffle_shop.stg_orders","status":"success"}]}') -).decode("utf-8") - -ENCODED_RUN_RESULTS_FAILED = base64.b64encode( - zlib.compress(b'{"results":[{"unique_id":"model.jaffle_shop.stg_orders","status":"fail"}]}') -).decode("utf-8") - -ENCODED_EVENT = base64.b64encode(zlib.compress(b'{"data": {"run_result": {"status": "success"}}}')).decode("utf-8") class TestDbtConsumerWatcherSensor: @@ -1037,7 +788,6 @@ def make_sensor(self, **kwargs): **kwargs, ) - sensor.invocation_mode = "DBT_RUNNER" sensor._get_producer_task_status = MagicMock(return_value=None) return sensor @@ -1158,26 +908,22 @@ def test_get_producer_task_status_airflow3_import_error(self): ) assert status is None - @patch("cosmos.operators.watcher.EventMsg") - def test_poke_status_none_from_events(self, MockEventMsg): - mock_event_instance = MagicMock() - mock_event_instance.status = "done" - MockEventMsg.return_value = mock_event_instance - + @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") + def test_poke_status_none(self, mock_startup_events): + """poke returns False when no status has been written to XCom yet.""" sensor = self.make_sensor() - sensor.invocation_mode = InvocationMode.DBT_RUNNER + ti = MagicMock() ti.try_number = 1 - ti.xcom_pull.side_effect = [None, None, None] # no event msg found + ti.xcom_pull.return_value = None context = self.make_context(ti) result = sensor.poke(context) assert result is False @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") - def test_poke_success_from_run_results(self, mock_startup_events): + def test_poke_success(self, mock_startup_events): sensor = self.make_sensor() - sensor.invocation_mode = "SUBPROCESS" ti = MagicMock() ti.try_number = 1 @@ -1207,65 +953,19 @@ def test_poke_subprocess_mode_extracts_compiled_sql_from_xcom(self, mock_startup assert sensor.compiled_sql == "SELECT * FROM orders" mock_override_rtif.assert_called_once_with(context) - @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor.use_event", return_value=True) - @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_status", return_value="running") - @patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif") - def test_poke_event_mode_extracts_compiled_sql_from_canonical_key( - self, mock_override_rtif, mock_get_producer, mock_use_event - ): - """Test that in event (DBT_RUNNER) mode, poke gets compiled_sql from canonical *_compiled_sql key after status.""" - sensor = self.make_sensor() - sensor.model_unique_id = MODEL_UNIQUE_ID - - ti = MagicMock() - ti.try_number = 1 - # _get_status_from_events: dbt_startup_events=None, nodefinished_*=ENCODED_EVENT; then get_xcom_val(compiled_sql_key), get_xcom_val(_dbt_event) - ti.xcom_pull.side_effect = [None, ENCODED_EVENT, "SELECT * FROM orders", None] - context = self.make_context(ti) - - assert sensor.compiled_sql == "" # Initially empty - result = sensor.poke(context) - assert result is True - assert sensor.compiled_sql == "SELECT * FROM orders" - mock_override_rtif.assert_called_once_with(context) - - @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_status", return_value=None) - def _fallback_to_non_watcher_run(self, mock_get_producer_task_status): - sensor = self.make_sensor() - sensor.invocation_mode = None - - ti = MagicMock() - ti.try_number = 1 - ti.xcom_pull.return_value = ENCODED_RUN_RESULTS - context = self.make_context(ti) - result = sensor.poke(context) - assert result is True - @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") - def test_poke_failure_from_run_results(self, mock_startup_events): + def test_poke_failure(self, mock_startup_events): + """poke raises AirflowException when model status is a failure value.""" sensor = self.make_sensor() - sensor.invocation_mode = "OTHER_MODE" ti = MagicMock() ti.try_number = 1 - ti.xcom_pull.return_value = ENCODED_RUN_RESULTS_FAILED + ti.xcom_pull.return_value = "failed" context = self.make_context(ti) with pytest.raises(AirflowException): sensor.poke(context) - def test_poke_status_none_from_run_results(self): - sensor = self.make_sensor() - sensor.invocation_mode = "OTHER_MODE" - - ti = MagicMock() - ti.try_number = 1 - ti.xcom_pull.return_value = None - context = self.make_context(ti) - - result = sensor.poke(context) - assert result is False - @patch("cosmos.operators.local.AbstractDbtLocalBase.build_and_run_cmd") def test_task_retry(self, mock_build_and_run_cmd): sensor = self.make_sensor() @@ -1300,97 +1000,6 @@ def test_filter_flags(self): assert result == expected - def test_get_status_from_run_results_success(self): - sensor = self.make_sensor() - ti = MagicMock() - ti.xcom_pull.return_value = ENCODED_RUN_RESULTS - - result = sensor._get_status_from_run_results(ti, _MockContext(ti=ti)) - assert result == "success" - - def test_get_status_from_run_results_logs_error(self): - sensor = self.make_sensor() - ti = MagicMock() - - run_results_payload = { - "results": [ - { - "unique_id": sensor.model_unique_id, - "status": "error", - "message": "dbt model failed", - "compiled_code": "select 1", - } - ] - } - - encoded = base64.b64encode(zlib.compress(json.dumps(run_results_payload).encode())).decode("utf-8") - - ti.xcom_pull.return_value = encoded - context = self.make_context(ti) - - with patch("cosmos.operators._watcher.base.logger") as mock_logger: - result = sensor._get_status_from_run_results(ti, context) - - assert result == "error" - mock_logger.error.assert_called_once() - - def test_get_status_from_run_results_none(self): - sensor = self.make_sensor() - ti = MagicMock() - ti.xcom_pull.return_value = None - - result = sensor._get_status_from_run_results(ti, _MockContext(ti=ti)) - assert result is None - - def test_get_status_from_events_success(self): - sensor = self.make_sensor() - ti = MagicMock() - ti.xcom_pull.side_effect = [None, ENCODED_EVENT] - context = self.make_context(ti) - - result = sensor._get_status_from_events(ti, context) - assert result == "success" - - def test_get_status_from_events_none(self): - sensor = self.make_sensor() - ti = MagicMock() - ti.xcom_pull.side_effect = [None, None] - context = self.make_context(ti) - - result = sensor._get_status_from_events(ti, context) - assert result is None - - def test_get_status_from_events_does_not_set_compiled_sql_from_event(self): - """compiled_sql is no longer in the event payload; it is read from the canonical XCom key in poke().""" - sensor = self.make_sensor() - ti = MagicMock() - event_payload = {"data": {"run_result": {"status": "success"}}} - encoded_event = base64.b64encode(zlib.compress(json.dumps(event_payload).encode())).decode("utf-8") - ti.xcom_pull.side_effect = [None, encoded_event] - context = self.make_context(ti) - - result = sensor._get_status_from_events(ti, context) - assert result == "success" - assert sensor.compiled_sql == "" # not set from event; poke() will get it from canonical key - - @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") - def test_get_status_logs_error(self, mock_log_startup_events): - sensor = self.make_sensor() - ti = MagicMock() - - event_payload = {"data": {"run_result": {"status": "error", "message": "dbt model failed"}}} - - encoded_event = base64.b64encode(zlib.compress(json.dumps(event_payload).encode())).decode("utf-8") - - ti.xcom_pull.return_value = encoded_event - context = self.make_context(ti) - - with patch("cosmos.operators.watcher.logger.error") as mock_log_error: - result = sensor._get_status_from_events(ti, context) - - assert result == "error" - mock_log_error.assert_called_once_with("%s", "dbt model failed") - @patch("cosmos.operators._watcher.base.get_xcom_val") @patch("cosmos.operators._watcher.base.BaseConsumerSensor._log_startup_events") def test_producer_state_failed(self, mock_startup_events, mock_get_xcom_val): @@ -1433,32 +1042,6 @@ def test_producer_state_does_not_fail_if_previously_upstream_failed( sensor.poke(context) mock_fallback_to_non_watcher_run.assert_called_once() - @patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif") - def test_get_status_from_run_results_with_compiled_sql(self, mock_override_rtif, monkeypatch): - sensor = self.make_sensor() - sensor.model_unique_id = "model.test_table" - - # Create a fake run_results payload containing compiled_code and status - run_results = { - "results": [ - { - "unique_id": "model.test_table", - "compiled_code": "SELECT * FROM dummy_table;", - "status": "success", - } - ] - } - - compressed = zlib.compress(json.dumps(run_results).encode()) - encoded = base64.b64encode(compressed).decode() - - # Mock TaskInstance.xcom_pull to return encoded results - ti = MagicMock() - ti.xcom_pull.return_value = encoded - context = {"ti": ti} - sensor._get_status_from_run_results(ti, context) - mock_override_rtif.assert_called_with(context) - @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor.poke") def test_sensor_deferred(self, mock_poke): mock_poke.return_value = False @@ -1551,21 +1134,20 @@ def test_dbt_build_watcher_operator_raises_not_implemented_error(self): class TestWatcherTrigger: """Tests for WatcherTrigger compiled_sql extraction.""" - def make_trigger(self, use_event: bool = False): + def make_trigger(self): return WatcherTrigger( model_unique_id="model.pkg.my_model", producer_task_id="dbt_producer_watcher", dag_id="test_dag", run_id="test_run", map_index=None, - use_event=use_event, poke_interval=1.0, ) @pytest.mark.asyncio async def test_parse_dbt_node_status_and_compiled_sql_subprocess_mode(self): """Test that compiled_sql is extracted from XCom in subprocess mode.""" - trigger = self.make_trigger(use_event=False) + trigger = self.make_trigger() # Mock get_xcom_val to return status and compiled_sql async def mock_get_xcom_val(key): @@ -1585,7 +1167,7 @@ async def mock_get_xcom_val(key): @pytest.mark.asyncio async def test_parse_dbt_node_status_and_compiled_sql_subprocess_no_compiled_sql(self): """Test that missing compiled_sql is handled gracefully in subprocess mode.""" - trigger = self.make_trigger(use_event=False) + trigger = self.make_trigger() # Mock get_xcom_val to return only status async def mock_get_xcom_val(key): @@ -1600,29 +1182,6 @@ async def mock_get_xcom_val(key): assert status == "success" assert compiled_sql is None - @pytest.mark.asyncio - async def test_parse_dbt_node_status_and_compiled_sql_dbt_runner_mode(self): - """Test that in dbt_runner mode status comes from event payload and compiled_sql from canonical key.""" - trigger = self.make_trigger(use_event=True) - - # Event payload (no longer contains compiled_sql; it is stored under canonical key) - event_data = {"data": {"run_result": {"status": "success"}}} - compressed = base64.b64encode(zlib.compress(json.dumps(event_data).encode())).decode() - - async def mock_get_xcom_val(key): - if key == "nodefinished_model__pkg__my_model": - return compressed - if key == "model__pkg__my_model_compiled_sql": - return "SELECT id FROM users" - return None - - trigger.get_xcom_val = mock_get_xcom_val - - status, compiled_sql = await trigger._parse_dbt_node_status_and_compiled_sql() - - assert status == "success" - assert compiled_sql == "SELECT id FROM users" - @pytest.mark.asyncio async def test_log_startup_events_returns_when_events_available(self, caplog): """Test that _log_startup_events returns once dbt_startup_events is available and logs.""" @@ -1689,7 +1248,6 @@ def test_dbt_dag_with_watcher(capsys): dag_id="watcher_dag", execution_config=ExecutionConfig( execution_mode=ExecutionMode.WATCHER, - invocation_mode=InvocationMode.DBT_RUNNER, ), render_config=RenderConfig(emit_datasets=False), operator_args={"trigger_rule": "all_success", "execution_timeout": timedelta(seconds=120)}, @@ -1916,7 +1474,6 @@ def test_dbt_dag_with_watcher_and_empty_model(caplog): dag_id="watcher_dag_empty_model", execution_config=ExecutionConfig( execution_mode=ExecutionMode.WATCHER, - invocation_mode=InvocationMode.DBT_RUNNER, ), render_config=RenderConfig(emit_datasets=False, test_behavior=TestBehavior.NONE), operator_args={ diff --git a/tests/operators/test_watcher_kubernetes_unit.py b/tests/operators/test_watcher_kubernetes_unit.py index 5ba35845e5..feb4766025 100644 --- a/tests/operators/test_watcher_kubernetes_unit.py +++ b/tests/operators/test_watcher_kubernetes_unit.py @@ -191,15 +191,6 @@ def test_retry_executes_as_dbt_run_kubernetes_operator(mock_build_and_run_cmd): mock_build_and_run_cmd.assert_called_once() -def test_use_event_returns_false(): - """ - DbtConsumerWatcherKubernetesSensor should return False for use_event(), - meaning it uses XCom-based status retrieval instead of events. - """ - sensor = make_sensor() - assert sensor.use_event() is False - - class TestCallbacksNormalization: """Tests for the callbacks normalization logic in DbtProducerWatcherKubernetesOperator.""" From f087857a059cc80124cd1cf5d10903271a449ec1 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 23 Mar 2026 22:23:58 +0000 Subject: [PATCH 02/13] Support both InvocationMode.SUBPROCESS and DBT_RUNNER in WATCHER producer The producer operator no longer forces InvocationMode.SUBPROCESS. Instead, the invocation mode is auto-discovered at runtime (DBT_RUNNER preferred when dbt-core is in-process, SUBPROCESS otherwise), while an explicit `invocation_mode` passed by the caller takes full precedence. Both modes share the same `store_dbt_resource_status_from_log` parser: - SUBPROCESS: JSON log lines from stdout are parsed directly (unchanged). - DBT_RUNNER: EventMsg callbacks are serialised to JSON via `google.protobuf.json_format.MessageToJson` (a transitive dbt-core dep) and then forwarded to the same parser. The wiring is done by overriding `run_subprocess` and `run_dbt_runner` so that `project_dir` (the temp copy) is captured correctly in each path. Co-Authored-By: Claude Sonnet 4.6 --- cosmos/operators/watcher.py | 71 ++++++++++++++++++++++++--------- tests/operators/test_watcher.py | 69 +++++++++++++++++++++++++++++--- 2 files changed, 117 insertions(+), 23 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index a2b972a6c2..17e0737b56 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -19,7 +19,6 @@ PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, PRODUCER_WATCHER_TASK_ID, WATCHER_TASK_WEIGHT_RULE, - InvocationMode, ) from cosmos.log import get_logger from cosmos.operators._watcher.base import ( @@ -57,14 +56,20 @@ class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator): Executes **one** ``dbt build`` covering the whole selection. - dbt is always invoked via ``InvocationMode.SUBPROCESS`` with ``--log-format json`` so that each - log line can be parsed in real-time by ``store_dbt_resource_status_from_log``. As each - ``NodeFinished`` event arrives the operator pushes the per-model status to XCom under key - ``_status`` so downstream sensors can react without waiting for the full build to - complete. + dbt is invoked with ``--log-format json`` and the invocation mode is auto-discovered at runtime: + ``InvocationMode.DBT_RUNNER`` is preferred when dbt-core is available in the same environment + (faster, no subprocess overhead), falling back to ``InvocationMode.SUBPROCESS`` otherwise. + The user may override this by passing ``invocation_mode`` explicitly — that value takes precedence. - This keeps the heavy dbt work centralised while providing near real-time - feedback and granular task-level observability downstream. + Both modes feed the same parser (``store_dbt_resource_status_from_log``): + - SUBPROCESS: each JSON log line from stdout is parsed directly. + - DBT_RUNNER: each ``EventMsg`` from the dbt callback is serialised to JSON via + ``google.protobuf.json_format.MessageToJson`` — a transitive dbt-core dependency — and then + passed through the same parser. + + As each ``NodeFinished`` event arrives the operator pushes the per-model status to XCom under + key ``_status`` so downstream sensors can react without waiting for the full build + to complete. """ template_fields = DbtLocalBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator] @@ -83,25 +88,55 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.setdefault("priority_weight", PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT) kwargs.setdefault("weight_rule", WATCHER_TASK_WEIGHT_RULE) kwargs["queue"] = watcher_dbt_execution_queue or kwargs.get("queue") or DEFAULT_QUEUE - # WATCHER always uses SUBPROCESS so that dbt JSON logs can be parsed line-by-line. - # This is the single, unified approach for real-time per-model status tracking. - kwargs["invocation_mode"] = InvocationMode.SUBPROCESS + # invocation_mode is intentionally NOT forced here; the parent's _discover_invocation_mode() + # picks DBT_RUNNER when available and falls back to SUBPROCESS otherwise. + # An explicit invocation_mode passed by the caller is preserved as-is. super().__init__(task_id=task_id, *args, **kwargs) self.log_format = "json" - def _set_process_log_line_callable(self) -> None: - logger.info( - "DbtProducerWatcherOperator: Setting process_log_line_callable to store_dbt_resource_status_from_log" - ) - self._process_log_line_callable = functools.partial( + def _make_parse_callable(self) -> Callable[[str, Any], None]: + """Returns store_dbt_resource_status_from_log with the operator's test maps pre-bound.""" + return functools.partial( store_dbt_resource_status_from_log, tests_per_model=self.tests_per_model, test_results_per_model=self.test_results_per_model, ) - def execute(self, context: Context, **kwargs: Any) -> Any: - self._set_process_log_line_callable() + def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> Any: + """Wire up per-line JSON log parsing before delegating to the subprocess runner. + + The subprocess hook passes ``{"context": ..., "project_dir": cwd}`` as ``extra_kwargs`` to + the callable, so no additional closure is needed here. + """ + self._process_log_line_callable = self._make_parse_callable() + return super().run_subprocess(command, env, cwd, **kwargs) + + def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> Any: + """Register an EventMsg → JSON → parse callback before delegating to the dbt runner. + + dbt callbacks receive only the ``EventMsg`` protobuf object; context and project_dir are + captured via closure so the unified ``store_dbt_resource_status_from_log`` parser can be + reused identically to the SUBPROCESS path. + ``google.protobuf.json_format`` is a transitive dependency of dbt-core and is always + available when ``InvocationMode.DBT_RUNNER`` is in use. + """ + context = kwargs.get("context") + extra_kwargs: dict[str, Any] = {"project_dir": cwd} + if context is not None: + extra_kwargs["context"] = context + parse = self._make_parse_callable() + + def _event_callback(event: Any) -> None: + from google.protobuf.json_format import MessageToJson + + json_str = MessageToJson(event, preserving_proto_field_name=True) + parse(json_str, extra_kwargs) + + self._dbt_runner_callbacks = [*(self._dbt_runner_callbacks or []), _event_callback] + return super().run_dbt_runner(command, env, cwd, **kwargs) + + def execute(self, context: Context, **kwargs: Any) -> Any: task_instance = context.get("ti") if task_instance is None: raise AirflowException("DbtProducerWatcherOperator expects a task instance in the execution context") diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 60c46529b4..aa90765c0e 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -764,13 +764,72 @@ def test_compiled_sql_flat_path_pushed(self, tmp_path): assert ti.store.get("model__pkg__foo_compiled_sql") == "SELECT 1" -def test_producer_always_uses_subprocess_invocation_mode(): - """DbtProducerWatcherOperator always forces InvocationMode.SUBPROCESS regardless of what was passed.""" - from cosmos.config import InvocationMode - +def test_producer_does_not_force_invocation_mode(): + """DbtProducerWatcherOperator does not force an invocation_mode; auto-discovery runs at runtime.""" op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + assert op.invocation_mode is None # resolved lazily by _discover_invocation_mode() + + +def test_producer_respects_explicit_invocation_mode(): + """An explicit invocation_mode passed by the caller is preserved unchanged.""" + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None, invocation_mode=InvocationMode.SUBPROCESS) assert op.invocation_mode == InvocationMode.SUBPROCESS - assert op.log_format == "json" + + op2 = DbtProducerWatcherOperator(project_dir=".", profile_config=None, invocation_mode=InvocationMode.DBT_RUNNER) + assert op2.invocation_mode == InvocationMode.DBT_RUNNER + + +def test_run_subprocess_sets_process_log_line_callable(): + """run_subprocess wires up _process_log_line_callable before executing the subprocess.""" + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + assert op._process_log_line_callable is None + + with patch("cosmos.operators.local.DbtLocalBaseOperator.run_subprocess", return_value=MagicMock()): + op.run_subprocess(command=["dbt", "build"], env={}, cwd="/tmp/proj") + + assert op._process_log_line_callable is not None + + +def test_run_dbt_runner_registers_event_callback(): + """run_dbt_runner appends an EventMsg→JSON→parse callback to _dbt_runner_callbacks.""" + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + assert not op._dbt_runner_callbacks + + mock_ti = _MockTI() + context = {"ti": mock_ti, "run_id": "run-1"} + + with patch("cosmos.operators.local.DbtLocalBaseOperator.run_dbt_runner", return_value=MagicMock()): + op.run_dbt_runner(command=["dbt", "build"], env={}, cwd="/tmp/proj", context=context) + + assert len(op._dbt_runner_callbacks) == 1 + + +def test_run_dbt_runner_event_callback_calls_store_from_log(): + """The registered callback converts an EventMsg to JSON and passes it to store_dbt_resource_status_from_log.""" + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + mock_ti = _MockTI() + context = {"ti": mock_ti, "run_id": "run-1"} + + fake_json = '{"info": {"name": "NodeFinished"}, "data": {}}' + fake_event = MagicMock() + + # Patch store_dbt_resource_status_from_log *before* run_dbt_runner so that _make_parse_callable + # captures the mock through functools.partial, not the real function. + with ( + patch("cosmos.operators.local.DbtLocalBaseOperator.run_dbt_runner", return_value=MagicMock()), + patch("cosmos.operators.watcher.store_dbt_resource_status_from_log") as mock_parse, + patch("google.protobuf.json_format.MessageToJson", return_value=fake_json) as mock_to_json, + ): + op.run_dbt_runner(command=["dbt", "build"], env={}, cwd="/tmp/proj", context=context) + callback = op._dbt_runner_callbacks[0] + callback(fake_event) + + mock_to_json.assert_called_once_with(fake_event, preserving_proto_field_name=True) + mock_parse.assert_called_once() + call_args = mock_parse.call_args + assert call_args[0][0] == fake_json # first positional arg is the JSON string + assert call_args[0][1]["project_dir"] == "/tmp/proj" + assert call_args[0][1]["context"] is context MODEL_UNIQUE_ID = "model.jaffle_shop.stg_orders" From 9df8d422aad0ba545f9ffe304cd0affee5572950 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 24 Mar 2026 09:42:13 +0000 Subject: [PATCH 03/13] Try to fix integration tests --- cosmos/operators/_watcher/base.py | 29 ++++++++++++++++++----------- cosmos/operators/watcher.py | 20 ++++++++++++-------- tests/dbt/test_runner.py | 6 +----- tests/operators/test_watcher.py | 16 ++++++++-------- 4 files changed, 39 insertions(+), 32 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 90ad8ba1e3..98ad03633b 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -174,6 +174,21 @@ def _store_startup_event_from_log(task_instance: Any, log_line: dict[str, Any]) safe_xcom_push(task_instance=task_instance, key=_DBT_STARTUP_EVENTS_XCOM_KEY, value=current) +def _log_dbt_msg(log_line: dict[str, Any]) -> None: + """Log the human-readable message from a parsed dbt JSON log line.""" + log_info = log_line.get("info", {}) + msg = log_info.get("msg") + if msg is None: + return + level = log_info.get("level", "INFO").upper() + ts = log_info.get("ts") + formatted_ts = _iso_to_string(ts) + if formatted_ts: + logger.log(getattr(logging, level, logging.INFO), "%s %s", formatted_ts, msg) + else: + logger.log(getattr(logging, level, logging.INFO), msg) + + def store_dbt_resource_status_from_log( line: str, extra_kwargs: Any, @@ -224,7 +239,8 @@ def store_dbt_resource_status_from_log( # TODO: handle all possible statuses including skipped, warn, etc. if is_dbt_node_status_terminal(dbt_node_status): context = extra_kwargs.get("context") - assert context is not None # Make MyPy happy + if context is None: + return if dbt_node_resource_type == "test" and tests_per_model and test_results_per_model is not None: logger.debug("Test '%s' finished with status '%s'", unique_id, dbt_node_status) push_test_result_or_aggregate( @@ -244,16 +260,7 @@ def store_dbt_resource_status_from_log( ) # Additionally, log the message from dbt logs - log_info = log_line.get("info", {}) - msg = log_info.get("msg") - level = log_info.get("level", "INFO").upper() - ts = log_info.get("ts") - if msg is not None: - formatted_ts = _iso_to_string(ts) - if formatted_ts: - logger.log(getattr(logging, level, logging.INFO), "%s %s", formatted_ts, msg) - else: - logger.log(getattr(logging, level, logging.INFO), msg) + _log_dbt_msg(log_line) class BaseConsumerSensor(BaseSensorOperator): # type: ignore[misc] diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 17e0737b56..680de8a80a 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -120,20 +120,24 @@ def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kw ``google.protobuf.json_format`` is a transitive dependency of dbt-core and is always available when ``InvocationMode.DBT_RUNNER`` is in use. + + The callback is only registered when ``context`` is present (i.e. during task execution, + not during auxiliary calls such as ``dbt deps``). Without a context there is no XCom + backend to push to, so registering a callback would cause it to raise and dbt would emit + ``GenericExceptionOnRun`` for every node. """ context = kwargs.get("context") - extra_kwargs: dict[str, Any] = {"project_dir": cwd} if context is not None: - extra_kwargs["context"] = context - parse = self._make_parse_callable() + extra_kwargs: dict[str, Any] = {"project_dir": cwd, "context": context} + parse = self._make_parse_callable() - def _event_callback(event: Any) -> None: - from google.protobuf.json_format import MessageToJson + def _event_callback(event: Any) -> None: + from google.protobuf.json_format import MessageToJson - json_str = MessageToJson(event, preserving_proto_field_name=True) - parse(json_str, extra_kwargs) + json_str = MessageToJson(event, preserving_proto_field_name=True) + parse(json_str, extra_kwargs) - self._dbt_runner_callbacks = [*(self._dbt_runner_callbacks or []), _event_callback] + self._dbt_runner_callbacks = [*(self._dbt_runner_callbacks or []), _event_callback] return super().run_dbt_runner(command, env, cwd, **kwargs) def execute(self, context: Context, **kwargs: Any) -> Any: diff --git a/tests/dbt/test_runner.py b/tests/dbt/test_runner.py index a696b9954d..e312dad4f2 100644 --- a/tests/dbt/test_runner.py +++ b/tests/dbt/test_runner.py @@ -299,11 +299,7 @@ def invoke(self, *args): ) op2.invocation_mode = InvocationMode.DBT_RUNNER - class _DummyEv: - pass - - with patch("cosmos.operators.watcher.EventMsg", _DummyEv): - op2.execute(context=mock_context) + op2.execute(context=mock_context) # Verify: # 1. We have two dbt Runner instances (cached + new with callbacks) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index aa90765c0e..abd2740ec1 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -1294,12 +1294,14 @@ async def mock_producer_failed(): @pytest.mark.integration -def test_dbt_dag_with_watcher(capsys): +def test_dbt_dag_with_watcher(caplog): """ Run a DbtDag using `ExecutionMode.WATCHER`. Confirm the right amount of tasks is created and that tasks are in the expected topological order. Confirm that the producer watcher task is created and that it is the parent of the root dbt nodes. """ + caplog.set_level(logging.INFO, logger="cosmos.operators._watcher.base") + watcher_dag = DbtDag( project_config=project_config, profile_config=profile_config, @@ -1357,20 +1359,16 @@ def test_dbt_dag_with_watcher(capsys): "raw_customers_seed", } - # dbt runner logs are not captured by caplog, so we need to capture them using capsys - capsys_output = capsys.readouterr() - stdout = capsys_output.out - assert ( '''"node_status": "success", "resource_type": "seed", "unique_id": "seed.jaffle_shop.raw_orders"''' - not in stdout + not in caplog.text ) log_message = "OK loaded seed file public.raw_orders" - assert log_message in stdout + assert log_message in caplog.text # Verify that log messages are not duplicated (each dbt message should appear only once) - message_count = stdout.count(log_message) + message_count = caplog.text.count(log_message) assert message_count == 1, f"Expected '{log_message}' to be logged exactly once, but found {message_count} times" @@ -1526,6 +1524,8 @@ def test_dbt_dag_with_watcher_and_empty_model(caplog): # 10:29:03 # 10:29:03 Done. PASS=1 WARN=0 ERROR=0 SKIP=0 NO-OP=0 TOTAL=1 + caplog.set_level(logging.DEBUG, logger="cosmos.operators._watcher.base") + watcher_dag = DbtDag( project_config=project_config, profile_config=profile_config, From 68c45c73d1df892b9276ef46cba21b1c35862869 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 24 Mar 2026 09:53:12 +0000 Subject: [PATCH 04/13] Fix unittests --- cosmos/operators/watcher.py | 8 ++++++++ tests/hooks/test_subprocess.py | 25 ++++++++++--------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 680de8a80a..b51af28a93 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -1,6 +1,8 @@ from __future__ import annotations +import contextlib import functools +import io from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any @@ -125,6 +127,10 @@ def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kw not during auxiliary calls such as ``dbt deps``). Without a context there is no XCom backend to push to, so registering a callback would cause it to raise and dbt would emit ``GenericExceptionOnRun`` for every node. + + When a callback is registered, dbt's stdout is redirected to a null buffer so that the + raw ``--log-format json`` lines do not appear in Airflow task logs alongside the + human-readable messages already emitted by ``_log_dbt_msg`` inside the callback. """ context = kwargs.get("context") if context is not None: @@ -138,6 +144,8 @@ def _event_callback(event: Any) -> None: parse(json_str, extra_kwargs) self._dbt_runner_callbacks = [*(self._dbt_runner_callbacks or []), _event_callback] + with contextlib.redirect_stdout(io.StringIO()): + return super().run_dbt_runner(command, env, cwd, **kwargs) return super().run_dbt_runner(command, env, cwd, **kwargs) def execute(self, context: Context, **kwargs: Any) -> Any: diff --git a/tests/hooks/test_subprocess.py b/tests/hooks/test_subprocess.py index 00e1bcd351..6611a81b84 100644 --- a/tests/hooks/test_subprocess.py +++ b/tests/hooks/test_subprocess.py @@ -81,33 +81,28 @@ def test_send_sigterm(mock_killpg, mock_getpgid): @pytest.mark.parametrize( - "status,context,expect_assert", + "status,context,expect_xcom_push", [ - ("success", {"ti": MagicMock()}, False), - ("failed", {"ti": MagicMock()}, False), - ("success", None, True), - ("failed", None, True), + ("success", {"ti": MagicMock()}, True), + ("failed", {"ti": MagicMock()}, True), + ("success", None, False), + ("failed", None, False), ], ) -def test_store_dbt_resource_status_from_log_param(status, context, expect_assert): +def test_store_dbt_resource_status_from_log_param(status, context, expect_xcom_push): # Prepare log line log_line = {"data": {"node_info": {"node_status": status, "unique_id": "model.jaffle_shop.stg_orders"}}} line = json.dumps(log_line) with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push: - if expect_assert: - with pytest.raises(AssertionError): - store_dbt_resource_status_from_log( - line, {"context": context}, tests_per_model={}, test_results_per_model={} - ) - else: - store_dbt_resource_status_from_log( - line, {"context": context}, tests_per_model={}, test_results_per_model={} - ) + store_dbt_resource_status_from_log(line, {"context": context}, tests_per_model={}, test_results_per_model={}) + if expect_xcom_push: mock_push.assert_called_with( task_instance=context["ti"], key="model__jaffle_shop__stg_orders_status", value=status ) assert mock_push.call_count == 1 + else: + mock_push.assert_not_called() def test_store_dbt_resource_status_from_log_invalid_json(): From ca56f7cb6790959024294634a98ac6f087d4f4a4 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 24 Mar 2026 10:51:44 +0000 Subject: [PATCH 05/13] Address memory utilisation concern raised during PR review --- cosmos/operators/watcher.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 2ec5b02f07..324dc8f6d6 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -2,7 +2,6 @@ import contextlib import functools -import io from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any @@ -47,6 +46,22 @@ logger = get_logger(__name__) +class _NullWriter: + """Write-only sink that discards all data; used to suppress dbt stdout in DBT_RUNNER mode. + + Preferred over ``io.StringIO()`` because StringIO buffers every byte written to it for the + lifetime of the context manager. On large projects dbt emits megabytes of JSON log lines, + so StringIO would grow unbounded and increase worker memory usage proportionally to project + size and verbosity. _NullWriter discards each write immediately with no allocation. + """ + + def write(self, *args: Any) -> int: + return 0 + + def flush(self) -> None: + pass + + class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator): """Run dbt build and update XCom with the progress of each model, as part of the *WATCHER* execution mode. @@ -138,7 +153,7 @@ def _event_callback(event: Any) -> None: parse(json_str, extra_kwargs) self._dbt_runner_callbacks = [*(self._dbt_runner_callbacks or []), _event_callback] - with contextlib.redirect_stdout(io.StringIO()): + with contextlib.redirect_stdout(_NullWriter()): return super().run_dbt_runner(command, env, cwd, **kwargs) return super().run_dbt_runner(command, env, cwd, **kwargs) From 319fd34d77a7fda5410e22ef86db957181dcf771 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 24 Mar 2026 11:46:21 +0000 Subject: [PATCH 06/13] Tackle issues while parsing dbt events gracefully, so sensors are not hanging --- cosmos/operators/watcher.py | 23 +++++++++++++++++++---- tests/operators/test_watcher.py | 23 +++++++++++++++++++++++ 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 324dc8f6d6..e09a47dee4 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -145,16 +145,31 @@ def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kw if context is not None: extra_kwargs: dict[str, Any] = {"project_dir": cwd, "context": context} parse = self._make_parse_callable() + # Collect callback errors rather than raising inside the callback: dbt catches + # exceptions raised by callbacks and wraps them as GenericExceptionOnRun, which + # would cause the build to emit spurious failures but potentially still succeed. + # Instead we capture the first error here and re-raise it after the dbt run so it + # propagates through execute(), triggering the existing task_status XCom mechanism + # that signals consumer sensors to check the producer task state. + callback_error: list[BaseException] = [] def _event_callback(event: Any) -> None: - from google.protobuf.json_format import MessageToJson + try: + from google.protobuf.json_format import MessageToJson - json_str = MessageToJson(event, preserving_proto_field_name=True) - parse(json_str, extra_kwargs) + json_str = MessageToJson(event, preserving_proto_field_name=True) + parse(json_str, extra_kwargs) + except Exception as e: + logger.exception("Error in dbt event callback: %s", e) + if not callback_error: + callback_error.append(e) self._dbt_runner_callbacks = [*(self._dbt_runner_callbacks or []), _event_callback] with contextlib.redirect_stdout(_NullWriter()): - return super().run_dbt_runner(command, env, cwd, **kwargs) + result = super().run_dbt_runner(command, env, cwd, **kwargs) + if callback_error: + raise callback_error[0] + return result return super().run_dbt_runner(command, env, cwd, **kwargs) def execute(self, context: Context, **kwargs: Any) -> Any: diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 87ab2b8df0..e864076d27 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -767,6 +767,29 @@ def test_run_dbt_runner_event_callback_calls_store_from_log(): assert call_args[0][1]["context"] is context +def test_run_dbt_runner_callback_error_fails_producer_after_run(caplog): + """A callback error must not surface as GenericExceptionOnRun inside dbt; instead it must be + re-raised after the dbt run so it propagates through execute() and triggers the task_status + XCom push that signals consumer sensors to check the producer task state.""" + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + context = {"ti": _MockTI(), "run_id": "run-1"} + + def fake_run_dbt_runner(self_inner, command, env, cwd, **kw): + # Simulate dbt calling the registered callback for one event, as the real runner would. + for cb in op._dbt_runner_callbacks or []: + cb(MagicMock()) + + with ( + patch("cosmos.operators.local.DbtLocalBaseOperator.run_dbt_runner", fake_run_dbt_runner), + patch("google.protobuf.json_format.MessageToJson", side_effect=RuntimeError("serialisation error")), + caplog.at_level(logging.ERROR), + pytest.raises(RuntimeError, match="serialisation error"), + ): + op.run_dbt_runner(command=["dbt", "build"], env={}, cwd="/tmp/proj", context=context) + + assert "Error in dbt event callback" in caplog.text + + MODEL_UNIQUE_ID = "model.jaffle_shop.stg_orders" From c548c23aee6021e02b864d88f78702b10b909747 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 24 Mar 2026 11:50:49 +0000 Subject: [PATCH 07/13] Add interface backward-compatibility. Removing the use_event parameter from WatcherTrigger.__init__ is a compatibility risk for in-flight deferred tasks: triggers serialized before this change may still contain a use_event kwarg, and deserialization would raise TypeError: __init__() got an unexpected keyword argument 'use_event'. --- cosmos/operators/_watcher/triggerer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index ff44dc9c45..0b5b404e64 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -42,6 +42,11 @@ def __init__( map_index: int | None, poke_interval: float = 5.0, is_test_sensor: bool = False, + # Accepted for upgrade-compatibility only: triggers serialized before the + # invocation-mode unification may still carry this kwarg. It is no longer + # used because both SUBPROCESS and DBT_RUNNER now push the same *_status + # XCom keys, so the trigger does not need to know the invocation mode. + use_event: bool = True, # noqa: ARG002 ): self.model_unique_id = model_unique_id self.producer_task_id = producer_task_id From 47f7463b164268c117573baaa4d437a9f9183606 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 24 Mar 2026 11:53:15 +0000 Subject: [PATCH 08/13] =?UTF-8?q?Fix=20incorrect=20signature=20for=20write?= =?UTF-8?q?=20method.=20=20matches=20io.TextIOBase.write(s:=20str)=20->=20?= =?UTF-8?q?int=20exactly=20=E2=80=94=20data=20is=20still=20discarded,=20bu?= =?UTF-8?q?t=20the=20contract=20is=20correct:=20callers=20get=20back=20the?= =?UTF-8?q?=20number=20of=20=20=20characters=20they=20wrote,=20which=20is?= =?UTF-8?q?=20what=20the=20standard=20IO=20protocol=20requires.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/operators/watcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index e09a47dee4..55662f1499 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -55,8 +55,8 @@ class _NullWriter: size and verbosity. _NullWriter discards each write immediately with no allocation. """ - def write(self, *args: Any) -> int: - return 0 + def write(self, s: str) -> int: + return len(s) def flush(self) -> None: pass From b1a85fa3267c2c50382a269775037047f41db4e7 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 24 Mar 2026 11:57:51 +0000 Subject: [PATCH 09/13] Address copilot concern in https://github.com/astronomer/astronomer-cosmos/pull/2498\#discussion_r2980974306. The fix is to hold xcom_set_lock across the entire pull+append+push. Since safe_xcom_push would re-acquire the lock and deadlock, we call xcom_push directly while already holding it. --- cosmos/operators/_watcher/base.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index e66828db6d..0a91983cdb 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -25,6 +25,7 @@ is_dbt_node_status_success, is_dbt_node_status_terminal, safe_xcom_push, + xcom_set_lock, ) from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger @@ -162,6 +163,12 @@ def _store_startup_event_from_log(task_instance: Any, log_line: dict[str, Any]) """ When dbt JSON log contains MainReportVersion or AdapterRegistered, append to dbt_startup_events XCom (same shape as runner path) for trigger to log versions. + + The pull+append+push is performed under ``xcom_set_lock`` to prevent a race + condition: dbt runner callbacks are invoked from multiple threads, so two + startup events arriving concurrently could both read the same stale list and + one append would be silently lost. Holding the same lock used by + ``safe_xcom_push`` makes the entire read-modify-write atomic. """ event_name = log_line.get("info", {}).get("name") if event_name not in ("MainReportVersion", "AdapterRegistered"): @@ -169,9 +176,13 @@ def _store_startup_event_from_log(task_instance: Any, log_line: dict[str, Any]) info = log_line.get("info", {}) msg = info.get("msg", "") ts = info.get("ts", "") - current = list(task_instance.xcom_pull(key=_DBT_STARTUP_EVENTS_XCOM_KEY) or []) - current.append({"name": event_name, "msg": msg, "ts": ts}) - safe_xcom_push(task_instance=task_instance, key=_DBT_STARTUP_EVENTS_XCOM_KEY, value=current) + # Hold the lock for the full read-modify-write cycle. We call xcom_push + # directly (bypassing safe_xcom_push) to avoid a deadlock: Lock is not + # re-entrant, so acquiring it again inside safe_xcom_push would block forever. + with xcom_set_lock: + current = list(task_instance.xcom_pull(key=_DBT_STARTUP_EVENTS_XCOM_KEY) or []) + current.append({"name": event_name, "msg": msg, "ts": ts}) + task_instance.xcom_push(key=_DBT_STARTUP_EVENTS_XCOM_KEY, value=current) def _log_dbt_msg(log_line: dict[str, Any]) -> None: From 7d776dbae076b5e5aec9b82a3477c0fd7b5429ac Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 24 Mar 2026 13:52:35 +0000 Subject: [PATCH 10/13] Update cosmos/operators/_watcher/base.py --- cosmos/operators/_watcher/base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 0a91983cdb..96f138df94 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -85,12 +85,12 @@ def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any]) -> None: event_name = info.get("name") if event_name not in _DBT_EVENT_ALLOWLIST: return None - node_info = data.get("node_info") - status = node_info.get("node_status") if node_info else None - unique_id = node_info.get("unique_id") if node_info else None - start_time = node_info.get("node_started_at") if node_info else None - finish_time = node_info.get("node_finished_at") if node_info else None - msg = data.get("msg") or info.get("msg") or None + node_info = data.get("node_info") or {} + status = node_info.get("node_status") + unique_id = node_info.get("unique_id") + start_time = node_info.get("node_started_at") + finish_time = node_info.get("node_finished_at") + msg = data.get("msg") or info.get("msg") if unique_id: dbt_event = { From 1f6a7ada62005930f687ad2738a5b9efb114fefd Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 26 Mar 2026 14:29:59 +0000 Subject: [PATCH 11/13] Apply suggestion from @tatiana --- cosmos/operators/_watcher/triggerer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index 0b5b404e64..0b19c465b7 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -43,7 +43,7 @@ def __init__( poke_interval: float = 5.0, is_test_sensor: bool = False, # Accepted for upgrade-compatibility only: triggers serialized before the - # invocation-mode unification may still carry this kwarg. It is no longer + # invocation-mode unification may still carry this kwarg (Cosmos < 1.14.0). It is no longer # used because both SUBPROCESS and DBT_RUNNER now push the same *_status # XCom keys, so the trigger does not need to know the invocation mode. use_event: bool = True, # noqa: ARG002 From 5e42d7eb4f38f301640357aa513d3e3ee0be9adb Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 26 Mar 2026 14:31:05 +0000 Subject: [PATCH 12/13] Apply suggestion from @michal-mrazek MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Michal Mrázek <121952333+michal-mrazek@users.noreply.github.com> --- cosmos/operators/_watcher/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 96f138df94..3bfdf8f56e 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -250,6 +250,11 @@ def store_dbt_resource_status_from_log( if is_dbt_node_status_terminal(dbt_node_status): context = extra_kwargs.get("context") if context is None: + logger.warning( + "context is None for terminal node '%s' — XCom status will not be pushed. " + "This is unexpected and should never happen; check the caller is passing context correctly.", + unique_id, + ) return if dbt_node_resource_type == "test" and tests_per_model and test_results_per_model is not None: logger.debug("Test '%s' finished with status '%s'", unique_id, dbt_node_status) From 8834e50a3791cfecfa58d0702878eb79d49b9112 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 26 Mar 2026 14:32:10 +0000 Subject: [PATCH 13/13] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tests/operators/_watcher/test_triggerer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/operators/_watcher/test_triggerer.py b/tests/operators/_watcher/test_triggerer.py index 060d97334f..c390a8de65 100644 --- a/tests/operators/_watcher/test_triggerer.py +++ b/tests/operators/_watcher/test_triggerer.py @@ -31,7 +31,7 @@ def test_serialize(self): assert args["poke_interval"] == 0.001 assert "use_event" not in args - @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Require Airflow < 3.0.0") + @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Require Airflow >= 3.0.0") @pytest.mark.asyncio async def test_get_xcom_val_af3(self): expected_value = {"foo": "bar"}