From daf08f748524f5de10d909dbe74118d9d7acc3a7 Mon Sep 17 00:00:00 2001 From: pankajastro Date: Tue, 23 Sep 2025 17:48:26 +0530 Subject: [PATCH 01/19] Add DbtModelStatusSensor for the proposed ExecutionMode.WATCHER Commit squased from previous version of the implementation, before merging to main --- cosmos/operators/base.py | 2 +- cosmos/operators/watcher.py | 1 + tests/operators/test_watcher.py | 211 ++++++++++---------------------- 3 files changed, 69 insertions(+), 145 deletions(-) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 02b3192535..c9f8754f26 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -309,7 +309,7 @@ def build_and_run_cmd( ) -> Any: """Override this method for the operator to execute the dbt command""" - def execute(self, context: Context, **kwargs) -> Any | None: # type: ignore + def execute(self, context: Context) -> Any | None: # type: ignore if self.extra_context: context_merge(context, self.extra_context) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 23cca8c3c4..79154b35eb 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import base64 import json import logging diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index ce8d915015..9aaabc61dd 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -187,162 +187,85 @@ def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available): assert op.invocation_mode == InvocationMode.SUBPROCESS -MODEL_UNIQUE_ID = "model.jaffle_shop.stg_orders" -ENCODED_RUN_RESULTS = base64.b64encode( - zlib.compress(b'{"results":[{"unique_id":"model.jaffle_shop.stg_orders","status":"success"}]}') -).decode("utf-8") +class TestDbtModelStatusSensor: -ENCODED_RUN_RESULTS_FAILED = base64.b64encode( - zlib.compress(b'{"results":[{"unique_id":"model.jaffle_shop.stg_orders","status":"fail"}]}') 
-).decode("utf-8") - -ENCODED_EVENT = base64.b64encode(zlib.compress(b'{"data": {"run_result": {"status": "success"}}}')).decode("utf-8") - - -class TestDbtConsumerWatcherSensor: - - def make_sensor(self, **kwargs): - extra_context = {"dbt_node_config": {"unique_id": "model.jaffle_shop.stg_orders"}} - kwargs["extra_context"] = extra_context - sensor = DbtConsumerWatcherSensor( + @pytest.fixture + def sensor(self): + return DbtModelStatusSensor( task_id="model.my_model", - project_dir="/tmp/project", - profile_config=None, - **kwargs, + model_unique_id="model.my_model", + project_dir="/fake/project", + profiles_dir="/fake/profiles", ) - sensor.invocation_mode = "DBT_RUNNER" - return sensor - - def make_context(self, ti_mock): - return {"ti": ti_mock} - - @patch("cosmos.operators.watcher.EventMsg") - def test_poke_status_none_from_events(self, MockEventMsg): - mock_event_instance = MagicMock() - mock_event_instance.status = "done" - MockEventMsg.return_value = mock_event_instance - - sensor = self.make_sensor() - sensor.invocation_mode = InvocationMode.DBT_RUNNER - ti = MagicMock() - ti.try_number = 1 - ti.xcom_pull.side_effect = [None, None] # no event msg found - context = self.make_context(ti) - - result = sensor.poke(context) - assert result is False - - def test_poke_success_from_run_results(self): - sensor = self.make_sensor() - sensor.invocation_mode = "SUBPROCESS" + def test_filter_flags_removes_select_and_exclude(self): + flags = ["--select", "model_a", "--exclude", "model_b", "--threads", "4"] + expected = ["--threads", "4"] + result = DbtModelStatusSensor._filter_flags(flags) + assert result == expected - ti = MagicMock() - ti.try_number = 1 - ti.xcom_pull.return_value = ENCODED_RUN_RESULTS - context = self.make_context(ti) + def test_filter_flags_handles_no_flags(self): + assert DbtModelStatusSensor._filter_flags([]) == [] - result = sensor.poke(context) - assert result is True + @patch("cosmos.operators.watcher.DbtModelStatusSensor.build_and_run_cmd") + 
def test_handle_task_retry_runs_model(self, mock_build_and_run_cmd, sensor): + mock_context = { + "ti": Mock(), + } - def test_invocation_mode_none(self): - sensor = self.make_sensor() - sensor.invocation_mode = None + mock_task = Mock() + mock_task.add_cmd_flags.return_value = ["--select", "model_a", "--threads", "4"] + mock_context["ti"].task.dag.get_task.return_value = mock_task - ti = MagicMock() - ti.try_number = 1 - ti.xcom_pull.return_value = ENCODED_RUN_RESULTS - context = self.make_context(ti) + result = sensor._handle_task_retry(2, mock_context) - result = sensor.poke(context) assert result is True - - def test_poke_failure_from_run_results(self): - sensor = self.make_sensor() - sensor.invocation_mode = "OTHER_MODE" - - ti = MagicMock() - ti.try_number = 1 - ti.xcom_pull.return_value = ENCODED_RUN_RESULTS_FAILED - context = self.make_context(ti) - - with pytest.raises(AirflowException): - sensor.poke(context) - - def test_poke_status_none_from_run_results(self): - sensor = self.make_sensor() - sensor.invocation_mode = "OTHER_MODE" - - ti = MagicMock() - ti.try_number = 1 - ti.xcom_pull.return_value = None - context = self.make_context(ti) - + mock_build_and_run_cmd.assert_called_once() + called_args = mock_build_and_run_cmd.call_args[1]["cmd_flags"] + assert "--select" in called_args + assert "my_model" in called_args + assert "--threads" in called_args + + def test_poke_returns_false_when_no_data(self, sensor): + ti_mock = Mock() + ti_mock.try_number = 1 + ti_mock.xcom_pull.return_value = None + context = {"ti": ti_mock} result = sensor.poke(context) assert result is False - @patch("cosmos.operators.local.AbstractDbtLocalBase.build_and_run_cmd") - def test_task_retry(self, mock_build_and_run_cmd): - sensor = self.make_sensor() - ti = MagicMock() - ti.try_number = 2 - ti.xcom_pull.return_value = None - context = self.make_context(ti) + def make_event_payload(self, status: str) -> str: + data = {"data": {"run_result": {"status": status}}} + compressed = 
zlib.compress(str(data).encode("utf-8")) + return base64.b64encode(compressed).decode("utf-8") + + def test_poke_success_returns_true(self, sensor): + ti_mock = Mock() + ti_mock.try_number = 1 + ti_mock.xcom_pull.side_effect = [ + [], # dbt_startup_events + [], # pipeline_outlets + self.make_event_payload("success"), # node_finished_key + ] + context = {"ti": ti_mock} + assert sensor.poke(context) is True + + def test_poke_failure_raises_exception(self, sensor): + ti_mock = Mock() + ti_mock.try_number = 1 + ti_mock.xcom_pull.side_effect = [ + None, # dbt_startup_events + None, # pipeline_outlets + self.make_event_payload("error"), # node_finished_key + ] + context = {"ti": ti_mock} + with pytest.raises(AirflowException, match="finished with status 'error'"): + sensor.poke(context) + @patch.object(DbtModelStatusSensor, "_handle_task_retry") + def test_poke_calls_retry_on_retry_attempt(self, mock_retry, sensor): + ti_mock = Mock() + ti_mock.try_number = 2 + context = {"ti": ti_mock} sensor.poke(context) - mock_build_and_run_cmd.assert_called_once() - - def test_handle_task_retry(self): - sensor = self.make_sensor() - ti = MagicMock() - ti.task.dag.get_task.return_value.add_cmd_flags.return_value = ["--select", "some_model", "--threads", "2"] - context = self.make_context(ti) - sensor.build_and_run_cmd = MagicMock() - - result = sensor._handle_task_retry(2, context) - - assert result is True - sensor.build_and_run_cmd.assert_called_once() - args, kwargs = sensor.build_and_run_cmd.call_args - assert "--select" in kwargs["cmd_flags"] - assert MODEL_UNIQUE_ID.split(".")[-1] in kwargs["cmd_flags"] - - def test_filter_flags(self): - flags = ["--select", "model", "--exclude", "other", "--threads", "2"] - expected = ["--threads", "2"] - - result = DbtConsumerWatcherSensor._filter_flags(flags) - - assert result == expected - - def test_get_status_from_run_results_success(self): - sensor = self.make_sensor() - ti = MagicMock() - ti.xcom_pull.return_value = ENCODED_RUN_RESULTS 
- - result = sensor._get_status_from_run_results(ti) - assert result == "success" - - def test_get_status_from_run_results_none(self): - sensor = self.make_sensor() - ti = MagicMock() - ti.xcom_pull.return_value = None - - result = sensor._get_status_from_run_results(ti) - assert result is None - - def test_get_status_from_events_success(self): - sensor = self.make_sensor() - ti = MagicMock() - ti.xcom_pull.side_effect = [None, ENCODED_EVENT] - - result = sensor._get_status_from_events(ti) - assert result == "success" - - def test_get_status_from_events_none(self): - sensor = self.make_sensor() - ti = MagicMock() - ti.xcom_pull.side_effect = [None, None] - - result = sensor._get_status_from_events(ti) - assert result is None + mock_retry.assert_called_once_with(2, context) From 29240239f17633d3064faf7fa5437ce680c95c12 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 25 Sep 2025 11:19:49 +0100 Subject: [PATCH 02/19] Introduce ExecutionMode.WATCHER - Managed to run ExecutionMode.WATCHER for modified basic_cosmos_dag DAG - Add Test nodes as EmptyOperators --- cosmos/airflow/graph.py | 55 ++++++++++++++++++++++- cosmos/constants.py | 3 ++ cosmos/operators/watcher.py | 85 ++++++++++++++++++++++++++++++++++-- dev/dags/basic_cosmos_dag.py | 11 ++++- 4 files changed, 149 insertions(+), 5 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index eb3186551c..9ea1fc27ce 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -19,6 +19,7 @@ DBT_SETUP_ASYNC_TASK_ID, DBT_TEARDOWN_ASYNC_TASK_ID, DEFAULT_DBT_RESOURCES, + PRODUCER_WATCHER_TASK_ID, SUPPORTED_BUILD_RESOURCES, TESTABLE_DBT_RESOURCES, DbtResourceType, @@ -193,6 +194,7 @@ def _get_task_id_and_args( node: DbtNode, args: dict[str, Any], use_task_group: bool, + execution_mode: ExecutionMode, normalize_task_id: Callable[..., Any] | None, normalize_task_display_name: Callable[..., Any] | None, resource_suffix: str, @@ -222,6 +224,9 @@ def _get_task_id_and_args( else: task_id = 
task_name + if execution_mode == ExecutionMode.WATCHER: + args_update["model_unique_id"] = node.unique_id + return task_id, args_update @@ -318,6 +323,7 @@ def create_task_metadata( normalize_task_display_name=normalize_task_display_name, resource_suffix=resource_suffix, include_resource_type=True, + execution_mode=execution_mode, ) elif node.resource_type == DbtResourceType.SOURCE: args["select"] = f"source:{node.resource_name}" @@ -331,7 +337,13 @@ def create_task_metadata( return None task_id, args = _get_task_id_and_args( - node, args, use_task_group, normalize_task_id, normalize_task_display_name, "source" + node, + args, + use_task_group, + normalize_task_id, + normalize_task_display_name, + "source", + execution_mode=execution_mode, ) if node.has_freshness is False and source_rendering_behavior == SourceRenderingBehavior.ALL: # render sources without freshness as empty operators @@ -350,6 +362,7 @@ def create_task_metadata( normalize_task_id=normalize_task_id, normalize_task_display_name=normalize_task_display_name, resource_suffix=resource_suffix, + execution_mode=execution_mode, ) _override_profile_if_needed(args, node.profile_config_to_override) @@ -505,6 +518,37 @@ def _add_dbt_setup_async_task( tasks_map[DBT_SETUP_ASYNC_TASK_ID] = setup_airflow_task +def _add_producer( + dag: DAG, + task_args: dict[str, Any], + tasks_map: dict[str, Any], + task_group: TaskGroup | None, + render_config: RenderConfig | None = None, +) -> None: + + producer_task_args = task_args.copy() + + if render_config is not None: + producer_task_args["select"] = render_config.select + producer_task_args["selector"] = render_config.selector + producer_task_args["exclude"] = render_config.exclude + + producer_task_metadata = TaskMetadata( + id=PRODUCER_WATCHER_TASK_ID, + operator_class="cosmos.operators.watcher.DbtProducerWatcherOperator", + arguments=producer_task_args, + ) + producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group) + + # If we 
trigger_rule="always" (https://github.com/astronomer/astronomer-cosmos/issues/1959), + # the producer task becomes the parent to the root nodes of the remaining DAG + # for task_id, task in tasks_map.items(): + # if not task.upstream_list: + # producer_airflow_task >> task + + tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task + + def should_create_detached_nodes(render_config: RenderConfig) -> bool: """ Decide if we should calculate / insert detached nodes into the graph. @@ -709,6 +753,15 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro tasks_map[node_id] = test_task create_airflow_task_dependencies(nodes, tasks_map) + + if execution_mode == ExecutionMode.WATCHER: + _add_producer( + dag, + task_args, + tasks_map, + task_group, + ) + if settings.enable_setup_async_task: _add_dbt_setup_async_task( dag, diff --git a/cosmos/constants.py b/cosmos/constants.py index 64b43f5931..8023cfbf94 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -90,6 +90,7 @@ class ExecutionMode(Enum): Where the Cosmos tasks should be executed. 
""" + WATCHER = "watcher" LOCAL = "local" AIRFLOW_ASYNC = "airflow_async" DOCKER = "docker" @@ -167,6 +168,8 @@ def _missing_value_(cls, value): # type: ignore DBT_SETUP_ASYNC_TASK_ID = "dbt_setup_async" DBT_TEARDOWN_ASYNC_TASK_ID = "dbt_teardown_async" +PRODUCER_WATCHER_TASK_ID = "dbt_producer" + TELEMETRY_URL = "https://astronomer.gateway.scarf.sh/astronomer-cosmos/{telemetry_version}/{cosmos_version}/{airflow_version}/{python_version}/{platform_system}/{platform_machine}/{event_type}/{status}/{dag_hash}/{task_count}/{cosmos_task_count}/{execution_modes}" TELEMETRY_VERSION = "v2" TELEMETRY_TIMEOUT = 1.0 diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 79154b35eb..8fe5b6e9fe 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -5,7 +5,7 @@ import json import logging import zlib -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Sequence if TYPE_CHECKING: # pragma: no cover try: @@ -19,9 +19,23 @@ from airflow.sensors.base import BaseSensorOperator from airflow.exceptions import AirflowException +try: + from airflow.providers.standard.operators.empty import EmptyOperator +except ImportError: + from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] + from cosmos.config import ProfileConfig -from cosmos.constants import InvocationMode -from cosmos.operators.local import DbtLocalBaseOperator, DbtRunLocalOperator +from cosmos.constants import PRODUCER_WATCHER_TASK_ID, InvocationMode +from cosmos.operators.base import ( + DbtRunMixin, + DbtSeedMixin, + DbtSnapshotMixin, +) +from cosmos.operators.local import ( + DbtLocalBaseOperator, + DbtRunLocalOperator, + DbtSourceLocalOperator, +) try: from dbt_common.events.base_types import EventMsg @@ -272,3 +286,68 @@ def poke(self, context: Context) -> bool: return True else: raise AirflowException(f"Model '{self.model_unique_id}' finished with status '{status}'") + + +# This Operator does not seem to make sense for this 
particular execution mode, since build is executed by the producer task. +# That said, it is important to raise an exception if users attempt to use TestBehavior.BUILD, until we have a better experience. +class DbtBuildWatcherOperator: + def __init__(self, *args: Any, **kwargs: Any): + raise NotImplementedError( + "`ExecutionMode.WATCHER` does not expose a DbtBuild operator, since the build command is executed by the producer task." + ) + + +class DbtSeedWatcherOperator(DbtSeedMixin, DbtModelStatusSensor): # type: ignore[misc] + """ + Watches for the progress of dbt seed execution, run by the producer task (DbtProducerWatcherOperator). + """ + + template_fields: tuple[str] = DbtModelStatusSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + +class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtModelStatusSensor): + """ + Watches for the progress of dbt snapshot execution, run by the producer task (DbtProducerWatcherOperator). + """ + + template_fields: tuple[str] = DbtModelStatusSensor.template_fields # type: ignore[operator] + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + +class DbtSourceWatcherOperator(DbtSourceLocalOperator): + """ + Executes a dbt source freshness command, synchronously, as ExecutionMode.LOCAL. + """ + + template_fields: Sequence[str] = DbtSourceLocalOperator.template_fields + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + +class DbtRunWatcherOperator(DbtModelStatusSensor): + """ + Watches for the progress of dbt model execution, run by the producer task (DbtProducerWatcherOperator). 
+ """ + + template_fields: tuple[str] = DbtModelStatusSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + +class DbtTestWatcherOperator(EmptyOperator): + """ + As a starting point, this operator does nothing. + We'll be implementing this operator as part of: https://github.com/astronomer/astronomer-cosmos/issues/1974 + """ + + def __init__(self, *args: Any, **kwargs: Any): + desired_keys = ("dag", "task_group", "task_id") + new_kwargs = {key: value for key, value in kwargs.items() if key in desired_keys} + super().__init__(**new_kwargs) # type: ignore[no-untyped-call] diff --git a/dev/dags/basic_cosmos_dag.py b/dev/dags/basic_cosmos_dag.py index dbdfa04804..75813cd8a5 100644 --- a/dev/dags/basic_cosmos_dag.py +++ b/dev/dags/basic_cosmos_dag.py @@ -1,3 +1,5 @@ +from cosmos.constants import ExecutionMode + """ An example DAG that uses Cosmos to render a dbt project into an Airflow DAG. 
""" @@ -7,7 +9,7 @@ from pathlib import Path # [START cosmos_init_imports] -from cosmos import DbtDag, ProfileConfig, ProjectConfig +from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig # [END cosmos_init_imports] from cosmos.profiles import PostgresUserPasswordProfileMapping @@ -28,11 +30,18 @@ ), ) + # [START local_example] basic_cosmos_dag = DbtDag( # dbt/cosmos-specific parameters project_config=ProjectConfig(DBT_PROJECT_PATH), profile_config=profile_config, + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, + ), + # render_config=RenderConfig( + # test_behavior=TestBehavior.NONE, + # ), operator_args={ "install_deps": True, # install any necessary dependencies before running any dbt command "full_refresh": True, # used only in dbt commands that support this flag From b7bcd12f21bc03e325a07c8c7421d7f8a7e43786 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 29 Sep 2025 22:44:31 +0100 Subject: [PATCH 03/19] Adjustments after rebase --- cosmos/airflow/graph.py | 23 ++++++++++++----------- cosmos/constants.py | 2 +- cosmos/operators/watcher.py | 22 +++++++++++++--------- cosmos/settings.py | 2 ++ dev/dags/basic_cosmos_dag.py | 10 +--------- 5 files changed, 29 insertions(+), 30 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 9ea1fc27ce..926f0c3b41 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -337,12 +337,13 @@ def create_task_metadata( return None task_id, args = _get_task_id_and_args( - node, - args, - use_task_group, - normalize_task_id, - normalize_task_display_name, - "source", + node=node, + args=args, + use_task_group=use_task_group, + normalize_task_id=normalize_task_id, + normalize_task_display_name=normalize_task_display_name, + resource_suffix=r"source", + include_resource_type=True, execution_mode=execution_mode, ) if node.has_freshness is False and source_rendering_behavior == SourceRenderingBehavior.ALL: @@ -540,11 +541,11 @@ def 
_add_producer( ) producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group) - # If we trigger_rule="always" (https://github.com/astronomer/astronomer-cosmos/issues/1959), - # the producer task becomes the parent to the root nodes of the remaining DAG - # for task_id, task in tasks_map.items(): - # if not task.upstream_list: - # producer_airflow_task >> task + for task_id, task in tasks_map.items(): + # we want to make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom + if not task.upstream_list: + producer_airflow_task >> task + task.trigger_rule = "always" tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task diff --git a/cosmos/constants.py b/cosmos/constants.py index 8023cfbf94..b6af7f6409 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -168,7 +168,7 @@ def _missing_value_(cls, value): # type: ignore DBT_SETUP_ASYNC_TASK_ID = "dbt_setup_async" DBT_TEARDOWN_ASYNC_TASK_ID = "dbt_teardown_async" -PRODUCER_WATCHER_TASK_ID = "dbt_producer" +PRODUCER_WATCHER_TASK_ID = "dbt_producer_watcher" TELEMETRY_URL = "https://astronomer.gateway.scarf.sh/astronomer-cosmos/{telemetry_version}/{cosmos_version}/{airflow_version}/{python_version}/{platform_system}/{platform_machine}/{event_type}/{status}/{dag_hash}/{task_count}/{cosmos_task_count}/{execution_modes}" TELEMETRY_VERSION = "v2" diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 8fe5b6e9fe..3cbc70e9dd 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -1,6 +1,5 @@ from __future__ import annotations -import ast import base64 import json import logging @@ -45,7 +44,9 @@ logger = logging.getLogger(__name__) +CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 10 PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 9999 +WEIGHT_RULE = "absolute" # the default "downstream" does not work with dag.test() class DbtProducerWatcherOperator(DbtLocalBaseOperator): @@ -77,7 +78,8 @@ class 
DbtProducerWatcherOperator(DbtLocalBaseOperator): def __init__(self, *args: Any, **kwargs: Any) -> None: task_id = kwargs.pop("task_id", "dbt_producer_watcher_operator") - kwargs["priority_weight"] = kwargs.get("priority_weight", PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) + kwargs.setdefault("priority_weight", PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) + kwargs.setdefault("weight_rule", WEIGHT_RULE) super().__init__(task_id=task_id, *args, **kwargs) @staticmethod @@ -149,12 +151,14 @@ def __init__( profile_config: ProfileConfig | None = None, project_dir: str | None = None, profiles_dir: str | None = None, - producer_task_id: str = "dbt_producer_watcher", + producer_task_id: str = PRODUCER_WATCHER_TASK_ID, poke_interval: int = 20, timeout: int = 60 * 60, # 1 h safety valve **kwargs: Any, ) -> None: extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {} + kwargs.setdefault("priority_weight", PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) + kwargs.setdefault("weight_rule", WEIGHT_RULE) super().__init__( poke_interval=poke_interval, timeout=timeout, @@ -297,23 +301,23 @@ def __init__(self, *args: Any, **kwargs: Any): ) -class DbtSeedWatcherOperator(DbtSeedMixin, DbtModelStatusSensor): # type: ignore[misc] +class DbtSeedWatcherOperator(DbtSeedMixin, DbtConsumerWatcherSensor): # type: ignore[misc] """ Watches for the progress of dbt seed execution, run by the producer task (DbtProducerWatcherOperator). 
""" - template_fields: tuple[str] = DbtModelStatusSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] + template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) -class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtModelStatusSensor): +class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtConsumerWatcherSensor): """ Watches for the progress of dbt snapshot execution, run by the producer task (DbtProducerWatcherOperator). """ - template_fields: tuple[str] = DbtModelStatusSensor.template_fields # type: ignore[operator] + template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields # type: ignore[operator] def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -330,12 +334,12 @@ def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) -class DbtRunWatcherOperator(DbtModelStatusSensor): +class DbtRunWatcherOperator(DbtConsumerWatcherSensor): """ Watches for the progress of dbt model execution, run by the producer task (DbtProducerWatcherOperator). 
""" - template_fields: tuple[str] = DbtModelStatusSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] + template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields + DbtRunMixin.template_fields # type: ignore[operator] def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) diff --git a/cosmos/settings.py b/cosmos/settings.py index a4f35aa74d..e9eb3eb72f 100644 --- a/cosmos/settings.py +++ b/cosmos/settings.py @@ -14,6 +14,8 @@ DEFAULT_OPENLINEAGE_NAMESPACE, ) +test_mode = conf.getboolean("cosmos", "test_mode", fallback=False) + # In MacOS users may want to set the envvar `TMPDIR` if they do not want the value of the temp directory to change DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), DEFAULT_COSMOS_CACHE_DIR_NAME) cache_dir = Path(conf.get("cosmos", "cache_dir", fallback=DEFAULT_CACHE_DIR) or DEFAULT_CACHE_DIR) diff --git a/dev/dags/basic_cosmos_dag.py b/dev/dags/basic_cosmos_dag.py index 75813cd8a5..7e7944a346 100644 --- a/dev/dags/basic_cosmos_dag.py +++ b/dev/dags/basic_cosmos_dag.py @@ -1,5 +1,3 @@ -from cosmos.constants import ExecutionMode - """ An example DAG that uses Cosmos to render a dbt project into an Airflow DAG. 
""" @@ -9,7 +7,7 @@ from pathlib import Path # [START cosmos_init_imports] -from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig +from cosmos import DbtDag, ProfileConfig, ProjectConfig # [END cosmos_init_imports] from cosmos.profiles import PostgresUserPasswordProfileMapping @@ -36,12 +34,6 @@ # dbt/cosmos-specific parameters project_config=ProjectConfig(DBT_PROJECT_PATH), profile_config=profile_config, - execution_config=ExecutionConfig( - execution_mode=ExecutionMode.WATCHER, - ), - # render_config=RenderConfig( - # test_behavior=TestBehavior.NONE, - # ), operator_args={ "install_deps": True, # install any necessary dependencies before running any dbt command "full_refresh": True, # used only in dbt commands that support this flag From 1f8950e43da3751437820af05912362ae053d467 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 30 Sep 2025 08:03:19 +0100 Subject: [PATCH 04/19] Fix watcher unit tests, revert to version on main --- tests/operators/test_watcher.py | 211 ++++++++++++++++++++++---------- 1 file changed, 144 insertions(+), 67 deletions(-) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 9aaabc61dd..ce8d915015 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -187,85 +187,162 @@ def test_execute_discovers_invocation_mode(_mock_execute, _mock_is_available): assert op.invocation_mode == InvocationMode.SUBPROCESS -class TestDbtModelStatusSensor: +MODEL_UNIQUE_ID = "model.jaffle_shop.stg_orders" +ENCODED_RUN_RESULTS = base64.b64encode( + zlib.compress(b'{"results":[{"unique_id":"model.jaffle_shop.stg_orders","status":"success"}]}') +).decode("utf-8") - @pytest.fixture - def sensor(self): - return DbtModelStatusSensor( +ENCODED_RUN_RESULTS_FAILED = base64.b64encode( + zlib.compress(b'{"results":[{"unique_id":"model.jaffle_shop.stg_orders","status":"fail"}]}') +).decode("utf-8") + +ENCODED_EVENT = base64.b64encode(zlib.compress(b'{"data": {"run_result": 
{"status": "success"}}}')).decode("utf-8") + + +class TestDbtConsumerWatcherSensor: + + def make_sensor(self, **kwargs): + extra_context = {"dbt_node_config": {"unique_id": "model.jaffle_shop.stg_orders"}} + kwargs["extra_context"] = extra_context + sensor = DbtConsumerWatcherSensor( task_id="model.my_model", - model_unique_id="model.my_model", - project_dir="/fake/project", - profiles_dir="/fake/profiles", + project_dir="/tmp/project", + profile_config=None, + **kwargs, ) - def test_filter_flags_removes_select_and_exclude(self): - flags = ["--select", "model_a", "--exclude", "model_b", "--threads", "4"] - expected = ["--threads", "4"] - result = DbtModelStatusSensor._filter_flags(flags) - assert result == expected + sensor.invocation_mode = "DBT_RUNNER" + return sensor - def test_filter_flags_handles_no_flags(self): - assert DbtModelStatusSensor._filter_flags([]) == [] + def make_context(self, ti_mock): + return {"ti": ti_mock} - @patch("cosmos.operators.watcher.DbtModelStatusSensor.build_and_run_cmd") - def test_handle_task_retry_runs_model(self, mock_build_and_run_cmd, sensor): - mock_context = { - "ti": Mock(), - } + @patch("cosmos.operators.watcher.EventMsg") + def test_poke_status_none_from_events(self, MockEventMsg): + mock_event_instance = MagicMock() + mock_event_instance.status = "done" + MockEventMsg.return_value = mock_event_instance - mock_task = Mock() - mock_task.add_cmd_flags.return_value = ["--select", "model_a", "--threads", "4"] - mock_context["ti"].task.dag.get_task.return_value = mock_task + sensor = self.make_sensor() + sensor.invocation_mode = InvocationMode.DBT_RUNNER + ti = MagicMock() + ti.try_number = 1 + ti.xcom_pull.side_effect = [None, None] # no event msg found + context = self.make_context(ti) - result = sensor._handle_task_retry(2, mock_context) + result = sensor.poke(context) + assert result is False + + def test_poke_success_from_run_results(self): + sensor = self.make_sensor() + sensor.invocation_mode = "SUBPROCESS" + + ti = 
MagicMock() + ti.try_number = 1 + ti.xcom_pull.return_value = ENCODED_RUN_RESULTS + context = self.make_context(ti) + result = sensor.poke(context) assert result is True - mock_build_and_run_cmd.assert_called_once() - called_args = mock_build_and_run_cmd.call_args[1]["cmd_flags"] - assert "--select" in called_args - assert "my_model" in called_args - assert "--threads" in called_args - - def test_poke_returns_false_when_no_data(self, sensor): - ti_mock = Mock() - ti_mock.try_number = 1 - ti_mock.xcom_pull.return_value = None - context = {"ti": ti_mock} + + def test_invocation_mode_none(self): + sensor = self.make_sensor() + sensor.invocation_mode = None + + ti = MagicMock() + ti.try_number = 1 + ti.xcom_pull.return_value = ENCODED_RUN_RESULTS + context = self.make_context(ti) + result = sensor.poke(context) - assert result is False + assert result is True - def make_event_payload(self, status: str) -> str: - data = {"data": {"run_result": {"status": status}}} - compressed = zlib.compress(str(data).encode("utf-8")) - return base64.b64encode(compressed).decode("utf-8") - - def test_poke_success_returns_true(self, sensor): - ti_mock = Mock() - ti_mock.try_number = 1 - ti_mock.xcom_pull.side_effect = [ - [], # dbt_startup_events - [], # pipeline_outlets - self.make_event_payload("success"), # node_finished_key - ] - context = {"ti": ti_mock} - assert sensor.poke(context) is True - - def test_poke_failure_raises_exception(self, sensor): - ti_mock = Mock() - ti_mock.try_number = 1 - ti_mock.xcom_pull.side_effect = [ - None, # dbt_startup_events - None, # pipeline_outlets - self.make_event_payload("error"), # node_finished_key - ] - context = {"ti": ti_mock} - with pytest.raises(AirflowException, match="finished with status 'error'"): + def test_poke_failure_from_run_results(self): + sensor = self.make_sensor() + sensor.invocation_mode = "OTHER_MODE" + + ti = MagicMock() + ti.try_number = 1 + ti.xcom_pull.return_value = ENCODED_RUN_RESULTS_FAILED + context = 
self.make_context(ti) + + with pytest.raises(AirflowException): sensor.poke(context) - @patch.object(DbtModelStatusSensor, "_handle_task_retry") - def test_poke_calls_retry_on_retry_attempt(self, mock_retry, sensor): - ti_mock = Mock() - ti_mock.try_number = 2 - context = {"ti": ti_mock} + def test_poke_status_none_from_run_results(self): + sensor = self.make_sensor() + sensor.invocation_mode = "OTHER_MODE" + + ti = MagicMock() + ti.try_number = 1 + ti.xcom_pull.return_value = None + context = self.make_context(ti) + + result = sensor.poke(context) + assert result is False + + @patch("cosmos.operators.local.AbstractDbtLocalBase.build_and_run_cmd") + def test_task_retry(self, mock_build_and_run_cmd): + sensor = self.make_sensor() + ti = MagicMock() + ti.try_number = 2 + ti.xcom_pull.return_value = None + context = self.make_context(ti) + sensor.poke(context) - mock_retry.assert_called_once_with(2, context) + mock_build_and_run_cmd.assert_called_once() + + def test_handle_task_retry(self): + sensor = self.make_sensor() + ti = MagicMock() + ti.task.dag.get_task.return_value.add_cmd_flags.return_value = ["--select", "some_model", "--threads", "2"] + context = self.make_context(ti) + sensor.build_and_run_cmd = MagicMock() + + result = sensor._handle_task_retry(2, context) + + assert result is True + sensor.build_and_run_cmd.assert_called_once() + args, kwargs = sensor.build_and_run_cmd.call_args + assert "--select" in kwargs["cmd_flags"] + assert MODEL_UNIQUE_ID.split(".")[-1] in kwargs["cmd_flags"] + + def test_filter_flags(self): + flags = ["--select", "model", "--exclude", "other", "--threads", "2"] + expected = ["--threads", "2"] + + result = DbtConsumerWatcherSensor._filter_flags(flags) + + assert result == expected + + def test_get_status_from_run_results_success(self): + sensor = self.make_sensor() + ti = MagicMock() + ti.xcom_pull.return_value = ENCODED_RUN_RESULTS + + result = sensor._get_status_from_run_results(ti) + assert result == "success" + + def 
test_get_status_from_run_results_none(self): + sensor = self.make_sensor() + ti = MagicMock() + ti.xcom_pull.return_value = None + + result = sensor._get_status_from_run_results(ti) + assert result is None + + def test_get_status_from_events_success(self): + sensor = self.make_sensor() + ti = MagicMock() + ti.xcom_pull.side_effect = [None, ENCODED_EVENT] + + result = sensor._get_status_from_events(ti) + assert result == "success" + + def test_get_status_from_events_none(self): + sensor = self.make_sensor() + ti = MagicMock() + ti.xcom_pull.side_effect = [None, None] + + result = sensor._get_status_from_events(ti) + assert result is None From 0d6088a68540b5691290ff4418bc45e3aa0b7def Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 30 Sep 2025 17:37:00 +0100 Subject: [PATCH 05/19] Fix minor issues --- cosmos/airflow/graph.py | 27 +++++++++++++-------------- cosmos/settings.py | 2 -- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 926f0c3b41..d7277683c0 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -224,9 +224,6 @@ def _get_task_id_and_args( else: task_id = task_name - if execution_mode == ExecutionMode.WATCHER: - args_update["model_unique_id"] = node.unique_id - return task_id, args_update @@ -343,7 +340,6 @@ def create_task_metadata( normalize_task_id=normalize_task_id, normalize_task_display_name=normalize_task_display_name, resource_suffix=r"source", - include_resource_type=True, execution_mode=execution_mode, ) if node.has_freshness is False and source_rendering_behavior == SourceRenderingBehavior.ALL: @@ -519,13 +515,13 @@ def _add_dbt_setup_async_task( tasks_map[DBT_SETUP_ASYNC_TASK_ID] = setup_airflow_task -def _add_producer( +def _add_producer_watcher( dag: DAG, task_args: dict[str, Any], tasks_map: dict[str, Any], task_group: TaskGroup | None, render_config: RenderConfig | None = None, -) -> None: +) -> str: producer_task_args = task_args.copy() @@ 
-548,6 +544,7 @@ def _add_producer( task.trigger_rule = "always" tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task + return producer_airflow_task.task_id def should_create_detached_nodes(render_config: RenderConfig) -> bool: @@ -692,6 +689,16 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro if execution_mode == ExecutionMode.AIRFLOW_ASYNC: virtualenv_dir = task_args.pop("virtualenv_dir", None) + if execution_mode == ExecutionMode.WATCHER: + producer_watcher_task_id = _add_producer_watcher( + dag, + task_args, + tasks_map, + task_group, + render_config=render_config, + ) + task_args["producer_watcher_task_id"] = producer_watcher_task_id + for node_id, node in nodes.items(): conversion_function = node_converters.get(node.resource_type, generate_task_or_group) if conversion_function != generate_task_or_group: @@ -755,14 +762,6 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro create_airflow_task_dependencies(nodes, tasks_map) - if execution_mode == ExecutionMode.WATCHER: - _add_producer( - dag, - task_args, - tasks_map, - task_group, - ) - if settings.enable_setup_async_task: _add_dbt_setup_async_task( dag, diff --git a/cosmos/settings.py b/cosmos/settings.py index e9eb3eb72f..a4f35aa74d 100644 --- a/cosmos/settings.py +++ b/cosmos/settings.py @@ -14,8 +14,6 @@ DEFAULT_OPENLINEAGE_NAMESPACE, ) -test_mode = conf.getboolean("cosmos", "test_mode", fallback=False) - # In MacOS users may want to set the envvar `TMPDIR` if they do not want the value of the temp directory to change DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), DEFAULT_COSMOS_CACHE_DIR_NAME) cache_dir = Path(conf.get("cosmos", "cache_dir", fallback=DEFAULT_CACHE_DIR) or DEFAULT_CACHE_DIR) From 2719484087561468ac5978b77c04112a3075c371 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 3 Oct 2025 11:22:12 +0100 Subject: [PATCH 06/19] Add watcher example DAG --- dev/dags/example_watcher.py | 51 
+++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 dev/dags/example_watcher.py diff --git a/dev/dags/example_watcher.py b/dev/dags/example_watcher.py new file mode 100644 index 0000000000..bbf01fdcd1 --- /dev/null +++ b/dev/dags/example_watcher.py @@ -0,0 +1,51 @@ +""" +An example DAG that uses Cosmos to render a dbt project into an Airflow DAG. +""" + +import os +from datetime import datetime +from pathlib import Path + +# [START cosmos_init_imports] +from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig +from cosmos.constants import ExecutionMode + +# [END cosmos_init_imports] +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +DBT_PROJECT_NAME = os.getenv("DBT_PROJECT_NAME", "jaffle_shop") +DBT_PROJECT_PATH = DBT_ROOT_PATH / DBT_PROJECT_NAME + + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +# [START local_example] +example_watcher = DbtDag( + # dbt/cosmos-specific parameters + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, + ), + project_config=ProjectConfig(DBT_PROJECT_PATH), + profile_config=profile_config, + operator_args={ + "install_deps": True, # install any necessary dependencies before running any dbt command + "full_refresh": True, # used only in dbt commands that support this flag + }, + # normal dag parameters + schedule="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="example_watcher", + default_args={"retries": 0}, +) +# [END local_example] From 66ff659bc0330ab29e972cf3b52a284b9631f0b5 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 3 Oct 2025 11:23:37 +0100 Subject: [PATCH 07/19] Fix dbt project hash 
locally in test graph --- tests/dbt/test_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index bbe2186f63..7a242d57b8 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -1908,7 +1908,7 @@ def test_save_dbt_ls_cache(mock_variable_set, mock_datetime, tmp_dbt_project_dir assert hash_args == "d41d8cd98f00b204e9800998ecf8427e" if sys.platform == "darwin": # We faced inconsistent hashing versions depending on the version of MacOS/Linux - the following line aims to address these. - assert hash_dir in ("481324dabe926f5cf6352b05e5ebe5d7", "60c08a4730a39d03d89f0f87a8ff3931") + assert hash_dir in ("7f64aab068fb7fcf912765605210bf02", "60c08a4730a39d03d89f0f87a8ff3931") else: assert hash_dir == "60c08a4730a39d03d89f0f87a8ff3931" From a411de941a3e29d8a0e001ed47152c42aafa7810 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 3 Oct 2025 11:28:53 +0100 Subject: [PATCH 08/19] Revert unintended rebase change that was breaking unittests --- cosmos/operators/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index c9f8754f26..02b3192535 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -309,7 +309,7 @@ def build_and_run_cmd( ) -> Any: """Override this method for the operator to execute the dbt command""" - def execute(self, context: Context) -> Any | None: # type: ignore + def execute(self, context: Context, **kwargs) -> Any | None: # type: ignore if self.extra_context: context_merge(context, self.extra_context) From 8c1b63f2594a10d2b33a9a9f1e64898efc1dd291 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 3 Oct 2025 11:35:10 +0100 Subject: [PATCH 09/19] Remove unnecessary blank line --- dev/dags/basic_cosmos_dag.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dev/dags/basic_cosmos_dag.py b/dev/dags/basic_cosmos_dag.py index 7e7944a346..dbdfa04804 100644 ---
a/dev/dags/basic_cosmos_dag.py +++ b/dev/dags/basic_cosmos_dag.py @@ -28,7 +28,6 @@ ), ) - # [START local_example] basic_cosmos_dag = DbtDag( # dbt/cosmos-specific parameters From db54d177b48a17d92173fb88c261207397508750 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 3 Oct 2025 11:55:11 +0100 Subject: [PATCH 10/19] Fix mypy check --- cosmos/operators/watcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index e92604d713..88173fd0ba 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -320,7 +320,7 @@ def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) -class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtConsumerWatcherSensor): +class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtConsumerWatcherSensor): # type: ignore[misc] """ Watches for the progress of dbt snapshot execution, run by the producer task (DbtProducerWatcherOperator). """ From d9923861d9034cc77de083750d9501b2fa6d4c5c Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 6 Oct 2025 13:34:56 +0100 Subject: [PATCH 11/19] Try to fix integration tests that failed --- cosmos/airflow/graph.py | 23 +++++++++++------------ cosmos/operators/watcher.py | 2 +- dev/dags/example_watcher.py | 20 +++++++++++++------- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 7db2143406..9b00456d0b 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -562,12 +562,11 @@ def _add_producer_watcher( arguments=producer_task_args, ) producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group) - for task_id, task in tasks_map.items(): # we want to make the producer task to be the parent of the root dbt nodes, without blocking them from sensing XCom if not task.upstream_list: producer_airflow_task >> task - task.trigger_rule = "always" + task.trigger_rule = 
task_args.get("trigger_rule", "always") tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task return producer_airflow_task.task_id @@ -716,16 +715,6 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro if execution_mode == ExecutionMode.AIRFLOW_ASYNC: virtualenv_dir = task_args.pop("virtualenv_dir", None) - if execution_mode == ExecutionMode.WATCHER: - producer_watcher_task_id = _add_producer_watcher( - dag, - task_args, - tasks_map, - task_group, - render_config=render_config, - ) - task_args["producer_watcher_task_id"] = producer_watcher_task_id - for node_id, node in nodes.items(): conversion_function = node_converters.get(node.resource_type, generate_task_or_group) if conversion_function != generate_task_or_group: @@ -756,6 +745,16 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro logger.debug(f"Conversion of <{node.unique_id}> was successful!") tasks_map[node_id] = task_or_group + if execution_mode == ExecutionMode.WATCHER: + producer_watcher_task_id = _add_producer_watcher( + dag, + task_args, + tasks_map, + task_group, + render_config=render_config, + ) + task_args["producer_watcher_task_id"] = producer_watcher_task_id + # If test_behaviour=="after_all", there will be one test task, run by the end of the DAG # The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks) if test_behavior == TestBehavior.AFTER_ALL: diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 88173fd0ba..f0250130c1 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -160,7 +160,7 @@ def __init__( project_dir: str | None = None, profiles_dir: str | None = None, producer_task_id: str = PRODUCER_WATCHER_TASK_ID, - poke_interval: int = 20, + poke_interval: int = 10, timeout: int = 60 * 60, # 1 h safety valve **kwargs: Any, ) -> None: diff --git a/dev/dags/example_watcher.py b/dev/dags/example_watcher.py index bbf01fdcd1..6f87aa94e1 100644 
--- a/dev/dags/example_watcher.py +++ b/dev/dags/example_watcher.py @@ -3,7 +3,7 @@ """ import os -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path # [START cosmos_init_imports] @@ -29,7 +29,16 @@ ), ) -# [START local_example] + +operator_args = { + "install_deps": True, # install any necessary dependencies before running any dbt command + "execution_timeout": timedelta(seconds=120), + # Currently airflow dags test ignores priority_weight and weight_rule, for this reason, we're setting the following in the CI only: + "trigger_rule": "all_success" if os.getenv("CI", False) else "always", +} + + +# [START example_watcher] example_watcher = DbtDag( # dbt/cosmos-specific parameters execution_config=ExecutionConfig( @@ -37,10 +46,7 @@ ), project_config=ProjectConfig(DBT_PROJECT_PATH), profile_config=profile_config, - operator_args={ - "install_deps": True, # install any necessary dependencies before running any dbt command - "full_refresh": True, # used only in dbt commands that support this flag - }, + operator_args=operator_args, # normal dag parameters schedule="@daily", start_date=datetime(2023, 1, 1), @@ -48,4 +54,4 @@ dag_id="example_watcher", default_args={"retries": 0}, ) -# [END local_example] +# [END example_watcher] From 99d165f01dd48c35b4ab0bf5db7faee69df135e9 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 6 Oct 2025 15:08:49 +0100 Subject: [PATCH 12/19] Try to improve test coverage --- cosmos/operators/watcher.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index f0250130c1..9f5f6e8949 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -327,9 +327,6 @@ class DbtSnapshotWatcherOperator(DbtSnapshotMixin, DbtConsumerWatcherSensor): # template_fields: tuple[str] = DbtConsumerWatcherSensor.template_fields # type: ignore[operator] - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, 
**kwargs) - class DbtSourceWatcherOperator(DbtSourceLocalOperator): """ @@ -338,9 +335,6 @@ class DbtSourceWatcherOperator(DbtSourceLocalOperator): template_fields: Sequence[str] = DbtSourceLocalOperator.template_fields - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - class DbtRunWatcherOperator(DbtConsumerWatcherSensor): """ From b980b69bcfef453f5c078b833a3c9307e135633d Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 6 Oct 2025 15:12:20 +0100 Subject: [PATCH 13/19] Improve test coverage --- tests/operators/test_watcher.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 0b9da1cb69..e8de7cdcf8 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -10,6 +10,7 @@ from cosmos.config import InvocationMode from cosmos.operators.watcher import ( PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT, + DbtBuildWatcherOperator, DbtConsumerWatcherSensor, DbtProducerWatcherOperator, ) @@ -380,3 +381,15 @@ def test_get_status_from_events_none(self): result = sensor._get_status_from_events(ti) assert result is None + + +class TestDbtBuildWatcherOperator: + + def test_dbt_build_watcher_operator_raises_not_implemented_error(self): + expected_message = ( + "`ExecutionMode.WATCHER` does not expose a DbtBuild operator, " + "since the build command is executed by the producer task." 
+ ) + + with pytest.raises(NotImplementedError, match=expected_message): + DbtBuildWatcherOperator() From 63e5d7f39eff64930df37cb1f8e1eb1f0f61dd3a Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 6 Oct 2025 15:56:02 +0100 Subject: [PATCH 14/19] Add test to confirm topology and how watcher DAG is built --- tests/operators/test_watcher.py | 86 +++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index e8de7cdcf8..4e2d18a110 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -1,18 +1,45 @@ import base64 import json import zlib +from datetime import datetime, timedelta +from pathlib import Path from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest from airflow.exceptions import AirflowException +from airflow.utils.state import DagRunState +from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig from cosmos.config import InvocationMode +from cosmos.constants import ExecutionMode from cosmos.operators.watcher import ( PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT, DbtBuildWatcherOperator, DbtConsumerWatcherSensor, DbtProducerWatcherOperator, + DbtRunWatcherOperator, + DbtSeedWatcherOperator, + DbtTestWatcherOperator, +) +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DBT_PROJECT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" +DBT_PROFILES_YAML_FILEPATH = DBT_PROJECT_PATH / "profiles.yml" + + +project_config = ProjectConfig( + dbt_project_path=DBT_PROJECT_PATH, +) + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), ) @@ -393,3 +420,62 @@ def test_dbt_build_watcher_operator_raises_not_implemented_error(self): with pytest.raises(NotImplementedError, 
match=expected_message): DbtBuildWatcherOperator() + + +@pytest.mark.integration +def test_dbt_dag_with_watcher(): + """ + Run a DbtDag using dbt Fusion. + Confirm it succeeds and has the expected amount of both: + - dbt resources + - Airflow tasks + And that the tasks are in the expected topological order. + """ + watcher_dag = DbtDag( + project_config=project_config, + profile_config=profile_config, + start_date=datetime(2023, 1, 1), + dag_id="watcher_dag", + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, + ), + operator_args={"trigger_rule": "all_success", "execution_timeout": timedelta(seconds=120)}, + ) + outcome = watcher_dag.test() + assert outcome.state == DagRunState.SUCCESS + + assert len(watcher_dag.dbt_graph.filtered_nodes) == 26 + + assert len(watcher_dag.task_dict) == 14 + tasks_names = [task.task_id for task in watcher_dag.topological_sort()] + expected_task_names = [ + "dbt_producer_watcher", + "raw_customers_seed", + "raw_orders_seed", + "raw_payments_seed", + "stg_customers.run", + "stg_customers.test", + "stg_orders.run", + "stg_orders.test", + "stg_payments.run", + "stg_payments.test", + "customers.run", + "customers.test", + "orders.run", + "orders.test", + ] + assert tasks_names == expected_task_names + isinstance(watcher_dag.task_dict["dbt_producer_watcher"], DbtProducerWatcherOperator) + isinstance(watcher_dag.task_dict["raw_customers_seed"], DbtSeedWatcherOperator) + isinstance(watcher_dag.task_dict["raw_orders_seed"], DbtSeedWatcherOperator) + isinstance(watcher_dag.task_dict["raw_payments_seed"], DbtSeedWatcherOperator) + isinstance(watcher_dag.task_dict["stg_customers.run"], DbtRunWatcherOperator) + isinstance(watcher_dag.task_dict["stg_orders.run"], DbtRunWatcherOperator) + isinstance(watcher_dag.task_dict["stg_payments.run"], DbtRunWatcherOperator) + isinstance(watcher_dag.task_dict["customers.run"], DbtRunWatcherOperator) + isinstance(watcher_dag.task_dict["orders.run"], DbtRunWatcherOperator) + 
isinstance(watcher_dag.task_dict["stg_customers.test"], DbtTestWatcherOperator) + isinstance(watcher_dag.task_dict["stg_orders.test"], DbtTestWatcherOperator) + isinstance(watcher_dag.task_dict["stg_payments.test"], DbtTestWatcherOperator) + isinstance(watcher_dag.task_dict["customers.test"], DbtTestWatcherOperator) + isinstance(watcher_dag.task_dict["orders.test"], DbtTestWatcherOperator) From 8c0cb15c06086220283e30f6667f2bfce78df72c Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 6 Oct 2025 16:24:20 +0100 Subject: [PATCH 15/19] Try to fix test that did not work in CI --- tests/operators/test_watcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 4e2d18a110..03e00f5819 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -10,7 +10,7 @@ from airflow.exceptions import AirflowException from airflow.utils.state import DagRunState -from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig +from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig from cosmos.config import InvocationMode from cosmos.constants import ExecutionMode from cosmos.operators.watcher import ( @@ -439,6 +439,7 @@ def test_dbt_dag_with_watcher(): execution_config=ExecutionConfig( execution_mode=ExecutionMode.WATCHER, ), + render_config=RenderConfig(emit_datasets=False), operator_args={"trigger_rule": "all_success", "execution_timeout": timedelta(seconds=120)}, ) outcome = watcher_dag.test() From 5c4990bf328978caef6ea83cb467c67c93807e3b Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 6 Oct 2025 16:29:40 +0100 Subject: [PATCH 16/19] Fix integration test --- tests/operators/test_watcher.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 03e00f5819..649dc0198d 100644 --- a/tests/operators/test_watcher.py +++ 
b/tests/operators/test_watcher.py @@ -9,6 +9,7 @@ import pytest from airflow.exceptions import AirflowException from airflow.utils.state import DagRunState +from packaging.version import Version from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig from cosmos.config import InvocationMode @@ -23,6 +24,7 @@ DbtTestWatcherOperator, ) from cosmos.profiles import PostgresUserPasswordProfileMapping +from tests.utils import AIRFLOW_VERSION DBT_PROJECT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" DBT_PROFILES_YAML_FILEPATH = DBT_PROJECT_PATH / "profiles.yml" @@ -422,6 +424,7 @@ def test_dbt_build_watcher_operator_raises_not_implemented_error(self): DbtBuildWatcherOperator() +@pytest.mark.skipif(AIRFLOW_VERSION < Version("2.7"), reason="Airflow did not have dag.test() until the 2.6 release") @pytest.mark.integration def test_dbt_dag_with_watcher(): """ From 1e1c2c117ba5fcf7df413f3c5d772a21f9874b42 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 6 Oct 2025 17:49:49 +0100 Subject: [PATCH 17/19] Fix weight of consumer tasks --- cosmos/operators/watcher.py | 2 +- dev/dags/example_watcher.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 9f5f6e8949..b642cb2a88 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -165,7 +165,7 @@ def __init__( **kwargs: Any, ) -> None: extra_context = kwargs.pop("extra_context") if "extra_context" in kwargs else {} - kwargs.setdefault("priority_weight", PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) + kwargs.setdefault("priority_weight", CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT) kwargs.setdefault("weight_rule", WEIGHT_RULE) super().__init__( poke_interval=poke_interval, diff --git a/dev/dags/example_watcher.py b/dev/dags/example_watcher.py index 6f87aa94e1..9feaf9a2e1 100644 --- a/dev/dags/example_watcher.py +++ b/dev/dags/example_watcher.py @@ -33,9 +33,10 @@ 
operator_args = { "install_deps": True, # install any necessary dependencies before running any dbt command "execution_timeout": timedelta(seconds=120), - # Currently airflow dags test ignores priority_weight and weight_rule, for this reason, we're setting the following in the CI only: - "trigger_rule": "all_success" if os.getenv("CI", False) else "always", } +# Currently airflow dags test ignores priority_weight and weight_rule, for this reason, we're setting the following in the CI only: +# if os.getenv("CI"): +# operator_args["trigger_rule"] = "all_success" # [START example_watcher] From 1f2dcb3c6a76764088375f3d7068bebdb5ad534f Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 6 Oct 2025 18:07:38 +0100 Subject: [PATCH 18/19] Fix CI --- dev/dags/example_watcher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dev/dags/example_watcher.py b/dev/dags/example_watcher.py index 9feaf9a2e1..0ba685e0d5 100644 --- a/dev/dags/example_watcher.py +++ b/dev/dags/example_watcher.py @@ -34,9 +34,10 @@ "install_deps": True, # install any necessary dependencies before running any dbt command "execution_timeout": timedelta(seconds=120), } + # Currently airflow dags test ignores priority_weight and weight_rule, for this reason, we're setting the following in the CI only: -# if os.getenv("CI"): -# operator_args["trigger_rule"] = "all_success" +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" # [START example_watcher] From 3d1f5b00fb8e8778d2ec997d2eee2434cbf72ba2 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 6 Oct 2025 18:41:06 +0100 Subject: [PATCH 19/19] Update tests/operators/test_watcher.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tests/operators/test_watcher.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 649dc0198d..677aeaceb9 100644 --- 
a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -469,17 +469,17 @@ def test_dbt_dag_with_watcher(): "orders.test", ] assert tasks_names == expected_task_names - isinstance(watcher_dag.task_dict["dbt_producer_watcher"], DbtProducerWatcherOperator) - isinstance(watcher_dag.task_dict["raw_customers_seed"], DbtSeedWatcherOperator) - isinstance(watcher_dag.task_dict["raw_orders_seed"], DbtSeedWatcherOperator) - isinstance(watcher_dag.task_dict["raw_payments_seed"], DbtSeedWatcherOperator) - isinstance(watcher_dag.task_dict["stg_customers.run"], DbtRunWatcherOperator) - isinstance(watcher_dag.task_dict["stg_orders.run"], DbtRunWatcherOperator) - isinstance(watcher_dag.task_dict["stg_payments.run"], DbtRunWatcherOperator) - isinstance(watcher_dag.task_dict["customers.run"], DbtRunWatcherOperator) - isinstance(watcher_dag.task_dict["orders.run"], DbtRunWatcherOperator) - isinstance(watcher_dag.task_dict["stg_customers.test"], DbtTestWatcherOperator) - isinstance(watcher_dag.task_dict["stg_orders.test"], DbtTestWatcherOperator) - isinstance(watcher_dag.task_dict["stg_payments.test"], DbtTestWatcherOperator) - isinstance(watcher_dag.task_dict["customers.test"], DbtTestWatcherOperator) - isinstance(watcher_dag.task_dict["orders.test"], DbtTestWatcherOperator) + assert isinstance(watcher_dag.task_dict["dbt_producer_watcher"], DbtProducerWatcherOperator) + assert isinstance(watcher_dag.task_dict["raw_customers_seed"], DbtSeedWatcherOperator) + assert isinstance(watcher_dag.task_dict["raw_orders_seed"], DbtSeedWatcherOperator) + assert isinstance(watcher_dag.task_dict["raw_payments_seed"], DbtSeedWatcherOperator) + assert isinstance(watcher_dag.task_dict["stg_customers.run"], DbtRunWatcherOperator) + assert isinstance(watcher_dag.task_dict["stg_orders.run"], DbtRunWatcherOperator) + assert isinstance(watcher_dag.task_dict["stg_payments.run"], DbtRunWatcherOperator) + assert isinstance(watcher_dag.task_dict["customers.run"], DbtRunWatcherOperator) + 
assert isinstance(watcher_dag.task_dict["orders.run"], DbtRunWatcherOperator) + assert isinstance(watcher_dag.task_dict["stg_customers.test"], DbtTestWatcherOperator) + assert isinstance(watcher_dag.task_dict["stg_orders.test"], DbtTestWatcherOperator) + assert isinstance(watcher_dag.task_dict["stg_payments.test"], DbtTestWatcherOperator) + assert isinstance(watcher_dag.task_dict["customers.test"], DbtTestWatcherOperator) + assert isinstance(watcher_dag.task_dict["orders.test"], DbtTestWatcherOperator)