diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 54721b04d9..c2e996ea1a 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -6,12 +6,15 @@ from typing import Any try: # Airflow 3 + from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.bases.operator import BaseOperator except ImportError: # Airflow 2 from airflow.models import BaseOperator + from airflow.operators.empty import EmptyOperator # type: ignore[no-redef] from airflow.models.base import ID_LEN as AIRFLOW_MAX_ID_LENGTH from airflow.models.dag import DAG +from airflow.utils.trigger_rule import TriggerRule try: # Airflow 3.1 onwards @@ -679,7 +682,7 @@ def _add_watcher_producer_task( render_config: RenderConfig | None = None, execution_mode: ExecutionMode = ExecutionMode.WATCHER, tests_per_model: dict[str, list[str]] | None = None, -) -> BaseOperator: +) -> tuple[BaseOperator, EmptyOperator]: """ Create the producer task for the watcher execution mode and add it to the tasks_map. The producer task is the task that will be used to produce the events for the watcher execution mode. @@ -711,12 +714,22 @@ def _add_watcher_producer_task( ) producer_airflow_task = create_airflow_task(producer_task_metadata, dag, task_group=task_group) tasks_map[PRODUCER_WATCHER_TASK_ID] = producer_airflow_task - return producer_airflow_task + + producer_task_gate = EmptyOperator( # type: ignore[no-untyped-call] + task_id=f"{PRODUCER_WATCHER_TASK_ID}_gate", + dag=dag, + task_group=task_group, + trigger_rule=TriggerRule.NONE_FAILED, + depends_on_past=producer_airflow_task.depends_on_past, + ) + producer_airflow_task >> producer_task_gate + return producer_airflow_task, producer_task_gate def _add_watcher_dependencies( dag: DAG, producer_airflow_task: BaseOperator, + producer_gate: BaseOperator, task_args: dict[str, Any], tasks_map: dict[str, Any], nodes: dict[str, DbtNode] | None = None, @@ -728,7 +741,7 @@ def _add_watcher_dependencies( """ for node_id, task_or_taskgroup in tasks_map.items(): # We do not want to set a dependency between the producer task and itself - if node_id == PRODUCER_WATCHER_TASK_ID: + if node_id == PRODUCER_WATCHER_TASK_ID or node_id == f"{PRODUCER_WATCHER_TASK_ID}_gate": continue node_tasks = ( @@ -758,6 +771,10 @@ def _add_watcher_dependencies( for task in always_run_tasks: task.trigger_rule = task_args.get("trigger_rule", "always") # type: ignore[attr-defined] + # If depends_on_past isn't true then gating all the tasks isn't really needed. + if producer_airflow_task.wait_for_downstream and not task_or_taskgroup.downstream_task_ids: + task_or_taskgroup >> producer_gate + def should_create_detached_nodes(render_config: RenderConfig) -> bool: """ @@ -908,7 +925,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro task_groups: dict[str, TaskGroup] = {} task_or_group: TaskGroup | BaseOperator | None parent_task_group = task_group - producer_task: BaseOperator | None = None + producer_tasks: tuple[BaseOperator, EmptyOperator] | None = None # Identify test nodes that should be run detached from the associated dbt resource nodes because they # have multiple parents @@ -926,7 +943,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro # We are intentionally creating the producer task ahead of the consumer tasks # Airflow priority weight is not being respected in multiple versions of the library, including 3.1 # To instantiate the producer before helps having it before on the DAG topological order and scheduling this task before the consumer tasks - producer_task = _add_watcher_producer_task( + producer_tasks = _add_watcher_producer_task( dag=dag, task_args={**task_args, **setup_operator_args}, tasks_map=tasks_map, @@ -988,6 +1005,9 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes) for leaf_node_id in leaves_ids: tasks_map[leaf_node_id] >> test_task + if producer_tasks and producer_tasks[0].depends_on_past: + test_task >> producer_tasks[1] + test_task.wait_for_downstream = True elif render_config.test_behavior in (TestBehavior.BUILD, TestBehavior.AFTER_EACH): # Handle detached test nodes for node_id, node in detached_nodes.items(): @@ -1012,10 +1032,11 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro create_airflow_task_dependencies(nodes, tasks_map) - if producer_task: + if producer_tasks: _add_watcher_dependencies( dag=dag, - producer_airflow_task=producer_task, + producer_airflow_task=producer_tasks[0], + producer_gate=producer_tasks[1], task_args=task_args, tasks_map=tasks_map, nodes=nodes, diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 5142645bc6..a4e6b68625 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -311,6 +311,9 @@ def __init__( self.deferrable = deferrable self.model_unique_id = extra_context.get("dbt_node_config", {}).get("unique_id") + if self.depends_on_past: + self.wait_for_downstream = True + @property def is_test_sensor(self) -> bool: """Whether this sensor watches aggregated test results instead of individual node results.""" @@ -558,11 +561,10 @@ def poke(self, context: Context) -> bool: _log_dbt_event(dbt_events) if status is None: - - if producer_task_state == "failed": + if producer_task_state == "failed" or producer_task_state == "skipped": if self.poke_retry_number > 0: raise AirflowException( - f"The dbt build command failed in producer task. Please check the log of task {self.producer_task_id} for details." + f"The dbt build command {producer_task_state} in the producer task. Please check the log of task {self.producer_task_id} for details." ) else: # This handles the scenario of tasks that failed with `State.UPSTREAM_FAILED` diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index b9a91f4f79..1fc075afc7 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -35,7 +35,6 @@ class WatcherEventReason(str, Enum): class WatcherTrigger(BaseTrigger): - def __init__( self, model_unique_id: str, @@ -237,10 +236,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]: event_data["compiled_sql"] = compiled_sql yield TriggerEvent(event_data) # type: ignore[no-untyped-call] return - elif producer_task_state == "failed": + elif producer_task_state == "failed" or producer_task_state == "skipped": logger.error( - "Watcher producer task '%s' failed before delivering results for node '%s'", + "Watcher producer task '%s' %s before delivering results for node '%s'", self.producer_task_id, + producer_task_state, self.model_unique_id, ) yield TriggerEvent({"status": EventStatus.FAILED, "reason": WatcherEventReason.PRODUCER_FAILED}) # type: ignore[no-untyped-call] diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 85eb90dccb..1a0a461c43 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from cosmos.config import ProfileConfig from cosmos.constants import ( @@ -106,6 +106,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: if self.invocation_mode == InvocationMode.SUBPROCESS: self.log_format = "json" + if self.depends_on_past: + self.wait_for_downstream = True + @staticmethod def _serialize_event(event_message: EventMsg) -> dict[str, Any]: """Convert structured dbt EventMsg to plain dict.""" @@ -173,12 +176,10 @@ def execute(self, context: Context, **kwargs: Any) -> Any: try_number = getattr(task_instance, "try_number", 1) if try_number > 1: - self.log.info( - "Dbt WATCHER producer task does not support Airflow retries. " - "Detected attempt #%s; skipping execution to avoid running a second dbt build.", - try_number, + raise AirflowSkipException( + "DbtProducerWatcherOperator does not support Airflow retries. " + f"Detected attempt #{try_number}; skipping execution to avoid running a second dbt build." ) - return None try: use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None @@ -219,9 +220,10 @@ def _callback(event_message: EventMsg) -> None: safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed") return return_value - except Exception: + except Exception as e: safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed") - raise + self.log.exception("DbtProducerWatcherOperator execution failed") + raise AirflowSkipException("Skipping execution due to task failure") from e class DbtConsumerWatcherSensor(BaseConsumerSensor, DbtRunLocalOperator): # type: ignore[misc] diff --git a/cosmos/operators/watcher_kubernetes.py b/cosmos/operators/watcher_kubernetes.py index 76b4da4e18..0614052fb7 100644 --- a/cosmos/operators/watcher_kubernetes.py +++ b/cosmos/operators/watcher_kubernetes.py @@ -13,7 +13,7 @@ from airflow.utils.context import Context # type: ignore[attr-defined] import kubernetes.client as k8s -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback, client_type try: @@ -44,7 +44,6 @@ class WatcherKubernetesCallback(KubernetesPodOperatorCallback): # type: ignore[misc] - @staticmethod def progress_callback( *, @@ -74,7 +73,6 @@ def progress_callback( class DbtProducerWatcherKubernetesOperator(DbtBuildKubernetesOperator): - template_fields: tuple[str, ...] = tuple(DbtBuildKubernetesOperator.template_fields) + ("deferrable",) _process_log_line_callable: Callable[[str, dict[str, Any]], None] | None = store_dbt_resource_status_from_log @@ -93,6 +91,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(task_id=task_id, *args, **kwargs) self.dbt_cmd_flags += ["--log-format", "json"] + if self.depends_on_past: + self.wait_for_downstream = True + @cached_property def pod_manager(self) -> CosmosKubernetesPodManager: return CosmosKubernetesPodManager(kube_client=self.client, callbacks=self.callbacks) @@ -107,18 +108,31 @@ def execute(self, context: Context, **kwargs: Any) -> Any: try_number = getattr(task_instance, "try_number", 1) if try_number > 1: - self.log.info( + raise AirflowSkipException( "DbtProducerWatcherKubernetesOperator does not support Airflow retries. " - "Detected attempt #%s; skipping execution to avoid running a second dbt build.", - try_number, + f"Detected attempt #{try_number}; skipping execution to avoid running a second dbt build." ) - return None # This global variable is used to make the task context available to the K8s callback. # While the callback is set during the operator initialization, the context is only created during the operator's execution. global producer_task_context producer_task_context = context - return super().execute(context, **kwargs) + try: + return super().execute(context, **kwargs) + except (AirflowSkipException, TaskDeferred): + raise + except Exception as e: + self.log.exception("Dbt execution failed") + raise AirflowSkipException("Skipping execution due to task failure") from e + + def trigger_reentry(self, *args: Any, **kwargs: Any) -> Any: + try: + return super().trigger_reentry(*args, **kwargs) + except (AirflowSkipException, TaskDeferred): + raise + except Exception as e: + self.log.exception("Dbt execution failed") + raise AirflowSkipException("Skipping execution due to task failure") from e class DbtConsumerWatcherKubernetesSensor(BaseConsumerSensor, DbtRunKubernetesOperator): diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 9e6863164f..e68a818bb7 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -107,7 +107,6 @@ tags=["nightly"], config={"materialized": "table", "meta": {"cosmos": {"operator_kwargs": {"pool": "custom_pool"}}}}, ) - sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node, child2_node] sample_nodes = {node.unique_id: node for node in sample_nodes_list} @@ -1178,13 +1177,87 @@ def test_test_behavior_for_watcher_mode(test_behavior): tasks = dag.tasks if test_behavior == TestBehavior.NONE: for task in tasks: - assert not isinstance(task, DbtTestWatcherOperator or DbtTestLocalOperator) - assert len(tasks) == 5 - if test_behavior == TestBehavior.AFTER_EACH: + assert not isinstance(task, (DbtTestWatcherOperator, DbtTestLocalOperator)) assert len(tasks) == 6 + if test_behavior == TestBehavior.AFTER_EACH: + assert len(tasks) == 7 if test_behavior == TestBehavior.AFTER_ALL: assert any(isinstance(task, DbtTestLocalOperator) for task in tasks) - assert len(tasks) == 6 + assert len(tasks) == 7 + + +@pytest.mark.parametrize("depends_on_past", [False, True]) +@pytest.mark.parametrize("test_behavior", [TestBehavior.NONE, TestBehavior.AFTER_EACH, TestBehavior.AFTER_ALL]) +def test_watcher_dependency_wiring(test_behavior, depends_on_past): + with DAG("test-id", start_date=datetime(2022, 1, 1), default_args={"depends_on_past": depends_on_past}) as dag: + task_args = { + "project_dir": SAMPLE_PROJ_PATH, + "conn_id": "fake_conn", + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + } + + child_2b = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.child2.v2_b", + resource_type=DbtResourceType.MODEL, + depends_on=[parent_node.unique_id], + path_base=SAMPLE_PROJ_PATH, + original_file_path=Path("gen3/models/child2_v2.sql"), + tags=["nightly"], + config={"materialized": "table", "meta": {"cosmos": {"operator_kwargs": {"pool": "custom_pool"}}}}, + has_test=True, + has_non_detached_test=True, + ) + child_2b_test = DbtNode( + unique_id=f"{DbtResourceType.TEST.value}.{SAMPLE_PROJ_PATH.stem}.child2.test_v2_b", + resource_type=DbtResourceType.TEST, + depends_on=[child_2b.unique_id], + path_base=Path("."), + original_file_path=Path("."), + ) + + build_airflow_graph( + nodes={child_2b.unique_id: child_2b, child_2b_test.unique_id: child_2b_test, **sample_nodes}, + dag=dag, + execution_mode=ExecutionMode.WATCHER, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args=task_args, + render_config=RenderConfig( + test_behavior=test_behavior, + ), + dbt_project_name="astro_shop", + ) + if not depends_on_past: + assert dag.task_dict["dbt_producer_watcher_gate"].upstream_task_ids == {"dbt_producer_watcher"} + assert all(task.wait_for_downstream is False for task in dag.tasks) + return + + assert all(task.wait_for_downstream is True for task in dag.tasks if task.task_id != "dbt_producer_watcher_gate") + if test_behavior == TestBehavior.NONE: + assert dag.task_dict["dbt_producer_watcher_gate"].upstream_task_ids == { + "child_run", + "dbt_producer_watcher", + "child2_v2_run", + "child2_v2_b_run", + } + if test_behavior == TestBehavior.AFTER_EACH: + assert dag.task_dict["dbt_producer_watcher_gate"].upstream_task_ids == { + "child_run", + "dbt_producer_watcher", + "child2_v2_run", + "child2_v2_b.test", + } + if test_behavior == TestBehavior.AFTER_ALL: + assert dag.task_dict["dbt_producer_watcher_gate"].upstream_task_ids == { + "dbt_producer_watcher", + "astro_shop_test", + } def test_custom_meta(): diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 5a2b0173b2..106ece03dd 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -11,7 +11,7 @@ from unittest.mock import ANY, MagicMock, Mock, patch import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.utils.state import DagRunState try: @@ -217,8 +217,9 @@ class TestException(Exception): with patch("cosmos.operators.local.DbtLocalBaseOperator.execute") as mock_execute: mock_execute.side_effect = TestException("test error") - with pytest.raises(TestException): + with pytest.raises(AirflowSkipException) as exc: op.execute(context=context) + assert isinstance(exc.value.__cause__, TestException) # Verify completed status was pushed even in failure case assert mock_ti.store.get("task_status") == "completed" @@ -271,20 +272,36 @@ def test_dbt_consumer_watcher_sensor_execute_complete_model_not_run_logs_message assert any("ephemeral model or if the model sql file is empty" in message for message in caplog.messages) -def test_dbt_producer_watcher_operator_skips_retry_attempt(caplog): +def test_dbt_producer_watcher_operator_logs_retry_message(caplog): op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) ti = _MockTI() - ti.try_number = 2 + ti.try_number = 1 context = {"ti": ti} - with patch("cosmos.operators.local.DbtLocalBaseOperator.execute") as mock_execute: + with patch("cosmos.operators.local.DbtLocalBaseOperator.execute", return_value="ok") as mock_execute: with caplog.at_level(logging.INFO): - result = op.execute(context=context) + op.execute(context=context) + + mock_execute.assert_called_once() + assert any("forces Airflow retries to 0" in message for message in caplog.messages) + + +def test_dbt_producer_watcher_operator_skips_retry_attempt(): + op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + ti = _MockTI() + ti.try_number = 2 + context = {"ti": ti} + + with ( + patch("cosmos.operators.local.DbtLocalBaseOperator.execute") as mock_execute, + pytest.raises( + AirflowSkipException, + match="DbtProducerWatcherOperator does not support Airflow retries. Detected attempt #2; skipping execution to avoid running a second dbt build.", + ), + ): + op.execute(context=context) mock_execute.assert_not_called() - assert result is None - assert any("does not support Airflow retries" in message for message in caplog.messages) - assert any("skipping execution" in message for message in caplog.messages) @pytest.mark.parametrize( @@ -1293,7 +1310,7 @@ def test_producer_state_failed(self, mock_get_xcom_val): with pytest.raises( AirflowException, - match="The dbt build command failed in producer task. Please check the log of task dbt_producer_watcher for details.", + match="The dbt build command failed in the producer task. Please check the log of task dbt_producer_watcher for details.", ): sensor.poke(context) @@ -1525,14 +1542,19 @@ def test_dbt_dag_with_watcher(capsys): invocation_mode=InvocationMode.DBT_RUNNER, ), render_config=RenderConfig(emit_datasets=False), - operator_args={"trigger_rule": "all_success", "execution_timeout": timedelta(seconds=120)}, + operator_args={ + "trigger_rule": "all_success", + "execution_timeout": timedelta(seconds=120), + "depends_on_past": True, + }, ) + outcome = new_test_dag(watcher_dag) assert outcome.state == DagRunState.SUCCESS assert len(watcher_dag.dbt_graph.filtered_nodes) == 23 - assert len(watcher_dag.task_dict) == 14 + assert len(watcher_dag.task_dict) == 15 tasks_names = [task.task_id for task in watcher_dag.topological_sort()] expected_task_names = [ "dbt_producer_watcher", @@ -1549,6 +1571,7 @@ def test_dbt_dag_with_watcher(capsys): "customers.test", "orders.run", "orders.test", + "dbt_producer_watcher_gate", ] assert tasks_names == expected_task_names @@ -1571,8 +1594,22 @@ def test_dbt_dag_with_watcher(capsys): "raw_payments_seed", "raw_orders_seed", "raw_customers_seed", + "dbt_producer_watcher_gate", } + assert ( + watcher_dag.task_dict["dbt_producer_watcher"].depends_on_past + and watcher_dag.task_dict["dbt_producer_watcher"].wait_for_downstream + ) + assert ( + watcher_dag.task_dict["customers.run"].depends_on_past + and watcher_dag.task_dict["customers.run"].wait_for_downstream + ) + assert ( + watcher_dag.task_dict["dbt_producer_watcher_gate"].depends_on_past + and not watcher_dag.task_dict["dbt_producer_watcher_gate"].wait_for_downstream + ) + # dbt runner logs are not captured by caplog, so we need to capture them using capsys capsys_output = capsys.readouterr() stdout = capsys_output.out @@ -1620,7 +1657,7 @@ def test_dbt_dag_with_watcher_and_group_nodes_by_folder(capsys): assert len(watcher_dag.dbt_graph.filtered_nodes) == 6 # 3 seeds + 3 models task_ids = set(watcher_dag.task_dict) # 1 producer + 3 seeds + 3 model runs + 1 after_all test = 8 - assert len(task_ids) == 8 + assert len(task_ids) == 9 assert "dbt_producer_watcher" in task_ids assert "seeds.seeds_a.products_seed" in task_ids assert "seeds.seeds_b.regions_seed" in task_ids @@ -1644,6 +1681,7 @@ def test_dbt_dag_with_watcher_and_group_nodes_by_folder(capsys): "seeds.seeds_b.regions_seed", "seeds.seeds_a.products_seed", "seeds.seeds_b.region_managers_seed", + "dbt_producer_watcher_gate", } assert watcher_dag.task_dict["seeds.seeds_a.products_seed"].downstream_task_ids == { "models.models_a.stg_products_run" @@ -1682,9 +1720,9 @@ def test_dbt_dag_with_watcher_and_subprocess(caplog): assert len(watcher_dag.dbt_graph.filtered_nodes) == 1 - assert len(watcher_dag.task_dict) == 3 + assert len(watcher_dag.task_dict) == 4 tasks_names = [task.task_id for task in watcher_dag.topological_sort()] - expected_task_names = ["dbt_producer_watcher", "raw_orders_seed", "jaffle_shop_test"] + expected_task_names = ["dbt_producer_watcher", "dbt_producer_watcher_gate", "raw_orders_seed", "jaffle_shop_test"] assert tasks_names == expected_task_names # Confirm that the dbt command was successfully run using the given dbt executable path: assert "venv-subprocess/bin/dbt'), 'build'" in caplog.text @@ -1763,10 +1801,11 @@ def test_dbt_dag_with_watcher_and_empty_model(caplog): assert len(watcher_dag.dbt_graph.filtered_nodes) == 2 - assert len(watcher_dag.task_dict) == 3 + assert len(watcher_dag.task_dict) == 4 tasks_names = [task.task_id for task in watcher_dag.topological_sort()] expected_task_names = [ "dbt_producer_watcher", + "dbt_producer_watcher_gate", "add_row_run", "empty_model_run", ] @@ -1779,6 +1818,7 @@ def test_dbt_dag_with_watcher_and_empty_model(caplog): assert watcher_dag.task_dict["dbt_producer_watcher"].downstream_task_ids == { "add_row_run", "empty_model_run", + "dbt_producer_watcher_gate", } assert "Total filtered nodes: 2" in caplog.text @@ -1841,12 +1881,13 @@ def test_dbt_task_group_with_watcher(): # assert outcome.state == DagRunState.SUCCESS # Fortunately, when we trigger the DAG run manually, the weight is being respected and the producer task is being picked up in advance. - assert len(dag_dbt_task_group_watcher.task_dict) == 10 + assert len(dag_dbt_task_group_watcher.task_dict) == 11 tasks_names = [task.task_id for task in dag_dbt_task_group_watcher.topological_sort()] expected_task_names = [ "pre_dbt", "dbt_task_group.dbt_producer_watcher", + "dbt_task_group.dbt_producer_watcher_gate", "dbt_task_group.raw_customers_seed", "dbt_task_group.raw_orders_seed", "dbt_task_group.raw_payments_seed", @@ -1861,6 +1902,7 @@ def test_dbt_task_group_with_watcher(): assert isinstance( dag_dbt_task_group_watcher.task_dict["dbt_task_group.dbt_producer_watcher"], DbtProducerWatcherOperator ) + assert isinstance(dag_dbt_task_group_watcher.task_dict["dbt_task_group.dbt_producer_watcher_gate"], EmptyOperator) assert isinstance(dag_dbt_task_group_watcher.task_dict["dbt_task_group.raw_customers_seed"], DbtSeedWatcherOperator) assert isinstance(dag_dbt_task_group_watcher.task_dict["dbt_task_group.raw_orders_seed"], DbtSeedWatcherOperator) assert isinstance(dag_dbt_task_group_watcher.task_dict["dbt_task_group.raw_payments_seed"], DbtSeedWatcherOperator) @@ -1870,7 +1912,9 @@ def test_dbt_task_group_with_watcher(): assert isinstance(dag_dbt_task_group_watcher.task_dict["dbt_task_group.customers_run"], DbtRunWatcherOperator) assert isinstance(dag_dbt_task_group_watcher.task_dict["dbt_task_group.orders_run"], DbtRunWatcherOperator) - assert dag_dbt_task_group_watcher.task_dict["dbt_task_group.dbt_producer_watcher"].downstream_task_ids == set() + assert dag_dbt_task_group_watcher.task_dict["dbt_task_group.dbt_producer_watcher"].downstream_task_ids == { + "dbt_task_group.dbt_producer_watcher_gate" + } @pytest.mark.integration diff --git a/tests/operators/test_watcher_kubernetes_integration.py b/tests/operators/test_watcher_kubernetes_integration.py index dfb4e99897..4c2e69b3dd 100644 --- a/tests/operators/test_watcher_kubernetes_integration.py +++ b/tests/operators/test_watcher_kubernetes_integration.py @@ -103,7 +103,7 @@ def test_dbt_dag_with_watcher_kubernetes(): run_dag(dag_dbt_watcher_kubernetes) - assert len(dag_dbt_watcher_kubernetes.task_dict) == 9 + assert len(dag_dbt_watcher_kubernetes.task_dict) == 10 tasks_names = [task.task_id for task in dag_dbt_watcher_kubernetes.topological_sort()] expected_task_names = [ @@ -116,6 +116,7 @@ def test_dbt_dag_with_watcher_kubernetes(): "stg_payments_run", "customers_run", "orders_run", + "dbt_producer_watcher_gate", ] assert tasks_names == expected_task_names @@ -136,6 +137,7 @@ def test_dbt_dag_with_watcher_kubernetes(): "raw_payments_seed", "raw_orders_seed", "raw_customers_seed", + "dbt_producer_watcher_gate", } assert ( dag_dbt_watcher_kubernetes.task_dict["dbt_producer_watcher"].downstream_task_ids == expected_downstream_task_ids diff --git a/tests/operators/test_watcher_kubernetes_unit.py b/tests/operators/test_watcher_kubernetes_unit.py index 5ba35845e5..5b42679c62 100644 --- a/tests/operators/test_watcher_kubernetes_unit.py +++ b/tests/operators/test_watcher_kubernetes_unit.py @@ -1,12 +1,12 @@ -import logging import os from pathlib import Path from unittest.mock import MagicMock, patch import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.providers.cncf.kubernetes import __version__ as airflow_k8s_provider_version from airflow.providers.cncf.kubernetes.secret import Secret +from kubernetes.client import ApiException from packaging.version import Version from cosmos.config import ProfileConfig, ProjectConfig, RenderConfig @@ -78,7 +78,7 @@ @patch("cosmos.operators.kubernetes.DbtBuildKubernetesOperator.execute") -def test_skips_retry_attempt(mock_execute, caplog): +def test_skips_retry_attempt(mock_execute): """ Test that the operator skips execution when a retry is attempted (try_number > 1). """ @@ -92,13 +92,13 @@ def test_skips_retry_attempt(mock_execute, caplog): ti.try_number = 2 context = {"ti": ti} - with caplog.at_level(logging.INFO): - result = op.execute(context=context) + with pytest.raises( + AirflowSkipException, + match="DbtProducerWatcherKubernetesOperator does not support Airflow retries. Detected attempt #2; skipping execution to avoid running a second dbt build.", + ): + op.execute(context=context) mock_execute.assert_not_called() - assert result is None - assert any("does not support Airflow retries" in message for message in caplog.messages) - assert any("skipping execution" in message for message in caplog.messages) def test_raises_exception_when_task_instance_missing(): @@ -339,3 +339,69 @@ def test_callbacks_included_in_producer_operator(): ) callback_classes = [callback.__name__ for callback in op.callbacks] assert "WatcherKubernetesCallback" in callback_classes + + +def test_exceptions_converted_to_airflow_skip_exception(): + """ + Test that if exceptions are raised during execute, they are converted to an AirflowSkipException. + """ + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + ) + ti = MagicMock() + ti.try_number = 1 + context = make_context(ti) + + with patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute") as mock_execute: + mock_execute.side_effect = ApiException("Kubernetes API error") + + with pytest.raises(AirflowSkipException, match="Skipping execution due to task failure") as execinfo: + op.execute(context=context) + + assert execinfo.value.__cause__ == mock_execute.side_effect + + with patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute") as mock_execute: + mock_execute.side_effect = AirflowException("Airflow exception during execution") + with pytest.raises(AirflowSkipException, match="Skipping execution due to task failure") as execinfo: + op.execute(context=context) + assert execinfo.value.__cause__ == mock_execute.side_effect + + # Ensure deferred exceptions are not caught and converted to AirflowSkipException + with patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute") as mock_execute: + mock_execute.side_effect = TaskDeferred(trigger="some_trigger", method_name="trigger_reentry") + + with pytest.raises(TaskDeferred): + op.execute(context=context) + + with patch( + "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.trigger_reentry" + ) as mock_trigger_reentry: + mock_trigger_reentry.side_effect = ApiException("Kubernetes API error") + with pytest.raises(AirflowSkipException, match="Skipping execution due to task failure") as execinfo: + op.trigger_reentry(context=context) + + assert execinfo.value.__cause__ == mock_trigger_reentry.side_effect + + with patch( + "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.trigger_reentry" + ) as mock_trigger_reentry: + mock_trigger_reentry.side_effect = TaskDeferred(trigger="some_trigger", method_name="trigger_reentry") + + with pytest.raises(TaskDeferred): + op.trigger_reentry(context=context) + + +@pytest.mark.parametrize("depends_on_past", [False, True]) +def test_depends_on_past_sets_wait_for_downstream(depends_on_past): + """ + Test that if depends_on_past is True, wait_for_downstream is also set to True. + """ + op = DbtProducerWatcherKubernetesOperator( + project_dir=".", + profile_config=None, + image="dbt-image:latest", + depends_on_past=depends_on_past, + ) + assert op.wait_for_downstream == depends_on_past diff --git a/tests/test_dbtf.py b/tests/test_dbtf.py index d9ccdd936b..287692192a 100644 --- a/tests/test_dbtf.py +++ b/tests/test_dbtf.py @@ -102,5 +102,6 @@ def test_dbt_fusion(dag_id, execution_config, profile_config): ] if execution_config.execution_mode == ExecutionMode.WATCHER: expected_task_names.insert(0, "dbt_producer_watcher") + expected_task_names.insert(1, "dbt_producer_watcher_gate") assert tasks_names == expected_task_names