diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index eb3186551c..f440a1f125 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -64,6 +64,24 @@ def calculate_operator_class( ) +def _is_source_used_by_filtered_nodes(source_node: DbtNode, filtered_nodes: dict[str, DbtNode]) -> bool: + """ + Check if a source node is referenced by any of the filtered nodes. + + :param source_node: The source node to check + :param filtered_nodes: Dictionary of filtered nodes + :returns: True if the source is used by any filtered node, False otherwise + """ + source_id = source_node.unique_id + + # Check if any filtered node depends on this source + for node in filtered_nodes.values(): + if source_id in node.depends_on: + return True + + return False + + def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[str]: """ Return a list of unique_ids for nodes that are not parents (don't have dependencies on other tasks). @@ -257,6 +275,7 @@ def create_task_metadata( dbt_dag_task_group_identifier: str, use_task_group: bool = False, source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE, + source_pruning: bool = False, normalize_task_id: Callable[..., Any] | None = None, normalize_task_display_name: Callable[..., Any] | None = None, test_behavior: TestBehavior = TestBehavior.AFTER_ALL, @@ -264,6 +283,7 @@ def create_task_metadata( on_warning_callback: Callable[..., Any] | None = None, detached_from_parent: dict[str, DbtNode] | None = None, enable_owner_inheritance: bool | None = None, + filtered_nodes: dict[str, DbtNode] | None = None, ) -> TaskMetadata | None: """ Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node. @@ -330,6 +350,8 @@ def create_task_metadata( ): return None + if source_pruning and filtered_nodes and not _is_source_used_by_filtered_nodes(node, filtered_nodes): + return None task_id, args = _get_task_id_and_args( node, args, use_task_group, normalize_task_id, normalize_task_display_name, "source" ) @@ -398,11 +420,13 @@ def generate_task_or_group( test_behavior: TestBehavior, source_rendering_behavior: SourceRenderingBehavior, test_indirect_selection: TestIndirectSelection, - on_warning_callback: Callable[..., Any] | None, + source_pruning: bool = False, + on_warning_callback: Callable[..., Any] | None = None, normalize_task_id: Callable[..., Any] | None = None, normalize_task_display_name: Callable[..., Any] | None = None, detached_from_parent: dict[str, DbtNode] | None = None, enable_owner_inheritance: bool | None = None, + filtered_nodes: dict[str, DbtNode] | None = None, **kwargs: Any, ) -> BaseOperator | TaskGroup | None: task_or_group: BaseOperator | TaskGroup | None = None @@ -421,6 +445,7 @@ def generate_task_or_group( dbt_dag_task_group_identifier=_get_dbt_dag_task_group_identifier(dag, task_group), use_task_group=use_task_group, source_rendering_behavior=source_rendering_behavior, + source_pruning=source_pruning, normalize_task_id=normalize_task_id, normalize_task_display_name=normalize_task_display_name, test_behavior=test_behavior, @@ -428,6 +453,7 @@ def generate_task_or_group( on_warning_callback=on_warning_callback, detached_from_parent=detached_from_parent, enable_owner_inheritance=enable_owner_inheritance, + filtered_nodes=filtered_nodes, ) # In most cases, we'll map one DBT node to one Airflow task @@ -631,6 +657,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro node_converters = render_config.node_converters or {} test_behavior = render_config.test_behavior source_rendering_behavior = render_config.source_rendering_behavior + source_pruning = render_config.source_pruning normalize_task_id = render_config.normalize_task_id normalize_task_display_name = render_config.normalize_task_display_name enable_owner_inheritance = render_config.enable_owner_inheritance @@ -663,6 +690,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro task_args=task_args, test_behavior=test_behavior, source_rendering_behavior=source_rendering_behavior, + source_pruning=source_pruning, test_indirect_selection=test_indirect_selection, on_warning_callback=on_warning_callback, normalize_task_id=normalize_task_id, @@ -670,6 +698,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro node=node, detached_from_parent=detached_from_parent, enable_owner_inheritance=enable_owner_inheritance, + filtered_nodes=nodes, ) if task_or_group is not None: logger.debug(f"Conversion of <{node.unique_id}> was successful!") diff --git a/cosmos/config.py b/cosmos/config.py index 5f4d540d9c..e0ad093689 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -69,6 +69,7 @@ class RenderConfig: :param dbt_ls_path: Configures the location of an output of ``dbt ls``. Required when using ``load_method=LoadMode.DBT_LS_FILE``. :param enable_mock_profile: Allows to enable/disable mocking profile. Enabled by default. Mock profiles are useful for parsing Cosmos DAGs in the CI, but should be disabled to benefit from partial parsing (since Cosmos 1.4). :param source_rendering_behavior: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). + :param source_pruning: Determines if source nodes without a corresponding downstream task should be removed or not. Default is False :param airflow_vars_to_purge_dbt_ls_cache: Specify Airflow variables that will affect the LoadMode.DBT_LS cache. :param normalize_task_id: A callable that takes a dbt node as input and returns the task ID. This allows users to assign a custom node ID separate from the display name. :param normalize_task_display_name: A callable that takes a dbt node as input and returns the task display name. This allows users to assign a custom task display name separate from the node ID. @@ -92,6 +93,7 @@ class RenderConfig: project_path: Path | None = field(init=False) enable_mock_profile: bool = True source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE + source_pruning: bool = False airflow_vars_to_purge_dbt_ls_cache: list[str] = field(default_factory=list) normalize_task_id: Callable[..., Any] | None = None normalize_task_display_name: Callable[..., Any] | None = None diff --git a/dev/dags/dbt/altered_jaffle_shop/models/staging/sources.yml b/dev/dags/dbt/altered_jaffle_shop/models/staging/sources.yml index a3139b5858..4c3a834262 100644 --- a/dev/dags/dbt/altered_jaffle_shop/models/staging/sources.yml +++ b/dev/dags/dbt/altered_jaffle_shop/models/staging/sources.yml @@ -6,6 +6,7 @@ sources: database: "{{ env_var('POSTGRES_DB') }}" schema: "{{ env_var('POSTGRES_SCHEMA') }}" tables: + - name: not_connected_source - name: raw_customers columns: - name: id diff --git a/dev/dags/example_source_pruning.py b/dev/dags/example_source_pruning.py new file mode 100644 index 0000000000..c8f6f316f3 --- /dev/null +++ b/dev/dags/example_source_pruning.py @@ -0,0 +1,54 @@ +""" +An example DAG that uses Cosmos to render a dbt project into an Airflow DAG using Cosmos source rendering, and then +prunes unused source nodes. +""" + +import os +from datetime import datetime +from pathlib import Path + +from cosmos import DbtDag, ProfileConfig, ProjectConfig, RenderConfig +from cosmos.constants import SourceRenderingBehavior +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +# [START cosmos_source_node_example] + +source_rendering_dag = DbtDag( + # dbt/cosmos-specific parameters + project_config=ProjectConfig( + DBT_ROOT_PATH / "altered_jaffle_shop", + ), + profile_config=profile_config, + operator_args={ + "install_deps": True, # install any necessary dependencies before running any dbt command + "full_refresh": True, # used only in dbt commands that support this flag + }, + render_config=RenderConfig( + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + exclude=[ + "multibyte" + ], # Exclude multibyte model to avoid Airflow 3.0 AssetAlias ASCII validation issues + ), + # normal dag parameters + schedule="@daily", + start_date=datetime(2024, 1, 1), + catchup=False, + dag_id="source_pruning_dag", + default_args={"retries": 0}, + on_warning_callback=lambda context: print(context), +) +# [END cosmos_source_node_example] diff --git a/docs/configuration/render-config.rst b/docs/configuration/render-config.rst index 26af118fe0..d58553f90f 100644 --- a/docs/configuration/render-config.rst +++ b/docs/configuration/render-config.rst @@ -22,7 +22,8 @@ The ``RenderConfig`` class takes the following arguments: - ``dbt_project_path``: Configures the DBT project location accessible on their airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM`` - ``airflow_vars_to_purge_cache``: (new in v1.5) Specify Airflow variables that will affect the ``LoadMode.DBT_LS`` cache. See `Caching <./caching.html>`_ for more information. - ``source_rendering_behavior``: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). See `Source Nodes Rendering <./source-nodes-rendering.html>`_ for more information. -- ``normalize_task_id``: A callable that takes a dbt node as input and returns the task ID. This function allows users to set a custom task_id independently of the model name, which can be specified as the task’s display_name. This way, task_id can be modified using a user-defined function, while the model name remains as the task’s display name. The display_name parameter is available in Airflow 2.9 and above. See `Task display name <./task-display-name.html>`_ for more information. +- ``source_pruning``: When set to ``True``, automatically removes (or "prunes") any dbt source nodes from your Airflow DAG that do not have any downstream dependencies within the selected portion of the dbt graph. Defaults to ``False``. See `Source Nodes Rendering <./source-nodes-rendering.html>`_ for more information. +- ``normalize_task_id``: A callable that takes a dbt node as input and returns the task ID. This function allows users to set a custom task_id independently of the model name, which can be specified as the task's display_name. This way, task_id can be modified using a user-defined function, while the model name remains as the task's display name. The display_name parameter is available in Airflow 2.9 and above. See `Task display name <./task-display-name.html>`_ for more information. - ``normalize_task_display_name``: This function allows users to set a custom user-defined function to alter the display name independently of the model name. This way, the task_id can be preserved while the model display name is modified. - ``should_detach_multiple_parents_tests``: A boolean to control if tests that depend on multiple parents should be run as standalone tasks. See `Parsing Methods `_ for more information. - ``enable_owner_inheritance``: (introduced in 1.10.2) A boolean to control if dbt owners should be imported as part of the airflow DAG owners. Defaults to True. diff --git a/docs/configuration/source-nodes-rendering.rst b/docs/configuration/source-nodes-rendering.rst index a773cf5da5..9bfcf0e97b 100644 --- a/docs/configuration/source-nodes-rendering.rst +++ b/docs/configuration/source-nodes-rendering.rst @@ -37,6 +37,32 @@ Example: ) +Source Pruning +~~~~~~~~~~~~~~ + +The ``source_pruning`` is a boolean parameter available in the ``RenderConfig``. +When set to ``True``, it automatically removes (or "prunes") any dbt source nodes from your Airflow DAG that do not have any downstream dependencies within the selected portion of the dbt graph. + +This is particularly useful for keeping your DAGs clean and focused, especially in large dbt projects where you might be selecting only a subset of models to run. + +By default, this option is set to ``False``. + +Example: + +.. code-block:: python + + from cosmos import DbtTaskGroup, RenderConfig + + jaffle_shop_pruned = DbtTaskGroup( + render_config=RenderConfig( + select=["+customers"], # select only customers and its parents + source_pruning=True, + ) + ) + +In this example, if the ``jaffle_shop`` project has multiple sources, but only some of them are upstream of the ``customers`` model, Cosmos will only render the necessary sources and prune the rest. + + on_warning_callback Callback ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 82ca1df655..bbe2186f63 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -654,7 +654,7 @@ def test_load_via_dbt_ls_with_exclude(postgres_profile_config): @pytest.mark.integration @pytest.mark.parametrize( "project_dir,node_count", - [(DBT_PROJECTS_ROOT_DIR / ALTERED_DBT_PROJECT_NAME, 39), (DBT_PROJECTS_ROOT_DIR / "jaffle_shop_python", 28)], + [(DBT_PROJECTS_ROOT_DIR / ALTERED_DBT_PROJECT_NAME, 40), (DBT_PROJECTS_ROOT_DIR / "jaffle_shop_python", 28)], ) def test_load_via_dbt_ls_without_exclude(project_dir, node_count, postgres_profile_config): project_config = ProjectConfig(dbt_project_path=project_dir) diff --git a/tests/dbt/test_pruning.py b/tests/dbt/test_pruning.py new file mode 100644 index 0000000000..9d34a0e3ea --- /dev/null +++ b/tests/dbt/test_pruning.py @@ -0,0 +1,515 @@ +""" +This test suite focuses on the source pruning logic that removes sources +without downstream dependencies from the rendered DAG. +""" + +from datetime import datetime +from pathlib import Path + +from airflow.models import DAG + +from cosmos.airflow.graph import ( + _is_source_used_by_filtered_nodes, + create_task_metadata, + generate_task_or_group, +) +from cosmos.config import ProfileConfig +from cosmos.constants import ( + DbtResourceType, + ExecutionMode, + SourceRenderingBehavior, + TestBehavior, + TestIndirectSelection, +) +from cosmos.dbt.graph import DbtNode +from cosmos.profiles import PostgresUserPasswordProfileMapping + + +class TestSourcePruningLogic: + """Test the core source pruning logic.""" + + def setup_method(self): + """Set up test data for each test method.""" + # Create a graph structure for testing: + # source_used -> model1 -> model2 + # source_orphaned (no downstream dependencies) + # source_with_multiple_deps -> model3 + # -> model4 + + self.source_used = DbtNode( + unique_id="source.test_project.source_used", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_used.yml"), + package_name="test_project", + ) + + self.source_orphaned = DbtNode( + unique_id="source.test_project.source_orphaned", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_orphaned.yml"), + package_name="test_project", + ) + + self.source_with_multiple_deps = DbtNode( + unique_id="source.test_project.source_multiple", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_multiple.yml"), + package_name="test_project", + ) + + self.model1 = DbtNode( + unique_id="model.test_project.model1", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.source_used"], + file_path=Path("/test/model1.sql"), + package_name="test_project", + ) + + self.model2 = DbtNode( + unique_id="model.test_project.model2", + resource_type=DbtResourceType.MODEL, + depends_on=["model.test_project.model1"], + file_path=Path("/test/model2.sql"), + package_name="test_project", + ) + + self.model3 = DbtNode( + unique_id="model.test_project.model3", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.source_multiple"], + file_path=Path("/test/model3.sql"), + package_name="test_project", + ) + + self.model4 = DbtNode( + unique_id="model.test_project.model4", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.source_multiple"], + file_path=Path("/test/model4.sql"), + package_name="test_project", + ) + + def test_is_source_used_by_filtered_nodes_basic(self): + """Test the basic functionality of _is_source_used_by_filtered_nodes.""" + # Test case 1: Source is used by a filtered node + filtered_nodes = { + self.model1.unique_id: self.model1, + self.model2.unique_id: self.model2, + } + + assert _is_source_used_by_filtered_nodes(self.source_used, filtered_nodes) is True + assert _is_source_used_by_filtered_nodes(self.source_orphaned, filtered_nodes) is False + + # Test case 2: Source is used by multiple filtered nodes + filtered_nodes_multiple = { + self.model3.unique_id: self.model3, + self.model4.unique_id: self.model4, + } + + assert _is_source_used_by_filtered_nodes(self.source_with_multiple_deps, filtered_nodes_multiple) is True + + def test_is_source_used_by_filtered_nodes_empty_filtered_nodes(self): + """Test behavior with empty filtered_nodes.""" + empty_filtered_nodes = {} + + # No filtered nodes means no source should be considered "used" + assert _is_source_used_by_filtered_nodes(self.source_used, empty_filtered_nodes) is False + assert _is_source_used_by_filtered_nodes(self.source_orphaned, empty_filtered_nodes) is False + + def test_is_source_used_by_filtered_nodes_no_dependencies(self): + """Test with filtered nodes that have no dependencies.""" + # Create a model with no dependencies + independent_model = DbtNode( + unique_id="model.test_project.independent", + resource_type=DbtResourceType.MODEL, + depends_on=[], # No dependencies + file_path=Path("/test/independent.sql"), + package_name="test_project", + ) + + filtered_nodes = { + independent_model.unique_id: independent_model, + } + + # Sources should not be considered "used" if no models depend on them + assert _is_source_used_by_filtered_nodes(self.source_used, filtered_nodes) is False + assert _is_source_used_by_filtered_nodes(self.source_orphaned, filtered_nodes) is False + + def test_is_source_used_by_filtered_nodes_mixed_dependencies(self): + """Test with filtered nodes that have mixed dependency types.""" + # Create a model that depends on another model AND a source + complex_model = DbtNode( + unique_id="model.test_project.complex", + resource_type=DbtResourceType.MODEL, + depends_on=[ + "model.test_project.model1", # Model dependency + "source.test_project.source_used", # Source dependency + ], + file_path=Path("/test/complex.sql"), + package_name="test_project", + ) + + filtered_nodes = { + self.model1.unique_id: self.model1, + complex_model.unique_id: complex_model, + } + + # Source should be considered "used" even when model has mixed dependencies + assert _is_source_used_by_filtered_nodes(self.source_used, filtered_nodes) is True + assert _is_source_used_by_filtered_nodes(self.source_orphaned, filtered_nodes) is False + + +class TestSourcePruningIntegration: + """Test source pruning integration with create_task_metadata and generate_task_or_group.""" + + def setup_method(self): + """Set up test data for integration tests.""" + self.dag = DAG("test_source_pruning", start_date=datetime(2023, 1, 1)) + + self.profile_config = ProfileConfig( + profile_name="test", + target_name="test", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ) + + self.task_args = { + "project_dir": Path("/test/project"), + "profile_config": self.profile_config, + } + + # Create test nodes + self.source_with_downstream = DbtNode( + unique_id="source.test_project.source_with_downstream", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_with_downstream.yml"), + package_name="test_project", + has_freshness=True, # Make it renderable + ) + + self.source_orphaned = DbtNode( + unique_id="source.test_project.source_orphaned", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_orphaned.yml"), + package_name="test_project", + has_freshness=True, # Make it renderable + ) + + self.model_using_source = DbtNode( + unique_id="model.test_project.model_using_source", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.source_with_downstream"], + file_path=Path("/test/model_using_source.sql"), + package_name="test_project", + ) + + # Filtered nodes represent what will actually be rendered + self.filtered_nodes_with_model = { + self.model_using_source.unique_id: self.model_using_source, + } + + self.filtered_nodes_empty = {} + + def test_create_task_metadata_source_pruning_disabled(self): + """Test create_task_metadata when source_pruning is disabled.""" + # Source should be rendered regardless of downstream dependencies + metadata_with_downstream = create_task_metadata( + node=self.source_with_downstream, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=False, # Disabled + filtered_nodes=self.filtered_nodes_with_model, + ) + + metadata_orphaned = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=False, # Disabled + filtered_nodes=self.filtered_nodes_empty, + ) + + # Both sources should generate metadata when pruning is disabled + assert metadata_with_downstream is not None + assert metadata_orphaned is not None + assert metadata_with_downstream.id == "source_with_downstream_source" + assert metadata_orphaned.id == "source_orphaned_source" + + def test_create_task_metadata_source_pruning_enabled(self): + """Test create_task_metadata when source_pruning is enabled.""" + # Source with downstream should be rendered + metadata_with_downstream = create_task_metadata( + node=self.source_with_downstream, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, # Enabled + filtered_nodes=self.filtered_nodes_with_model, + ) + + # Orphaned source should NOT be rendered + metadata_orphaned = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, # Enabled + filtered_nodes=self.filtered_nodes_with_model, + ) + + assert metadata_with_downstream is not None + assert metadata_orphaned is None # Pruned! + assert metadata_with_downstream.id == "source_with_downstream_source" + + def test_create_task_metadata_source_pruning_edge_cases(self): + """Test create_task_metadata edge cases with filtered_nodes.""" + # When filtered_nodes is None, pruning should be skipped (condition: source_pruning and filtered_nodes) + metadata_none = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + filtered_nodes=None, # None means skip pruning check + ) + + # When filtered_nodes is empty, pruning should be skipped (empty dict is falsy) + metadata_empty = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + filtered_nodes={}, # Empty dict means skip pruning check + ) + + # When filtered_nodes contains nodes that don't use this source, it should be pruned + unrelated_model = DbtNode( + unique_id="model.test_project.unrelated", + resource_type=DbtResourceType.MODEL, + depends_on=["some.other.source"], # Doesn't depend on source_orphaned + file_path=Path("/test/unrelated.sql"), + package_name="test_project", + ) + metadata_unused = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + filtered_nodes={unrelated_model.unique_id: unrelated_model}, # Non-empty but doesn't use this source + ) + + # Looking at the code: if source_pruning and filtered_nodes and not _is_source_used_by_filtered_nodes(...) + # When filtered_nodes is None, the condition fails at "filtered_nodes", so no pruning occurs + # When filtered_nodes is {}, the condition also fails at "filtered_nodes" (empty dict is falsy), so no pruning occurs + # When filtered_nodes is non-empty but doesn't use the source, pruning occurs + assert metadata_none is not None # Not pruned when filtered_nodes is None + assert metadata_empty is not None # Not pruned when filtered_nodes is empty dict (falsy) + assert metadata_unused is None # Pruned when filtered_nodes doesn't use this source + + def test_generate_task_or_group_source_pruning(self): + """Test generate_task_or_group with source pruning.""" + # Source with downstream dependencies should generate a task + task_with_downstream = generate_task_or_group( + dag=self.dag, + task_group=None, + node=self.source_with_downstream, + execution_mode=ExecutionMode.LOCAL, + task_args=self.task_args, + test_behavior=TestBehavior.NONE, + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + test_indirect_selection=TestIndirectSelection.EAGER, + filtered_nodes=self.filtered_nodes_with_model, + ) + + # Orphaned source should NOT generate a task + task_orphaned = generate_task_or_group( + dag=self.dag, + task_group=None, + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + task_args=self.task_args, + test_behavior=TestBehavior.NONE, + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + test_indirect_selection=TestIndirectSelection.EAGER, + filtered_nodes=self.filtered_nodes_with_model, + ) + + assert task_with_downstream is not None + assert task_orphaned is None + + def test_source_pruning_with_different_source_rendering_behaviors(self): + """Test source pruning interaction with different source rendering behaviors.""" + # Create source without freshness + source_no_freshness = DbtNode( + unique_id="source.test_project.source_no_freshness", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_no_freshness.yml"), + package_name="test_project", + has_freshness=False, # No freshness + ) + + # Test with SourceRenderingBehavior.NONE - should return None regardless of pruning + metadata_none = create_task_metadata( + node=source_no_freshness, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.NONE, + source_pruning=True, + filtered_nodes=self.filtered_nodes_with_model, + ) + + # Test with SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS - should return None for source without freshness + metadata_conditional = create_task_metadata( + node=source_no_freshness, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS, + source_pruning=True, + filtered_nodes=self.filtered_nodes_with_model, + ) + + # Both should be None due to source rendering behavior, before pruning logic is even reached + assert metadata_none is None + assert metadata_conditional is None + + +class TestSourcePruningEdgeCases: + """Test edge cases and error conditions for source pruning.""" + + def test_non_source_node_with_pruning_function(self): + """Test that _is_source_used_by_filtered_nodes works with non-source nodes.""" + # Create a model node + model_node = DbtNode( + unique_id="model.test_project.some_model", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.some_source"], + file_path=Path("/test/some_model.sql"), + package_name="test_project", + ) + + # Create a filtered node that depends on the model + dependent_model = DbtNode( + unique_id="model.test_project.dependent_model", + resource_type=DbtResourceType.MODEL, + depends_on=["model.test_project.some_model"], + file_path=Path("/test/dependent_model.sql"), + package_name="test_project", + ) + + filtered_nodes = { + dependent_model.unique_id: dependent_model, + } + + # Function should work with any node type, not just sources + assert _is_source_used_by_filtered_nodes(model_node, filtered_nodes) is True + + def test_source_pruning_with_complex_dependencies(self): + """Test source pruning with complex dependency chains.""" + # Create a complex dependency chain + source = DbtNode( + unique_id="source.test_project.raw_data", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/raw_data.yml"), + package_name="test_project", + has_freshness=True, + ) + + staging_model = DbtNode( + unique_id="model.test_project.stg_data", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.raw_data"], + file_path=Path("/test/stg_data.sql"), + package_name="test_project", + ) + + mart_model = DbtNode( + unique_id="model.test_project.mart_data", + resource_type=DbtResourceType.MODEL, + depends_on=["model.test_project.stg_data"], + file_path=Path("/test/mart_data.sql"), + package_name="test_project", + ) + + # Test case 1: Only the mart model is in filtered_nodes + # Source should still be considered "unused" because the direct dependency (staging_model) is not included + filtered_nodes_mart_only = { + mart_model.unique_id: mart_model, + } + + assert _is_source_used_by_filtered_nodes(source, filtered_nodes_mart_only) is False + + # Test case 2: Both staging and mart models are in filtered_nodes + # Source should be considered "used" because staging model directly depends on it + filtered_nodes_both = { + staging_model.unique_id: staging_model, + mart_model.unique_id: mart_model, + } + + assert _is_source_used_by_filtered_nodes(source, filtered_nodes_both) is True + + def test_source_pruning_parameter_defaults(self): + """Test that source_pruning parameter defaults work correctly.""" + # Create minimal test data + source = DbtNode( + unique_id="source.test_project.test_source", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/test_source.yml"), + package_name="test_project", + has_freshness=True, + ) + + dag = DAG("test", start_date=datetime(2023, 1, 1)) + task_args = { + "project_dir": Path("/test"), + "profile_config": ProfileConfig( + profile_name="test", + target_name="test", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + } + + # Test that generate_task_or_group works without explicitly passing source_pruning + # (should default to False) + task = generate_task_or_group( + dag=dag, + task_group=None, + node=source, + execution_mode=ExecutionMode.LOCAL, + task_args=task_args, + test_behavior=TestBehavior.NONE, + source_rendering_behavior=SourceRenderingBehavior.ALL, + test_indirect_selection=TestIndirectSelection.EAGER, + # source_pruning not specified - should default to False + # filtered_nodes not specified - should default to None + ) + + # Should create a task because pruning is disabled by default + assert task is not None