Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cosmos/operators/_watcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"is_dbt_node_status_failed",
"is_dbt_node_status_terminal",
"WatcherTrigger",
"_parse_compressed_xcom",
]

from cosmos.operators._watcher.state import (
Expand All @@ -19,4 +18,4 @@
is_dbt_node_status_terminal,
safe_xcom_push,
)
from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom
from cosmos.operators._watcher.triggerer import WatcherTrigger
148 changes: 55 additions & 93 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
from cosmos.log import get_logger
from cosmos.operators._watcher.aggregation import get_tests_status_xcom_key, push_test_result_or_aggregate
from cosmos.operators._watcher.state import (
DBT_FAILED_STATUSES,
_iso_to_string,
_log_dbt_event,
build_producer_state_fetcher,
get_xcom_val,
is_dbt_node_status_success,
is_dbt_node_status_terminal,
safe_xcom_push,
xcom_set_lock,
)
from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger, _parse_compressed_xcom
from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger

try:
from airflow.sdk.bases.sensor import BaseSensorOperator
Expand All @@ -36,11 +36,6 @@
from airflow.sensors.base import BaseSensorOperator
from airflow.utils.context import Context # type: ignore[attr-defined]

try:
from dbt_common.events.base_types import EventMsg
except ImportError: # pragma: no cover
EventMsg = None

logger = get_logger(__name__)

# Subset of dbt event types that represent errors/failures.
Expand Down Expand Up @@ -82,31 +77,20 @@
_DBT_EVENT_ALLOWLIST = _DBT_ERROR_EVENTS_TYPES | _DBT_NODE_STATUS_EVENT_TYPES


def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any] | EventMsg) -> None:
def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any]) -> None:
logger.debug("dbt_log: %s", dbt_log)
if isinstance(dbt_log, dict): # Subprocess
data = dbt_log.get("data", {})
info = dbt_log.get("info", {})
data = dbt_log.get("data", {})
info = dbt_log.get("info", {})

event_name = info.get("name")
if event_name not in _DBT_EVENT_ALLOWLIST:
return None
node_info = data.get("node_info")
status = node_info.get("node_status") if node_info else None
unique_id = node_info.get("unique_id") if node_info else None
start_time = node_info.get("node_started_at") if node_info else None
finish_time = node_info.get("node_finished_at") if node_info else None
msg = data.get("msg") or info.get("msg") or None
else: # Runner
event_name = getattr(getattr(dbt_log, "info", None), "name", None)
if event_name not in _DBT_EVENT_ALLOWLIST:
return None
node_info = getattr(dbt_log.data, "node_info", None)
unique_id = getattr(node_info, "unique_id") if node_info else None
status = getattr(node_info, "node_status", None) if node_info else None
start_time = getattr(node_info, "node_started_at", None) if node_info else None
finish_time = getattr(node_info, "node_finished_at", None) if node_info else None
msg = getattr(dbt_log.info, "msg", None)
event_name = info.get("name")
if event_name not in _DBT_EVENT_ALLOWLIST:
return None
node_info = data.get("node_info") or {}
status = node_info.get("node_status")
unique_id = node_info.get("unique_id")
start_time = node_info.get("node_started_at")
finish_time = node_info.get("node_finished_at")
msg = data.get("msg") or info.get("msg")

