diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 3da4f30f3f..90784bc84a 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -32,6 +32,7 @@ is_dbt_node_status_skipped, is_dbt_node_status_success, is_dbt_node_status_terminal, + is_dbt_upstream_failure_skip_event, is_producer_task_terminated, safe_xcom_push, xcom_set_lock, @@ -253,6 +254,33 @@ def _ensure_subprocess_model_outlet_uris( model_outlet_uris[_MODEL_OUTLET_URIS_ATTEMPTED_KEY] = [] # type: ignore[assignment] +def _rewrite_upstream_failure_skip_status( + log_line: dict[str, Any], + unique_id: str | None, + dbt_node_status: Any, + upstream_failure_skipped_ids: set[str] | None, +) -> Any: + """Track upstream-failure-skipped nodes and rewrite their status from "skipped" to "failed". + + dbt emits two events for a node that it skipped because of an upstream + failure: ``SkippingDetails``/``LogSkipBecauseError`` during ``on_skip()``, + and a later ``NodeFinished`` from the runnable ``finally``-block -- both + carry ``node_status="skipped"``. The first identifies the upstream-failure + cause (those events are reached only via ``do_skip(cause=...)``, exclusively + on upstream-node failure); the second would otherwise overwrite the + rewritten XCom with the original ``"skipped"``. Tracking affected + ``unique_id``\\ s in a per-execution set lets the parser rewrite both. See + #2698. + """ + if not unique_id or upstream_failure_skipped_ids is None: + return dbt_node_status + if is_dbt_upstream_failure_skip_event(log_line.get("info", {}).get("name")): + upstream_failure_skipped_ids.add(unique_id) + if dbt_node_status == "skipped" and unique_id in upstream_failure_skipped_ids: + return "failed" + return dbt_node_status + + def _surface_non_json_stdout(line: str) -> None: """Forward non-JSON stdout (e.g. Snowflake connector's "Going to open: " prompt during externalbrowser auth) to the task log instead of silently dropping it. 
@@ -269,6 +297,7 @@ def store_dbt_resource_status_from_log( test_results_per_model: dict[str, list[str]] | None = None, model_outlet_uris: dict[str, list[str]] | None = None, dataset_namespace: str | None = None, + upstream_failure_skipped_ids: set[str] | None = None, ) -> None: """ Parses a single line from dbt JSON logs and stores node status to Airflow XCom. @@ -289,6 +318,12 @@ def store_dbt_resource_status_from_log( :param model_outlet_uris: Mutable dict mapping unique_id to outlet URIs. Populated lazily from the manifest on first terminal status detection. :param dataset_namespace: The OL-compatible dataset namespace for URI construction. + :param upstream_failure_skipped_ids: Mutable accumulator set of node unique_ids + that dbt skipped because of an upstream-node failure. Populated when this + function sees a ``SkippingDetails`` or ``LogSkipBecauseError`` event; later + ``NodeFinished`` events with ``node_status="skipped"`` for these unique_ids + are rewritten to ``"failed"`` so the consumer sensor fails (and Airflow can + retry it) rather than going SKIPPED. See #2698. """ try: log_line = json.loads(line) @@ -309,6 +344,12 @@ def store_dbt_resource_status_from_log( dbt_node_resource_type = node_info.get("resource_type") unique_id = node_info.get("unique_id") + # Rewrite "skipped" to "failed" when dbt skipped the node because an + # upstream node failed (#2698). See _rewrite_upstream_failure_skip_status. 
+ dbt_node_status = _rewrite_upstream_failure_skip_status( + log_line, unique_id, dbt_node_status, upstream_failure_skipped_ids + ) + logger.debug("Model: %s is in %s state", unique_id, dbt_node_status) if is_dbt_node_status_terminal(dbt_node_status): diff --git a/cosmos/operators/_watcher/state.py b/cosmos/operators/_watcher/state.py index b27e63fafe..58ff8095d3 100644 --- a/cosmos/operators/_watcher/state.py +++ b/cosmos/operators/_watcher/state.py @@ -25,6 +25,12 @@ DBT_FAILED_STATUSES = frozenset({"failed", "fail", "error", "runtime error"}) DBT_SKIPPED_STATUSES = frozenset({"skipped"}) +# dbt event names that signal a node was skipped because an upstream node failed. +# dbt fires SkippingDetails (non-ephemeral upstream) or LogSkipBecauseError +# (ephemeral upstream) only from on_skip(), which is reached only when +# do_skip(cause=...) was called -- i.e. exclusively on upstream node failure. +DBT_UPSTREAM_FAILURE_SKIP_EVENT_NAMES = frozenset({"SkippingDetails", "LogSkipBecauseError"}) + # dbt source freshness statuses that mark a source as stale and propagate skips downstream. 
DBT_SOURCE_FRESHNESS_STALE_STATUSES = frozenset({"error", "warn"}) @@ -66,6 +72,11 @@ def is_dbt_node_status_terminal(status: str | None) -> bool: return is_dbt_node_status_success(status) or is_dbt_node_status_failed(status) or is_dbt_node_status_skipped(status) +def is_dbt_upstream_failure_skip_event(event_name: str | None) -> bool: + """Check if the dbt event name indicates a node was skipped because an upstream node failed.""" + return event_name in DBT_UPSTREAM_FAILURE_SKIP_EVENT_NAMES + + def is_producer_task_terminated(state: str | None) -> bool: """Return True when the producer task is in a terminal state.""" return state in PRODUCER_TERMINAL_STATES diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index f7e0e707cb..0a15d2d558 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -234,6 +234,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Mutable dict populated lazily from the manifest; shared with the log parser. self._dataset_namespace: str | None = None self._model_outlet_uris: dict[str, list[str]] = {} + # Mutable set populated by the log parser when dbt emits SkippingDetails + # or LogSkipBecauseError for a node; subsequent "skipped" terminal events + # for those unique_ids are rewritten to "failed" so the consumer sensor + # fails on attempt 1 (instead of SKIPPED, which Airflow will not retry). + # See #2698. 
+ self._upstream_failure_skipped_ids: set[str] = set() def _handle_datasets(self, context: Context) -> None: """No-op override: consumer tasks handle their own dataset emission in WATCHER mode.""" @@ -246,6 +252,7 @@ def _make_parse_callable(self) -> Callable[[str, Any], None]: test_results_per_model=self.test_results_per_model, model_outlet_uris=self._model_outlet_uris, dataset_namespace=self._dataset_namespace, + upstream_failure_skipped_ids=self._upstream_failure_skipped_ids, ) def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> Any: @@ -439,6 +446,7 @@ def execute(self, context: Context, **kwargs: Any) -> Any: # Pre-compute the dataset namespace for per-model outlet URI generation. self._dataset_namespace = get_dataset_namespace(self.profile_config) self._model_outlet_uris.clear() + self._upstream_failure_skipped_ids.clear() task_instance = context.get("ti") if task_instance is None: diff --git a/cosmos/operators/watcher_kubernetes.py b/cosmos/operators/watcher_kubernetes.py index 3cab4835e4..5b5811627d 100644 --- a/cosmos/operators/watcher_kubernetes.py +++ b/cosmos/operators/watcher_kubernetes.py @@ -44,9 +44,13 @@ logger = get_logger(__name__) -# This global variable is currently used to make the task context available to the K8s callback. -# While the callback is set during the operator initialization, the context is only created during the operator's execution. +# Module-level globals used to make per-execution state available to the static +# K8s callback (see WatcherKubernetesCallback). They are set in +# DbtProducerWatcherKubernetesOperator.execute, read in the callback, and +# mirror the per-operator state DbtProducerWatcherOperator threads through its +# functools.partial-wrapped parser. 
producer_task_context = None +producer_upstream_failure_skipped_ids: set[str] | None = None class WatcherKubernetesCallback(KubernetesPodOperatorCallback): # type: ignore[misc] @@ -73,10 +77,12 @@ def progress_callback( :param pod: the pod from which the log line was read. """ if "context" not in kwargs: - # This global variable is used to make the task context available to the K8s callback. - # While the callback is set during the operator initialization, the context is only created during the operator's execution. kwargs["context"] = producer_task_context - store_dbt_resource_status_from_log(line, kwargs) + store_dbt_resource_status_from_log( + line, + kwargs, + upstream_failure_skipped_ids=producer_upstream_failure_skipped_ids, + ) class DbtProducerWatcherKubernetesOperator(DbtBuildKubernetesOperator): @@ -98,6 +104,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["callbacks"] = normalized_callbacks super().__init__(task_id=task_id, *args, **kwargs) self.dbt_cmd_flags += ["--log-format", "json"] + # Mutable set populated by the log parser when dbt emits SkippingDetails + # or LogSkipBecauseError for a node; subsequent "skipped" terminal events + # for those unique_ids are rewritten to "failed" so the consumer sensor + # fails on attempt 1 (instead of SKIPPED, which Airflow will not retry). + # Mirrors DbtProducerWatcherOperator._upstream_failure_skipped_ids; see #2698. + self._upstream_failure_skipped_ids: set[str] = set() @cached_property def pod_manager(self) -> CosmosKubernetesPodManager: @@ -121,10 +133,13 @@ def execute(self, context: Context, **kwargs: Any) -> Any: _init_xcom_backup(context) - # This global variable is used to make the task context available to the K8s callback. - # While the callback is set during the operator initialization, the context is only created during the operator's execution. 
- global producer_task_context + self._upstream_failure_skipped_ids.clear() + + # Make per-execution state available to the static K8s callback. See the + # module-level globals at the top of this file for details. + global producer_task_context, producer_upstream_failure_skipped_ids producer_task_context = context + producer_upstream_failure_skipped_ids = self._upstream_failure_skipped_ids try: return_value = super().execute(context, **kwargs) diff --git a/dev/dags/dbt/watcher_upstream_failure_recovery/dbt_project.yml b/dev/dags/dbt/watcher_upstream_failure_recovery/dbt_project.yml new file mode 100644 index 0000000000..b6a3459288 --- /dev/null +++ b/dev/dags/dbt/watcher_upstream_failure_recovery/dbt_project.yml @@ -0,0 +1,30 @@ +name: 'watcher_upstream_failure_recovery' + +config-version: 2 +version: '0.1' + +profile: 'default' + +model-paths: ["models"] +seed-paths: ["seeds"] +test-paths: ["tests"] +macro-paths: ["macros"] + +target-path: "target" +clean-targets: + - "target" + - "dbt_modules" + - "logs" + +require-dbt-version: [">=1.0.0", "<2.0.0"] + +# Sequence used by model_flaky.sql to fail on first run and succeed on subsequent +# runs (see #2698 regression test). Using a project-specific sequence name +# avoids state leaking from / into watcher_downstream_not_skipped, which uses +# the same fail-once recipe but tests a different invariant. 
+on-run-start: + - "CREATE SEQUENCE IF NOT EXISTS {{ target.schema }}._cosmos_recovery_fail_once_seq" + +models: + watcher_upstream_failure_recovery: + +materialized: table diff --git a/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_a.sql b/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_a.sql new file mode 100644 index 0000000000..2031e20406 --- /dev/null +++ b/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_a.sql @@ -0,0 +1,5 @@ +select 1 as id, 'Alice' as first_name, 'Smith' as last_name, 'alice@example.com' as email +union all +select 2, 'Bob', 'Jones', 'bob@example.com' +union all +select 3, 'Charlie', 'Brown', 'charlie@example.com' diff --git a/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_downstream.sql b/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_downstream.sql new file mode 100644 index 0000000000..7b373bb260 --- /dev/null +++ b/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_downstream.sql @@ -0,0 +1,4 @@ +select + id, + first_name +from {{ ref('model_flaky') }} diff --git a/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_downstream_2.sql b/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_downstream_2.sql new file mode 100644 index 0000000000..62c07fefcf --- /dev/null +++ b/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_downstream_2.sql @@ -0,0 +1,4 @@ +select + id, + first_name +from {{ ref('model_downstream') }} diff --git a/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_flaky.sql b/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_flaky.sql new file mode 100644 index 0000000000..e4821c1a49 --- /dev/null +++ b/dev/dags/dbt/watcher_upstream_failure_recovery/models/model_flaky.sql @@ -0,0 +1,12 @@ +{{ + config( + pre_hook=[ + "DO $$ BEGIN IF nextval('{{ target.schema }}._cosmos_recovery_fail_once_seq') <= 1 THEN RAISE EXCEPTION 'fail_once: intentional first-run failure'; END IF; END $$" + ] + ) +}} + +select + id, + 
first_name +from {{ ref('model_a') }} diff --git a/dev/failed_dags/example_watcher_recovers_skipped_downstream.py b/dev/failed_dags/example_watcher_recovers_skipped_downstream.py new file mode 100644 index 0000000000..d728d5b4c5 --- /dev/null +++ b/dev/failed_dags/example_watcher_recovers_skipped_downstream.py @@ -0,0 +1,113 @@ +""" +Demonstrate watcher-mode recovery of a downstream model that dbt skipped +because its upstream failed on the producer's first attempt (#2698). + +Without the fix in ``cosmos/operators/_watcher/base.py``: +- A dbt model fails on the first producer attempt. +- dbt marks every downstream node ``skipped`` with the upstream-failure cause. +- The producer log parser pushes that ``"skipped"`` status to XCom. +- The downstream consumer sensor raises ``AirflowSkipException`` -- SKIPPED. +- Airflow retries the producer task (the producer's retry is a no-op by + design: Cosmos restores XCom and raises ``AirflowSkipException`` to avoid + re-running the whole dbt build). +- The consumer sensor for the failing upstream retries on its own and falls + back to running ``dbt --select`` for that model locally, which succeeds. +- The downstream consumer, however, was already SKIPPED. Airflow does not + retry skipped tasks, so the downstream model is never re-run even though + its upstream has now recovered. +- The DAG ends in ``success`` because Airflow treats SKIPPED as non-failure + -- a "false green" outcome with un-materialized downstream tables. + +With the fix, the producer parser rewrites the ``"skipped"`` status to +``"failed"`` for any node that dbt skipped via ``SkippingDetails`` / +``LogSkipBecauseError`` (the only paths reached when ``do_skip(cause=...)`` +fires -- i.e. exclusively on upstream-node failure). The downstream consumer +then fails on attempt 1, Airflow retries it, and the same consumer-fallback +path that recovers the failing upstream now runs the downstream locally. 
+ +Models used (from ``dev/dags/dbt/watcher_upstream_failure_recovery``): +- ``model_a``: trivial source-style model, succeeds. +- ``model_flaky``: uses an ``on-run-start`` Postgres sequence to fail on + the first ``nextval`` call (``<= 1`` → ``RAISE EXCEPTION``) and succeed + on subsequent calls. +- ``model_downstream``: depends on ``model_flaky``; dbt skips it on + attempt 1 because its upstream failed. + +A ``post_dbt`` ``EmptyOperator`` downstream of the task group makes the +"green DAG" visible. A ``cleanup`` SQL task drops the sequence at the end +so the DAG is re-runnable from a clean state. +""" + +import os +from datetime import datetime, timedelta +from pathlib import Path + +from airflow.models import DAG +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator + +try: + from airflow.providers.standard.operators.empty import EmptyOperator +except ImportError: + from airflow.operators.empty import EmptyOperator + +from cosmos import DbtTaskGroup, ExecutionConfig, ProfileConfig, ProjectConfig +from cosmos.constants import ExecutionMode +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent / "dags/dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) +DBT_PROJECT_PATH = DBT_ROOT_PATH / "watcher_upstream_failure_recovery" + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +execution_config = ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, +) + +operator_args = { + "install_deps": True, + "execution_timeout": timedelta(seconds=120), +} + +if os.getenv("CI"): + operator_args["trigger_rule"] = "all_success" + +default_args = { + "retries": 2, + "retry_delay": timedelta(seconds=0), +} + +with DAG( + dag_id="example_watcher_recovers_skipped_downstream", + 
schedule="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + default_args=default_args, +): + dbt_group = DbtTaskGroup( + group_id="watcher_upstream_failure_recovery", + execution_config=execution_config, + project_config=ProjectConfig(DBT_PROJECT_PATH), + profile_config=profile_config, + operator_args=operator_args, + ) + + post_dbt = EmptyOperator(task_id="post_dbt") + + cleanup = SQLExecuteQueryOperator( + task_id="drop_fail_once_marker", + conn_id="example_conn", + sql="DROP SEQUENCE IF EXISTS public._cosmos_recovery_fail_once_seq;", + trigger_rule="all_done", + ) + + dbt_group >> post_dbt + dbt_group >> cleanup diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 181d585f68..ab4d6ce0e7 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -50,6 +50,9 @@ DBT_WATCHER_DOWNSTREAM_NOT_SKIPPED_PATH = ( Path(__file__).parent.parent.parent / "dev/dags/dbt/watcher_downstream_not_skipped" ) +DBT_WATCHER_UPSTREAM_FAILURE_RECOVERY_PATH = ( + Path(__file__).parent.parent.parent / "dev/dags/dbt/watcher_upstream_failure_recovery" +) DBT_EXECUTABLE_PATH = Path(__file__).parent.parent.parent / "venv-subprocess/bin/dbt" DBT_PROJECT_WITH_EMPTY_MODEL_PATH = Path(__file__).parent.parent / "sample/dbt_project_with_empty_model" @@ -85,6 +88,38 @@ class _MockContext(dict): pass +@pytest.fixture +def reset_fail_once_sequence(): + """Drop a fail-once Postgres sequence on ``example_conn`` so a watcher + retry-recovery integration test starts from a clean state. + + The watcher integration tests use a per-project Postgres sequence + (created in the dbt project's ``on-run-start``) plus a model ``pre_hook`` + that ``nextval``s into it and raises if the value is <= 1. That makes the + first run of the flaky model fail and any subsequent run succeed -- the + minimal recipe for exercising the consumer-fallback retry path. 
Each test + drops its sequence before running so re-runs and CI parallelism don't + leak sequence state. + """ + import psycopg2 + from airflow.hooks.base import BaseHook + + def _drop(sequence_name: str) -> None: + airflow_conn = BaseHook.get_connection("example_conn") + with psycopg2.connect( + host=airflow_conn.host, + port=airflow_conn.port or 5432, + dbname=airflow_conn.schema or "postgres", + user=airflow_conn.login, + password=airflow_conn.password, + ) as conn: + conn.autocommit = True + with conn.cursor() as cur: + cur.execute(f"DROP SEQUENCE IF EXISTS public.{sequence_name}") + + return _drop + + def test_dbt_producer_watcher_operator_priority_weight_default(): """Test that DbtProducerWatcherOperator uses default priority_weight of 9999.""" op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) @@ -372,6 +407,20 @@ def test_emit_datasets_pushes_xcom_when_enabled(self, mock_register): class TestStoreDbtStatusFromLog: """Tests for store_dbt_resource_status_from_log and _process_log_line_callable.""" + @staticmethod + def _node_log_line(unique_id: str, node_status: str, event_name: str = "NodeFinished") -> str: + return json.dumps( + { + "info": {"name": event_name}, + "data": { + "node_info": { + "node_status": node_status, + "unique_id": unique_id, + } + }, + } + ) + def test_store_dbt_resource_status_from_log_success(self): """Test that success status is correctly parsed and stored in XCom.""" ti = _MockTI() @@ -394,6 +443,64 @@ def test_store_dbt_resource_status_from_log_failed(self): assert ti.store.get("model__pkg__failed_model_status") == {"status": "failed", "outlet_uris": []} + @pytest.mark.parametrize( + "first_event_name, second_event_name", + [ + ("SkippingDetails", "NodeFinished"), + ("NodeFinished", "SkippingDetails"), + ("LogSkipBecauseError", "NodeFinished"), + ("NodeFinished", "LogSkipBecauseError"), + ], + ) + def test_store_dbt_resource_status_from_log_rewrites_upstream_failure_skips( + self, first_event_name, 
second_event_name + ): + """Upstream-failure skip events are stored as failed so Airflow can retry the consumer.""" + ti = _MockTI() + ctx = {"ti": ti} + upstream_failure_skipped_ids: set[str] = set() + + store_dbt_resource_status_from_log( + self._node_log_line("model.pkg.downstream_model", "skipped", first_event_name), + {"context": ctx}, + tests_per_model={}, + test_results_per_model={}, + upstream_failure_skipped_ids=upstream_failure_skipped_ids, + ) + store_dbt_resource_status_from_log( + self._node_log_line("model.pkg.downstream_model", "skipped", second_event_name), + {"context": ctx}, + tests_per_model={}, + test_results_per_model={}, + upstream_failure_skipped_ids=upstream_failure_skipped_ids, + ) + + assert upstream_failure_skipped_ids == {"model.pkg.downstream_model"} + assert ti.store.get("model__pkg__downstream_model_status") == {"status": "failed", "outlet_uris": []} + + def test_store_dbt_resource_status_from_log_keeps_plain_skips_as_skipped(self): + """Plain skips are unchanged, including when no rewrite accumulator is passed.""" + ti = _MockTI() + ctx = {"ti": ti} + upstream_failure_skipped_ids: set[str] = set() + + store_dbt_resource_status_from_log( + self._node_log_line("model.pkg.with_accumulator", "skipped"), + {"context": ctx}, + tests_per_model={}, + test_results_per_model={}, + upstream_failure_skipped_ids=upstream_failure_skipped_ids, + ) + store_dbt_resource_status_from_log( + self._node_log_line("model.pkg.without_accumulator", "skipped", "SkippingDetails"), + {"context": ctx}, + tests_per_model={}, + test_results_per_model={}, + ) + + assert ti.store.get("model__pkg__with_accumulator_status") == {"status": "skipped", "outlet_uris": []} + assert ti.store.get("model__pkg__without_accumulator_status") == {"status": "skipped", "outlet_uris": []} + def test_store_dbt_resource_status_from_log_ignores_other_statuses(self): """Test that statuses other than success/failed are ignored.""" ti = _MockTI() @@ -2090,14 +2197,12 @@ def 
skip_orders(context, dag, task_group, nodes, sources_json): ), ) @pytest.mark.integration -def test_dbt_task_group_watcher_gateway_prevents_downstream_skip(caplog): +def test_dbt_task_group_watcher_gateway_prevents_downstream_skip(caplog, reset_fail_once_sequence): """ Verify that the dbt_producer_watcher_done gateway task prevents the producer's skip from propagating to tasks downstream of the DbtTaskGroup. """ - import psycopg2 from airflow import DAG - from airflow.hooks.base import BaseHook from cosmos import DbtTaskGroup @@ -2106,18 +2211,7 @@ def test_dbt_task_group_watcher_gateway_prevents_downstream_skip(caplog): except ImportError: from airflow.operators.empty import EmptyOperator - # Reset the fail_once sequence using credentials from the same Airflow connection - airflow_conn = BaseHook.get_connection("example_conn") - with psycopg2.connect( - host=airflow_conn.host, - port=airflow_conn.port or 5432, - dbname=airflow_conn.schema or "postgres", - user=airflow_conn.login, - password=airflow_conn.password, - ) as conn: - conn.autocommit = True - with conn.cursor() as cur: - cur.execute("DROP SEQUENCE IF EXISTS public._cosmos_fail_once_seq") + reset_fail_once_sequence("_cosmos_fail_once_seq") caplog.set_level(logging.DEBUG, logger="cosmos.operators._watcher.base") @@ -2158,6 +2252,111 @@ def test_dbt_task_group_watcher_gateway_prevents_downstream_skip(caplog): assert tis["watcher_downstream_not_skipped.dbt_producer_watcher_done"].state == "success" +@pytest.mark.skipif( + AIRFLOW_VERSION < Version("2.10") or (Version("3.0.0") <= AIRFLOW_VERSION < Version("3.2.0")), + reason=( + "dag.test() in Airflow 2.9 hangs when a task fails with retries configured. " + "Airflow 3.0 runs tasks inline via _run_raw_task without the task SDK supervisor, " + "so RuntimeTaskInstance.get_task_states raises NameError for SUPERVISOR_COMMS and " + "the watcher sensor cannot detect producer termination on retry. 
" + "Airflow 3.1.x crashes during task finalization (SetRenderedFields) when retrying " + "tasks inside a DbtTaskGroup via dag.test()." + ), +) +@pytest.mark.integration +@pytest.mark.parametrize( + "downstream_task_id", + [ + # Level-1 downstream of the flaky model. + "watcher_upstream_failure_recovery.model_downstream_run", + # Level-2 downstream -- confirms the fix handles arbitrary chain depth, + # since dbt emits SkippingDetails for each transitively-skipped node. + "watcher_upstream_failure_recovery.model_downstream_2_run", + ], +) +def test_dbt_task_group_watcher_retry_recovers_skipped_downstream(caplog, reset_fail_once_sequence, downstream_task_id): + """Regression for upstream-failure skips that previously left downstream models skipped (#2698). + + When a dbt model fails on the producer's first attempt, dbt emits + ``SkippingDetails`` (and/or ``LogSkipBecauseError`` for ephemeral upstreams) + plus a later ``NodeFinished`` event with ``node_status="skipped"`` for each + downstream node. Before the fix that "skipped" status went straight to XCom, + the consumer sensor raised ``AirflowSkipException``, and the task ended + SKIPPED. Airflow does not retry SKIPPED tasks, so even after the failing + upstream recovered on its own consumer retry, the downstream model was never + re-run -- the DAG ended in a "false green" outcome with un-materialized + downstream tables. + + The fix (``store_dbt_resource_status_from_log`` in + ``cosmos/operators/_watcher/base.py``) tracks unique_ids that dbt skipped + due to upstream-node failure (the only path that fires + ``SkippingDetails``/``LogSkipBecauseError``, since both come from + ``on_skip()`` which is only reached via ``do_skip(cause=...)``). Subsequent + "skipped" terminal events for those unique_ids are rewritten to "failed" so + the consumer fails on attempt 1, Airflow retries it, and the existing + ``_fallback_to_non_watcher_run`` path runs the model locally once its + upstream has recovered. 
+ """ + from airflow import DAG + + from cosmos import DbtTaskGroup + + try: + from airflow.providers.standard.operators.empty import EmptyOperator + except ImportError: + from airflow.operators.empty import EmptyOperator + + reset_fail_once_sequence("_cosmos_recovery_fail_once_seq") + + caplog.set_level(logging.DEBUG, logger="cosmos.operators._watcher.base") + + with DAG( + dag_id="watcher_upstream_failure_recovery_test", + start_date=datetime(2023, 1, 1), + default_args={"retries": 2, "retry_delay": timedelta(seconds=0)}, + dagrun_timeout=timedelta(seconds=180), + ) as dag: + dbt_group = DbtTaskGroup( + group_id="watcher_upstream_failure_recovery", + execution_config=ExecutionConfig(execution_mode=ExecutionMode.WATCHER), + project_config=ProjectConfig(dbt_project_path=DBT_WATCHER_UPSTREAM_FAILURE_RECOVERY_PATH), + profile_config=profile_config, + render_config=RenderConfig(emit_datasets=False, test_behavior=TestBehavior.NONE), + operator_args={"trigger_rule": "none_failed", "execution_timeout": timedelta(seconds=180)}, + ) + + # Force gateway >> root consumers so dag.test() respects execution order + # (see test_dbt_task_group_watcher_gateway_prevents_downstream_skip for the + # apache/airflow#56723 context). + gateway = dag.task_dict["watcher_upstream_failure_recovery.dbt_producer_watcher_done"] + for root_task in dbt_group.get_roots(): + if root_task.task_id.endswith("_done") or root_task.task_id.endswith("dbt_producer_watcher"): + continue + gateway >> root_task + + post_dbt = EmptyOperator(task_id="post_dbt") + dbt_group >> post_dbt + + outcome = new_test_dag(dag, expected_dag_state=DagRunState.SUCCESS) + + tis = {ti.task_id: ti for ti in outcome.get_task_instances()} + + # Core #2698 assertion: the parametrized downstream model must end SUCCESS, + # not SKIPPED, and must have actually been retried (try_number > 1 proves + # the consumer-fallback path fired rather than the task happening to succeed + # on the first attempt). 
+ assert tis[downstream_task_id].state == "success" + assert tis[downstream_task_id].try_number > 1 + + # The flaky upstream itself also recovers via the consumer-retry fallback. + assert tis["watcher_upstream_failure_recovery.model_flaky_run"].state == "success" + assert tis["watcher_upstream_failure_recovery.model_a_run"].state == "success" + + # Producer/gateway/post_dbt sanity (matches the gateway-prevents-skip test). + assert tis["watcher_upstream_failure_recovery.dbt_producer_watcher_done"].state == "success" + assert tis["post_dbt"].state == "success" + + def test_dbt_source_watcher_operator_template_fields(): """Test that DbtSourceWatcherOperator includes model_unique_id as a consumer sensor.""" from cosmos.operators._watcher.base import BaseConsumerSensor