-
Notifications
You must be signed in to change notification settings - Fork 296
Prevent watcher producer skip propagating to downstream tasks via trigger_rule #2591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
acf93ee
bec86e0
4199c20
a2a7892
c258efa
44d8a9c
56650f6
58eaa09
822a6e5
260a450
526b30d
3e4feca
aba8c9d
3af4f18
839ef9a
7247b32
8f32df8
cf888ac
679007b
cb2fc1e
4f7e948
7ce3478
4e1ab94
ecda234
cd366d8
4a8b966
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,6 +12,9 @@ | |||||||||||
| except ImportError: | ||||||||||||
| from airflow.utils.task_group import TaskGroup | ||||||||||||
|
|
||||||||||||
| from cosmos import settings | ||||||||||||
| from cosmos.config import ExecutionConfig | ||||||||||||
| from cosmos.constants import ExecutionMode | ||||||||||||
| from cosmos.converter import DbtToAirflowConverter, airflow_kwargs, specific_kwargs | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
|
|
@@ -26,7 +29,52 @@ def __init__( | |||||||||||
| *args: Any, | ||||||||||||
| **kwargs: Any, | ||||||||||||
| ) -> None: | ||||||||||||
| self._execution_config = kwargs.get("execution_config") | ||||||||||||
| kwargs["group_id"] = group_id | ||||||||||||
| TaskGroup.__init__(self, *args, **airflow_kwargs(**kwargs)) | ||||||||||||
| kwargs["task_group"] = self | ||||||||||||
| DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs)) | ||||||||||||
|
|
||||||||||||
| @property | ||||||||||||
| def is_watcher_mode(self) -> bool: | ||||||||||||
| """Whether this task group uses a watcher execution mode.""" | ||||||||||||
| return isinstance(self._execution_config, ExecutionConfig) and self._execution_config.execution_mode in ( | ||||||||||||
| ExecutionMode.WATCHER, | ||||||||||||
| ExecutionMode.WATCHER_KUBERNETES, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| def _set_watcher_downstream_tasks_trigger_rule(self, downstream_tasks: Any) -> None: | ||||||||||||
| """In watcher mode the producer task may be skipped on retry. | ||||||||||||
|
|
||||||||||||
| Set ``trigger_rule="none_failed_min_one_success"`` on downstream tasks so the skip | ||||||||||||
| does not propagate outside the task group. | ||||||||||||
| """ | ||||||||||||
| if not self.is_watcher_mode or not settings.propagate_watcher_trigger_rule: | ||||||||||||
| return | ||||||||||||
| items = downstream_tasks if isinstance(downstream_tasks, (list, tuple)) else [downstream_tasks] | ||||||||||||
| for item in items: | ||||||||||||
| if isinstance(item, TaskGroup): | ||||||||||||
| # For downstream TaskGroups, only set trigger_rule on root tasks | ||||||||||||
| # (tasks with no upstream within the group) to avoid altering | ||||||||||||
| # the group's internal dependency semantics. | ||||||||||||
| for root_task in item.get_roots(): | ||||||||||||
| if hasattr(root_task, "trigger_rule"): | ||||||||||||
| root_task.trigger_rule = "none_failed_min_one_success" # type: ignore[assignment] | ||||||||||||
| elif hasattr(item, "trigger_rule"): | ||||||||||||
| item.trigger_rule = "none_failed_min_one_success" # type: ignore[assignment] | ||||||||||||
|
|
||||||||||||
| def __rshift__(self, other: Any) -> Any: | ||||||||||||
| # dbt_group >> post_dbt — post_dbt is downstream of dbt_group | ||||||||||||
| result = super().__rshift__(other) | ||||||||||||
| self._set_watcher_downstream_tasks_trigger_rule(other) | ||||||||||||
| return result | ||||||||||||
|
|
||||||||||||
| def __rlshift__(self, other: Any) -> Any: | ||||||||||||
| # other << dbt_group — other is downstream of dbt_group | ||||||||||||
|
||||||||||||
| # other << dbt_group — other is downstream of dbt_group | |
| # Reflected ``<<`` hook used only when Python/Airflow dispatch resolves to this | |
| # TaskGroup instance. This is not a general interception point for | |
| # ``task << dbt_group`` / ``task.set_upstream(dbt_group)``; watcher trigger-rule | |
| # propagation must not rely on this method being invoked for those cases. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| name: 'watcher_downstream_not_skipped' | ||
|
|
||
| config-version: 2 | ||
| version: '0.1' | ||
|
|
||
| profile: 'default' | ||
|
|
||
| model-paths: ["models"] | ||
| seed-paths: ["seeds"] | ||
| test-paths: ["tests"] | ||
| macro-paths: ["macros"] | ||
|
|
||
| target-path: "target" | ||
| clean-targets: | ||
| - "target" | ||
| - "dbt_modules" | ||
| - "logs" | ||
|
|
||
| require-dbt-version: [">=1.0.0", "<2.0.0"] | ||
|
|
||
| on-run-start: | ||
| - "CREATE SEQUENCE IF NOT EXISTS {{ target.schema }}._cosmos_fail_once_seq" | ||
|
|
||
| models: | ||
| watcher_downstream_not_skipped: | ||
| +materialized: table |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| select 1 as id, 'Alice' as first_name, 'Smith' as last_name, 'alice@example.com' as email | ||
| union all | ||
| select 2, 'Bob', 'Jones', 'bob@example.com' | ||
| union all | ||
| select 3, 'Charlie', 'Brown', 'charlie@example.com' |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| {{ | ||
| config( | ||
| pre_hook=[ | ||
| "DO $$ BEGIN IF nextval('{{ target.schema }}._cosmos_fail_once_seq') <= 1 THEN RAISE EXCEPTION 'fail_once: intentional first-run failure'; END IF; END $$" | ||
| ] | ||
| ) | ||
| }} | ||
|
|
||
| select | ||
| id, | ||
| first_name | ||
| from {{ ref('model_a') }} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| """ | ||
| Airflow DAG to verify watcher downstream-task behavior when a producer task is skipped on retry. | ||
|
|
||
| In watcher mode, when a dbt model fails the producer retries and raises AirflowSkipException. | ||
| By default, the skip propagates to all tasks downstream of the TaskGroup (e.g. post_dbt), | ||
| even though the consumer tasks inside the group succeeded. | ||
|
|
||
| To prevent this, enable the ``propagate_watcher_trigger_rule`` Cosmos setting:: | ||
|
|
||
| export AIRFLOW__COSMOS__PROPAGATE_WATCHER_TRIGGER_RULE=True | ||
|
|
||
| When enabled, Cosmos automatically sets ``trigger_rule="none_failed"`` on tasks downstream | ||
| of a watcher DbtTaskGroup, so the producer skip does not propagate. | ||
|
|
||
| Without this setting, post_dbt will be skipped. | ||
|
|
||
| This DAG demonstrates the behaviour with the setting enabled: | ||
| - model_a succeeds in the producer and the consumer reads the result from XCom | ||
| - model_retry fails on the first attempt (via a fail_once sequence pre-hook) but succeeds | ||
| on the consumer retry fallback | ||
| - post_dbt (an EmptyOperator downstream of the group) runs successfully — it is NOT skipped | ||
|
|
||
| A cleanup task drops the PostgreSQL sequence so the DAG can be re-triggered. | ||
| """ | ||
|
tatiana marked this conversation as resolved.
|
||
|
|
||
| import os | ||
| from datetime import datetime, timedelta | ||
| from pathlib import Path | ||
|
|
||
| from airflow.models import DAG | ||
| from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator | ||
|
|
||
| try: | ||
| from airflow.providers.standard.operators.empty import EmptyOperator | ||
| except ImportError: | ||
| from airflow.operators.empty import EmptyOperator | ||
|
|
||
| from cosmos import DbtTaskGroup, ExecutionConfig, ProfileConfig, ProjectConfig | ||
| from cosmos.constants import ExecutionMode | ||
| from cosmos.profiles import PostgresUserPasswordProfileMapping | ||
|
|
||
| DEFAULT_DBT_ROOT_PATH = Path(__file__).parent.parent / "dags/dbt" | ||
| DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) | ||
| DBT_PROJECT_PATH = DBT_ROOT_PATH / "watcher_downstream_not_skipped" | ||
|
|
||
| profile_config = ProfileConfig( | ||
| profile_name="default", | ||
| target_name="dev", | ||
| profile_mapping=PostgresUserPasswordProfileMapping( | ||
| conn_id="example_conn", | ||
| profile_args={"schema": "public"}, | ||
| disable_event_tracking=True, | ||
| ), | ||
| ) | ||
|
|
||
| execution_config = ExecutionConfig( | ||
| execution_mode=ExecutionMode.WATCHER, | ||
| ) | ||
|
|
||
| operator_args = { | ||
| "install_deps": True, | ||
| "execution_timeout": timedelta(seconds=120), | ||
| } | ||
|
|
||
| if os.getenv("CI"): | ||
| operator_args["trigger_rule"] = "all_success" | ||
|
|
||
| default_args = { | ||
| "retries": 2, | ||
| "retry_delay": timedelta(seconds=0), | ||
| "depends_on_past": True, | ||
| } | ||
|
|
||
| with DAG( | ||
| dag_id="example_watcher_downstream_not_skipped", | ||
| schedule="@daily", | ||
| start_date=datetime(2023, 1, 1), | ||
| catchup=False, | ||
| default_args=default_args, | ||
| ): | ||
| dbt_group = DbtTaskGroup( | ||
| group_id="watcher_downstream_not_skipped", | ||
| execution_config=execution_config, | ||
| project_config=ProjectConfig(DBT_PROJECT_PATH), | ||
| profile_config=profile_config, | ||
| operator_args=operator_args, | ||
| ) | ||
|
|
||
| cleanup = SQLExecuteQueryOperator( | ||
| task_id="drop_fail_once_marker", | ||
| conn_id="example_conn", | ||
| sql="DROP SEQUENCE IF EXISTS public._cosmos_fail_once_seq;", | ||
| trigger_rule="all_done", | ||
| ) | ||
|
|
||
| post_dbt = EmptyOperator(task_id="post_dbt") | ||
|
|
||
| dbt_group >> post_dbt | ||
| dbt_group >> cleanup | ||
Uh oh!
There was an error while loading. Please reload this page.