From 25bc89c8f6d5794bbad546c02dad13eb4393aeba Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 5 Jan 2026 16:21:24 +0000 Subject: [PATCH 1/8] Refactor watcher operators so the base definitions can be reused by the K8s watcher operator --- cosmos/constants.py | 3 + cosmos/operators/_watcher/base.py | 279 ++++++++++++++++++++++++++++++ cosmos/operators/watcher.py | 245 +++----------------------- tests/hooks/test_subprocess.py | 4 +- tests/operators/test_watcher.py | 37 ++-- 5 files changed, 326 insertions(+), 242 deletions(-) create mode 100644 cosmos/operators/_watcher/base.py diff --git a/cosmos/constants.py b/cosmos/constants.py index 313036a8ff..024297d071 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -178,6 +178,9 @@ def _missing_value_(cls, value): # type: ignore DBT_SETUP_ASYNC_TASK_ID = "dbt_setup_async" DBT_TEARDOWN_ASYNC_TASK_ID = "dbt_teardown_async" +WATCHER_TASK_WEIGHT_RULE = "absolute" +CONSUMER_WATCHER_DEFAULT_PRIORITY_WEIGHT = 2 +PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT = 20 PRODUCER_WATCHER_TASK_ID = "dbt_producer_watcher" # Historical telemetry endpoints retained for reference: diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py new file mode 100644 index 0000000000..9f23db2c45 --- /dev/null +++ b/cosmos/operators/_watcher/base.py @@ -0,0 +1,279 @@ +import json +from datetime import timedelta +from typing import TYPE_CHECKING, Any + +from airflow.exceptions import AirflowException + +from cosmos.config import ProfileConfig +from cosmos.constants import ( + AIRFLOW_VERSION, + CONSUMER_WATCHER_DEFAULT_PRIORITY_WEIGHT, + PRODUCER_WATCHER_TASK_ID, + WATCHER_TASK_WEIGHT_RULE, +) +from cosmos.log import get_logger +from cosmos.operators._watcher.state import build_producer_state_fetcher, get_xcom_val, safe_xcom_push +from cosmos.operators._watcher.triggerer import WatcherTrigger, _parse_compressed_xcom + +try: + from airflow.sdk.bases.sensor import BaseSensorOperator +except ImportError: # pragma: no cover + from airflow.sensors.base import BaseSensorOperator + + +if TYPE_CHECKING: # pragma: no cover + try: + from airflow.sdk.definitions.context import Context + except ImportError: + from airflow.utils.context import Context # type: ignore[attr-defined] + + +logger = get_logger(__name__) + + +def _store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None: + """ + Parses a single line from dbt JSON logs and stores node status to Airflow XCom. + + This method parses each log line from dbt when --log-format json is used, + extracts node status information, and pushes it to XCom for consumption + by downstream watcher sensors. + """ + try: + log_line = json.loads(line) + except json.JSONDecodeError: + logger.debug("Failed to parse log: %s", line) + log_line = {} + node_info = log_line.get("data", {}).get("node_info", {}) + node_status = node_info.get("node_status") + unique_id = node_info.get("unique_id") + + logger.debug("Model: %s is in %s state", unique_id, node_status) + + # TODO: Handle and store all possible node statuses, not just the current success and failed + if node_status in ["success", "failed"]: + context = extra_kwargs.get("context") + assert context is not None # Make MyPy happy + safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status) + + +class BaseConsumerSensor(BaseSensorOperator): # type: ignore[misc] + template_fields: tuple[str, ...] = ("model_unique_id", "compiled_sql") # type: ignore[operator] + poke_retry_number: int = 0 + + def __init__( + self, + *, + profile_config: ProfileConfig | None = None, + project_dir: str | None = None, + profiles_dir: str | None = None, + producer_task_id: str = PRODUCER_WATCHER_TASK_ID, + poke_interval: int = 10, + timeout: int = 60 * 60, # 1 h safety valve + execution_timeout: timedelta = timedelta(hours=1), + deferrable: bool = True, + **kwargs: Any, + ) -> None: + self.compiled_sql = "" + extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {} + kwargs.setdefault("priority_weight", CONSUMER_WATCHER_DEFAULT_PRIORITY_WEIGHT) + kwargs.setdefault("weight_rule", WATCHER_TASK_WEIGHT_RULE) + super().__init__( + poke_interval=poke_interval, + timeout=timeout, + execution_timeout=execution_timeout, + profile_config=profile_config, + project_dir=project_dir, + profiles_dir=profiles_dir, + **kwargs, + ) + self.model_unique_id = extra_context.get("dbt_node_config", {}).get("unique_id") + self.project_dir = project_dir + self.producer_task_id = producer_task_id + self.deferrable = deferrable + + @staticmethod + def _filter_flags(flags: list[str]) -> list[str]: + """Filters out dbt flags that are incompatible with retry (e.g., --select, --exclude).""" + filtered = [] + skip_next = False + for token in flags: + if skip_next: + if token.startswith("--"): + skip_next = False + else: + continue # skip value of previous flag + if token in ("--select", "--exclude"): + skip_next = True + continue + filtered.append(token) + return filtered + + def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> bool: + """ + Handles logic for retrying a failed dbt model execution. + Reconstructs the dbt command by cloning the project and re-running the model + with appropriate flags, while ensuring flags like `--select` or `--exclude` are excluded. + """ + logger.info( + "Retry attempt #%s – Running model '%s' from project '%s' using ExecutionMode.LOCAL", + try_number - 1, + self.model_unique_id, + self.project_dir, + ) + + upstream_task = context["ti"].task.dag.get_task(self.producer_task_id) + + extra_flags: list[str] = [] + if upstream_task and hasattr(upstream_task, "add_cmd_flags"): + raw_flags = upstream_task.add_cmd_flags() + extra_flags = self._filter_flags(raw_flags) + + model_selector = self.model_unique_id.split(".")[-1] + cmd_flags = extra_flags + ["--select", model_selector] + + self.build_and_run_cmd(context, cmd_flags=cmd_flags) + + logger.info("dbt run completed successfully on retry for model '%s'", self.model_unique_id) + return True + + def _get_status_from_run_results(self, ti: Any, context: Context) -> Any: + compressed_b64_run_results = ti.xcom_pull(task_ids=self.producer_task_id, key="run_results") + + if not compressed_b64_run_results: + return None + + run_results_json = _parse_compressed_xcom(compressed_b64_run_results) + + logger.debug("Run results: %s", run_results_json) + + results = run_results_json.get("results", []) + node_result = next((r for r in results if r.get("unique_id") == self.model_unique_id), None) + + if not node_result: # pragma: no cover + logger.warning("No matching result found for unique_id '%s'", self.model_unique_id) + return None + + logger.info("Node Info: %s", run_results_json) + self.compiled_sql = node_result.get("compiled_code") + if self.compiled_sql: + self._override_rtif(context) + + return node_result.get("status") + + def _get_producer_task_status(self, context: Context) -> str | None: + """ + Get the task status of the producer task for both Airflow 2 and Airflow 3. + + Returns the state of the producer task instance, or None if not found. + """ + ti = context["ti"] + run_id = context["run_id"] + dag_id = ti.dag_id + + fetch_state = build_producer_state_fetcher( + airflow_version=AIRFLOW_VERSION, + dag_id=dag_id, + run_id=run_id, + producer_task_id=self.producer_task_id, + logger=logger, + ) + if fetch_state is None: + return None + + return fetch_state() + + def execute(self, context: Context, **kwargs: Any) -> None: + if not self.deferrable: + super().execute(context) + elif not self.poke(context): + self.defer( + trigger=WatcherTrigger( + model_unique_id=self.model_unique_id, + producer_task_id=self.producer_task_id, + dag_id=self.dag_id, + run_id=context["run_id"], + map_index=context["task_instance"].map_index, + use_event=self._use_event(), + poke_interval=self.poke_interval, + ), + timeout=self.execution_timeout, + method_name=self.execute_complete.__name__, + ) + + def execute_complete(self, context: Context, event: dict[str, str]) -> None: + status = event.get("status") + if status != "failed": + return + + reason = event.get("reason") + if reason == "model_failed": + raise AirflowException( + f"dbt model '{self.model_unique_id}' failed. Review the producer task '{self.producer_task_id}' logs for details." + ) + + if reason == "producer_failed": + raise AirflowException( + f"Watcher producer task '{self.producer_task_id}' failed before reporting model results. Check its logs for the underlying error." + ) + + def _use_event(self) -> bool: + raise NotImplementedError("Subclasses must implement this method") + + def _get_status_from_events(self, ti: Any, context: Context) -> Any: + raise NotImplementedError("Subclasses should implement this method if `_use_event` may return True") + + def _override_rtif(self, context: Context) -> None: + raise NotImplementedError( + "Subclasses should implement this method, or inherit from a class that implements it (e.g. DbtRunLocalOperator)." + ) + + def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> None: + raise NotImplementedError( + "Subclasses should implement this method, or inherit from a class that implements it (e.g. DbtRunLocalOperator)." + ) + + def poke(self, context: Context) -> bool: + """ + Checks the status of a dbt model run by pulling relevant XComs from the master task. + Handles retries and checks for successful completion of the model execution. + """ + ti = context["ti"] + try_number = ti.try_number + + logger.info( + "Try number #%s, poke attempt #%s: Pulling status from task_id '%s' for model '%s'", + try_number, + self.poke_retry_number, + self.producer_task_id, + self.model_unique_id, + ) + + if try_number > 1: + return self._fallback_to_non_watcher_run(try_number, context) + + # We have assumption here that both the build producer and the sensor task will have same invocation mode + producer_task_state = self._get_producer_task_status(context) + if self._use_event(): + status = self._get_status_from_events(ti, context) + else: + status = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") + + if status is None: + + if producer_task_state == "failed": + if self.poke_retry_number > 0: + raise AirflowException( + f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details." + ) + else: + # This handles the scenario of tasks that failed with `State.UPSTREAM_FAILED` + return self._fallback_to_non_watcher_run(try_number, context) + + self.poke_retry_number += 1 + + return False + elif status == "success": + return True + else: + raise AirflowException(f"Model '{self.model_unique_id}' finished with status '{status}'") diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 8654074451..01caf2373c 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -2,36 +2,32 @@ import base64 import json -import logging import zlib from collections.abc import Sequence from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING, Any -from cosmos.operators._watcher import WatcherTrigger, _parse_compressed_xcom, get_xcom_val, safe_xcom_push - -if TYPE_CHECKING: # pragma: no cover - try: - from airflow.sdk.definitions.context import Context - except ImportError: - from airflow.utils.context import Context # type: ignore[attr-defined] - -try: - from airflow.sdk.bases.sensor import BaseSensorOperator -except ImportError: # pragma: no cover - from airflow.sensors.base import BaseSensorOperator from airflow.exceptions import AirflowException +from cosmos.config import ProfileConfig +from cosmos.constants import CONSUMER_WATCHER_DEFAULT_PRIORITY_WEIGHT +from cosmos.operators._watcher import _parse_compressed_xcom, safe_xcom_push + try: from airflow.providers.standard.operators.empty import EmptyOperator except ImportError: # pragma: no cover from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] -from cosmos.config import ProfileConfig -from cosmos.constants import AIRFLOW_VERSION, PRODUCER_WATCHER_TASK_ID, InvocationMode -from cosmos.operators._watcher.state import build_producer_state_fetcher +from cosmos.constants import ( + PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, + PRODUCER_WATCHER_TASK_ID, + WATCHER_TASK_WEIGHT_RULE, + InvocationMode, +) +from cosmos.log import get_logger +from cosmos.operators._watcher.base import BaseConsumerSensor, _store_dbt_resource_status_from_log from cosmos.operators.base import ( DbtBuildMixin, DbtRunMixin, @@ -49,37 +45,15 @@ except ImportError: # pragma: no cover EventMsg = None -logger = logging.getLogger(__name__) - -CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 10 -PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 9999 -WEIGHT_RULE = "absolute" # the default "downstream" does not work with dag.test() - -def _store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None: - """ - Parses a single line from dbt JSON logs and stores node status to Airflow XCom. - - This method parses each log line from dbt when --log-format json is used, - extracts node status information, and pushes it to XCom for consumption - by downstream watcher sensors. - """ +if TYPE_CHECKING: # pragma: no cover try: - log_line = json.loads(line) - except json.JSONDecodeError: - logger.debug("Failed to parse log: %s", line) - log_line = {} - node_info = log_line.get("data", {}).get("node_info", {}) - node_status = node_info.get("node_status") - unique_id = node_info.get("unique_id") + from airflow.sdk.definitions.context import Context + except ImportError: + from airflow.utils.context import Context # type: ignore[attr-defined] - logger.debug("Model: %s is in %s state", unique_id, node_status) - # TODO: Handle and store all possible node statuses, not just the current success and failed - if node_status in ["success", "failed"]: - context = extra_kwargs.get("context") - assert context is not None # Make MyPy happy - safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status) +logger = get_logger(__name__) class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator): @@ -113,9 +87,9 @@ class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator): _process_log_line_callable = staticmethod(_store_dbt_resource_status_from_log) def __init__(self, *args: Any, **kwargs: Any) -> None: - task_id = kwargs.pop("task_id", "dbt_producer_watcher_operator") - kwargs.setdefault("priority_weight", PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) - kwargs.setdefault("weight_rule", WEIGHT_RULE) + task_id = kwargs.pop("task_id", PRODUCER_WATCHER_TASK_ID) + kwargs.setdefault("priority_weight", PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT) + kwargs.setdefault("weight_rule", WATCHER_TASK_WEIGHT_RULE) # Consumer watcher retry logic handles model-level reruns using the LOCAL execution mode; rerunning the producer # would repeat the full dbt build and duplicate watcher callbacks which may not be processed by the consumers if # they have already processed output XCOMs from the first run of the producer, so we disable retries. @@ -237,9 +211,8 @@ def _callback(event_message: EventMsg) -> None: raise -class DbtConsumerWatcherSensor(BaseSensorOperator, DbtRunLocalOperator): # type: ignore[misc] - template_fields: tuple[str, ...] = DbtRunLocalOperator.template_fields + ("model_unique_id",) # type: ignore[operator] - poke_retry_number: int = 0 +class DbtConsumerWatcherSensor(BaseConsumerSensor, DbtRunLocalOperator): # type: ignore[misc] + template_fields: tuple[str, ...] = BaseConsumerSensor.template_fields + DbtRunLocalOperator.template_fields # type: ignore[operator] def __init__( self, @@ -256,8 +229,8 @@ def __init__( ) -> None: self.compiled_sql = "" extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {} - kwargs.setdefault("priority_weight", CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) - kwargs.setdefault("weight_rule", WEIGHT_RULE) + kwargs.setdefault("priority_weight", CONSUMER_WATCHER_DEFAULT_PRIORITY_WEIGHT) + kwargs.setdefault("weight_rule", WATCHER_TASK_WEIGHT_RULE) super().__init__( poke_interval=poke_interval, timeout=timeout, @@ -271,51 +244,6 @@ def __init__( self.producer_task_id = producer_task_id self.deferrable = deferrable - @staticmethod - def _filter_flags(flags: list[str]) -> list[str]: - """Filters out dbt flags that are incompatible with retry (e.g., --select, --exclude).""" - filtered = [] - skip_next = False - for token in flags: - if skip_next: - if token.startswith("--"): - skip_next = False - else: - continue # skip value of previous flag - if token in ("--select", "--exclude"): - skip_next = True - continue - filtered.append(token) - return filtered - - def _fallback_to_local_run(self, try_number: int, context: Context) -> bool: - """ - Handles logic for retrying a failed dbt model execution. - Reconstructs the dbt command by cloning the project and re-running the model - with appropriate flags, while ensuring flags like `--select` or `--exclude` are excluded. - """ - logger.info( - "Retry attempt #%s – Running model '%s' from project '%s' using ExecutionMode.LOCAL", - try_number - 1, - self.model_unique_id, - self.project_dir, - ) - - upstream_task = context["ti"].task.dag.get_task(self.producer_task_id) - - extra_flags: list[str] = [] - if upstream_task and hasattr(upstream_task, "add_cmd_flags"): - raw_flags = upstream_task.add_cmd_flags() - extra_flags = self._filter_flags(raw_flags) - - model_selector = self.model_unique_id.split(".")[-1] - cmd_flags = extra_flags + ["--select", model_selector] - - self.build_and_run_cmd(context, cmd_flags=cmd_flags) - - logger.info("dbt run completed successfully on retry for model '%s'", self.model_unique_id) - return True - def _get_status_from_events(self, ti: Any, context: Context) -> Any: dbt_startup_events = ti.xcom_pull(task_ids=self.producer_task_id, key="dbt_startup_events") @@ -339,136 +267,11 @@ def _get_status_from_events(self, ti: Any, context: Context) -> Any: return event_json.get("data", {}).get("run_result", {}).get("status") - def _get_status_from_run_results(self, ti: Any, context: Context) -> Any: - compressed_b64_run_results = ti.xcom_pull(task_ids=self.producer_task_id, key="run_results") - - if not compressed_b64_run_results: - return None - - run_results_json = _parse_compressed_xcom(compressed_b64_run_results) - - logger.debug("Run results: %s", run_results_json) - - results = run_results_json.get("results", []) - node_result = next((r for r in results if r.get("unique_id") == self.model_unique_id), None) - - if not node_result: # pragma: no cover - logger.warning("No matching result found for unique_id '%s'", self.model_unique_id) - return None - - logger.info("Node Info: %s", run_results_json) - self.compiled_sql = node_result.get("compiled_code") - if self.compiled_sql: - self._override_rtif(context) - - return node_result.get("status") - - def _get_producer_task_status(self, context: Context) -> str | None: - """ - Get the task status of the producer task for both Airflow 2 and Airflow 3. - - Returns the state of the producer task instance, or None if not found. - """ - ti = context["ti"] - run_id = context["run_id"] - dag_id = ti.dag_id - - fetch_state = build_producer_state_fetcher( - airflow_version=AIRFLOW_VERSION, - dag_id=dag_id, - run_id=run_id, - producer_task_id=self.producer_task_id, - logger=logger, - ) - if fetch_state is None: - return None - - return fetch_state() - - def execute(self, context: Context, **kwargs: Any) -> None: - if not self.deferrable: - super().execute(context) - elif not self.poke(context): - self.defer( - trigger=WatcherTrigger( - model_unique_id=self.model_unique_id, - producer_task_id=self.producer_task_id, - dag_id=self.dag_id, - run_id=context["run_id"], - map_index=context["task_instance"].map_index, - use_event=self._use_event(), - poke_interval=self.poke_interval, - ), - timeout=self.execution_timeout, - method_name=self.execute_complete.__name__, - ) - - def execute_complete(self, context: Context, event: dict[str, str]) -> None: - status = event.get("status") - if status != "failed": - return - - reason = event.get("reason") - if reason == "model_failed": - raise AirflowException( - f"dbt model '{self.model_unique_id}' failed. Review the producer task '{self.producer_task_id}' logs for details." - ) - - if reason == "producer_failed": - raise AirflowException( - f"Watcher producer task '{self.producer_task_id}' failed before reporting model results. Check its logs for the underlying error." - ) - def _use_event(self) -> bool: if not self.invocation_mode: self._discover_invocation_mode() return self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None - def poke(self, context: Context) -> bool: - """ - Checks the status of a dbt model run by pulling relevant XComs from the master task. - Handles retries and checks for successful completion of the model execution. - """ - ti = context["ti"] - try_number = ti.try_number - - logger.info( - "Try number #%s, poke attempt #%s: Pulling status from task_id '%s' for model '%s'", - try_number, - self.poke_retry_number, - self.producer_task_id, - self.model_unique_id, - ) - - if try_number > 1: - return self._fallback_to_local_run(try_number, context) - - # We have assumption here that both the build producer and the sensor task will have same invocation mode - producer_task_state = self._get_producer_task_status(context) - if self._use_event(): - status = self._get_status_from_events(ti, context) - else: - status = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") - - if status is None: - - if producer_task_state == "failed": - if self.poke_retry_number > 0: - raise AirflowException( - f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details." - ) - else: - # This handles the scenario of tasks that failed with `State.UPSTREAM_FAILED` - return self._fallback_to_local_run(try_number, context) - - self.poke_retry_number += 1 - - return False - elif status == "success": - return True - else: - raise AirflowException(f"Model '{self.model_unique_id}' finished with status '{status}'") - # This Operator does not seem to make sense for this particular execution mode, since build is executed by the producer task. # That said, it is important to raise an exception if users attempt to use TestBehavior.BUILD, until we have a better experience. diff --git a/tests/hooks/test_subprocess.py b/tests/hooks/test_subprocess.py index 6fb56092ea..5955a15a80 100644 --- a/tests/hooks/test_subprocess.py +++ b/tests/hooks/test_subprocess.py @@ -96,7 +96,7 @@ def test_store_dbt_resource_status_from_log_param(status, context, should_push, log_line = {"data": {"node_info": {"node_status": status, "unique_id": "model.jaffle_shop.stg_orders"}}} line = json.dumps(log_line) - with patch("cosmos.operators.watcher.safe_xcom_push") as mock_push: + with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push: if expect_assert: with pytest.raises(AssertionError): _store_dbt_resource_status_from_log(line, {"context": context}) @@ -113,6 +113,6 @@ def test_store_dbt_resource_status_from_log_param(status, context, should_push, def test_store_dbt_resource_status_from_log_invalid_json(): invalid_line = "{not a valid json}" - with patch("cosmos.operators.watcher.safe_xcom_push") as mock_push: + with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push: _store_dbt_resource_status_from_log(invalid_line, {"context": {"ti": MagicMock()}}) mock_push.assert_not_called() diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index c56d5e87c2..15ee443c4c 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -17,10 +17,9 @@ from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig, TestBehavior from cosmos.config import InvocationMode -from cosmos.constants import ExecutionMode +from cosmos.constants import PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT, ExecutionMode from cosmos.operators._watcher import WatcherTrigger from cosmos.operators.watcher import ( - PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT, DbtBuildWatcherOperator, DbtConsumerWatcherSensor, DbtProducerWatcherOperator, @@ -104,7 +103,7 @@ def test_serialize_event(mock_mtd): def test_dbt_producer_watcher_operator_priority_weight_default(): """Test that DbtProducerWatcherOperator uses default priority_weight of 9999.""" op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - assert op.priority_weight == PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT + assert op.priority_weight == PRODUCER_WATCHER_DEFAULT_PRIORITY_WEIGHT def test_dbt_producer_watcher_operator_priority_weight_override(): @@ -610,7 +609,7 @@ def make_context(self, ti_mock, *, run_id: str = "test-run", map_index: int = 0) } @pytest.mark.skipif(AIRFLOW_VERSION >= Version("3.0.0"), reason="RuntimeTaskInstance path in Airflow >= 3.0") - @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("2.7.0")) + @patch("cosmos.operators._watcher.base.AIRFLOW_VERSION", new=Version("2.7.0")) def test_get_producer_task_status_airflow2(self): sensor = self.make_sensor() sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__( @@ -622,7 +621,7 @@ def test_get_producer_task_status_airflow2(self): fetcher = MagicMock(return_value="success") - with patch("cosmos.operators.watcher.build_producer_state_fetcher", return_value=fetcher) as mock_builder: + with patch("cosmos.operators._watcher.base.build_producer_state_fetcher", return_value=fetcher) as mock_builder: status = sensor._get_producer_task_status(context) mock_builder.assert_called_once_with( @@ -636,7 +635,7 @@ def test_get_producer_task_status_airflow2(self): assert status == "success" @pytest.mark.skipif(AIRFLOW_VERSION >= Version("3.0.0"), reason="RuntimeTaskInstance path in Airflow >= 3.0") - @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("2.7.0")) + @patch("cosmos.operators._watcher.base.AIRFLOW_VERSION", new=Version("2.7.0")) def test_get_producer_task_status_airflow2_missing_instance(self): sensor = self.make_sensor() sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__( @@ -648,14 +647,14 @@ def test_get_producer_task_status_airflow2_missing_instance(self): fetcher = MagicMock(return_value=None) - with patch("cosmos.operators.watcher.build_producer_state_fetcher", return_value=fetcher): + with patch("cosmos.operators._watcher.base.build_producer_state_fetcher", return_value=fetcher): status = sensor._get_producer_task_status(context) fetcher.assert_called_once_with() assert status is None @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0") - @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0")) + @patch("cosmos.operators._watcher.base.AIRFLOW_VERSION", new=Version("3.0.0")) @patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states") def test_get_producer_task_status_airflow3(self, mock_get_task_states): sensor = self.make_sensor() @@ -676,7 +675,7 @@ def test_get_producer_task_status_airflow3(self, mock_get_task_states): ) @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0") - @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0")) + @patch("cosmos.operators._watcher.base.AIRFLOW_VERSION", new=Version("3.0.0")) @patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states") def test_get_producer_task_status_airflow3_missing_state(self, mock_get_task_states): sensor = self.make_sensor() @@ -697,7 +696,7 @@ def test_get_producer_task_status_airflow3_missing_state(self, mock_get_task_sta ) @pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0") - @patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0")) + @patch("cosmos.operators._watcher.base.AIRFLOW_VERSION", new=Version("3.0.0")) def test_get_producer_task_status_airflow3_import_error(self): sensor = self.make_sensor() sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__( @@ -707,7 +706,7 @@ def test_get_producer_task_status_airflow3_import_error(self): ti.dag_id = "example_dag" context = self.make_context(ti, run_id="run_4") - with patch("cosmos.operators.watcher.build_producer_state_fetcher", return_value=None) as mock_builder: + with patch("cosmos.operators._watcher.base.build_producer_state_fetcher", return_value=None) as mock_builder: status = sensor._get_producer_task_status(context) mock_builder.assert_called_once_with( @@ -748,7 +747,7 @@ def test_poke_success_from_run_results(self): assert result is True @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_status", return_value=None) - def _fallback_to_local_run(self, mock_get_producer_task_status): + def _fallback_to_non_watcher_run(self, mock_get_producer_task_status): sensor = self.make_sensor() sensor.invocation_mode = None @@ -794,14 +793,14 @@ def test_task_retry(self, mock_build_and_run_cmd): sensor.poke(context) mock_build_and_run_cmd.assert_called_once() - def test_fallback_to_local_run(self): + def test_fallback_to_non_watcher_run(self): sensor = self.make_sensor() ti = MagicMock() ti.task.dag.get_task.return_value.add_cmd_flags.return_value = ["--select", "some_model", "--threads", "2"] context = self.make_context(ti) sensor.build_and_run_cmd = MagicMock() - result = sensor._fallback_to_local_run(2, context) + result = sensor._fallback_to_non_watcher_run(2, context) assert result is True sensor.build_and_run_cmd.assert_called_once() @@ -863,7 +862,7 @@ def test_get_status_from_events_sets_compiled_sql(self): assert result == "success" assert sensor.compiled_sql == "select 42" - @patch("cosmos.operators.watcher.get_xcom_val") + @patch("cosmos.operators._watcher.base.get_xcom_val") def test_producer_state_failed(self, mock_get_xcom_val): sensor = self.make_sensor() sensor._get_producer_task_status.return_value = "failed" @@ -881,10 +880,10 @@ def test_producer_state_failed(self, mock_get_xcom_val): ): sensor.poke(context) - @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._fallback_to_local_run") - @patch("cosmos.operators.watcher.get_xcom_val") + @patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._fallback_to_non_watcher_run") + @patch("cosmos.operators._watcher.base.get_xcom_val") def test_producer_state_does_not_fail_if_previously_upstream_failed( - self, mock_get_xcom_val, mock_fallback_to_local_run + self, mock_get_xcom_val, mock_fallback_to_non_watcher_run ): """ Attempt to run the task using ExecutionMode.LOCAL if State.UPSTREAM_FAILED happens. @@ -901,7 +900,7 @@ def test_producer_state_does_not_fail_if_previously_upstream_failed( context = self.make_context(ti) sensor.poke(context) - mock_fallback_to_local_run.assert_called_once() + mock_fallback_to_non_watcher_run.assert_called_once() @patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif") def test_get_status_from_run_results_with_compiled_sql(self, mock_override_rtif, monkeypatch): From 7c96d1bb0f0ca61f4b148ec4f040332a1d4fc975 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 5 Jan 2026 16:32:07 +0000 Subject: [PATCH 2/8] Fix unittests --- cosmos/operators/_watcher/base.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 9f23db2c45..b6ade82098 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -1,6 +1,6 @@ import json from datetime import timedelta -from typing import TYPE_CHECKING, Any +from typing import Any from airflow.exceptions import AirflowException @@ -17,15 +17,10 @@ try: from airflow.sdk.bases.sensor import BaseSensorOperator + from airflow.sdk.definitions.context import Context except ImportError: # pragma: no cover from airflow.sensors.base import BaseSensorOperator - - -if TYPE_CHECKING: # pragma: no cover - try: - from airflow.sdk.definitions.context import Context - except ImportError: - from airflow.utils.context import Context # type: ignore[attr-defined] + from airflow.utils.context import Context # type: ignore[attr-defined] logger = get_logger(__name__) @@ -132,7 +127,7 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo model_selector = self.model_unique_id.split(".")[-1] cmd_flags = extra_flags + ["--select", model_selector] - self.build_and_run_cmd(context, cmd_flags=cmd_flags) + self.build_and_run_cmd(context, cmd_flags=cmd_flags) # type: ignore[attr-defined] logger.info("dbt run completed successfully on retry for model '%s'", self.model_unique_id) return True @@ -157,7 +152,7 @@ def _get_status_from_run_results(self, ti: Any, context: Context) -> Any: logger.info("Node Info: %s", run_results_json) self.compiled_sql = node_result.get("compiled_code") if self.compiled_sql: - self._override_rtif(context) + self._override_rtif(context) # type: ignore[attr-defined] return node_result.get("status") @@ -223,16 +218,6 @@ def _use_event(self) -> bool: def _get_status_from_events(self, ti: Any, context: Context) -> Any: raise NotImplementedError("Subclasses should implement this method if `_use_event` may return True") - def _override_rtif(self, context: Context) -> None: - raise NotImplementedError( - "Subclasses should implement this method, or inherit from a class that implements it (e.g. DbtRunLocalOperator)." - ) - - def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> None: - raise NotImplementedError( - "Subclasses should implement this method, or inherit from a class that implements it (e.g. DbtRunLocalOperator)." - ) - def poke(self, context: Context) -> bool: """ Checks the status of a dbt model run by pulling relevant XComs from the master task. From a993b0da98900889f714f176a881bedc6fbeceeb Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 5 Jan 2026 17:52:47 +0000 Subject: [PATCH 3/8] Improve test coverage --- cosmos/operators/_watcher/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index b6ade82098..73c4b8b92d 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -82,10 +82,10 @@ def __init__( profiles_dir=profiles_dir, **kwargs, ) - self.model_unique_id = extra_context.get("dbt_node_config", {}).get("unique_id") self.project_dir = project_dir self.producer_task_id = producer_task_id self.deferrable = deferrable + self.model_unique_id = extra_context.get("dbt_node_config", {}).get("unique_id") @staticmethod def _filter_flags(flags: list[str]) -> list[str]: From 0b5b237959a46b72f12ce388d355a0ea9f4fc605 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 6 Jan 2026 11:59:26 +0000 Subject: [PATCH 4/8] Simplify DbtConsumerWatcherSensor based on code review: https://github.com/astronomer/astronomer-cosmos/pull/2245\#discussion_r2664617989 --- cosmos/operators/watcher.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 01caf2373c..bb593aaeb6 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -11,7 +11,6 @@ from airflow.exceptions import AirflowException from cosmos.config import ProfileConfig -from cosmos.constants import CONSUMER_WATCHER_DEFAULT_PRIORITY_WEIGHT from cosmos.operators._watcher import _parse_compressed_xcom, safe_xcom_push try: @@ -227,10 +226,6 @@ def __init__( deferrable: bool = True, **kwargs: Any, ) -> None: - self.compiled_sql = "" - extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {} - kwargs.setdefault("priority_weight", CONSUMER_WATCHER_DEFAULT_PRIORITY_WEIGHT) - kwargs.setdefault("weight_rule", WATCHER_TASK_WEIGHT_RULE) super().__init__( poke_interval=poke_interval, timeout=timeout, @@ -240,9 +235,6 @@ def __init__( profiles_dir=profiles_dir, **kwargs, ) - self.model_unique_id = extra_context.get("dbt_node_config", {}).get("unique_id") - self.producer_task_id = producer_task_id - self.deferrable = deferrable def _get_status_from_events(self, ti: Any, context: Context) -> Any: From 6e077eb8cdc8c89cff82f90c77aa3742f344839c Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 6 Jan 2026 12:03:09 +0000 Subject: [PATCH 5/8] Address CR feedback: https://github.com/astronomer/astronomer-cosmos/pull/2245/changes\#r2664673872 --- cosmos/operators/_watcher/base.py | 2 +- cosmos/operators/watcher.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 73c4b8b92d..6774f4bc95 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -26,7 +26,7 @@ logger = get_logger(__name__) -def _store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None: +def store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None: """ Parses a single line from dbt JSON logs and stores node status to Airflow XCom. diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index bb593aaeb6..cc97e84718 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -26,7 +26,7 @@ InvocationMode, ) from cosmos.log import get_logger -from cosmos.operators._watcher.base import BaseConsumerSensor, _store_dbt_resource_status_from_log +from cosmos.operators._watcher.base import BaseConsumerSensor, store_dbt_resource_status_from_log from cosmos.operators.base import ( DbtBuildMixin, DbtRunMixin, @@ -83,7 +83,7 @@ class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator): template_fields = DbtLocalBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator] # Use staticmethod to prevent Python's descriptor protocol from binding the function to `self` # when accessed via instance, which would incorrectly pass `self` as the first argument - _process_log_line_callable = staticmethod(_store_dbt_resource_status_from_log) + _process_log_line_callable = staticmethod(store_dbt_resource_status_from_log) def __init__(self, *args: Any, **kwargs: Any) -> None: task_id = kwargs.pop("task_id", PRODUCER_WATCHER_TASK_ID) From 7107ec03a743a961b51ec592a069977b664187eb Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 6 Jan 2026 12:04:03 +0000 Subject: [PATCH 6/8] Apply suggestion from @pankajkoti Co-authored-by: Pankaj Koti --- cosmos/operators/_watcher/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 6774f4bc95..8070de9b19 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -111,7 +111,7 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo with appropriate flags, while ensuring flags like `--select` or `--exclude` are excluded. """ logger.info( - "Retry attempt #%s – Running model '%s' from project '%s' using ExecutionMode.LOCAL", + f"Retry attempt #%s – Running model '%s' from project '%s' using {execution_mode}", try_number - 1, self.model_unique_id, self.project_dir, From 75243ff7e8d8be79d96d38a67597d57227176dcf Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 6 Jan 2026 12:06:33 +0000 Subject: [PATCH 7/8] Address CR feedback https://github.com/astronomer/astronomer-cosmos/pull/2245/changes\#r2664693660 --- cosmos/operators/_watcher/base.py | 8 ++++---- cosmos/operators/watcher.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 8070de9b19..60f8fedd74 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -189,7 +189,7 @@ def execute(self, context: Context, **kwargs: Any) -> None: dag_id=self.dag_id, run_id=context["run_id"], map_index=context["task_instance"].map_index, - use_event=self._use_event(), + use_event=self.use_event(), poke_interval=self.poke_interval, ), timeout=self.execution_timeout, @@ -212,11 +212,11 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None: f"Watcher producer task '{self.producer_task_id}' failed before reporting model results. Check its logs for the underlying error." ) - def _use_event(self) -> bool: + def use_event(self) -> bool: raise NotImplementedError("Subclasses must implement this method") def _get_status_from_events(self, ti: Any, context: Context) -> Any: - raise NotImplementedError("Subclasses should implement this method if `_use_event` may return True") + raise NotImplementedError("Subclasses should implement this method if `use_event` may return True") def poke(self, context: Context) -> bool: """ @@ -239,7 +239,7 @@ def poke(self, context: Context) -> bool: # We have assumption here that both the build producer and the sensor task will have same invocation mode producer_task_state = self._get_producer_task_status(context) - if self._use_event(): + if self.use_event(): status = self._get_status_from_events(ti, context) else: status = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index cc97e84718..0b86c27cdd 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -259,7 +259,7 @@ def _get_status_from_events(self, ti: Any, context: Context) -> Any: return event_json.get("data", {}).get("run_result", {}).get("status") - def _use_event(self) -> bool: + def use_event(self) -> bool: if not self.invocation_mode: self._discover_invocation_mode() return self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None From 5d93d86a90e9075342079e456b3d51b493dda78f Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 6 Jan 2026 12:13:59 +0000 Subject: [PATCH 8/8] Fix broken tests after applying CR feedback --- cosmos/operators/_watcher/base.py | 2 +- tests/hooks/test_subprocess.py | 12 ++++++------ tests/operators/test_watcher.py | 30 +++++++++++++++--------------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 60f8fedd74..2a018a9146 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -111,7 +111,7 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo with appropriate flags, while ensuring flags like `--select` or `--exclude` are excluded. """ logger.info( - f"Retry attempt #%s – Running model '%s' from project '%s' using {execution_mode}", + f"Retry attempt #%s – Running model '%s' from project '%s' using {self.__class__.__name__}", try_number - 1, self.model_unique_id, self.project_dir, diff --git a/tests/hooks/test_subprocess.py b/tests/hooks/test_subprocess.py index 5955a15a80..232a8ce8f8 100644 --- a/tests/hooks/test_subprocess.py +++ b/tests/hooks/test_subprocess.py @@ -7,7 +7,7 @@ import pytest from cosmos.hooks.subprocess import FullOutputSubprocessHook -from cosmos.operators.watcher import _store_dbt_resource_status_from_log +from cosmos.operators.watcher import store_dbt_resource_status_from_log OS_ENV_KEY = "SUBPROCESS_ENV_TEST" OS_ENV_VAL = "this-is-from-os-environ" @@ -91,7 +91,7 @@ def test_send_sigterm(mock_killpg, mock_getpgid): ("failed", None, False, True), ], ) -def test_store_dbt_resource_status_from_log_param(status, context, should_push, expect_assert): +def teststore_dbt_resource_status_from_log_param(status, context, should_push, expect_assert): # Prepare log line log_line = {"data": {"node_info": {"node_status": status, "unique_id": "model.jaffle_shop.stg_orders"}}} line = json.dumps(log_line) @@ -99,9 +99,9 @@ def test_store_dbt_resource_status_from_log_param(status, context, should_push, with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push: if expect_assert: with pytest.raises(AssertionError): - _store_dbt_resource_status_from_log(line, {"context": context}) + store_dbt_resource_status_from_log(line, {"context": context}) else: - _store_dbt_resource_status_from_log(line, {"context": context}) + store_dbt_resource_status_from_log(line, {"context": context}) if should_push: mock_push.assert_called_once_with( task_instance=context["ti"], key="model__jaffle_shop__stg_orders_status", value=status @@ -110,9 +110,9 @@ def test_store_dbt_resource_status_from_log_param(status, context, should_push, mock_push.assert_not_called() -def test_store_dbt_resource_status_from_log_invalid_json(): +def teststore_dbt_resource_status_from_log_invalid_json(): invalid_line = "{not a valid json}" with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push: - _store_dbt_resource_status_from_log(invalid_line, {"context": {"ti": MagicMock()}}) + store_dbt_resource_status_from_log(invalid_line, {"context": {"ti": MagicMock()}}) mock_push.assert_not_called() diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 15ee443c4c..7f313fc324 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -26,7 +26,7 @@ DbtRunWatcherOperator, DbtSeedWatcherOperator, DbtTestWatcherOperator, - _store_dbt_resource_status_from_log, + store_dbt_resource_status_from_log, ) from cosmos.profiles import PostgresUserPasswordProfileMapping, get_automatic_profile_mapping from tests.utils import AIRFLOW_VERSION, new_test_dag @@ -220,11 +220,11 @@ def test_dbt_producer_watcher_operator_blocks_retry_attempt(caplog): ({"status": "success"}, None), ( {"status": "failed", "reason": "model_failed"}, - "dbt model 'model.pkg.m' failed. Review the producer task 'dbt_producer_watcher_operator' logs for details.", + "dbt model 'model.pkg.m' failed. Review the producer task 'dbt_producer_watcher' logs for details.", ), ( {"status": "failed", "reason": "producer_failed"}, - "Watcher producer task 'dbt_producer_watcher_operator' failed before reporting model results. Check its logs for the underlying error.", + "Watcher producer task 'dbt_producer_watcher' failed before reporting model results. Check its logs for the underlying error.", ), ], ) @@ -426,31 +426,31 @@ def fake_build_run(self, context, **kw): class TestStoreDbStatusFromLog: - """Tests for _store_dbt_resource_status_from_log and _process_log_line_callable.""" + """Tests for store_dbt_resource_status_from_log and _process_log_line_callable.""" - def test_store_dbt_resource_status_from_log_success(self): + def teststore_dbt_resource_status_from_log_success(self): """Test that success status is correctly parsed and stored in XCom.""" ti = _MockTI() ctx = {"ti": ti} log_line = json.dumps({"data": {"node_info": {"node_status": "success", "unique_id": "model.pkg.my_model"}}}) - _store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log(log_line, {"context": ctx}) assert ti.store.get("model__pkg__my_model_status") == "success" - def test_store_dbt_resource_status_from_log_failed(self): + def teststore_dbt_resource_status_from_log_failed(self): """Test that failed status is correctly parsed and stored in XCom.""" ti = _MockTI() ctx = {"ti": ti} log_line = json.dumps({"data": {"node_info": {"node_status": "failed", "unique_id": "model.pkg.failed_model"}}}) - _store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log(log_line, {"context": ctx}) assert ti.store.get("model__pkg__failed_model_status") == "failed" - def test_store_dbt_resource_status_from_log_ignores_other_statuses(self): + def teststore_dbt_resource_status_from_log_ignores_other_statuses(self): """Test that statuses other than success/failed are ignored.""" ti = _MockTI() ctx = {"ti": ti} @@ -459,22 +459,22 @@ def test_store_dbt_resource_status_from_log_ignores_other_statuses(self): {"data": {"node_info": {"node_status": "running", "unique_id": "model.pkg.running_model"}}} ) - _store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log(log_line, {"context": ctx}) assert "model__pkg__running_model_status" not in ti.store - def test_store_dbt_resource_status_from_log_handles_invalid_json(self, caplog): + def teststore_dbt_resource_status_from_log_handles_invalid_json(self, caplog): """Test that invalid JSON doesn't raise an exception.""" ti = _MockTI() ctx = {"ti": ti} # Should not raise an exception - _store_dbt_resource_status_from_log("not valid json {{{", {"context": ctx}) + store_dbt_resource_status_from_log("not valid json {{{", {"context": ctx}) # No status should be stored assert len(ti.store) == 0 - def test_store_dbt_resource_status_from_log_handles_missing_node_info(self): + def teststore_dbt_resource_status_from_log_handles_missing_node_info(self): """Test that missing node_info doesn't raise an exception.""" ti = _MockTI() ctx = {"ti": ti} @@ -482,7 +482,7 @@ def test_store_dbt_resource_status_from_log_handles_missing_node_info(self): log_line = json.dumps({"data": {"other_key": "value"}}) # Should not raise an exception - _store_dbt_resource_status_from_log(log_line, {"context": ctx}) + store_dbt_resource_status_from_log(log_line, {"context": ctx}) # No status should be stored assert len(ti.store) == 0 @@ -506,7 +506,7 @@ def test_process_log_line_callable_is_not_bound_method(self): ), "_process_log_line_callable should not be a bound method when accessed through instance" # Verify it's the original function - assert callable_from_instance is _store_dbt_resource_status_from_log + assert callable_from_instance is store_dbt_resource_status_from_log def test_process_log_line_callable_accepts_two_arguments(self): """Test that the callable can be called with exactly 2 arguments (line, kwargs).