Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,15 @@ def _add_watcher_producer_task(
task_group: TaskGroup | None,
render_config: RenderConfig | None = None,
execution_mode: ExecutionMode = ExecutionMode.WATCHER,
tests_per_model: dict[str, list[str]] | None = None,
) -> BaseOperator:
"""
Create the producer task for the watcher execution mode and add it to the tasks_map.
The producer task is the task that will be used to produce the events for the watcher execution mode.
"""
producer_task_args = task_args.copy()
if tests_per_model is not None:
Comment thread
michal-mrazek marked this conversation as resolved.
producer_task_args["tests_per_model"] = tests_per_model

if render_config is not None:
producer_task_args["select"] = _convert_list_to_str(render_config.select)
Expand Down Expand Up @@ -855,6 +858,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command
async_py_requirements: list[str] | None = None,
execution_config: ExecutionConfig | None = None,
tests_per_model: dict[str, list[str]] | None = None,
) -> dict[str, TaskGroup | BaseOperator]:
"""
Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
Expand Down Expand Up @@ -907,6 +911,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro
task_group=task_group,
render_config=render_config,
execution_mode=execution_mode,
tests_per_model=tests_per_model,
)

for node_id, node in nodes.items():
Expand Down
1 change: 1 addition & 0 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def __init__(
render_config=render_config,
async_py_requirements=execution_config.async_py_requirements,
execution_config=execution_config,
tests_per_model=self.dbt_graph.tests_per_model,
)

current_time = time.perf_counter()
Expand Down
8 changes: 7 additions & 1 deletion cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ class DbtGraph:

nodes: dict[str, DbtNode] = dict()
filtered_nodes: dict[str, DbtNode] = dict()
tests_per_model: dict[str, list[str]] = dict()
load_method: LoadMode = LoadMode.AUTOMATIC

def __init__(
Expand Down Expand Up @@ -1217,11 +1218,14 @@ def load_from_dbt_manifest(self) -> None: # noqa: C901
def update_node_dependency(self) -> None:
"""
This will update the property `has_test` if node has `dbt` test and update the property
`has_non_detached_test` if there's at least one non-detached `dbt` test
`has_non_detached_test` if there's at least one non-detached `dbt` test.
Also builds `tests_per_model`: a mapping of model unique_id to its associated test names.

Updates in-place:
* self.filtered_nodes
* self.tests_per_model
"""
tests_per_model: dict[str, list[str]] = {}
for _, node in list(self.nodes.items()):
if node.resource_type == DbtResourceType.TEST:
for node_id in node.depends_on:
Expand All @@ -1233,8 +1237,10 @@ def update_node_dependency(self) -> None:
or self.render_config.should_detach_multiple_parents_tests is False
):
self.filtered_nodes[node_id].has_non_detached_test = True
tests_per_model.setdefault(node_id, []).append(node.unique_id)
else:
for parent_node_id in node.depends_on:
parent_node = self.nodes.get(parent_node_id)
if parent_node is not None:
parent_node.downstream.append(node.unique_id)
self.tests_per_model = tests_per_model
20 changes: 18 additions & 2 deletions cosmos/operators/_watcher/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
from __future__ import annotations

__all__ = ["get_xcom_val", "safe_xcom_push", "build_producer_state_fetcher", "WatcherTrigger", "_parse_compressed_xcom"]
__all__ = [
"get_xcom_val",
"safe_xcom_push",
"build_producer_state_fetcher",
"is_dbt_node_status_success",
"is_dbt_node_status_failed",
"is_dbt_node_status_terminal",
"WatcherTrigger",
"_parse_compressed_xcom",
]

from cosmos.operators._watcher.state import build_producer_state_fetcher, get_xcom_val, safe_xcom_push
from cosmos.operators._watcher.state import (
build_producer_state_fetcher,
get_xcom_val,
is_dbt_node_status_failed,
is_dbt_node_status_success,
is_dbt_node_status_terminal,
safe_xcom_push,
)
from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom
93 changes: 93 additions & 0 deletions cosmos/operators/_watcher/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

from threading import Lock
from typing import Any

from cosmos.log import get_logger
from cosmos.operators._watcher.state import DbtTestStatus, is_dbt_node_status_success, safe_xcom_push

logger = get_logger(__name__)

# Protects all mutations of ``test_results_per_model`` so that concurrent
# dbt threads cannot interleave ``setdefault`` / ``append`` / ``len`` checks.
_test_results_lock = Lock()


def get_tests_status_xcom_key(model_uid: str) -> str:
"""Return the XCom key used to store the aggregated test status for a model."""
return f"{model_uid.replace('.', '__')}_tests_status"


def accumulate_test_result(
test_unique_id: str,
status: str,
tests_per_model: dict[str, list[str]],
test_results_per_model: dict[str, list[str]],
) -> str | None:
"""Accumulate a test's terminal status into test_results_per_model for its parent model.

