Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
WATCHER_TASK_WEIGHT_RULE,
)
from cosmos.log import get_logger
from cosmos.operators._watcher.aggregation import push_test_result_or_aggregate
from cosmos.operators._watcher.aggregation import get_tests_status_xcom_key, push_test_result_or_aggregate
from cosmos.operators._watcher.state import (
DBT_FAILED_STATUSES,
_iso_to_string,
Expand All @@ -27,7 +27,7 @@
is_dbt_node_status_terminal,
safe_xcom_push,
)
from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom
from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger, _parse_compressed_xcom

try:
from airflow.sdk.bases.sensor import BaseSensorOperator
Expand Down Expand Up @@ -312,6 +312,16 @@ def __init__(
self.deferrable = deferrable
self.model_unique_id = extra_context.get("dbt_node_config", {}).get("unique_id")

@property
def is_test_sensor(self) -> bool:
"""Whether this sensor watches aggregated test results instead of individual node results."""
return False

@property
def _resource_label(self) -> str:
"""Human-readable label for log and error messages."""
return "Tests for model" if self.is_test_sensor else "Model"

@staticmethod
def _filter_flags(flags: list[str]) -> list[str]:
"""Filters out dbt flags that are incompatible with retry (e.g., --select, --exclude)."""
Expand All @@ -334,7 +344,16 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo
Handles logic for retrying a failed dbt model execution.
Reconstructs the dbt command by cloning the project and re-running the model
with appropriate flags, while ensuring flags like `--select` or `--exclude` are excluded.

For test sensors, re-execution is not supported in watcher mode; retries are skipped.
"""
if self.is_test_sensor:
raise AirflowException(
f"Test re-execution is not yet supported in watcher mode. "
f"{self._resource_label} '{self.model_unique_id}' cannot be retried. "
f"A future release will add fallback to local test execution."
)
Comment thread
michal-mrazek marked this conversation as resolved.

Comment thread
michal-mrazek marked this conversation as resolved.
logger.info(
f"Retry attempt #%s – Running model '%s' from project '%s' using {self.__class__.__name__}",
try_number - 1,
Expand Down Expand Up @@ -425,6 +444,7 @@ def execute(self, context: Context, **kwargs: Any) -> None:
map_index=context["task_instance"].map_index,
use_event=self.use_event(),
poke_interval=self.poke_interval,
is_test_sensor=self.is_test_sensor,
),
timeout=self.execution_timeout,
method_name=self.execute_complete.__name__,
Expand All @@ -434,9 +454,10 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None:
status = event.get("status")
reason = event.get("reason")

if status == "success" and reason == "model_not_run":
if status == "success" and reason == WatcherEventReason.NODE_NOT_RUN:
logger.info(
"Model '%s' was skipped by the dbt command. This may happen if it is an ephemeral model or if the model sql file is empty.",
"%s '%s' was skipped by the dbt command. This may happen if it is an ephemeral model or if the model sql file is empty.",
self._resource_label,
self.model_unique_id,
)
Comment thread
michal-mrazek marked this conversation as resolved.
Comment thread
michal-mrazek marked this conversation as resolved.
Comment thread
michal-mrazek marked this conversation as resolved.
Comment thread
michal-mrazek marked this conversation as resolved.

Expand All @@ -456,14 +477,14 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None:
task_ids=self.producer_task_id,
)
_log_dbt_event(dbt_events)
if reason == "model_failed":
if reason == WatcherEventReason.NODE_FAILED:
raise AirflowException(
f"dbt model '{self.model_unique_id}' failed. Review the producer task '{self.producer_task_id}' logs for details."
f"dbt {self._resource_label.lower()} '{self.model_unique_id}' failed. Review the producer task '{self.producer_task_id}' logs for details."
)

if reason == "producer_failed":
if reason == WatcherEventReason.PRODUCER_FAILED:
raise AirflowException(
f"Watcher producer task '{self.producer_task_id}' failed before reporting model results. Check its logs for the underlying error."
f"Watcher producer task '{self.producer_task_id}' failed before reporting results for {self._resource_label.lower()} '{self.model_unique_id}'. Check its logs for the underlying error."
)

def use_event(self) -> bool:
Expand All @@ -481,19 +502,33 @@ def _log_startup_events(self, ti: Any) -> None:
# Adding debug level to avoid redundant logs for non-deferrable mode
logger.debug("%s", event.get("msg"))

def _get_node_status(self, ti: Any, context: Context) -> Any:
"""Return the current status of the watched dbt node from XCom.

For test sensors, reads the aggregated ``_tests_status`` key.
For model sensors, reads from event-based or subprocess-based keys.
"""
if self.is_test_sensor:
xcom_key = get_tests_status_xcom_key(self.model_unique_id)
return get_xcom_val(ti, self.producer_task_id, xcom_key)
if self.use_event():
return self._get_status_from_events(ti, context)
return get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status")

def poke(self, context: Context) -> bool:
"""
Checks the status of a dbt model run by pulling relevant XComs from the master task.
Handles retries and checks for successful completion of the model execution.
Checks the status of a dbt node (model or aggregated tests) by pulling relevant XComs from the producer task.
Handles retries and checks for successful completion.
"""
ti = context["ti"]
try_number = ti.try_number

logger.info(
"Try number #%s, poke attempt #%s: Pulling status from task_id '%s' for model '%s'",
"Try number #%s, poke attempt #%s: Pulling status from task_id '%s' for %s '%s'",
try_number,
Comment thread
michal-mrazek marked this conversation as resolved.
self.poke_retry_number,
self.producer_task_id,
self._resource_label.lower(),
self.model_unique_id,
)

Expand All @@ -502,11 +537,9 @@ def poke(self, context: Context) -> bool:

# We have assumption here that both the build producer and the sensor task will have same invocation mode
producer_task_state = self._get_producer_task_status(context)
if self.use_event():
status = self._get_status_from_events(ti, context)
else:
if not self.use_event() and not self.is_test_sensor:
self._log_startup_events(ti)
status = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status")
status = self._get_node_status(ti, context)

# compiled_sql is always in the canonical per-model XCom key (same for event and subprocess modes)
if status is not None:
Expand Down Expand Up @@ -542,4 +575,4 @@ def poke(self, context: Context) -> bool:
elif is_dbt_node_status_success(status):
return True
else:
raise AirflowException(f"Model '{self.model_unique_id}' finished with status '{status}'")
raise AirflowException(f"{self._resource_label} '{self.model_unique_id}' finished with status '{status}'")
32 changes: 28 additions & 4 deletions cosmos/operators/_watcher/triggerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import zlib
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any

from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand All @@ -25,6 +26,14 @@
logger = get_logger(__name__)


class WatcherEventReason(str, Enum):
"""Reason codes used in TriggerEvent payloads between WatcherTrigger and BaseConsumerSensor.execute_complete."""

NODE_FAILED = "node_failed"
PRODUCER_FAILED = "producer_failed"
NODE_NOT_RUN = "node_not_run"


class WatcherTrigger(BaseTrigger):

def __init__(
Expand All @@ -36,6 +45,7 @@ def __init__(
map_index: int | None,
use_event: bool,
poke_interval: float = 5.0,
is_test_sensor: bool = False,
):
self.model_unique_id = model_unique_id
self.producer_task_id = producer_task_id
Expand All @@ -44,6 +54,7 @@ def __init__(
self.map_index = map_index
self.use_event = use_event
self.poke_interval = poke_interval
self.is_test_sensor = is_test_sensor

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
Expand All @@ -56,6 +67,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"map_index": self.map_index,
"use_event": self.use_event,
"poke_interval": self.poke_interval,
"is_test_sensor": self.is_test_sensor,
},
)

Expand Down Expand Up @@ -127,9 +139,21 @@ async def _parse_dbt_node_status_and_compiled_sql(self) -> tuple[str | None, str
Parse node status and compiled_sql from XCom.

Returns a tuple of (status, compiled_sql).
Status comes from mode-specific keys (nodefinished_* for event, *_status for subprocess).

For test sensors (``is_test_sensor=True``), the aggregated test status is
read from the ``<model_uid>_tests_status`` key. No compiled_sql is relevant.

For regular sensors, status comes from mode-specific keys
(nodefinished_* for event, *_status for subprocess).
compiled_sql is always read from the canonical per-model key (same for both modes).
"""
if self.is_test_sensor:
from cosmos.operators._watcher.aggregation import get_tests_status_xcom_key

status_key = get_tests_status_xcom_key(self.model_unique_id)
status = await self.get_xcom_val(status_key)
return status, None
Comment thread
michal-mrazek marked this conversation as resolved.

Comment thread
michal-mrazek marked this conversation as resolved.
compiled_sql_key = f"{self.model_unique_id.replace('.', '__')}_compiled_sql"

status = await self._get_node_status()
Expand Down Expand Up @@ -208,7 +232,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
return
elif is_dbt_node_status_failed(dbt_node_status):
logger.warning("dbt node '%s' failed", self.model_unique_id)
event_data = {"status": EventStatus.FAILED, "reason": "model_failed"}
event_data = {"status": EventStatus.FAILED, "reason": WatcherEventReason.NODE_FAILED}
if compiled_sql:
event_data["compiled_sql"] = compiled_sql
yield TriggerEvent(event_data) # type: ignore[no-untyped-call]
Expand All @@ -219,15 +243,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.producer_task_id,
self.model_unique_id,
)
yield TriggerEvent({"status": EventStatus.FAILED, "reason": "producer_failed"}) # type: ignore[no-untyped-call]
yield TriggerEvent({"status": EventStatus.FAILED, "reason": WatcherEventReason.PRODUCER_FAILED}) # type: ignore[no-untyped-call]
return
elif producer_task_state == "success" and dbt_node_status is None:
logger.info(
"The producer task '%s' succeeded. There is no information about the node '%s' execution.",
self.producer_task_id,
self.model_unique_id,
)
yield TriggerEvent({"status": EventStatus.SUCCESS, "reason": "model_not_run"}) # type: ignore[no-untyped-call]
yield TriggerEvent({"status": EventStatus.SUCCESS, "reason": WatcherEventReason.NODE_NOT_RUN}) # type: ignore[no-untyped-call]
return

