diff --git a/cosmos/airflow/task_group.py b/cosmos/airflow/task_group.py index a7eb32b300..efb94dd20b 100644 --- a/cosmos/airflow/task_group.py +++ b/cosmos/airflow/task_group.py @@ -12,6 +12,9 @@ except ImportError: from airflow.utils.task_group import TaskGroup +from cosmos import settings +from cosmos.config import ExecutionConfig +from cosmos.constants import ExecutionMode from cosmos.converter import DbtToAirflowConverter, airflow_kwargs, specific_kwargs @@ -26,7 +29,52 @@ def __init__( *args: Any, **kwargs: Any, ) -> None: + self._execution_config = kwargs.get("execution_config") kwargs["group_id"] = group_id TaskGroup.__init__(self, *args, **airflow_kwargs(**kwargs)) kwargs["task_group"] = self DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs)) + + @property + def is_watcher_mode(self) -> bool: + """Whether this task group uses a watcher execution mode.""" + return isinstance(self._execution_config, ExecutionConfig) and self._execution_config.execution_mode in ( + ExecutionMode.WATCHER, + ExecutionMode.WATCHER_KUBERNETES, + ) + + def _set_watcher_downstream_tasks_trigger_rule(self, downstream_tasks: Any) -> None: + """In watcher mode the producer task may be skipped on retry. + + Set ``trigger_rule="none_failed_min_one_success"`` on downstream tasks so the skip + does not propagate outside the task group. + """ + if not self.is_watcher_mode or not settings.propagate_watcher_trigger_rule: + return + items = downstream_tasks if isinstance(downstream_tasks, (list, tuple)) else [downstream_tasks] + for item in items: + if isinstance(item, TaskGroup): + # For downstream TaskGroups, only set trigger_rule on root tasks + # (tasks with no upstream within the group) to avoid altering + # the group's internal dependency semantics. 
+ for root_task in item.get_roots(): + if hasattr(root_task, "trigger_rule"): + root_task.trigger_rule = "none_failed_min_one_success" # type: ignore[assignment] + elif hasattr(item, "trigger_rule"): + item.trigger_rule = "none_failed_min_one_success" # type: ignore[assignment] + + def __rshift__(self, other: Any) -> Any: + # dbt_group >> post_dbt — post_dbt is downstream of dbt_group + result = super().__rshift__(other) + self._set_watcher_downstream_tasks_trigger_rule(other) + return result + + def __rlshift__(self, other: Any) -> Any: + # other << dbt_group — other is downstream of dbt_group + result = super().__rlshift__(other) + self._set_watcher_downstream_tasks_trigger_rule(other) + return result + + def set_downstream(self, task_or_task_list: Any, edge_modifier: Any = None) -> None: # type: ignore[override] + super().set_downstream(task_or_task_list, edge_modifier=edge_modifier) + self._set_watcher_downstream_tasks_trigger_rule(task_or_task_list) diff --git a/cosmos/settings.py b/cosmos/settings.py index 40255fac83..3d4922d4dc 100644 --- a/cosmos/settings.py +++ b/cosmos/settings.py @@ -68,6 +68,7 @@ # this setting allows retries to run on a queue with larger resources, which is often necessary for larger dbt projects # this would also be used to run the producer task watcher_dbt_execution_queue = conf.get("cosmos", "watcher_dbt_execution_queue", fallback=None) +propagate_watcher_trigger_rule = conf.getboolean("cosmos", "propagate_watcher_trigger_rule", fallback=False) # The following environment variable is populated in Astro Cloud in_astro_cloud = os.getenv("ASTRONOMER_ENVIRONMENT") == "cloud" diff --git a/dev/dags/dbt/watcher_downstream_not_skipped/dbt_project.yml b/dev/dags/dbt/watcher_downstream_not_skipped/dbt_project.yml new file mode 100644 index 0000000000..196cb8ba40 --- /dev/null +++ b/dev/dags/dbt/watcher_downstream_not_skipped/dbt_project.yml @@ -0,0 +1,26 @@ +name: 'watcher_downstream_not_skipped' + +config-version: 2 +version: '0.1' + 
+profile: 'default' + +model-paths: ["models"] +seed-paths: ["seeds"] +test-paths: ["tests"] +macro-paths: ["macros"] + +target-path: "target" +clean-targets: + - "target" + - "dbt_modules" + - "logs" + +require-dbt-version: [">=1.0.0", "<2.0.0"] + +on-run-start: + - "CREATE SEQUENCE IF NOT EXISTS {{ target.schema }}._cosmos_fail_once_seq" + +models: + watcher_downstream_not_skipped: + +materialized: table diff --git a/dev/dags/dbt/watcher_downstream_not_skipped/models/model_a.sql b/dev/dags/dbt/watcher_downstream_not_skipped/models/model_a.sql new file mode 100644 index 0000000000..2031e20406 --- /dev/null +++ b/dev/dags/dbt/watcher_downstream_not_skipped/models/model_a.sql @@ -0,0 +1,5 @@ +select 1 as id, 'Alice' as first_name, 'Smith' as last_name, 'alice@example.com' as email +union all +select 2, 'Bob', 'Jones', 'bob@example.com' +union all +select 3, 'Charlie', 'Brown', 'charlie@example.com' diff --git a/dev/dags/dbt/watcher_downstream_not_skipped/models/model_retry.sql b/dev/dags/dbt/watcher_downstream_not_skipped/models/model_retry.sql new file mode 100644 index 0000000000..bf5f905bb3 --- /dev/null +++ b/dev/dags/dbt/watcher_downstream_not_skipped/models/model_retry.sql @@ -0,0 +1,12 @@ +{{ + config( + pre_hook=[ + "DO $$ BEGIN IF nextval('{{ target.schema }}._cosmos_fail_once_seq') <= 1 THEN RAISE EXCEPTION 'fail_once: intentional first-run failure'; END IF; END $$" + ] + ) +}} + +select + id, + first_name +from {{ ref('model_a') }} diff --git a/dev/failed_dags/example_watcher_downstream_not_skipped.py b/dev/failed_dags/example_watcher_downstream_not_skipped.py new file mode 100644 index 0000000000..6614bd6c02 --- /dev/null +++ b/dev/failed_dags/example_watcher_downstream_not_skipped.py @@ -0,0 +1,99 @@ +""" +Airflow DAG to verify watcher downstream-task behavior when a producer task is skipped on retry. + +In watcher mode, when a dbt model fails the producer retries and raises AirflowSkipException. 
+By default, the skip propagates to all tasks downstream of the TaskGroup (e.g. post_dbt), +even though the consumer tasks inside the group succeeded. + +To prevent this, enable the ``propagate_watcher_trigger_rule`` Cosmos setting:: + + export AIRFLOW__COSMOS__PROPAGATE_WATCHER_TRIGGER_RULE=True + +When enabled, Cosmos automatically sets ``trigger_rule="none_failed_min_one_success"`` on tasks downstream +of a watcher DbtTaskGroup, so the producer skip does not propagate. + +Without this setting, post_dbt will be skipped. + +This DAG demonstrates the behaviour with the setting enabled: +- model_a succeeds in the producer and the consumer reads the result from XCom +- model_retry fails on the first attempt (via a fail_once sequence pre-hook) but succeeds + on the consumer retry fallback +- post_dbt (an EmptyOperator downstream of the group) runs successfully — it is NOT skipped + +A cleanup task drops the PostgreSQL sequence so the DAG can be re-triggered. +""" + +import os +from datetime import datetime, timedelta +from pathlib import Path + +from airflow.models import DAG +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator + +try: + from airflow.providers.standard.operators.empty import EmptyOperator +except ImportError: + from airflow.operators.empty import EmptyOperator + +from cosmos import DbtTaskGroup, ExecutionConfig, ProfileConfig, ProjectConfig +from cosmos.constants import ExecutionMode +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent / "dags/dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +DBT_PROJECT_PATH = DBT_ROOT_PATH / "watcher_downstream_not_skipped" + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +execution_config = ExecutionConfig( + 
execution_mode=ExecutionMode.WATCHER, +) + +operator_args = { + "install_deps": True, + "execution_timeout": timedelta(seconds=120), +} + +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + +default_args = { + "retries": 2, + "retry_delay": timedelta(seconds=0), + "depends_on_past": True, +} + +with DAG( + dag_id="example_watcher_downstream_not_skipped", + schedule="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + default_args=default_args, +): + dbt_group = DbtTaskGroup( + group_id="watcher_downstream_not_skipped", + execution_config=execution_config, + project_config=ProjectConfig(DBT_PROJECT_PATH), + profile_config=profile_config, + operator_args=operator_args, + ) + + cleanup = SQLExecuteQueryOperator( + task_id="drop_fail_once_marker", + conn_id="example_conn", + sql="DROP SEQUENCE IF EXISTS public._cosmos_fail_once_seq;", + trigger_rule="all_done", + ) + + post_dbt = EmptyOperator(task_id="post_dbt") + + dbt_group >> post_dbt + dbt_group >> cleanup diff --git a/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst b/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst index 5375cd461e..b9dec4855b 100644 --- a/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst +++ b/docs/guides/run_dbt/airflow-worker/watcher-execution-mode.rst @@ -224,23 +224,85 @@ If a branch of the DAG fails, users can clear the status of a failed consumer ta **Producer retry behavior** -.. versionadded:: 1.12.2 +.. versionchanged:: 1.14.1 -When the ``DbtProducerWatcherOperator`` is triggered for a retry (try_number > 1), it will not re-run the dbt build command and will succeed. In previous versions of Cosmos, the producer task would fail during retries. -This behavior is designed to support TaskGroup-level retries, as reported in `#2282 `_. +When the ``DbtProducerWatcherOperator`` is triggered for a retry (``try_number > 1``), it raises +``AirflowSkipException`` instead of re-running the ``dbt build`` command. 
Before skipping, it restores +XCom values from a backup so that consumer sensors can still read model statuses from the first attempt. -**Why this matters:** +**XCom backup and restore:** -- In earlier versions, attempting to retry the producer task would raise an ``AirflowException``, causing the retry to fail immediately. -- Now, the producer gracefully skips execution on retries, logging an informational message explaining that the retry was skipped to avoid running a second ``dbt build``. -- This allows users to retry entire TaskGroups and/or DAGs without the producer task blocking the retry flow. +During execution, each XCom push is incrementally backed up to an Airflow Variable. This ensures that +when the producer fails and Airflow clears XCom entries before the retry, the backed-up values can be +restored. On a successful run, the backup Variable is automatically deleted to avoid stale data +accumulating over time. + +**How consumer retries work:** + +1. The producer runs ``dbt build`` — some models succeed, some fail. +2. The producer task fails, and XCom values are backed up to a Variable. +3. On retry, the producer restores XCom from the Variable and raises ``AirflowSkipException``. +4. Consumer sensors read model statuses from the restored XCom. +5. Consumers for successful models complete immediately. +6. Consumers for failed models detect the error status and raise ``AirflowException``. +7. On their own retry, failed consumers fall back to running dbt individually for their model + (via ``_fallback_to_non_watcher_run``), behaving like ``ExecutionMode.LOCAL``. **Important considerations:** -- Retries are no longer forced to ``0`` by Cosmos, since 1.14.0. Users may configure ``retries`` freely on the producer task. On any retry attempt (``try_number > 1``), the producer gracefully skips execution and returns success — it will not re-run the ``dbt build`` command. 
This means retrying the producer (or clearing an entire TaskGroup) is safe and will not cause duplicate dbt builds. During the retry of the sensor tasks, they will effectively run the corresponding dbt commands. +- Users may configure ``retries`` freely on the producer task. On any retry attempt (``try_number > 1``), + the producer gracefully skips execution — it will not re-run the ``dbt build`` command. +- Retrying the producer (or clearing an entire TaskGroup) is safe and will not cause duplicate dbt builds. +- During the retry of the sensor tasks, they will effectively run the corresponding dbt commands. +- When using ``DbtTaskGroup``, the producer's skip state propagates to tasks downstream of the group + by default — Airflow's ``trigger_rule="all_success"`` causes them to be skipped even when all consumer + tasks succeeded. To prevent this, either: + + - Set the ``propagate_watcher_trigger_rule`` Cosmos setting to ``True`` (see below), which automatically + sets ``trigger_rule="none_failed_min_one_success"`` on tasks downstream of the watcher ``DbtTaskGroup``. + - Or manually set ``trigger_rule="none_failed_min_one_success"`` on your downstream tasks. + + This issue does not affect ``DbtDag``, where the producer-to-consumer dependency is handled differently. The overall retry behavior will be further improved once `#1978 `_ is implemented. +Propagate Watcher Trigger Rule +.............................. + +.. versionadded:: 1.14.1 + +When using ``DbtTaskGroup`` in watcher mode, the producer task may be skipped on retry. By default, +this causes Airflow to skip any tasks downstream of the task group (due to ``trigger_rule="all_success"``). 
+ +The ``propagate_watcher_trigger_rule`` setting makes Cosmos automatically set ``trigger_rule="none_failed_min_one_success"`` +on tasks wired downstream of a watcher ``DbtTaskGroup`` when the dependency is created from the +``DbtTaskGroup`` side: + +- ``dbt_group >> task`` +- ``dbt_group.set_downstream(task)`` + +**Configuration:** + +.. code-block:: ini + + [cosmos] + propagate_watcher_trigger_rule = True + +Or via environment variable: + +.. code-block:: bash + + export AIRFLOW__COSMOS__PROPAGATE_WATCHER_TRIGGER_RULE=True + +**Limitations:** + +- This does not work when the dependency is created from the downstream task side + (e.g., ``task.set_upstream(dbt_group)`` or ``task << dbt_group``), since Cosmos cannot + intercept methods called on tasks it does not control. +- When enabled, it overrides any user-defined ``trigger_rule`` on downstream tasks with ``"none_failed_min_one_success"``. + If you need a different ``trigger_rule`` on a downstream task, do not enable this setting and instead + set ``trigger_rule="none_failed_min_one_success"`` manually on the specific tasks that need it. + Watcher dbt Execution Queue ........................... diff --git a/docs/reference/configs/cosmos-conf.rst b/docs/reference/configs/cosmos-conf.rst index 39e0bf2042..1d77745912 100644 --- a/docs/reference/configs/cosmos-conf.rst +++ b/docs/reference/configs/cosmos-conf.rst @@ -355,6 +355,29 @@ This page lists all available Airflow configurations that affect ``astronomer-co - Default: ``None`` - Environment Variable: ``AIRFLOW__COSMOS__WATCHER_DBT_EXECUTION_QUEUE`` +.. _propagate_watcher_trigger_rule: + +`propagate_watcher_trigger_rule`_: + (Introduced in Cosmos 1.14.1) When using ``DbtTaskGroup`` in watcher mode, the producer task may be + skipped on retry (via ``AirflowSkipException``). By default, this causes Airflow to skip any tasks + downstream of the task group due to the default ``trigger_rule="all_success"``. 
+ + When this setting is enabled, Cosmos automatically sets ``trigger_rule="none_failed_min_one_success"`` on tasks wired + downstream of a watcher ``DbtTaskGroup`` when the dependency is created from the ``DbtTaskGroup`` side: + + - ``dbt_group >> task`` + - ``dbt_group.set_downstream(task)`` + + **Limitations:** + + - Does not work when the dependency is created from the downstream task side + (e.g., ``task.set_upstream(dbt_group)`` or ``task << dbt_group``), since Cosmos cannot + intercept methods called on tasks it does not control. + - When enabled, overrides any user-defined ``trigger_rule`` on downstream tasks with ``"none_failed_min_one_success"``. + + - Default: ``False`` + - Environment Variable: ``AIRFLOW__COSMOS__PROPAGATE_WATCHER_TRIGGER_RULE`` + [openlineage] +++++++++++++ diff --git a/tests/airflow/test_task_group.py b/tests/airflow/test_task_group.py new file mode 100644 index 0000000000..9ea1f64814 --- /dev/null +++ b/tests/airflow/test_task_group.py @@ -0,0 +1,111 @@ +"""Unit tests for cosmos.airflow.task_group module.""" + +from __future__ import annotations + +from datetime import datetime +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from airflow import DAG + +try: + from airflow.sdk import TaskGroup +except ImportError: + from airflow.utils.task_group import TaskGroup + +from cosmos import DbtTaskGroup, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig +from cosmos.constants import ExecutionMode, TestBehavior +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DBT_PROJECT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + + +@patch("cosmos.settings.enable_cache", False) +def _make_task_group(execution_mode): + """Create a 
DbtTaskGroup inside a DAG context with caching disabled.""" + with DAG(dag_id="test_dag", start_date=datetime(2023, 1, 1)): + return DbtTaskGroup( + group_id="test_group", + execution_config=ExecutionConfig(execution_mode=execution_mode), + project_config=ProjectConfig(dbt_project_path=DBT_PROJECT_PATH), + profile_config=profile_config, + render_config=RenderConfig(test_behavior=TestBehavior.NONE), + ) + + +@pytest.mark.parametrize( + "execution_mode, expected", + [ + (ExecutionMode.WATCHER, True), + (ExecutionMode.LOCAL, False), + ], +) +def test_is_watcher_mode(execution_mode, expected): + tg = _make_task_group(execution_mode) + assert tg.is_watcher_mode is expected + + +@patch("cosmos.settings.propagate_watcher_trigger_rule", True) +def test_rshift_sets_trigger_rule_in_watcher_mode(): + tg = _make_task_group(ExecutionMode.WATCHER) + task = MagicMock(trigger_rule="all_success") + + tg >> task + + assert task.trigger_rule == "none_failed_min_one_success" + + +@patch("cosmos.settings.propagate_watcher_trigger_rule", False) +def test_rshift_noop_when_setting_disabled(): + tg = _make_task_group(ExecutionMode.WATCHER) + task = MagicMock(trigger_rule="all_success") + + tg >> task + + assert task.trigger_rule == "all_success" + + +@patch("cosmos.settings.propagate_watcher_trigger_rule", True) +def test_rshift_sets_trigger_rule_on_downstream_task_group_roots(): + tg = _make_task_group(ExecutionMode.WATCHER) + + root_task = MagicMock(trigger_rule="all_success") + downstream_group = MagicMock() + downstream_group.get_roots.return_value = [root_task] + # isinstance(item, TaskGroup) check needs to pass + downstream_group.__class__ = TaskGroup + + tg >> downstream_group + + assert root_task.trigger_rule == "none_failed_min_one_success" + + +@patch("cosmos.settings.propagate_watcher_trigger_rule", True) +def test_set_downstream_sets_trigger_rule(): + tg = _make_task_group(ExecutionMode.WATCHER) + task = MagicMock(trigger_rule="all_success") + + tg.set_downstream(task) + + 
assert task.trigger_rule == "none_failed_min_one_success" + + +@patch("cosmos.settings.propagate_watcher_trigger_rule", True) +def test_rlshift_sets_trigger_rule(): + tg = _make_task_group(ExecutionMode.WATCHER) + task = MagicMock(trigger_rule="all_success") + + tg.__rlshift__(task) + + assert task.trigger_rule == "none_failed_min_one_success" diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index f3dbd9cbfc..a99bdad3a7 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -39,6 +39,9 @@ DBT_PROFILES_YAML_FILEPATH = DBT_PROJECT_PATH / "profiles.yml" MULTI_FOLDER_DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/multi_folder" DBT_WATCHER_FAILING_TESTS_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/watcher_failing_tests" +DBT_WATCHER_DOWNSTREAM_NOT_SKIPPED_PATH = ( + Path(__file__).parent.parent.parent / "dev/dags/dbt/watcher_downstream_not_skipped" +) DBT_EXECUTABLE_PATH = Path(__file__).parent.parent.parent / "venv-subprocess/bin/dbt" DBT_PROJECT_WITH_EMPTY_MODEL_PATH = Path(__file__).parent.parent / "sample/dbt_project_with_empty_model" @@ -1929,7 +1932,7 @@ def test_dbt_dag_with_watcher_and_failing_model(caplog): execution_config=ExecutionConfig(execution_mode=ExecutionMode.WATCHER), render_config=RenderConfig(emit_datasets=False, test_behavior=TestBehavior.NONE), default_args={"retries": 2, "retry_delay": timedelta(seconds=0)}, - operator_args={"trigger_rule": "none_failed", "execution_timeout": timedelta(seconds=120)}, + operator_args={"trigger_rule": "none_failed_min_one_success", "execution_timeout": timedelta(seconds=120)}, dagrun_timeout=timedelta(seconds=120), ) outcome = new_test_dag(watcher_dag, expected_dag_state=DagRunState.FAILED) @@ -1956,6 +1959,174 @@ def test_dbt_dag_with_watcher_and_failing_model(caplog): assert dbt_error_message in caplog.text +@pytest.mark.skipif( + AIRFLOW_VERSION < Version("2.10") or AIRFLOW_VERSION == Version("3.1.0"), + reason=( + 
"dag.test() in Airflow 2.9 hangs when a task fails with retries configured. " + "Airflow 3.1.0 crashes during task finalization (SetRenderedFields) when retrying " + "tasks inside a DbtTaskGroup via dag.test()." + ), +) +@pytest.mark.integration +def test_dbt_task_group_watcher_downstream_skipped_by_default(caplog): + """ + Without propagate_watcher_trigger_rule, tasks downstream of a watcher DbtTaskGroup + are skipped when the producer is skipped on retry. + """ + import psycopg2 + from airflow import DAG + from airflow.hooks.base import BaseHook + + from cosmos import DbtTaskGroup + + try: + from airflow.providers.standard.operators.empty import EmptyOperator + except ImportError: + from airflow.operators.empty import EmptyOperator + + # Reset the fail_once sequence using credentials from the same Airflow connection + airflow_conn = BaseHook.get_connection("example_conn") + with psycopg2.connect( + host=airflow_conn.host, + port=airflow_conn.port or 5432, + dbname=airflow_conn.schema or "postgres", + user=airflow_conn.login, + password=airflow_conn.password, + ) as conn: + conn.autocommit = True + with conn.cursor() as cur: + cur.execute("DROP SEQUENCE IF EXISTS public._cosmos_fail_once_seq") + + caplog.set_level(logging.DEBUG, logger="cosmos.operators._watcher.base") + + with DAG( + dag_id="watcher_taskgroup_downstream_skipped", + start_date=datetime(2023, 1, 1), + default_args={"retries": 2, "retry_delay": timedelta(seconds=0)}, + dagrun_timeout=timedelta(seconds=120), + ) as dag: + dbt_group = DbtTaskGroup( + group_id="watcher_downstream_not_skipped", + execution_config=ExecutionConfig(execution_mode=ExecutionMode.WATCHER), + project_config=ProjectConfig(dbt_project_path=DBT_WATCHER_DOWNSTREAM_NOT_SKIPPED_PATH), + profile_config=profile_config, + render_config=RenderConfig(emit_datasets=False, test_behavior=TestBehavior.NONE), + operator_args={"trigger_rule": "none_failed_min_one_success", "execution_timeout": timedelta(seconds=120)}, + ) + + # Force producer 
>> root consumers so dag.test() respects execution order. + # Only root nodes (no upstream within the group) are wired, and their trigger_rule + # is set to none_failed so they run even when the producer is skipped. + # This workaround is only needed for dag.test() — the real scheduler uses priority_weight. + # See https://github.com/apache/airflow/issues/56723 + producer = dag.task_dict["watcher_downstream_not_skipped.dbt_producer_watcher"] + for root_task in dbt_group.get_roots(): + if root_task != producer: + producer >> root_task + # Use "none_failed" (not "none_failed_min_one_success") here because the producer + # is the only direct upstream and it will be skipped — "min_one_success" would + # block the consumer since a skip is not a success. + root_task.trigger_rule = "none_failed" + + post_dbt = EmptyOperator(task_id="post_dbt") + dbt_group >> post_dbt + # Also wire producer >> post_dbt so the producer skip propagates directly, + # simulating what happens in a real deployment where the producer is a leaf + # task in the group and its skip affects downstream tasks. + producer >> post_dbt + + outcome = new_test_dag(dag, expected_dag_state=DagRunState.SUCCESS) + + tis = {ti.task_id: ti for ti in outcome.get_task_instances()} + # The downstream task is skipped because the producer skip propagates (default behaviour) + assert tis["post_dbt"].state == "skipped" + + +@pytest.mark.skipif( + AIRFLOW_VERSION < Version("2.10") or AIRFLOW_VERSION == Version("3.1.0"), + reason=( + "dag.test() in Airflow 2.9 hangs when a task fails with retries configured. " + "Airflow 3.1.0 crashes during task finalization (SetRenderedFields) when retrying " + "tasks inside a DbtTaskGroup via dag.test()." 
+ ), +) +@pytest.mark.integration +@patch("cosmos.settings.propagate_watcher_trigger_rule", True) +def test_dbt_task_group_watcher_downstream_not_skipped_with_setting(caplog): + """ + With propagate_watcher_trigger_rule=True, tasks downstream of a watcher DbtTaskGroup + run successfully even when the producer is skipped on retry. + """ + import psycopg2 + from airflow import DAG + from airflow.hooks.base import BaseHook + + from cosmos import DbtTaskGroup + + # Reset the fail_once sequence using credentials from the same Airflow connection + airflow_conn = BaseHook.get_connection("example_conn") + with psycopg2.connect( + host=airflow_conn.host, + port=airflow_conn.port or 5432, + dbname=airflow_conn.schema or "postgres", + user=airflow_conn.login, + password=airflow_conn.password, + ) as conn: + conn.autocommit = True + with conn.cursor() as cur: + cur.execute("DROP SEQUENCE IF EXISTS public._cosmos_fail_once_seq") + + try: + from airflow.providers.standard.operators.empty import EmptyOperator + except ImportError: + from airflow.operators.empty import EmptyOperator + + caplog.set_level(logging.DEBUG, logger="cosmos.operators._watcher.base") + + with DAG( + dag_id="watcher_taskgroup_downstream_not_skipped", + start_date=datetime(2023, 1, 1), + default_args={"retries": 2, "retry_delay": timedelta(seconds=0)}, + dagrun_timeout=timedelta(seconds=120), + ) as dag: + dbt_group = DbtTaskGroup( + group_id="watcher_downstream_not_skipped", + execution_config=ExecutionConfig(execution_mode=ExecutionMode.WATCHER), + project_config=ProjectConfig(dbt_project_path=DBT_WATCHER_DOWNSTREAM_NOT_SKIPPED_PATH), + profile_config=profile_config, + render_config=RenderConfig(emit_datasets=False, test_behavior=TestBehavior.NONE), + operator_args={"trigger_rule": "none_failed_min_one_success", "execution_timeout": timedelta(seconds=120)}, + ) + + # Force producer >> root consumers so dag.test() respects execution order. 
+ # See https://github.com/apache/airflow/issues/56723 + producer = dag.task_dict["watcher_downstream_not_skipped.dbt_producer_watcher"] + for root_task in dbt_group.get_roots(): + if root_task != producer: + producer >> root_task + # Use "none_failed" (not "none_failed_min_one_success") here because the producer + # is the only direct upstream and it will be skipped — "min_one_success" would + # block the consumer since a skip is not a success. + root_task.trigger_rule = "none_failed" + + post_dbt = EmptyOperator(task_id="post_dbt") + post_dbt_2 = EmptyOperator(task_id="post_dbt_2") + dbt_group >> post_dbt + dbt_group.set_downstream(post_dbt_2) + # Wire producer >> downstream tasks so the producer skip propagates directly, + # simulating what happens in a real deployment where the producer is a leaf + # task in the group and its skip affects downstream tasks. + producer >> post_dbt + producer >> post_dbt_2 + + outcome = new_test_dag(dag, expected_dag_state=DagRunState.SUCCESS) + + tis = {ti.task_id: ti for ti in outcome.get_task_instances()} + # Both downstream tasks run because propagate_watcher_trigger_rule prevents skip propagation + assert tis["post_dbt"].state == "success" + assert tis["post_dbt_2"].state == "success" + + def test_dbt_source_watcher_operator_template_fields(): """Test that DbtSourceWatcherOperator includes model_unique_id as a consumer sensor.""" from cosmos.operators._watcher.base import BaseConsumerSensor