Returns the parent model's unique_id if found, else None.
"""
for model_uid, test_uids in tests_per_model.items():
if test_unique_id in test_uids:
test_results_per_model.setdefault(model_uid, []).append(status)
return model_uid
Comment thread
michal-mrazek marked this conversation as resolved.
return None
Comment thread
tatiana marked this conversation as resolved.
Comment thread
tatiana marked this conversation as resolved.


def get_aggregated_test_status(
model_uid: str,
tests_per_model: dict[str, list[str]],
test_results_per_model: dict[str, list[str]],
) -> str | None:
"""
Check if all tests for a model have finished and return aggregated status.

Returns:
"pass" if all tests passed, "fail" if any test failed,
or None if not all tests have reported yet.
"""
expected = tests_per_model.get(model_uid)
if not expected:
return None
collected = test_results_per_model.get(model_uid, [])
if len(collected) < len(expected):
logger.debug(
"Model '%s' has %s tests, but only %s have reported results so far.",
model_uid,
len(expected),
len(collected),
)
Comment thread
michal-mrazek marked this conversation as resolved.
return None
aggregated_test_result = (
DbtTestStatus.PASS if all(is_dbt_node_status_success(s) for s in collected) else DbtTestStatus.FAIL
)
Comment thread
michal-mrazek marked this conversation as resolved.
logger.debug("Model '%s' has all tests reported. Aggregated result: %s", model_uid, aggregated_test_result)
return aggregated_test_result


def push_test_result_or_aggregate(
test_unique_id: str,
status: str,
tests_per_model: dict[str, list[str]],
test_results_per_model: dict[str, list[str]],
task_instance: Any,
) -> None:
"""Accumulate a test result and, when all tests for the parent model have reported, push aggregated XCom.

:param test_unique_id: The unique_id of the finished test node.
:param status: The terminal status of the test (e.g. "pass", "fail").
:param tests_per_model: Mapping of model unique_id → list of test unique_ids.
:param test_results_per_model: Mutable accumulator, mutated in place.
:param task_instance: The Airflow task instance used for XCom push.
"""
with _test_results_lock:
model_uid = accumulate_test_result(test_unique_id, status, tests_per_model, test_results_per_model)
if model_uid is not None:
aggregated = get_aggregated_test_status(model_uid, tests_per_model, test_results_per_model)
if aggregated is not None:
safe_xcom_push(
task_instance=task_instance,
key=get_tests_status_xcom_key(model_uid),
value=aggregated,
)
47 changes: 39 additions & 8 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
WATCHER_TASK_WEIGHT_RULE,
)
from cosmos.log import get_logger
from cosmos.operators._watcher.state import build_producer_state_fetcher, get_xcom_val, safe_xcom_push
from cosmos.operators._watcher.aggregation import push_test_result_or_aggregate
from cosmos.operators._watcher.state import (
build_producer_state_fetcher,
get_xcom_val,
is_dbt_node_status_success,
is_dbt_node_status_terminal,
safe_xcom_push,
)
from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom

try:
Expand Down Expand Up @@ -85,13 +92,27 @@ def store_compiled_sql_for_model(
_push_compiled_sql_for_model(task_instance, unique_id, compiled_sql)


def store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None:
def store_dbt_resource_status_from_log(
line: str,
extra_kwargs: Any,
*,
tests_per_model: dict[str, list[str]] | None = None,
test_results_per_model: dict[str, list[str]] | None = None,
) -> None:
"""
Parses a single line from dbt JSON logs and stores node status to Airflow XCom.