# Sleep briefly before re-polling
Expand Down
43 changes: 26 additions & 17 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@
from airflow.exceptions import AirflowException

from cosmos.config import ProfileConfig
from cosmos.operators._watcher import _parse_compressed_xcom, safe_xcom_push
from cosmos.operators._watcher.state import DBT_FAILED_STATUSES
from cosmos.settings import watcher_dbt_execution_queue

try:
from airflow.providers.standard.operators.empty import EmptyOperator
except ImportError: # pragma: no cover
from airflow.operators.empty import EmptyOperator # type: ignore[no-redef]

from cosmos.constants import (
_DBT_STARTUP_EVENTS_XCOM_KEY,
Comment thread
michal-mrazek marked this conversation as resolved.
PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT,
Comment thread
michal-mrazek marked this conversation as resolved.
Expand All @@ -27,13 +18,15 @@
InvocationMode,
)
from cosmos.log import get_logger
from cosmos.operators._watcher import _parse_compressed_xcom, safe_xcom_push
from cosmos.operators._watcher.aggregation import push_test_result_or_aggregate
Comment thread
michal-mrazek marked this conversation as resolved.
from cosmos.operators._watcher.base import (
BaseConsumerSensor,
_process_dbt_log_event,
store_compiled_sql_for_model,
store_dbt_resource_status_from_log,
)
from cosmos.operators._watcher.state import DBT_FAILED_STATUSES
from cosmos.operators.base import (
DbtBuildMixin,
DbtRunMixin,
Expand All @@ -45,6 +38,7 @@
DbtRunLocalOperator,
DbtSourceLocalOperator,
)
from cosmos.settings import watcher_dbt_execution_queue

