diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 54721b04d9..e648999399 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -694,12 +694,13 @@ def _add_watcher_producer_task( producer_task_args["exclude"] = _convert_list_to_str(render_config.exclude) if render_config.test_behavior in [TestBehavior.NONE, TestBehavior.AFTER_ALL]: - additional_excludes = "resource_type:test resource_type:unit_test" - current_exclude = producer_task_args.get("exclude") - if current_exclude: - producer_task_args["exclude"] = f"{current_exclude} {additional_excludes}" - else: - producer_task_args["exclude"] = additional_excludes + # Use --resource-type to exclude tests from the producer dbt build command. + # This works both with and without selectors (--exclude is ignored by dbt when a selector is used). + existing_flags = producer_task_args.get("dbt_cmd_flags") or [] + dbt_cmd_flags = list(existing_flags) + for resource_type in SUPPORTED_BUILD_RESOURCES: + dbt_cmd_flags.extend(["--resource-type", resource_type.value]) # type: ignore[attr-defined] + producer_task_args["dbt_cmd_flags"] = dbt_cmd_flags class_name = calculate_operator_class(execution_mode, "DbtProducer") diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 9e6863164f..db22d8c127 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -6,7 +6,7 @@ import pytest from airflow.models import DAG -from cosmos.operators.watcher import DbtTestWatcherOperator +from cosmos.operators.watcher import DbtProducerWatcherOperator, DbtTestWatcherOperator try: # Airflow 3.1 onwards @@ -33,6 +33,7 @@ ) from cosmos.config import ProfileConfig, RenderConfig from cosmos.constants import ( + SUPPORTED_BUILD_RESOURCES, DbtResourceType, ExecutionMode, SourceRenderingBehavior, @@ -1187,6 +1188,84 @@ def test_test_behavior_for_watcher_mode(test_behavior): assert len(tasks) == 6 +@pytest.mark.parametrize("test_behavior", [TestBehavior.NONE, TestBehavior.AFTER_ALL]) +@pytest.mark.parametrize("selector", [None, "my_selector"]) +def test_watcher_producer_uses_resource_type_flag_to_exclude_tests(test_behavior, selector): + """The producer should use --resource-type to exclude tests, regardless of whether a selector is used.""" + with DAG("test-id", start_date=datetime(2022, 1, 1)) as dag: + task_args = { + "project_dir": SAMPLE_PROJ_PATH, + "conn_id": "fake_conn", + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + } + + render_config = RenderConfig(test_behavior=test_behavior) + if selector: + render_config = RenderConfig(test_behavior=test_behavior, selector=selector) + + build_airflow_graph( + nodes=sample_nodes, + dag=dag, + execution_mode=ExecutionMode.WATCHER, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args=task_args, + render_config=render_config, + dbt_project_name="astro_shop", + ) + producer_task = next(task for task in dag.tasks if isinstance(task, DbtProducerWatcherOperator)) + + # Should use --resource-type flags, not --exclude + expected_resource_types = [rt.value for rt in SUPPORTED_BUILD_RESOURCES] + flags = producer_task.dbt_cmd_flags + actual_resource_types = [flags[i + 1] for i in range(len(flags)) if flags[i] == "--resource-type"] + assert actual_resource_types == expected_resource_types + assert "resource_type:test" not in (producer_task.exclude or "") + + +@pytest.mark.parametrize("test_behavior", [TestBehavior.NONE, TestBehavior.AFTER_ALL]) +def test_watcher_producer_preserves_existing_dbt_cmd_flags(test_behavior): + """The producer should not mutate the original dbt_cmd_flags list and should preserve existing flags.""" + with DAG("test-id", start_date=datetime(2022, 1, 1)) as dag: + original_flags = ["--full-refresh"] + task_args = { + "project_dir": SAMPLE_PROJ_PATH, + "conn_id": "fake_conn", + "dbt_cmd_flags": original_flags, + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + } + + build_airflow_graph( + nodes=sample_nodes, + dag=dag, + execution_mode=ExecutionMode.WATCHER, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args=task_args, + render_config=RenderConfig(test_behavior=test_behavior), + dbt_project_name="astro_shop", + ) + producer_task = next(task for task in dag.tasks if isinstance(task, DbtProducerWatcherOperator)) + + # Original flags should not be mutated + assert original_flags == ["--full-refresh"] + # Producer should have the original flag plus the resource-type flags + assert "--full-refresh" in producer_task.dbt_cmd_flags + assert "--resource-type" in producer_task.dbt_cmd_flags + + def test_custom_meta(): with DAG("test-id", start_date=datetime(2022, 1, 1)) as dag: task_args = {