This method parses each log line from dbt when --log-format json is used,
extracts node status information, and pushes it to XCom for consumption
by downstream watcher sensors.

:param tests_per_model: Mapping of model unique_id to list of test unique_ids
associated with that model, as built by DbtGraph.update_node_dependency().
Empty dict when no tests exist.
:param test_results_per_model: Mutable accumulator dict. For each model that has
tests, collects the terminal statuses of those tests as they finish.
Keyed by model unique_id, values are lists of test statuses (e.g. ``["pass", "pass"]``).
Mutated in place by this function.
"""
try:
log_line = json.loads(line)
Expand All @@ -101,16 +122,26 @@ def store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None:
else:
logger.debug("Log line: %s", log_line)
node_info = log_line.get("data", {}).get("node_info", {})
node_status = node_info.get("node_status")
dbt_node_status = node_info.get("node_status")
dbt_node_resource_type = node_info.get("resource_type")
unique_id = node_info.get("unique_id")

logger.debug("Model: %s is in %s state", unique_id, node_status)
logger.debug("Model: %s is in %s state", unique_id, dbt_node_status)

# TODO: Handle and store all possible node statuses, not just the current success and failed
if node_status in ["success", "failed"]:
# Handle terminal statuses for both models (success/failed) and tests (pass/fail)
# TODO: handle all possible statuses including skipped, warn, etc.
if is_dbt_node_status_terminal(dbt_node_status):
Comment thread
tatiana marked this conversation as resolved.
context = extra_kwargs.get("context")
assert context is not None # Make MyPy happy
safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status)
if dbt_node_resource_type == "test" and tests_per_model and test_results_per_model is not None:
logger.debug("Test '%s' finished with status '%s'", unique_id, dbt_node_status)
push_test_result_or_aggregate(
unique_id, dbt_node_status, tests_per_model, test_results_per_model, context["ti"]
)
else:
safe_xcom_push(
task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=dbt_node_status
)
Comment thread
michal-mrazek marked this conversation as resolved.

# Extract and push compiled_sql for models (centralised for both subprocess and node-event)
# compiled_sql is available for both success and failed models - it's compiled before execution
Expand Down Expand Up @@ -378,7 +409,7 @@ def poke(self, context: Context) -> bool:
self.poke_retry_number += 1

return False
elif status == "success":
elif is_dbt_node_status_success(status):
return True
else:
raise AirflowException(f"Model '{self.model_unique_id}' finished with status '{status}'")
29 changes: 29 additions & 0 deletions cosmos/operators/_watcher/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from collections.abc import Callable
from enum import Enum
from threading import Lock
from typing import Any

Expand All @@ -14,6 +15,34 @@

ProducerStateFetcher = Callable[[], str | None]

# dbt uses different status values for different node types (models/tests):"
Comment thread
michal-mrazek marked this conversation as resolved.
Comment thread
michal-mrazek marked this conversation as resolved.
DBT_SUCCESS_STATUSES = frozenset({"success", "pass"})
DBT_FAILED_STATUSES = frozenset({"failed", "fail", "error"})


class DbtTestStatus(str, Enum):
"""Aggregated status of all tests for a given model."""

__test__ = False

PASS = "pass"
FAIL = "fail"


def is_dbt_node_status_success(status: str | None) -> bool:
"""Check if the dbt node status indicates success (works for both models and tests)."""
return status in DBT_SUCCESS_STATUSES


def is_dbt_node_status_failed(status: str | None) -> bool:
"""Check if the dbt node status indicates failure (works for both models and tests)."""
return status in DBT_FAILED_STATUSES


def is_dbt_node_status_terminal(status: str | None) -> bool:
"""Check if the dbt node status is terminal (success or failed)."""
return is_dbt_node_status_success(status) or is_dbt_node_status_failed(status)


xcom_set_lock = Lock()

Expand Down
Loading