Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
is_dbt_node_status_skipped,
is_dbt_node_status_success,
is_dbt_node_status_terminal,
is_dbt_upstream_failure_skip_event,
is_producer_task_terminated,
safe_xcom_push,
xcom_set_lock,
Expand Down Expand Up @@ -253,6 +254,33 @@ def _ensure_subprocess_model_outlet_uris(
model_outlet_uris[_MODEL_OUTLET_URIS_ATTEMPTED_KEY] = [] # type: ignore[assignment]


def _rewrite_upstream_failure_skip_status(
log_line: dict[str, Any],
unique_id: str | None,
dbt_node_status: Any,
upstream_failure_skipped_ids: set[str] | None,
) -> Any:
"""Track upstream-failure-skipped nodes and rewrite their status from "skipped" to "failed".

dbt emits two events for a node that it skipped because of an upstream
failure: ``SkippingDetails``/``LogSkipBecauseError`` during ``on_skip()``,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these public dbt contracts? If internal, would it be possible to do a quick sanity check that these are consistent across the dbt versions we support? I guess for future releases we might be able to catch if these change when we test against newly released versions, but it would be nice to check for earlier versions once.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tatiana. I see the PR touching on the --log-format key. What I meant here was that the docstring talks about upstream failures: SkippingDetails / LogSkipBecauseError and whether they are the same for the dbt versions we support.

and a later ``NodeFinished`` from the runnable ``finally``-block -- both
carry ``node_status="skipped"``. The first identifies the upstream-failure
cause (those events are reached only via ``do_skip(cause=...)``, exclusively
on upstream-node failure); the second would otherwise overwrite the
rewritten XCom with the original ``"skipped"``. Tracking affected
``unique_id``\\ s in a per-execution set lets the parser rewrite both. See
#2698.
"""
if not unique_id or upstream_failure_skipped_ids is None:
return dbt_node_status
if is_dbt_upstream_failure_skip_event(log_line.get("info", {}).get("name")):
upstream_failure_skipped_ids.add(unique_id)
if dbt_node_status == "skipped" and unique_id in upstream_failure_skipped_ids:
return "failed"
Comment thread
tatiana marked this conversation as resolved.
return dbt_node_status
Comment thread
tatiana marked this conversation as resolved.


def _surface_non_json_stdout(line: str) -> None:
"""Forward non-JSON stdout (e.g. Snowflake connector's "Going to open: <URL>" prompt
during externalbrowser auth) to the task log instead of silently dropping it.
Expand All @@ -269,6 +297,7 @@ def store_dbt_resource_status_from_log(
test_results_per_model: dict[str, list[str]] | None = None,
model_outlet_uris: dict[str, list[str]] | None = None,
dataset_namespace: str | None = None,
upstream_failure_skipped_ids: set[str] | None = None,
) -> None:
Comment thread
tatiana marked this conversation as resolved.
"""
Parses a single line from dbt JSON logs and stores node status to Airflow XCom.
Expand All @@ -289,6 +318,12 @@ def store_dbt_resource_status_from_log(
:param model_outlet_uris: Mutable dict mapping unique_id to outlet URIs.
Populated lazily from the manifest on first terminal status detection.
:param dataset_namespace: The OL-compatible dataset namespace for URI construction.
:param upstream_failure_skipped_ids: Mutable accumulator set of node unique_ids
that dbt skipped because of an upstream-node failure. Populated when this
function sees a ``SkippingDetails`` or ``LogSkipBecauseError`` event; later
``NodeFinished`` events with ``node_status="skipped"`` for these unique_ids
are rewritten to ``"failed"`` so the consumer sensor fails (and Airflow can
retry it) rather than going SKIPPED. See #2698.
"""
try:
log_line = json.loads(line)
Expand All @@ -309,6 +344,12 @@ def store_dbt_resource_status_from_log(
dbt_node_resource_type = node_info.get("resource_type")
unique_id = node_info.get("unique_id")

# Rewrite "skipped" to "failed" when dbt skipped the node because an
# upstream node failed (#2698). See _rewrite_upstream_failure_skip_status.
dbt_node_status = _rewrite_upstream_failure_skip_status(
log_line, unique_id, dbt_node_status, upstream_failure_skipped_ids
)

logger.debug("Model: %s is in %s state", unique_id, dbt_node_status)

if is_dbt_node_status_terminal(dbt_node_status):
Expand Down
11 changes: 11 additions & 0 deletions cosmos/operators/_watcher/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
DBT_FAILED_STATUSES = frozenset({"failed", "fail", "error", "runtime error"})
DBT_SKIPPED_STATUSES = frozenset({"skipped"})

# dbt event names that signal a node was skipped because an upstream node failed.
# dbt fires SkippingDetails (non-ephemeral upstream) or LogSkipBecauseError
# (ephemeral upstream) only from on_skip(), which is reached only when
# do_skip(cause=...) was called -- i.e. exclusively on upstream node failure.
DBT_UPSTREAM_FAILURE_SKIP_EVENT_NAMES = frozenset({"SkippingDetails", "LogSkipBecauseError"})

# dbt source freshness statuses that mark a source as stale and propagate skips downstream.
DBT_SOURCE_FRESHNESS_STALE_STATUSES = frozenset({"error", "warn"})

Expand Down Expand Up @@ -66,6 +72,11 @@ def is_dbt_node_status_terminal(status: str | None) -> bool:
return is_dbt_node_status_success(status) or is_dbt_node_status_failed(status) or is_dbt_node_status_skipped(status)


def is_dbt_upstream_failure_skip_event(event_name: str | None) -> bool:
"""Check if the dbt event name indicates a node was skipped because an upstream node failed."""
return event_name in DBT_UPSTREAM_FAILURE_SKIP_EVENT_NAMES


def is_producer_task_terminated(state: str | None) -> bool:
"""Return True when the producer task is in a terminal state."""
return state in PRODUCER_TERMINAL_STATES
Expand Down
8 changes: 8 additions & 0 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# Mutable dict populated lazily from the manifest; shared with the log parser.
self._dataset_namespace: str | None = None
self._model_outlet_uris: dict[str, list[str]] = {}
# Mutable set populated by the log parser when dbt emits SkippingDetails
# or LogSkipBecauseError for a node; subsequent "skipped" terminal events
# for those unique_ids are rewritten to "failed" so the consumer sensor
# fails on attempt 1 (instead of SKIPPED, which Airflow will not retry).
# See #2698.
self._upstream_failure_skipped_ids: set[str] = set()

def _handle_datasets(self, context: Context) -> None:
"""No-op override: consumer tasks handle their own dataset emission in WATCHER mode."""
Expand All @@ -246,6 +252,7 @@ def _make_parse_callable(self) -> Callable[[str, Any], None]:
test_results_per_model=self.test_results_per_model,
model_outlet_uris=self._model_outlet_uris,
dataset_namespace=self._dataset_namespace,
upstream_failure_skipped_ids=self._upstream_failure_skipped_ids,
)

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -439,6 +446,7 @@ def execute(self, context: Context, **kwargs: Any) -> Any:
# Pre-compute the dataset namespace for per-model outlet URI generation.
self._dataset_namespace = get_dataset_namespace(self.profile_config)
self._model_outlet_uris.clear()
self._upstream_failure_skipped_ids.clear()

task_instance = context.get("ti")
if task_instance is None:
Expand Down
31 changes: 23 additions & 8 deletions cosmos/operators/watcher_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,13 @@
logger = get_logger(__name__)


# This global variable is currently used to make the task context available to the K8s callback.
# While the callback is set during the operator initialization, the context is only created during the operator's execution.
# Module-level globals used to make per-execution state available to the static
# K8s callback (see WatcherKubernetesCallback). They are set in
# DbtProducerWatcherKubernetesOperator.execute, read in the callback, and
# mirror the per-operator state DbtProducerWatcherOperator threads through its
# functools.partial-wrapped parser.
producer_task_context = None
producer_upstream_failure_skipped_ids: set[str] | None = None


class WatcherKubernetesCallback(KubernetesPodOperatorCallback): # type: ignore[misc]
Expand All @@ -73,10 +77,12 @@ def progress_callback(
:param pod: the pod from which the log line was read.
"""
if "context" not in kwargs:
# This global variable is used to make the task context available to the K8s callback.
# While the callback is set during the operator initialization, the context is only created during the operator's execution.
kwargs["context"] = producer_task_context
store_dbt_resource_status_from_log(line, kwargs)
store_dbt_resource_status_from_log(
line,
kwargs,
upstream_failure_skipped_ids=producer_upstream_failure_skipped_ids,
)


class DbtProducerWatcherKubernetesOperator(DbtBuildKubernetesOperator):
Expand All @@ -98,6 +104,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["callbacks"] = normalized_callbacks
super().__init__(task_id=task_id, *args, **kwargs)
self.dbt_cmd_flags += ["--log-format", "json"]
# Mutable set populated by the log parser when dbt emits SkippingDetails
# or LogSkipBecauseError for a node; subsequent "skipped" terminal events
# for those unique_ids are rewritten to "failed" so the consumer sensor
# fails on attempt 1 (instead of SKIPPED, which Airflow will not retry).
# Mirrors DbtProducerWatcherOperator._upstream_failure_skipped_ids; see #2698.
self._upstream_failure_skipped_ids: set[str] = set()

@cached_property
def pod_manager(self) -> CosmosKubernetesPodManager:
Expand All @@ -121,10 +133,13 @@ def execute(self, context: Context, **kwargs: Any) -> Any:

_init_xcom_backup(context)

# This global variable is used to make the task context available to the K8s callback.
# While the callback is set during the operator initialization, the context is only created during the operator's execution.
global producer_task_context
self._upstream_failure_skipped_ids.clear()

# Make per-execution state available to the static K8s callback. See the
# module-level globals at the top of this file for details.
global producer_task_context, producer_upstream_failure_skipped_ids
producer_task_context = context
producer_upstream_failure_skipped_ids = self._upstream_failure_skipped_ids

try:
return_value = super().execute(context, **kwargs)
Expand Down
30 changes: 30 additions & 0 deletions dev/dags/dbt/watcher_upstream_failure_recovery/dbt_project.yml
Comment thread
tatiana marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: 'watcher_upstream_failure_recovery'

config-version: 2
version: '0.1'

profile: 'default'

model-paths: ["models"]
seed-paths: ["seeds"]
test-paths: ["tests"]
macro-paths: ["macros"]

target-path: "target"
clean-targets:
- "target"
- "dbt_modules"
- "logs"

require-dbt-version: [">=1.0.0", "<2.0.0"]

# Sequence used by model_flaky.sql to fail on first run and succeed on subsequent
# runs (see #2698 regression test). Using a project-specific sequence name
# avoids state leaking from / into watcher_downstream_not_skipped, which uses
# the same fail-once recipe but tests a different invariant.
on-run-start:
- "CREATE SEQUENCE IF NOT EXISTS {{ target.schema }}._cosmos_recovery_fail_once_seq"

models:
watcher_upstream_failure_recovery:
+materialized: table
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
select 1 as id, 'Alice' as first_name, 'Smith' as last_name, 'alice@example.com' as email
union all
select 2, 'Bob', 'Jones', 'bob@example.com'
union all
select 3, 'Charlie', 'Brown', 'charlie@example.com'
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
select
id,
first_name
from {{ ref('model_flaky') }}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
select
id,
first_name
from {{ ref('model_downstream') }}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{{
config(
pre_hook=[
"DO $$ BEGIN IF nextval('{{ target.schema }}._cosmos_recovery_fail_once_seq') <= 1 THEN RAISE EXCEPTION 'fail_once: intentional first-run failure'; END IF; END $$"
]
)
}}

select
id,
first_name
from {{ ref('model_a') }}
113 changes: 113 additions & 0 deletions dev/failed_dags/example_watcher_recovers_skipped_downstream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
Demonstrate watcher-mode recovery of a downstream model that dbt skipped
because its upstream failed on the producer's first attempt (#2698).

Without the fix in ``cosmos/operators/_watcher/base.py``:
- A dbt model fails on the first producer attempt.
- dbt marks every downstream node ``skipped`` with the upstream-failure cause.
Comment thread
tatiana marked this conversation as resolved.
- The producer log parser pushes that ``"skipped"`` status to XCom.
- The downstream consumer sensor raises ``AirflowSkipException`` -- SKIPPED.
- Airflow retries the producer task (the producer's retry is a no-op by
design: Cosmos restores XCom and raises ``AirflowSkipException`` to avoid
re-running the whole dbt build).
- The consumer sensor for the failing upstream retries on its own and falls
back to running ``dbt --select <model>`` locally, which succeeds.
- The downstream consumer, however, was already SKIPPED. Airflow does not
retry skipped tasks, so the downstream model is never re-run even though
its upstream has now recovered.
- The DAG ends in ``success`` because Airflow treats SKIPPED as non-failure
-- a "false green" outcome with un-materialized downstream tables.

With the fix, the producer parser rewrites the ``"skipped"`` status to
``"failed"`` for any node that dbt skipped via ``SkippingDetails`` /
``LogSkipBecauseError`` (the only paths reached when ``do_skip(cause=...)``
fires -- i.e. exclusively on upstream-node failure). The downstream consumer
then fails on attempt 1, Airflow retries it, and the same consumer-fallback
path that recovers the failing upstream now runs the downstream locally.

Models used (from ``dev/dags/dbt/watcher_upstream_failure_recovery``):
- ``model_a``: trivial source-style model, succeeds.
- ``model_flaky``: uses an ``on-run-start`` Postgres sequence to fail on
the first ``nextval`` call (``<= 1`` → ``RAISE EXCEPTION``) and succeed
on subsequent calls.
- ``model_downstream``: depends on ``model_flaky``; dbt skips it on
attempt 1 because its upstream failed.

A ``post_dbt`` ``EmptyOperator`` downstream of the task group makes the
"green DAG" visible. A ``cleanup`` SQL task drops the sequence at the end
so the DAG is re-runnable from a clean state.
"""

import os
from datetime import datetime, timedelta
from pathlib import Path

from airflow.models import DAG
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

try:
from airflow.providers.standard.operators.empty import EmptyOperator
except ImportError:
from airflow.operators.empty import EmptyOperator

from cosmos import DbtTaskGroup, ExecutionConfig, ProfileConfig, ProjectConfig
from cosmos.constants import ExecutionMode
from cosmos.profiles import PostgresUserPasswordProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent / "dags/dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))
DBT_PROJECT_PATH = DBT_ROOT_PATH / "watcher_upstream_failure_recovery"

profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="example_conn",
profile_args={"schema": "public"},
disable_event_tracking=True,
),
)

execution_config = ExecutionConfig(
execution_mode=ExecutionMode.WATCHER,
)

operator_args = {
"install_deps": True,
"execution_timeout": timedelta(seconds=120),
}

if os.getenv("CI"):
operator_args["trigger_rule"] = "all_success"

default_args = {
"retries": 2,
"retry_delay": timedelta(seconds=0),
}

with DAG(
dag_id="example_watcher_recovers_skipped_downstream",
schedule="@daily",
start_date=datetime(2023, 1, 1),
catchup=False,
default_args=default_args,
):
dbt_group = DbtTaskGroup(
group_id="watcher_upstream_failure_recovery",
execution_config=execution_config,
project_config=ProjectConfig(DBT_PROJECT_PATH),
profile_config=profile_config,
operator_args=operator_args,
)

post_dbt = EmptyOperator(task_id="post_dbt")

cleanup = SQLExecuteQueryOperator(
task_id="drop_fail_once_marker",
conn_id="example_conn",
sql="DROP SEQUENCE IF EXISTS public._cosmos_recovery_fail_once_seq;",
trigger_rule="all_done",
)

dbt_group >> post_dbt
dbt_group >> cleanup
Loading
Loading