if unique_id:
dbt_event = {
Expand All @@ -126,9 +110,7 @@ def _extract_compiled_sql(
"""
Extract compiled SQL from the target directory for a given dbt node.

Used by both the subprocess strategy (via store_dbt_resource_status_from_log)
and the node-event strategy (via DbtProducerWatcherOperator._handle_node_finished);
both consume from the same target/compiled layout under project_dir.
Used by store_dbt_resource_status_from_log; reads from the target/compiled layout under project_dir.

Assumes inputs come from dbt (relative node_path, unique_id like model.package.name).
"""
Expand Down Expand Up @@ -181,16 +163,41 @@ def _store_startup_event_from_log(task_instance: Any, log_line: dict[str, Any])
"""
When dbt JSON log contains MainReportVersion or AdapterRegistered, append to
dbt_startup_events XCom (same shape as runner path) for trigger to log versions.

The pull+append+push is performed under ``xcom_set_lock`` to prevent a race
condition: dbt runner callbacks are invoked from multiple threads, so two
startup events arriving concurrently could both read the same stale list and
one append would be silently lost. Holding the same lock used by
``safe_xcom_push`` makes the entire read-modify-write atomic.
"""
event_name = log_line.get("info", {}).get("name")
if event_name not in ("MainReportVersion", "AdapterRegistered"):
return
info = log_line.get("info", {})
msg = info.get("msg", "")
ts = info.get("ts", "")
current = list(task_instance.xcom_pull(key=_DBT_STARTUP_EVENTS_XCOM_KEY) or [])
current.append({"name": event_name, "msg": msg, "ts": ts})
safe_xcom_push(task_instance=task_instance, key=_DBT_STARTUP_EVENTS_XCOM_KEY, value=current)
# Hold the lock for the full read-modify-write cycle. We call xcom_push
# directly (bypassing safe_xcom_push) to avoid a deadlock: Lock is not
# re-entrant, so acquiring it again inside safe_xcom_push would block forever.
with xcom_set_lock:
current = list(task_instance.xcom_pull(key=_DBT_STARTUP_EVENTS_XCOM_KEY) or [])
current.append({"name": event_name, "msg": msg, "ts": ts})
task_instance.xcom_push(key=_DBT_STARTUP_EVENTS_XCOM_KEY, value=current)


def _log_dbt_msg(log_line: dict[str, Any]) -> None:
    """Log the human-readable message from a parsed dbt JSON log line.

    The message is read from ``log_line["info"]["msg"]`` and emitted through the
    module logger at the level named by ``info["level"]``. When ``info["ts"]``
    can be formatted by ``_iso_to_string`` the timestamp is prefixed to the message.

    :param log_line: A dbt JSON log line already parsed into a dict.
    """
    # ``or {}`` also covers an explicit ``"info": null`` in the JSON payload.
    log_info = log_line.get("info") or {}
    msg = log_info.get("msg")
    if msg is None:
        return

    # Resolve the level once. A missing, null or non-string level falls back to INFO
    # instead of raising on ``.upper()``.
    level_name = str(log_info.get("level") or "INFO").upper()
    level = getattr(logging, level_name, logging.INFO)
    if not isinstance(level, int):
        # Some upper-case attributes of ``logging`` are not levels (e.g. BASIC_FORMAT).
        level = logging.INFO

    formatted_ts = _iso_to_string(log_info.get("ts"))
    if formatted_ts:
        logger.log(level, "%s %s", formatted_ts, msg)
    else:
        # Pass the message as an argument so it is never treated as a format string.
        logger.log(level, "%s", msg)


def store_dbt_resource_status_from_log(
Expand Down Expand Up @@ -242,7 +249,13 @@ def store_dbt_resource_status_from_log(
# TODO: handle all possible statuses including skipped, warn, etc.
if is_dbt_node_status_terminal(dbt_node_status):
context = extra_kwargs.get("context")
assert context is not None # Make MyPy happy
if context is None:
Comment thread
tatiana marked this conversation as resolved.
logger.warning(
"context is None for terminal node '%s' — XCom status will not be pushed. "
"This is unexpected and should never happen; check the caller is passing context correctly.",
unique_id,
)
return
if dbt_node_resource_type == "test" and tests_per_model and test_results_per_model is not None:
logger.debug("Test '%s' finished with status '%s'", unique_id, dbt_node_status)
push_test_result_or_aggregate(
Expand All @@ -262,16 +275,7 @@ def store_dbt_resource_status_from_log(
)

# Additionally, log the message from dbt logs
log_info = log_line.get("info", {})
msg = log_info.get("msg")
level = log_info.get("level", "INFO").upper()
ts = log_info.get("ts")
if msg is not None:
formatted_ts = _iso_to_string(ts)
if formatted_ts:
logger.log(getattr(logging, level, logging.INFO), "%s %s", formatted_ts, msg)
else:
logger.log(getattr(logging, level, logging.INFO), msg)
_log_dbt_msg(log_line)


class BaseConsumerSensor(BaseSensorOperator): # type: ignore[misc]
Expand Down Expand Up @@ -375,39 +379,6 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo
logger.info("dbt run completed successfully on retry for model '%s'", self.model_unique_id)
return True

def _get_status_from_run_results(self, ti: Any, context: Context) -> Any:
    """Return the watched node's status from the producer's ``run_results`` XCom.

    Pulls the ``run_results`` XCom pushed by ``self.producer_task_id``, decodes it
    with ``_parse_compressed_xcom`` and looks up the entry whose ``unique_id``
    equals ``self.model_unique_id``.

    Side effects: stores the entry's ``compiled_code`` on ``self.compiled_sql`` and,
    when that value is truthy and the instance defines ``_override_rtif``, calls it
    with ``context``.

    :param ti: Task instance used to pull the XCom.
    :param context: Airflow task context, forwarded to ``_override_rtif``.
    :return: The node's ``status`` value, or ``None`` when the XCom is absent/empty
        or no entry matches the watched node.
    """
    compressed_b64_run_results = ti.xcom_pull(task_ids=self.producer_task_id, key="run_results")
    if not compressed_b64_run_results:
        return None

    run_results_json = _parse_compressed_xcom(compressed_b64_run_results)
    logger.debug("Run results: %s", run_results_json)

    results = run_results_json.get("results", [])
    node_result = next((r for r in results if r.get("unique_id") == self.model_unique_id), None)

    if not node_result:  # pragma: no cover
        logger.warning(
            "The dbt node with unique_id '%s' was not executed by the dbt command run in the producer task. This may happen if it is an ephemeral model or if the model sql file is empty.",
            self.model_unique_id,
        )
        return None

    # Log only this node's entry; the full payload is already available at DEBUG level above.
    logger.info("Node Info: %s", node_result)

    status = node_result.get("status")
    if status in DBT_FAILED_STATUSES:
        logger.error("%s", node_result.get("message"))

    self.compiled_sql = node_result.get("compiled_code")
    if self.compiled_sql and hasattr(self, "_override_rtif"):
        self._override_rtif(context)

    return status

def _get_producer_task_status(self, context: Context) -> str | None:
"""
Get the task status of the producer task for both Airflow 2 and Airflow 3.
Expand Down Expand Up @@ -441,7 +412,6 @@ def execute(self, context: Context, **kwargs: Any) -> None:
dag_id=self.dag_id,
run_id=context["run_id"],
map_index=context["task_instance"].map_index,
use_event=self.use_event(),
poke_interval=self.poke_interval,
is_test_sensor=self.is_test_sensor,
),
Expand Down Expand Up @@ -486,17 +456,11 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None:
f"Watcher producer task '{self.producer_task_id}' failed before reporting results for {self._resource_label.lower()} '{self.model_unique_id}'. Check its logs for the underlying error."
)

def use_event(self) -> bool:
raise NotImplementedError("Subclasses must implement this method")

def _get_status_from_events(self, ti: Any, context: Context) -> Any:
raise NotImplementedError("Subclasses should implement this method if `use_event` may return True")

def _log_startup_events(self, ti: Any) -> None:
dbt_startup_events: list[dict[str, Any]] = ti.xcom_pull(
task_ids=self.producer_task_id, key=_DBT_STARTUP_EVENTS_XCOM_KEY
)
if dbt_startup_events: # pragma: no cover
if isinstance(dbt_startup_events, list) and dbt_startup_events: # pragma: no cover
for event in dbt_startup_events:
# Adding debug level to avoid redundant logs for non-deferrable mode
logger.debug("%s", event.get("msg"))
Expand All @@ -505,13 +469,12 @@ def _get_node_status(self, ti: Any, context: Context) -> Any:
"""Return the current status of the watched dbt node from XCom.

For test sensors, reads the aggregated ``_tests_status`` key.
For model sensors, reads from event-based or subprocess-based keys.
For model sensors, reads the per-model ``*_status`` key (same for both
SUBPROCESS and DBT_RUNNER invocation modes).
"""
if self.is_test_sensor:
xcom_key = get_tests_status_xcom_key(self.model_unique_id)
return get_xcom_val(ti, self.producer_task_id, xcom_key)
if self.use_event():
return self._get_status_from_events(ti, context)
return get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status")

def poke(self, context: Context) -> bool:
Expand All @@ -534,9 +497,8 @@ def poke(self, context: Context) -> bool:
if try_number > 1:
return self._fallback_to_non_watcher_run(try_number, context)

# We have assumption here that both the build producer and the sensor task will have same invocation mode
producer_task_state = self._get_producer_task_status(context)
if not self.use_event() and not self.is_test_sensor:
if not self.is_test_sensor:
self._log_startup_events(ti)
status = self._get_node_status(ti, context)

Expand Down
42 changes: 11 additions & 31 deletions cosmos/operators/_watcher/triggerer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from __future__ import annotations

import asyncio
import base64
import json
import zlib
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any
Expand Down Expand Up @@ -43,16 +40,19 @@ def __init__(
dag_id: str,
run_id: str,
map_index: int | None,
use_event: bool,
poke_interval: float = 5.0,
is_test_sensor: bool = False,
Comment thread
tatiana marked this conversation as resolved.
# Accepted for upgrade-compatibility only: triggers serialized before the
Comment thread
tatiana marked this conversation as resolved.
# invocation-mode unification may still carry this kwarg (Cosmos < 1.14.0). It is no longer
# used because both SUBPROCESS and DBT_RUNNER now push the same *_status
# XCom keys, so the trigger does not need to know the invocation mode.
use_event: bool = True, # noqa: ARG002
):
self.model_unique_id = model_unique_id
self.producer_task_id = producer_task_id
self.dag_id = dag_id
self.run_id = run_id
self.map_index = map_index
self.use_event = use_event
self.poke_interval = poke_interval
self.is_test_sensor = is_test_sensor

Expand All @@ -65,7 +65,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"dag_id": self.dag_id,
"run_id": self.run_id,
"map_index": self.map_index,
"use_event": self.use_event,
"poke_interval": self.poke_interval,
"is_test_sensor": self.is_test_sensor,
},
Expand Down Expand Up @@ -118,21 +117,8 @@ async def get_xcom_val(self, key: str) -> Any | None:
return await self.get_xcom_val_af3(key)

async def _get_node_status(self) -> Any | None:
status_key = (
f"nodefinished_{self.model_unique_id.replace('.', '__')}"
if self.use_event
else f"{self.model_unique_id.replace('.', '__')}_status"
)

if self.use_event:
compressed_xcom_val = await self.get_xcom_val(status_key)
if not compressed_xcom_val:
return None
data_json = _parse_compressed_xcom(compressed_xcom_val)
status = data_json.get("data", {}).get("run_result", {}).get("status")
else:
status = await self.get_xcom_val(status_key)
return status
status_key = f"{self.model_unique_id.replace('.', '__')}_status"
return await self.get_xcom_val(status_key)

async def _parse_dbt_node_status_and_compiled_sql(self) -> tuple[str | None, str | None]:
"""
Expand All @@ -143,9 +129,10 @@ async def _parse_dbt_node_status_and_compiled_sql(self) -> tuple[str | None, str
For test sensors (``is_test_sensor=True``), the aggregated test status is
read from the ``<model_uid>_tests_status`` key. No compiled_sql is relevant.

For regular sensors, status comes from mode-specific keys
(nodefinished_* for event, *_status for subprocess).
compiled_sql is always read from the canonical per-model key (same for both modes).
For regular sensors, status is read from the per-model ``*_status`` XCom key
pushed by store_dbt_resource_status_from_log (same key for both SUBPROCESS
and DBT_RUNNER invocation modes).
compiled_sql is always read from the canonical per-model ``*_compiled_sql`` key.
"""
if self.is_test_sensor:
from cosmos.operators._watcher.aggregation import get_tests_status_xcom_key
Expand Down Expand Up @@ -257,10 +244,3 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
# Sleep briefly before re-polling
await asyncio.sleep(self.poke_interval)
logger.debug("Polling again for node '%s' status...", self.model_unique_id)


def _parse_compressed_xcom(compressed_b64_event_msg: str | bytes) -> Any:
    """Decode and decompress a base64-encoded, zlib-compressed XCom payload.

    The payload is processed in three steps: base64-decode, zlib-decompress,
    then parse the resulting UTF-8 text as JSON.

    :param compressed_b64_event_msg: Base64 text (or ASCII bytes) wrapping
        zlib-compressed, UTF-8 encoded JSON.
    :return: The deserialized JSON value.
    :raises binascii.Error: If the input is not valid base64.
    :raises zlib.error: If the decoded bytes are not valid zlib data.
    :raises UnicodeDecodeError: If the decompressed bytes are not valid UTF-8.
    :raises json.JSONDecodeError: If the decompressed text is not valid JSON.
    """
    compressed_bytes = base64.b64decode(compressed_b64_event_msg)
    event_json_str = zlib.decompress(compressed_bytes).decode("utf-8")
    return json.loads(event_json_str)
Loading
Loading