try:
from dbt_common.events.base_types import EventMsg
Expand Down Expand Up @@ -328,13 +322,28 @@ def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)


class DbtTestWatcherOperator(EmptyOperator):
"""
As a starting point, this operator does nothing.
We'll be implementing this operator as part of: https://github.com/astronomer/astronomer-cosmos/issues/1974
class DbtTestWatcherOperator(DbtConsumerWatcherSensor): # type: ignore[misc]
"""Sensor that watches the aggregated test status for a dbt model in watcher execution mode.

The producer task (``DbtProducerWatcherOperator``) collects individual test
results as they finish and, once every test for a given model has reported,
pushes a single aggregated XCom (``"pass"`` or ``"fail"``) under the key
``<model_unique_id>_tests_status``.
Comment thread
michal-mrazek marked this conversation as resolved.
Comment thread
michal-mrazek marked this conversation as resolved.
Comment thread
michal-mrazek marked this conversation as resolved.
Comment thread
michal-mrazek marked this conversation as resolved.

This sensor polls that key and:
Comment thread
michal-mrazek marked this conversation as resolved.
* returns success when the value is ``"pass"``,
* raises ``AirflowException`` when the value is ``"fail"``.

Deferral is fully supported: the ``WatcherTrigger`` receives
Comment thread
michal-mrazek marked this conversation as resolved.
``is_test_sensor=True`` and polls the correct aggregated key.
"""

def __init__(self, *args: Any, **kwargs: Any):
desired_keys = ("dag", "task_group", "task_id")
new_kwargs = {key: value for key, value in kwargs.items() if key in desired_keys}
super().__init__(**new_kwargs) # type: ignore[no-untyped-call]
template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields # type: ignore[operator]

@property
def is_test_sensor(self) -> bool:
return True

def use_event(self) -> bool:
"""This sensor relies on the producer task pushing aggregated test results to XCom, so it does not use real-time events."""
return False
27 changes: 20 additions & 7 deletions tests/operators/_watcher/test_triggerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from packaging.version import Version

from cosmos.constants import _DBT_STARTUP_EVENTS_XCOM_KEY, AIRFLOW_VERSION
from cosmos.operators._watcher.triggerer import WatcherTrigger
from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger

_STARTUP_EVENTS = [{"name": "MainReportVersion", "msg": "Running with dbt=1.0.0", "ts": ""}]

Expand Down Expand Up @@ -118,6 +118,19 @@ async def mock_get_xcom_val(key):
assert status == expected_status
assert compiled_sql == expected_compiled_sql

async def test_parse_dbt_node_status_for_test_sensor(self):
"""When is_test_sensor=True, _parse_dbt_node_status_and_compiled_sql reads the aggregated tests_status key."""
self.trigger.is_test_sensor = True
self.trigger.model_unique_id = "model.jaffle_shop.stg_orders"

mock_get_xcom_val = AsyncMock(return_value="pass")
with patch.object(self.trigger, "get_xcom_val", mock_get_xcom_val):
status, compiled_sql = await self.trigger._parse_dbt_node_status_and_compiled_sql()

assert status == "pass"
assert compiled_sql is None
mock_get_xcom_val.assert_called_once_with("model__jaffle_shop__stg_orders_tests_status")

@pytest.mark.parametrize(
"airflow_version, expected_val",
[
Expand All @@ -140,9 +153,9 @@ async def test_get_xcom_val_branches(self, airflow_version, expected_val):
"dbt_node_status, producer_state, expected",
[
("success", "running", {"status": "success"}),
("failed", "running", {"status": "failed", "reason": "model_failed"}),
(None, "failed", {"status": "failed", "reason": "producer_failed"}),
(None, "success", {"status": "success", "reason": "model_not_run"}),
("failed", "running", {"status": "failed", "reason": WatcherEventReason.NODE_FAILED}),
(None, "failed", {"status": "failed", "reason": WatcherEventReason.PRODUCER_FAILED}),
(None, "success", {"status": "success", "reason": WatcherEventReason.NODE_NOT_RUN}),
],
)
@patch("cosmos.operators._watcher.triggerer.WatcherTrigger._log_startup_events")
Expand Down Expand Up @@ -236,7 +249,7 @@ def _import_side_effect(name: str, *args, **kwargs):

@pytest.mark.asyncio
async def test_run_producer_success_model_not_run(self, caplog):
"""Test that when producer succeeds but model has no status, trigger yields success with model_not_run reason."""
"""Test that when producer succeeds but model has no status, trigger yields success with node_not_run reason."""
get_xcom_val_mock = AsyncMock(
side_effect=lambda key: _STARTUP_EVENTS if key == _DBT_STARTUP_EVENTS_XCOM_KEY else None
)
Expand All @@ -257,7 +270,7 @@ async def test_run_producer_success_model_not_run(self, caplog):
events.append(event)

assert len(events) == 1
assert events[0].payload == {"status": "success", "reason": "model_not_run"}
assert events[0].payload == {"status": "success", "reason": WatcherEventReason.NODE_NOT_RUN}
assert "The producer task 'task_1' succeeded" in caplog.text
assert "There is no information about the node 'model.test' execution" in caplog.text

Expand Down Expand Up @@ -309,7 +322,7 @@ async def get_xcom_val_side_effect(key):
events = [event async for event in self.trigger.run()]
assert len(events) == 1
assert events[0].payload["status"] == "failed"
assert events[0].payload["reason"] == "model_failed"
assert events[0].payload["reason"] == WatcherEventReason.NODE_FAILED
assert events[0].payload["compiled_sql"] == "SELECT * FROM broken_model"

@patch("cosmos.operators._watcher.triggerer.logger")
Expand Down
Loading
Loading