diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 83826ff171..0a323ceba4 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -678,12 +678,15 @@ def _add_watcher_producer_task( task_group: TaskGroup | None, render_config: RenderConfig | None = None, execution_mode: ExecutionMode = ExecutionMode.WATCHER, + tests_per_model: dict[str, list[str]] | None = None, ) -> BaseOperator: """ Create the producer task for the watcher execution mode and add it to the tasks_map. The producer task is the task that will be used to produce the events for the watcher execution mode. """ producer_task_args = task_args.copy() + if tests_per_model is not None: + producer_task_args["tests_per_model"] = tests_per_model if render_config is not None: producer_task_args["select"] = _convert_list_to_str(render_config.select) @@ -855,6 +858,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command async_py_requirements: list[str] | None = None, execution_config: ExecutionConfig | None = None, + tests_per_model: dict[str, list[str]] | None = None, ) -> dict[str, TaskGroup | BaseOperator]: """ Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory). @@ -907,6 +911,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro task_group=task_group, render_config=render_config, execution_mode=execution_mode, + tests_per_model=tests_per_model, ) for node_id, node in nodes.items(): diff --git a/cosmos/converter.py b/cosmos/converter.py index 16026bf6a5..c06a0db106 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -365,6 +365,7 @@ def __init__( render_config=render_config, async_py_requirements=execution_config.async_py_requirements, execution_config=execution_config, + tests_per_model=self.dbt_graph.tests_per_model, ) current_time = time.perf_counter() diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index b59d4cc181..bf9fdd1ae5 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -373,6 +373,7 @@ class DbtGraph: nodes: dict[str, DbtNode] = dict() filtered_nodes: dict[str, DbtNode] = dict() + tests_per_model: dict[str, list[str]] = dict() load_method: LoadMode = LoadMode.AUTOMATIC def __init__( @@ -1217,11 +1218,14 @@ def load_from_dbt_manifest(self) -> None: # noqa: C901 def update_node_dependency(self) -> None: """ This will update the property `has_test` if node has `dbt` test and update the property - `has_non_detached_test` if there's at least one non-detached `dbt` test + `has_non_detached_test` if there's at least one non-detached `dbt` test. + Also builds `tests_per_model`: a mapping of model unique_id to its associated test names. Updates in-place: * self.filtered_nodes + * self.tests_per_model """ + tests_per_model: dict[str, list[str]] = {} for _, node in list(self.nodes.items()): if node.resource_type == DbtResourceType.TEST: for node_id in node.depends_on: @@ -1233,8 +1237,10 @@ def update_node_dependency(self) -> None: or self.render_config.should_detach_multiple_parents_tests is False ): self.filtered_nodes[node_id].has_non_detached_test = True + tests_per_model.setdefault(node_id, []).append(node.unique_id) else: for parent_node_id in node.depends_on: parent_node = self.nodes.get(parent_node_id) if parent_node is not None: parent_node.downstream.append(node.unique_id) + self.tests_per_model = tests_per_model diff --git a/cosmos/operators/_watcher/__init__.py b/cosmos/operators/_watcher/__init__.py index 5db5b607b5..0eac9da614 100644 --- a/cosmos/operators/_watcher/__init__.py +++ b/cosmos/operators/_watcher/__init__.py @@ -1,6 +1,22 @@ from __future__ import annotations -__all__ = ["get_xcom_val", "safe_xcom_push", "build_producer_state_fetcher", "WatcherTrigger", "_parse_compressed_xcom"] +__all__ = [ + "get_xcom_val", + "safe_xcom_push", + "build_producer_state_fetcher", + "is_dbt_node_status_success", + "is_dbt_node_status_failed", + "is_dbt_node_status_terminal", + "WatcherTrigger", + "_parse_compressed_xcom", +] -from cosmos.operators._watcher.state import build_producer_state_fetcher, get_xcom_val, safe_xcom_push +from cosmos.operators._watcher.state import ( + build_producer_state_fetcher, + get_xcom_val, + is_dbt_node_status_failed, + is_dbt_node_status_success, + is_dbt_node_status_terminal, + safe_xcom_push, +) from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom diff --git a/cosmos/operators/_watcher/aggregation.py b/cosmos/operators/_watcher/aggregation.py new file mode 100644 index 0000000000..eeec6ac518 --- /dev/null +++ b/cosmos/operators/_watcher/aggregation.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from threading import Lock +from typing import Any + +from cosmos.log import get_logger +from cosmos.operators._watcher.state import DbtTestStatus, is_dbt_node_status_success, safe_xcom_push + +logger = get_logger(__name__) + +# Protects all mutations of ``test_results_per_model`` so that concurrent +# dbt threads cannot interleave ``setdefault`` / ``append`` / ``len`` checks. +_test_results_lock = Lock() + + +def get_tests_status_xcom_key(model_uid: str) -> str: + """Return the XCom key used to store the aggregated test status for a model.""" + return f"{model_uid.replace('.', '__')}_tests_status" + + +def accumulate_test_result( + test_unique_id: str, + status: str, + tests_per_model: dict[str, list[str]], + test_results_per_model: dict[str, list[str]], +) -> str | None: + """Accumulate a test's terminal status into test_results_per_model for its parent model. + + Returns the parent model's unique_id if found, else None. + """ + for model_uid, test_uids in tests_per_model.items(): + if test_unique_id in test_uids: + test_results_per_model.setdefault(model_uid, []).append(status) + return model_uid + return None + + +def get_aggregated_test_status( + model_uid: str, + tests_per_model: dict[str, list[str]], + test_results_per_model: dict[str, list[str]], +) -> str | None: + """ + Check if all tests for a model have finished and return aggregated status. + + Returns: + "pass" if all tests passed, "fail" if any test failed, + or None if not all tests have reported yet. + """ + expected = tests_per_model.get(model_uid) + if not expected: + return None + collected = test_results_per_model.get(model_uid, []) + if len(collected) < len(expected): + logger.debug( + "Model '%s' has %s tests, but only %s have reported results so far.", + model_uid, + len(expected), + len(collected), + ) + return None + aggregated_test_result = ( + DbtTestStatus.PASS if all(is_dbt_node_status_success(s) for s in collected) else DbtTestStatus.FAIL + ) + logger.debug("Model '%s' has all tests reported. Aggregated result: %s", model_uid, aggregated_test_result) + return aggregated_test_result + + +def push_test_result_or_aggregate( + test_unique_id: str, + status: str, + tests_per_model: dict[str, list[str]], + test_results_per_model: dict[str, list[str]], + task_instance: Any, +) -> None: + """Accumulate a test result and, when all tests for the parent model have reported, push aggregated XCom. + + :param test_unique_id: The unique_id of the finished test node. + :param status: The terminal status of the test (e.g. "pass", "fail"). + :param tests_per_model: Mapping of model unique_id → list of test unique_ids. + :param test_results_per_model: Mutable accumulator, mutated in place. + :param task_instance: The Airflow task instance used for XCom push. + """ + with _test_results_lock: + model_uid = accumulate_test_result(test_unique_id, status, tests_per_model, test_results_per_model) + if model_uid is not None: + aggregated = get_aggregated_test_status(model_uid, tests_per_model, test_results_per_model) + if aggregated is not None: + safe_xcom_push( + task_instance=task_instance, + key=get_tests_status_xcom_key(model_uid), + value=aggregated, + ) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index f128e29b0c..4161dcf282 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -14,7 +14,14 @@ WATCHER_TASK_WEIGHT_RULE, ) from cosmos.log import get_logger -from cosmos.operators._watcher.state import build_producer_state_fetcher, get_xcom_val, safe_xcom_push +from cosmos.operators._watcher.aggregation import push_test_result_or_aggregate +from cosmos.operators._watcher.state import ( + build_producer_state_fetcher, + get_xcom_val, + is_dbt_node_status_success, + is_dbt_node_status_terminal, + safe_xcom_push, +) from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom try: @@ -85,13 +92,27 @@ def store_compiled_sql_for_model( _push_compiled_sql_for_model(task_instance, unique_id, compiled_sql) -def store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None: +def store_dbt_resource_status_from_log( + line: str, + extra_kwargs: Any, + *, + tests_per_model: dict[str, list[str]] | None = None, + test_results_per_model: dict[str, list[str]] | None = None, +) -> None: """ Parses a single line from dbt JSON logs and stores node status to Airflow XCom. This method parses each log line from dbt when --log-format json is used, extracts node status information, and pushes it to XCom for consumption by downstream watcher sensors. + + :param tests_per_model: Mapping of model unique_id to list of test unique_ids + associated with that model, as built by DbtGraph.update_node_dependency(). + Empty dict when no tests exist. + :param test_results_per_model: Mutable accumulator dict. For each model that has + tests, collects the terminal statuses of those tests as they finish. + Keyed by model unique_id, values are lists of test statuses (e.g. ``["pass", "pass"]``). + Mutated in place by this function. """ try: log_line = json.loads(line) @@ -101,16 +122,26 @@ def store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None: else: logger.debug("Log line: %s", log_line) node_info = log_line.get("data", {}).get("node_info", {}) - node_status = node_info.get("node_status") + dbt_node_status = node_info.get("node_status") + dbt_node_resource_type = node_info.get("resource_type") unique_id = node_info.get("unique_id") - logger.debug("Model: %s is in %s state", unique_id, node_status) + logger.debug("Model: %s is in %s state", unique_id, dbt_node_status) - # TODO: Handle and store all possible node statuses, not just the current success and failed - if node_status in ["success", "failed"]: + # Handle terminal statuses for both models (success/failed) and tests (pass/fail) + # TODO: handle all possible statuses including skipped, warn, etc. + if is_dbt_node_status_terminal(dbt_node_status): context = extra_kwargs.get("context") assert context is not None # Make MyPy happy - safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status) + if dbt_node_resource_type == "test" and tests_per_model and test_results_per_model is not None: + logger.debug("Test '%s' finished with status '%s'", unique_id, dbt_node_status) + push_test_result_or_aggregate( + unique_id, dbt_node_status, tests_per_model, test_results_per_model, context["ti"] + ) + else: + safe_xcom_push( + task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=dbt_node_status + ) # Extract and push compiled_sql for models (centralised for both subprocess and node-event) # compiled_sql is available for both success and failed models - it's compiled before execution @@ -378,7 +409,7 @@ def poke(self, context: Context) -> bool: self.poke_retry_number += 1 return False - elif status == "success": + elif is_dbt_node_status_success(status): return True else: raise AirflowException(f"Model '{self.model_unique_id}' finished with status '{status}'") diff --git a/cosmos/operators/_watcher/state.py b/cosmos/operators/_watcher/state.py index 002eb52940..f249a24117 100644 --- a/cosmos/operators/_watcher/state.py +++ b/cosmos/operators/_watcher/state.py @@ -2,6 +2,7 @@ import logging from collections.abc import Callable +from enum import Enum from threading import Lock from typing import Any @@ -14,6 +15,34 @@ ProducerStateFetcher = Callable[[], str | None] +# dbt uses different status values for different node types (models/tests):" +DBT_SUCCESS_STATUSES = frozenset({"success", "pass"}) +DBT_FAILED_STATUSES = frozenset({"failed", "fail", "error"}) + + +class DbtTestStatus(str, Enum): + """Aggregated status of all tests for a given model.""" + + __test__ = False + + PASS = "pass" + FAIL = "fail" + + +def is_dbt_node_status_success(status: str | None) -> bool: + """Check if the dbt node status indicates success (works for both models and tests).""" + return status in DBT_SUCCESS_STATUSES + + +def is_dbt_node_status_failed(status: str | None) -> bool: + """Check if the dbt node status indicates failure (works for both models and tests).""" + return status in DBT_FAILED_STATUSES + + +def is_dbt_node_status_terminal(status: str | None) -> bool: + """Check if the dbt node status is terminal (success or failed).""" + return is_dbt_node_status_success(status) or is_dbt_node_status_failed(status) + xcom_set_lock = Lock() diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index 96ffa926fc..17cd28238f 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -12,8 +12,13 @@ from packaging.version import Version from cosmos.constants import AIRFLOW_VERSION +from cosmos.listeners.dag_run_listener import EventStatus from cosmos.log import get_logger -from cosmos.operators._watcher.state import build_producer_state_fetcher +from cosmos.operators._watcher.state import ( + build_producer_state_fetcher, + is_dbt_node_status_failed, + is_dbt_node_status_success, +) logger = get_logger(__name__) @@ -98,7 +103,7 @@ async def get_xcom_val(self, key: str) -> Any | None: else: return await self.get_xcom_val_af3(key) - async def _parse_node_status_and_compiled_sql(self) -> tuple[str | None, str | None]: + async def _parse_dbt_node_status_and_compiled_sql(self) -> tuple[str | None, str | None]: """ Parse node status and compiled_sql from XCom. @@ -145,41 +150,41 @@ async def run(self) -> AsyncIterator[TriggerEvent]: while True: producer_task_state = await self._get_producer_task_status() - node_status, compiled_sql = await self._parse_node_status_and_compiled_sql() - if node_status == "success": - logger.info("Model '%s' succeeded", self.model_unique_id) - event_data: dict[str, Any] = {"status": "success"} + dbt_node_status, compiled_sql = await self._parse_dbt_node_status_and_compiled_sql() + if is_dbt_node_status_success(dbt_node_status): + logger.info("dbt node '%s' succeeded", self.model_unique_id) + event_data: dict[str, Any] = {"status": EventStatus.SUCCESS} if compiled_sql: event_data["compiled_sql"] = compiled_sql yield TriggerEvent(event_data) # type: ignore[no-untyped-call] return - elif node_status == "failed": - logger.warning("Model '%s' failed", self.model_unique_id) - event_data = {"status": "failed", "reason": "model_failed"} + elif is_dbt_node_status_failed(dbt_node_status): + logger.warning("dbt node '%s' failed", self.model_unique_id) + event_data = {"status": EventStatus.FAILED, "reason": "model_failed"} if compiled_sql: event_data["compiled_sql"] = compiled_sql yield TriggerEvent(event_data) # type: ignore[no-untyped-call] return elif producer_task_state == "failed": logger.error( - "Watcher producer task '%s' failed before delivering results for model '%s'", + "Watcher producer task '%s' failed before delivering results for node '%s'", self.producer_task_id, self.model_unique_id, ) - yield TriggerEvent({"status": "failed", "reason": "producer_failed"}) # type: ignore[no-untyped-call] + yield TriggerEvent({"status": EventStatus.FAILED, "reason": "producer_failed"}) # type: ignore[no-untyped-call] return - elif producer_task_state == "success" and node_status is None: + elif producer_task_state == "success" and dbt_node_status is None: logger.info( - "The producer task '%s' succeeded. There is no information about the model '%s' execution.", + "The producer task '%s' succeeded. There is no information about the node '%s' execution.", self.producer_task_id, self.model_unique_id, ) - yield TriggerEvent({"status": "success", "reason": "model_not_run"}) # type: ignore[no-untyped-call] + yield TriggerEvent({"status": EventStatus.SUCCESS, "reason": "model_not_run"}) # type: ignore[no-untyped-call] return # Sleep briefly before re-polling await asyncio.sleep(self.poke_interval) - logger.debug("Polling again for model '%s' status...", self.model_unique_id) + logger.debug("Polling again for node '%s' status...", self.model_unique_id) def _parse_compressed_xcom(compressed_b64_event_msg: str) -> Any: diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index d440276af5..39cb271caa 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import functools import json import zlib from collections.abc import Callable, Sequence @@ -25,6 +26,7 @@ InvocationMode, ) from cosmos.log import get_logger +from cosmos.operators._watcher.aggregation import push_test_result_or_aggregate from cosmos.operators._watcher.base import ( BaseConsumerSensor, store_compiled_sql_for_model, @@ -94,6 +96,8 @@ class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator): def __init__(self, *args: Any, **kwargs: Any) -> None: task_id = kwargs.pop("task_id", PRODUCER_WATCHER_TASK_ID) + self.tests_per_model: dict[str, list[str]] = kwargs.pop("tests_per_model", {}) + self.test_results_per_model: dict[str, list[str]] = {} kwargs.setdefault("priority_weight", PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT) kwargs.setdefault("weight_rule", WATCHER_TASK_WEIGHT_RULE) # Consumer watcher retry logic handles model-level reruns using the LOCAL execution mode; rerunning the producer @@ -134,8 +138,14 @@ def _handle_node_finished( resource_type = getattr(event_message.data.node_info, "resource_type", None) event_message_dict = self._serialize_event(event_message) store_compiled_sql_for_model(context["ti"], self.project_dir, uid, node_path, resource_type) - payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode() - safe_xcom_push(task_instance=context["ti"], key=f"nodefinished_{uid.replace('.', '__')}", value=payload) + + if resource_type == "test" and self.tests_per_model: + status: str = getattr(event_message.data.node_info, "node_status", None) or "" + logger.debug("Test '%s' finished with status '%s'", uid, status) + push_test_result_or_aggregate(uid, status, self.tests_per_model, self.test_results_per_model, context["ti"]) + else: + payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode() + safe_xcom_push(task_instance=context["ti"], key=f"nodefinished_{uid.replace('.', '__')}", value=payload) def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None: # Only push startup events; per-model statuses are available via individual nodefinished_ entries. @@ -153,7 +163,11 @@ def _set_process_log_line_callable_if_subprocess(self) -> None: "DbtProducerWatcherOperator: Setting log_format to json and process_log_line_callable to store_dbt_resource_status_from_log" ) self.log_format = "json" - self._process_log_line_callable = store_dbt_resource_status_from_log + self._process_log_line_callable = functools.partial( + store_dbt_resource_status_from_log, + tests_per_model=self.tests_per_model, + test_results_per_model=self.test_results_per_model, + ) def execute(self, context: Context, **kwargs: Any) -> Any: self._set_invocation_mode_if_not_set() diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 8f2aa9e43a..7d3f814192 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -1451,6 +1451,62 @@ def test_update_node_dependency_test_not_exist(): assert nodes.has_non_detached_test is False +def test_tests_per_model_populated(): + project_config = ProjectConfig( + dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, manifest_path=SAMPLE_MANIFEST + ) + profile_config = ProfileConfig( + profile_name="test", + target_name="test", + profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml", + ) + execution_config = ExecutionConfig(dbt_project_path=project_config.dbt_project_path) + dbt_graph = DbtGraph( + project=project_config, + execution_config=execution_config, + profile_config=profile_config, + ) + dbt_graph.load() + + # Every model that has_test should appear in tests_per_model + for node_id, node in dbt_graph.filtered_nodes.items(): + if node.resource_type == DbtResourceType.MODEL and node.has_test: + assert node_id in dbt_graph.tests_per_model + assert len(dbt_graph.tests_per_model[node_id]) > 0 + + # Spot-check: customers model should have known tests + customers_id = "model.jaffle_shop.customers" + assert customers_id in dbt_graph.tests_per_model + customers_tests = dbt_graph.tests_per_model[customers_id] + assert any("not_null_customers_customer_id" in t for t in customers_tests) + assert any("unique_customers_customer_id" in t for t in customers_tests) + + +def test_tests_per_model_empty_when_no_tests(): + project_config = ProjectConfig( + dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, manifest_path=SAMPLE_MANIFEST + ) + profile_config = ProfileConfig( + profile_name="test", + target_name="test", + profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml", + ) + render_config = RenderConfig( + exclude=["config.materialized:test"], + source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR, + ) + execution_config = ExecutionConfig(dbt_project_path=project_config.dbt_project_path) + dbt_graph = DbtGraph( + project=project_config, + execution_config=execution_config, + profile_config=profile_config, + render_config=render_config, + ) + dbt_graph.load_from_dbt_manifest() + + assert dbt_graph.tests_per_model == {} + + def test_tag_selected_node_test_exist(): project_config = ProjectConfig( dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, manifest_path=SAMPLE_MANIFEST diff --git a/tests/hooks/test_subprocess.py b/tests/hooks/test_subprocess.py index 26cfce0c8c..d239829437 100644 --- a/tests/hooks/test_subprocess.py +++ b/tests/hooks/test_subprocess.py @@ -99,9 +99,13 @@ def test_store_dbt_resource_status_from_log_param(status, context, should_push, with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push: if expect_assert: with pytest.raises(AssertionError): - store_dbt_resource_status_from_log(line, {"context": context}) + store_dbt_resource_status_from_log( + line, {"context": context}, tests_per_model={}, test_results_per_model={} + ) else: - store_dbt_resource_status_from_log(line, {"context": context}) + store_dbt_resource_status_from_log( + line, {"context": context}, tests_per_model={}, test_results_per_model={} + ) if should_push: mock_push.assert_called_once_with( task_instance=context["ti"], key="model__jaffle_shop__stg_orders_status", value=status @@ -114,5 +118,7 @@ def test_store_dbt_resource_status_from_log_invalid_json(): invalid_line = "{not a valid json}" with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push: - store_dbt_resource_status_from_log(invalid_line, {"context": {"ti": MagicMock()}}) + store_dbt_resource_status_from_log( + invalid_line, {"context": {"ti": MagicMock()}}, tests_per_model={}, test_results_per_model={} + ) mock_push.assert_not_called() diff --git a/tests/operators/_watcher/test_aggregation.py b/tests/operators/_watcher/test_aggregation.py new file mode 100644 index 0000000000..40e1802ba6 --- /dev/null +++ b/tests/operators/_watcher/test_aggregation.py @@ -0,0 +1,176 @@ +"""Unit tests for cosmos.operators._watcher.aggregation module.""" + +from __future__ import annotations + +import threading +from unittest.mock import MagicMock, patch + +from cosmos.operators._watcher.aggregation import ( + accumulate_test_result, + get_aggregated_test_status, + get_tests_status_xcom_key, + push_test_result_or_aggregate, +) +from cosmos.operators._watcher.state import DbtTestStatus + +TESTS_PER_MODEL = { + "model.pkg.orders": ["test.pkg.not_null_orders_id", "test.pkg.unique_orders_id"], + "model.pkg.customers": ["test.pkg.not_null_customers_id"], +} + + +class TestGetTestsStatusXcomKey: + """Tests for get_tests_status_xcom_key.""" + + def test_replaces_dots_with_double_underscores(self): + assert get_tests_status_xcom_key("model.pkg.orders") == "model__pkg__orders_tests_status" + + def test_no_dots(self): + assert get_tests_status_xcom_key("orders") == "orders_tests_status" + + +class TestAccumulateTestResult: + """Tests for accumulate_test_result.""" + + def test_returns_model_uid_when_test_found(self): + results: dict[str, list[str]] = {} + model_uid = accumulate_test_result("test.pkg.not_null_orders_id", "pass", TESTS_PER_MODEL, results) + assert model_uid == "model.pkg.orders" + assert results == {"model.pkg.orders": ["pass"]} + + def test_returns_none_when_test_not_found(self): + results: dict[str, list[str]] = {} + model_uid = accumulate_test_result("test.pkg.unknown_test", "pass", TESTS_PER_MODEL, results) + assert model_uid is None + assert results == {} + + def test_accumulates_multiple_results(self): + results: dict[str, list[str]] = {} + accumulate_test_result("test.pkg.not_null_orders_id", "pass", TESTS_PER_MODEL, results) + accumulate_test_result("test.pkg.unique_orders_id", "fail", TESTS_PER_MODEL, results) + assert results == {"model.pkg.orders": ["pass", "fail"]} + + def test_accumulates_across_models(self): + results: dict[str, list[str]] = {} + accumulate_test_result("test.pkg.not_null_orders_id", "pass", TESTS_PER_MODEL, results) + accumulate_test_result("test.pkg.not_null_customers_id", "pass", TESTS_PER_MODEL, results) + assert results == {"model.pkg.orders": ["pass"], "model.pkg.customers": ["pass"]} + + +class TestGetAggregatedTestStatus: + """Tests for get_aggregated_test_status.""" + + def test_returns_none_when_model_not_in_tests_per_model(self): + assert get_aggregated_test_status("model.pkg.unknown", TESTS_PER_MODEL, {}) is None + + def test_returns_none_when_not_all_tests_reported(self): + results = {"model.pkg.orders": ["pass"]} + assert get_aggregated_test_status("model.pkg.orders", TESTS_PER_MODEL, results) is None + + def test_returns_pass_when_all_tests_pass(self): + results = {"model.pkg.orders": ["pass", "pass"]} + assert get_aggregated_test_status("model.pkg.orders", TESTS_PER_MODEL, results) == DbtTestStatus.PASS + + def test_returns_fail_when_any_test_fails(self): + results = {"model.pkg.orders": ["pass", "fail"]} + assert get_aggregated_test_status("model.pkg.orders", TESTS_PER_MODEL, results) == DbtTestStatus.FAIL + + def test_returns_fail_when_all_tests_fail(self): + results = {"model.pkg.orders": ["fail", "fail"]} + assert get_aggregated_test_status("model.pkg.orders", TESTS_PER_MODEL, results) == DbtTestStatus.FAIL + + def test_single_test_pass(self): + results = {"model.pkg.customers": ["pass"]} + assert get_aggregated_test_status("model.pkg.customers", TESTS_PER_MODEL, results) == DbtTestStatus.PASS + + def test_single_test_fail(self): + results = {"model.pkg.customers": ["fail"]} + assert get_aggregated_test_status("model.pkg.customers", TESTS_PER_MODEL, results) == DbtTestStatus.FAIL + + def test_returns_fail_when_test_has_error_status(self): + results = {"model.pkg.orders": ["pass", "error"]} + assert get_aggregated_test_status("model.pkg.orders", TESTS_PER_MODEL, results) == DbtTestStatus.FAIL + + def test_treats_success_status_as_pass(self): + """'success' is used by models; tests use 'pass' — both should be treated as success.""" + results = {"model.pkg.orders": ["success", "pass"]} + assert get_aggregated_test_status("model.pkg.orders", TESTS_PER_MODEL, results) == DbtTestStatus.PASS + + +class TestPushTestResultOrAggregate: + """Tests for push_test_result_or_aggregate.""" + + @patch("cosmos.operators._watcher.aggregation.safe_xcom_push") + def test_pushes_xcom_when_all_tests_reported(self, mock_xcom_push: MagicMock): + results: dict[str, list[str]] = {} + ti = MagicMock() + push_test_result_or_aggregate("test.pkg.not_null_orders_id", "pass", TESTS_PER_MODEL, results, ti) + push_test_result_or_aggregate("test.pkg.unique_orders_id", "pass", TESTS_PER_MODEL, results, ti) + mock_xcom_push.assert_called_once_with( + task_instance=ti, + key="model__pkg__orders_tests_status", + value=DbtTestStatus.PASS, + ) + + @patch("cosmos.operators._watcher.aggregation.safe_xcom_push") + def test_does_not_push_xcom_before_all_tests_reported(self, mock_xcom_push: MagicMock): + results: dict[str, list[str]] = {} + ti = MagicMock() + push_test_result_or_aggregate("test.pkg.not_null_orders_id", "pass", TESTS_PER_MODEL, results, ti) + mock_xcom_push.assert_not_called() + + @patch("cosmos.operators._watcher.aggregation.safe_xcom_push") + def test_does_not_push_xcom_for_unknown_test(self, mock_xcom_push: MagicMock): + results: dict[str, list[str]] = {} + ti = MagicMock() + push_test_result_or_aggregate("test.pkg.unknown", "pass", TESTS_PER_MODEL, results, ti) + mock_xcom_push.assert_not_called() + + @patch("cosmos.operators._watcher.aggregation.safe_xcom_push") + def test_pushes_fail_when_any_test_fails(self, mock_xcom_push: MagicMock): + results: dict[str, list[str]] = {} + ti = MagicMock() + push_test_result_or_aggregate("test.pkg.not_null_orders_id", "pass", TESTS_PER_MODEL, results, ti) + push_test_result_or_aggregate("test.pkg.unique_orders_id", "fail", TESTS_PER_MODEL, results, ti) + mock_xcom_push.assert_called_once_with( + task_instance=ti, + key="model__pkg__orders_tests_status", + value=DbtTestStatus.FAIL, + ) + + +class TestPushTestResultOrAggregateConcurrency: + """Tests that push_test_result_or_aggregate is thread-safe.""" + + @patch("cosmos.operators._watcher.aggregation.safe_xcom_push") + def test_concurrent_threads_do_not_lose_results(self, mock_xcom_push: MagicMock): + """Simulate many concurrent dbt threads pushing test results for the same model. + + Without the lock, setdefault+append interleaving can lose results and + the aggregated XCom may never fire or fire prematurely. + """ + num_tests = 50 + test_ids = [f"test.pkg.test_{i}" for i in range(num_tests)] + tests_per_model: dict[str, list[str]] = {"model.pkg.big_model": test_ids} + results: dict[str, list[str]] = {} + ti = MagicMock() + barrier = threading.Barrier(num_tests) + + def worker(test_id: str) -> None: + barrier.wait() # Force all threads to start at the same instant + push_test_result_or_aggregate(test_id, "pass", tests_per_model, results, ti) + + threads = [threading.Thread(target=worker, args=(tid,)) for tid in test_ids] + for t in threads: + t.start() + for t in threads: + t.join() + + # All 50 results must be recorded — none lost to races + assert len(results["model.pkg.big_model"]) == num_tests + # XCom should have been pushed exactly once + mock_xcom_push.assert_called_once_with( + task_instance=ti, + key="model__pkg__big_model_tests_status", + value=DbtTestStatus.PASS, + ) diff --git a/tests/operators/_watcher/test_state.py b/tests/operators/_watcher/test_state.py new file mode 100644 index 0000000000..4831a08f07 --- /dev/null +++ b/tests/operators/_watcher/test_state.py @@ -0,0 +1,39 @@ +"""Unit tests for cosmos.operators._watcher.state module.""" + +from __future__ import annotations + +import pytest + +from cosmos.operators._watcher.state import ( + is_dbt_node_status_failed, + is_dbt_node_status_success, + is_dbt_node_status_terminal, +) + + +class TestNodeStatusHelpers: + """Tests for node status helper functions.""" + + @pytest.mark.parametrize("status", ["success", "pass"]) + def test_is_dbt_node_status_success_true(self, status: str): + assert is_dbt_node_status_success(status) is True + + @pytest.mark.parametrize("status", ["failed", "fail", "error", "skipped", "warn", None, ""]) + def test_is_dbt_node_status_success_false(self, status: str | None): + assert is_dbt_node_status_success(status) is False + + @pytest.mark.parametrize("status", ["failed", "fail", "error"]) + def test_is_dbt_node_status_failed_true(self, status: str): + assert is_dbt_node_status_failed(status) is True + + @pytest.mark.parametrize("status", ["success", "pass", "skipped", "warn", None, ""]) + def test_is_dbt_node_status_failed_false(self, status: str | None): + assert is_dbt_node_status_failed(status) is False + + @pytest.mark.parametrize("status", ["success", "pass", "failed", "fail", "error"]) + def test_is_dbt_node_status_terminal_true(self, status: str): + assert is_dbt_node_status_terminal(status) is True + + @pytest.mark.parametrize("status", ["skipped", "warn", "running", None, ""]) + def test_is_dbt_node_status_terminal_false(self, status: str | None): + assert is_dbt_node_status_terminal(status) is False diff --git a/tests/operators/_watcher/test_triggerer.py b/tests/operators/_watcher/test_triggerer.py index 07623e0300..4418dd41a2 100644 --- a/tests/operators/_watcher/test_triggerer.py +++ b/tests/operators/_watcher/test_triggerer.py @@ -92,7 +92,7 @@ async def runner(*args, **kwargs): (False, "success", "success", "SELECT * FROM table"), ], ) - async def test_parse_node_status_and_compiled_sql( + async def test_parse_dbt_node_status_and_compiled_sql( self, use_event, xcom_val, expected_status, expected_compiled_sql ): self.trigger.use_event = use_event @@ -112,7 +112,7 @@ async def mock_get_xcom_val(key): patch("cosmos.operators._watcher.triggerer._parse_compressed_xcom", return_value=xcom_val), patch.object(self.trigger, "get_xcom_val", AsyncMock(side_effect=mock_get_xcom_val)), ): - status, compiled_sql = await self.trigger._parse_node_status_and_compiled_sql() + status, compiled_sql = await self.trigger._parse_dbt_node_status_and_compiled_sql() assert status == expected_status assert compiled_sql == expected_compiled_sql @@ -135,7 +135,7 @@ async def test_get_xcom_val_branches(self, airflow_version, expected_val): assert val == "af3" @pytest.mark.parametrize( - "node_status, producer_state, expected", + "dbt_node_status, producer_state, expected", [ ("success", "running", {"status": "success"}), ("failed", "running", {"status": "failed", "reason": "model_failed"}), @@ -143,7 +143,7 @@ async def test_get_xcom_val_branches(self, airflow_version, expected_val): (None, "success", {"status": "success", "reason": "model_not_run"}), ], ) - async def test_run_various_outcomes(self, node_status, producer_state, expected): + async def test_run_various_outcomes(self, dbt_node_status, producer_state, expected): async def fake_get_xcom_val(key): # Return None for compiled_sql key so payload matches expected (no compiled_sql) @@ -156,7 +156,7 @@ async def fake_get_xcom_val(key): patch.object(self.trigger, "_get_producer_task_status", AsyncMock(return_value=producer_state)), patch( "cosmos.operators._watcher.triggerer._parse_compressed_xcom", - return_value={"data": {"run_result": {"status": node_status}}} if node_status else {}, + return_value={"data": {"run_result": {"status": dbt_node_status}}} if dbt_node_status else {}, ), ): events = [event async for event in self.trigger.run()] @@ -232,13 +232,15 @@ def _import_side_effect(name: str, *args, **kwargs): async def test_run_producer_success_model_not_run(self, caplog): """Test that when producer succeeds but model has no status, trigger yields success with model_not_run reason.""" get_producer_status_mock = AsyncMock(return_value="success") - parse_node_status_and_compiled_sql_mock = AsyncMock(return_value=(None, None)) + parse_dbt_node_status_and_compiled_sql_mock = AsyncMock(return_value=(None, None)) caplog.set_level("INFO") with ( patch.object(self.trigger, "_get_producer_task_status", get_producer_status_mock), - patch.object(self.trigger, "_parse_node_status_and_compiled_sql", parse_node_status_and_compiled_sql_mock), + patch.object( + self.trigger, "_parse_dbt_node_status_and_compiled_sql", parse_dbt_node_status_and_compiled_sql_mock + ), ): events = [] async for event in self.trigger.run(): @@ -247,13 +249,13 @@ async def test_run_producer_success_model_not_run(self, caplog): assert len(events) == 1 assert events[0].payload == {"status": "success", "reason": "model_not_run"} assert "The producer task 'task_1' succeeded" in caplog.text - assert "There is no information about the model 'model.test' execution" in caplog.text + assert "There is no information about the node 'model.test' execution" in caplog.text @pytest.mark.asyncio async def test_run_poke_interval_and_debug_log(self, caplog): get_xcom_val_mock = AsyncMock(side_effect=["compressed_data"]) get_producer_status_mock = AsyncMock(side_effect=["running", "running", "running"]) - parse_node_status_and_compiled_sql_mock = AsyncMock( + parse_dbt_node_status_and_compiled_sql_mock = AsyncMock( side_effect=[(None, None), (None, None), ("success", "SELECT 1")] ) @@ -262,7 +264,9 @@ async def test_run_poke_interval_and_debug_log(self, caplog): with ( patch.object(self.trigger, "get_xcom_val", get_xcom_val_mock), patch.object(self.trigger, "_get_producer_task_status", get_producer_status_mock), - patch.object(self.trigger, "_parse_node_status_and_compiled_sql", parse_node_status_and_compiled_sql_mock), + patch.object( + self.trigger, "_parse_dbt_node_status_and_compiled_sql", parse_dbt_node_status_and_compiled_sql_mock + ), patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock, ): events = [] @@ -280,7 +284,7 @@ async def test_run_failed_model_includes_compiled_sql_in_event(self): parse_mock = AsyncMock(return_value=("failed", "SELECT * FROM broken_model")) with ( patch.object(self.trigger, "_get_producer_task_status", AsyncMock(return_value="running")), - patch.object(self.trigger, "_parse_node_status_and_compiled_sql", parse_mock), + patch.object(self.trigger, "_parse_dbt_node_status_and_compiled_sql", parse_mock), ): events = [event async for event in self.trigger.run()] assert len(events) == 1 diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 20b6c6b4a1..32bbabef85 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -536,7 +536,7 @@ def test_store_dbt_resource_status_from_log_success(self): log_line = json.dumps({"data": {"node_info": {"node_status": "success", "unique_id": "model.pkg.my_model"}}}) - store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) assert ti.store.get("model__pkg__my_model_status") == "success" @@ -547,7 +547,7 @@ def test_store_dbt_resource_status_from_log_failed(self): log_line = json.dumps({"data": {"node_info": {"node_status": "failed", "unique_id": "model.pkg.failed_model"}}}) - store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) assert ti.store.get("model__pkg__failed_model_status") == "failed" @@ -560,17 +560,141 @@ def test_store_dbt_resource_status_from_log_ignores_other_statuses(self): {"data": {"node_info": {"node_status": "running", "unique_id": "model.pkg.running_model"}}} ) - store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) assert "model__pkg__running_model_status" not in ti.store - def test_store_dbt_resource_status_from_log_handles_invalid_json(self, caplog): + def test_store_dbt_resources_status_from_log_detects_passed_test_status(self): + """Test that a passed test status is correctly parsed and stored in XCom.""" + ti = _MockTI() + ctx = {"ti": ti} + + log_line = json.dumps( + { + "data": { + "node_info": { + "node_status": "pass", + "unique_id": "test.pkg.my_test", + } + } + } + ) + + store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) + + assert ti.store.get("test__pkg__my_test_status") == "pass" + + def test_store_dbt_resource_status_from_log_detects_failed_test_status(self): + """Test that a failed test status is correctly parsed and stored in XCom.""" + ti = _MockTI() + ctx = {"ti": ti} + + log_line = json.dumps( + { + "data": { + "node_info": { + "node_status": "fail", + "unique_id": "test.pkg.my_test", + } + } + } + ) + + store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) + + assert ti.store.get("test__pkg__my_test_status") == "fail" + + def test_store_dbt_resource_status_from_log_aggregates_test_results_when_tests_per_model_provided(self): + """When tests_per_model is non-empty and a test node finishes, the function should + accumulate results and push a single aggregated *_tests_status XCom once all tests + for the model have reported — instead of pushing individual *_status keys per test. + """ + ti = _MockTI() + ctx = {"ti": ti} + + tests_per_model = { + "model.pkg.orders": ["test.pkg.not_null_orders_id", "test.pkg.unique_orders_id"], + } + test_results_per_model: dict[str, list[str]] = {} + + # First test passes — not all tests reported yet, no XCom push + log_line_1 = json.dumps( + { + "data": { + "node_info": { + "node_status": "pass", + "unique_id": "test.pkg.not_null_orders_id", + "resource_type": "test", + } + } + } + ) + store_dbt_resource_status_from_log( + log_line_1, + {"context": ctx}, + tests_per_model=tests_per_model, + test_results_per_model=test_results_per_model, + ) + assert "test__pkg__not_null_orders_id_status" not in ti.store # no per-test key + assert "model__pkg__orders_tests_status" not in ti.store # not yet aggregated + + # Second test passes — all tests reported, aggregated XCom should be pushed + log_line_2 = json.dumps( + { + "data": { + "node_info": { + "node_status": "pass", + "unique_id": "test.pkg.unique_orders_id", + "resource_type": "test", + } + } + } + ) + store_dbt_resource_status_from_log( + log_line_2, + {"context": ctx}, + tests_per_model=tests_per_model, + test_results_per_model=test_results_per_model, + ) + assert "test__pkg__unique_orders_id_status" not in ti.store # no per-test key + assert ti.store.get("model__pkg__orders_tests_status") == "pass" + + def test_store_dbt_resource_status_from_log_aggregates_fail_when_any_test_fails(self): + """When at least one test fails, the aggregated status should be 'fail'.""" + ti = _MockTI() + ctx = {"ti": ti} + + tests_per_model = { + "model.pkg.orders": ["test.pkg.not_null_orders_id", "test.pkg.unique_orders_id"], + } + test_results_per_model: dict[str, list[str]] = {} + + for uid, status in [ + ("test.pkg.not_null_orders_id", "pass"), + ("test.pkg.unique_orders_id", "fail"), + ]: + log_line = json.dumps( + {"data": {"node_info": {"node_status": status, "unique_id": uid, "resource_type": "test"}}} + ) + store_dbt_resource_status_from_log( + log_line, + {"context": ctx}, + tests_per_model=tests_per_model, + test_results_per_model=test_results_per_model, + ) + + assert ti.store.get("model__pkg__orders_tests_status") == "fail" + # No per-test status keys should exist + assert "test__pkg__not_null_orders_id_status" not in ti.store + assert "test__pkg__unique_orders_id_status" not in ti.store """Test that invalid JSON doesn't raise an exception.""" ti = _MockTI() ctx = {"ti": ti} # Should not raise an exception - store_dbt_resource_status_from_log("not valid json {{{", {"context": ctx}) + store_dbt_resource_status_from_log( + "not valid json {{{", {"context": ctx}, tests_per_model={}, test_results_per_model={} + ) # No status should be stored assert len(ti.store) == 0 @@ -583,7 +707,7 @@ def test_store_dbt_resource_status_from_log_handles_missing_node_info(self): log_line = json.dumps({"data": {"other_key": "value"}}) # Should not raise an exception - store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log(log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={}) # No status should be stored assert len(ti.store) == 0 @@ -606,7 +730,9 @@ def test_store_dbt_resource_status_from_log_outputs_dbt_info(self, caplog, msg, log_line = json.dumps({"info": {"msg": msg, "level": level}}) dynamic_level = getattr(logging, level.upper(), logging.INFO) with caplog.at_level(dynamic_level): - store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log( + log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={} + ) assert msg in caplog.text assert any(record.levelname == logging.getLevelName(dynamic_level) for record in caplog.records) @@ -620,7 +746,9 @@ def test_store_dbt_resource_status_from_log_logs_message_only_once(self, caplog) log_line = json.dumps({"info": {"msg": test_msg, "level": "info", "ts": "2025-01-29T13:16:05.123456Z"}}) with caplog.at_level(logging.INFO): - store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log( + log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={} + ) # Count how many times the message appears in log records message_count = sum(1 for record in caplog.records if test_msg in record.message) @@ -635,7 +763,9 @@ def test_store_dbt_resource_status_from_log_formats_timestamp(self, caplog): log_line = json.dumps({"info": {"msg": test_msg, "level": "info", "ts": "2025-01-29T13:16:05.123456Z"}}) with caplog.at_level(logging.INFO): - store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log( + log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={} + ) # Verify the timestamp is formatted as HH:MM:SS assert any("13:16:05" in record.message and test_msg in record.message for record in caplog.records) @@ -651,15 +781,25 @@ def test_store_dbt_resource_status_from_log_invalid_timestamp_falls_back_to_raw( log_line = json.dumps({"info": {"msg": test_msg, "level": "info", "ts": invalid_ts}}) with caplog.at_level(logging.INFO): - store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log( + log_line, {"context": ctx}, tests_per_model={}, test_results_per_model={} + ) # Verify the raw timestamp is used when parsing fails assert any(invalid_ts in record.message and test_msg in record.message for record in caplog.records) def test_process_log_line_callable_integration_with_subprocess_pattern(self): - """Test the exact pattern used in subprocess.py: process_log_line(line, kwargs).""" + """Test the exact pattern used in subprocess.py: process_log_line(line, kwargs). + + The production code uses functools.partial to bind tests_per_model, + so the subprocess hook can still call process_log_line(line, kwargs) with 2 positional args. + """ + import functools + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - op._process_log_line_callable = store_dbt_resource_status_from_log + op._process_log_line_callable = functools.partial( + store_dbt_resource_status_from_log, tests_per_model={}, test_results_per_model={} + ) ti = _MockTI() ctx = {"ti": ti} @@ -1326,7 +1466,7 @@ def make_trigger(self, use_event: bool = False): ) @pytest.mark.asyncio - async def test_parse_node_status_and_compiled_sql_subprocess_mode(self): + async def test_parse_dbt_node_status_and_compiled_sql_subprocess_mode(self): """Test that compiled_sql is extracted from XCom in subprocess mode.""" trigger = self.make_trigger(use_event=False) @@ -1340,13 +1480,13 @@ async def mock_get_xcom_val(key): trigger.get_xcom_val = mock_get_xcom_val - status, compiled_sql = await trigger._parse_node_status_and_compiled_sql() + status, compiled_sql = await trigger._parse_dbt_node_status_and_compiled_sql() assert status == "success" assert compiled_sql == "SELECT * FROM orders" @pytest.mark.asyncio - async def test_parse_node_status_and_compiled_sql_subprocess_no_compiled_sql(self): + async def test_parse_dbt_node_status_and_compiled_sql_subprocess_no_compiled_sql(self): """Test that missing compiled_sql is handled gracefully in subprocess mode.""" trigger = self.make_trigger(use_event=False) @@ -1358,13 +1498,13 @@ async def mock_get_xcom_val(key): trigger.get_xcom_val = mock_get_xcom_val - status, compiled_sql = await trigger._parse_node_status_and_compiled_sql() + status, compiled_sql = await trigger._parse_dbt_node_status_and_compiled_sql() assert status == "success" assert compiled_sql is None @pytest.mark.asyncio - async def test_parse_node_status_and_compiled_sql_dbt_runner_mode(self): + async def test_parse_dbt_node_status_and_compiled_sql_dbt_runner_mode(self): """Test that in dbt_runner mode status comes from event payload and compiled_sql from canonical key.""" trigger = self.make_trigger(use_event=True) @@ -1381,7 +1521,7 @@ async def mock_get_xcom_val(key): trigger.get_xcom_val = mock_get_xcom_val - status, compiled_sql = await trigger._parse_node_status_and_compiled_sql() + status, compiled_sql = await trigger._parse_dbt_node_status_and_compiled_sql() assert status == "success" assert compiled_sql == "SELECT id FROM users"