From be8e5d0dc78bc68f4f5f84330bc2d5fe2e8d43c5 Mon Sep 17 00:00:00 2001 From: Giovanni Corsetti <155465603+corsettigyg@users.noreply.github.com> Date: Fri, 19 Sep 2025 07:23:58 +0200 Subject: [PATCH 01/26] Add source_prunning logic to cosmos --- cosmos/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cosmos/config.py b/cosmos/config.py index e1a4a0c353..0551450f94 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -63,6 +63,7 @@ class RenderConfig: :param dbt_ls_path: Configures the location of an output of ``dbt ls``. Required when using ``load_method=LoadMode.DBT_LS_FILE``. :param enable_mock_profile: Allows to enable/disable mocking profile. Enabled by default. Mock profiles are useful for parsing Cosmos DAGs in the CI, but should be disabled to benefit from partial parsing (since Cosmos 1.4). :param source_rendering_behavior: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). + :param source_prunning: Determines if sources without a corresponding downstream task should be removed or not. Default is False :param airflow_vars_to_purge_dbt_ls_cache: Specify Airflow variables that will affect the LoadMode.DBT_LS cache. :param normalize_task_id: A callable that takes a dbt node as input and returns the task ID. This allows users to assign a custom node ID separate from the display name. :param normalize_task_display_name: A callable that takes a dbt node as input and returns the task display name. This allows users to assign a custom task display name separate from the node ID. @@ -86,6 +87,7 @@ class RenderConfig: project_path: Path | None = field(init=False) enable_mock_profile: bool = True source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE + source_prunning: bool = False airflow_vars_to_purge_dbt_ls_cache: list[str] = field(default_factory=list) normalize_task_id: Callable[..., Any] | None = None normalize_task_display_name: Callable[..., Any] | None = None From d94314e98cdb4fffcafb1021710c1299f0524682 Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Fri, 19 Sep 2025 12:26:56 +0200 Subject: [PATCH 02/26] add source prunning logic --- cosmos/airflow/graph.py | 9 ++- cosmos/dbt/graph.py | 60 +++++++++++++++++++ .../models/staging/sources.yml | 1 + dev/dags/example_source_prunning.py | 48 +++++++++++++++ docs/configuration/source-nodes-rendering.rst | 26 ++++++++ 5 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 dev/dags/example_source_prunning.py diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index eb3186551c..74fa6e6339 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -257,6 +257,7 @@ def create_task_metadata( dbt_dag_task_group_identifier: str, use_task_group: bool = False, source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE, + source_prunning: bool = False, normalize_task_id: Callable[..., Any] | None = None, normalize_task_display_name: Callable[..., Any] | None = None, test_behavior: TestBehavior = TestBehavior.AFTER_ALL, @@ -330,6 +331,8 @@ def create_task_metadata( ): return None + if source_prunning and not node.downstream_nodes: + return None task_id, args = _get_task_id_and_args( node, args, use_task_group, normalize_task_id, normalize_task_display_name, "source" ) @@ -397,8 +400,9 @@ def generate_task_or_group( task_args: dict[str, Any], test_behavior: TestBehavior, source_rendering_behavior: SourceRenderingBehavior, + source_prunning: bool, test_indirect_selection: TestIndirectSelection, - on_warning_callback: Callable[..., Any] | None, + on_warning_callback: Callable[..., Any] | None = None, normalize_task_id: Callable[..., Any] | None = None, normalize_task_display_name: Callable[..., Any] | None = None, detached_from_parent: dict[str, DbtNode] | None = None, @@ -421,6 +425,7 @@ def generate_task_or_group( dbt_dag_task_group_identifier=_get_dbt_dag_task_group_identifier(dag, task_group), use_task_group=use_task_group, source_rendering_behavior=source_rendering_behavior, + source_prunning=source_prunning, normalize_task_id=normalize_task_id, normalize_task_display_name=normalize_task_display_name, test_behavior=test_behavior, @@ -631,6 +636,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro node_converters = render_config.node_converters or {} test_behavior = render_config.test_behavior source_rendering_behavior = render_config.source_rendering_behavior + source_prunning = render_config.source_prunning normalize_task_id = render_config.normalize_task_id normalize_task_display_name = render_config.normalize_task_display_name enable_owner_inheritance = render_config.enable_owner_inheritance @@ -663,6 +669,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro task_args=task_args, test_behavior=test_behavior, source_rendering_behavior=source_rendering_behavior, + source_prunning=source_prunning, test_indirect_selection=test_indirect_selection, on_warning_callback=on_warning_callback, normalize_task_id=normalize_task_id, diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index ac17a02692..ccab4c2d32 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -83,6 +83,7 @@ class DbtNode: config: dict[str, Any] = field(default_factory=lambda: {}) has_freshness: bool = False has_test: bool = False + _graph_nodes: dict[str, "DbtNode"] = field(default_factory=dict, init=False, repr=False) @property def meta(self) -> Dict[str, Any]: @@ -170,8 +171,42 @@ def context_dict(self) -> dict[str, Any]: "has_test": self.has_test, "resource_name": self.resource_name, "name": self.name, + "downstream_ids": self.downstream_ids, } + @property + def downstream_nodes(self) -> list["DbtNode"]: + """ + Find all nodes that depend on this node (downstream dependencies). + + :returns: List of DbtNode objects that depend on this node + """ + downstream_nodes = [] + for node in self._graph_nodes.values(): + if self.unique_id in node.depends_on: + downstream_nodes.append(node) + return downstream_nodes + + @property + def downstream_ids(self) -> list[str]: + """ + Find all node IDs that depend on this node (downstream dependencies). + + :returns: List of unique_ids of nodes that depend on this node + """ + downstream_ids = [] + for node in self._graph_nodes.values(): + if self.unique_id in node.depends_on: + downstream_ids.append(node.unique_id) + return downstream_ids + + def _set_graph_nodes(self, graph_nodes: dict[str, "DbtNode"]) -> None: + """ + Internal method to set the graph nodes reference. + This should only be called by DbtGraph after loading nodes. + """ + self._graph_nodes = graph_nodes + def is_freshness_effective(freshness: Optional[dict[str, Any]]) -> bool: """Function to find if a source has null freshness. Scenarios where freshness @@ -562,6 +597,7 @@ def load( load_method[method]() self.update_node_dependency() + self._set_graph_references() logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.filtered_nodes)) @@ -963,3 +999,27 @@ def update_node_dependency(self) -> None: if node_id in self.filtered_nodes: self.filtered_nodes[node_id].has_test = True self.filtered_nodes[node.unique_id] = node + + def _set_graph_references(self) -> None: + """ + Set the graph nodes reference for all nodes so they can compute downstream dependencies. + This should be called after loading and filtering nodes. + """ + for node in self.nodes.values(): + node._set_graph_nodes(self.filtered_nodes) + + def get_downstream_nodes(self, node_unique_id: str, use_filtered: bool = True) -> list[DbtNode]: + """ + Get all downstream nodes for a given node. + + :param node_unique_id: The unique_id of the node to find downstream dependencies for + :param use_filtered: Whether to search in filtered_nodes (True) or all nodes (False) + :returns: List of DbtNode objects that depend on the specified node + """ + nodes_dict = self.filtered_nodes if use_filtered else self.nodes + + if node_unique_id not in nodes_dict: + return [] + + node = nodes_dict[node_unique_id] + return node.downstream_nodes diff --git a/dev/dags/dbt/altered_jaffle_shop/models/staging/sources.yml b/dev/dags/dbt/altered_jaffle_shop/models/staging/sources.yml index a3139b5858..4c3a834262 100644 --- a/dev/dags/dbt/altered_jaffle_shop/models/staging/sources.yml +++ b/dev/dags/dbt/altered_jaffle_shop/models/staging/sources.yml @@ -6,6 +6,7 @@ sources: database: "{{ env_var('POSTGRES_DB') }}" schema: "{{ env_var('POSTGRES_SCHEMA') }}" tables: + - name: not_connected_source - name: raw_customers columns: - name: id diff --git a/dev/dags/example_source_prunning.py b/dev/dags/example_source_prunning.py new file mode 100644 index 0000000000..bc6e4afd6a --- /dev/null +++ b/dev/dags/example_source_prunning.py @@ -0,0 +1,48 @@ +""" +An example DAG that uses Cosmos to render a dbt project into an Airflow DAG using Cosmos source rendering. +""" + +import os +from datetime import datetime +from pathlib import Path + +from cosmos import DbtDag, ProfileConfig, ProjectConfig, RenderConfig +from cosmos.constants import SourceRenderingBehavior +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +# [START cosmos_source_node_example] + +source_rendering_dag = DbtDag( + # dbt/cosmos-specific parameters + project_config=ProjectConfig( + DBT_ROOT_PATH / "altered_jaffle_shop", + ), + profile_config=profile_config, + operator_args={ + "install_deps": True, # install any necessary dependencies before running any dbt command + "full_refresh": True, # used only in dbt commands that support this flag + }, + render_config=RenderConfig(source_rendering_behavior=SourceRenderingBehavior.ALL, + source_prunning=True), + # normal dag parameters + schedule="@daily", + start_date=datetime(2024, 1, 1), + catchup=False, + dag_id="source_prunning_dag", + default_args={"retries": 0}, + on_warning_callback=lambda context: print(context), +) +# [END cosmos_source_node_example] diff --git a/docs/configuration/source-nodes-rendering.rst b/docs/configuration/source-nodes-rendering.rst index a773cf5da5..864d9b4d70 100644 --- a/docs/configuration/source-nodes-rendering.rst +++ b/docs/configuration/source-nodes-rendering.rst @@ -37,6 +37,32 @@ Example: ) +Source Pruning +~~~~~~~~~~~~~~ + +The ``source_prunning`` is a boolean parameter available in the ``RenderConfig``. +When set to ``True``, it automatically removes (or "prunes") any dbt source nodes from your Airflow DAG that do not have any downstream dependencies within the selected portion of the dbt graph. + +This is particularly useful for keeping your DAGs clean and focused, especially in large dbt projects where you might be selecting only a subset of models to run. + +By default, this option is set to ``False``. + +Example: + +.. code-block:: python + + from cosmos import DbtTaskGroup, RenderConfig + + jaffle_shop_pruned = DbtTaskGroup( + render_config=RenderConfig( + select=["+customers"], # select only customers and its parents + source_prunning=True, + ) + ) + +In this example, if the ``jaffle_shop`` project has multiple sources, but only some of them are upstream of the ``customers`` model, Cosmos will only render the necessary sources and prune the rest. + + on_warning_callback Callback ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 25800b7f3daf5f7d1c8bf228ff1ae67c16a5f44b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 11:01:55 +0000 Subject: [PATCH 03/26] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/dbt/graph.py | 16 ++++++++-------- dev/dags/example_source_prunning.py | 3 +-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index ccab4c2d32..0a146e1059 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -83,7 +83,7 @@ class DbtNode: config: dict[str, Any] = field(default_factory=lambda: {}) has_freshness: bool = False has_test: bool = False - _graph_nodes: dict[str, "DbtNode"] = field(default_factory=dict, init=False, repr=False) + _graph_nodes: dict[str, DbtNode] = field(default_factory=dict, init=False, repr=False) @property def meta(self) -> Dict[str, Any]: @@ -175,10 +175,10 @@ def context_dict(self) -> dict[str, Any]: } @property - def downstream_nodes(self) -> list["DbtNode"]: + def downstream_nodes(self) -> list[DbtNode]: """ Find all nodes that depend on this node (downstream dependencies). - + :returns: List of DbtNode objects that depend on this node """ downstream_nodes = [] @@ -191,7 +191,7 @@ def downstream_nodes(self) -> list["DbtNode"]: def downstream_ids(self) -> list[str]: """ Find all node IDs that depend on this node (downstream dependencies). - + :returns: List of unique_ids of nodes that depend on this node """ downstream_ids = [] @@ -200,7 +200,7 @@ def downstream_ids(self) -> list[str]: downstream_ids.append(node.unique_id) return downstream_ids - def _set_graph_nodes(self, graph_nodes: dict[str, "DbtNode"]) -> None: + def _set_graph_nodes(self, graph_nodes: dict[str, DbtNode]) -> None: """ Internal method to set the graph nodes reference. This should only be called by DbtGraph after loading nodes. @@ -1011,15 +1011,15 @@ def _set_graph_references(self) -> None: def get_downstream_nodes(self, node_unique_id: str, use_filtered: bool = True) -> list[DbtNode]: """ Get all downstream nodes for a given node. - + :param node_unique_id: The unique_id of the node to find downstream dependencies for :param use_filtered: Whether to search in filtered_nodes (True) or all nodes (False) :returns: List of DbtNode objects that depend on the specified node """ nodes_dict = self.filtered_nodes if use_filtered else self.nodes - + if node_unique_id not in nodes_dict: return [] - + node = nodes_dict[node_unique_id] return node.downstream_nodes diff --git a/dev/dags/example_source_prunning.py b/dev/dags/example_source_prunning.py index bc6e4afd6a..37934f422b 100644 --- a/dev/dags/example_source_prunning.py +++ b/dev/dags/example_source_prunning.py @@ -35,8 +35,7 @@ "install_deps": True, # install any necessary dependencies before running any dbt command "full_refresh": True, # used only in dbt commands that support this flag }, - render_config=RenderConfig(source_rendering_behavior=SourceRenderingBehavior.ALL, - source_prunning=True), + render_config=RenderConfig(source_rendering_behavior=SourceRenderingBehavior.ALL, source_prunning=True), # normal dag parameters schedule="@daily", start_date=datetime(2024, 1, 1), From a61ceb87012117946aedfba97a53703f6a7baf2d Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Fri, 19 Sep 2025 16:50:40 +0200 Subject: [PATCH 04/26] add cached property to transverse map --- cosmos/dbt/graph.py | 103 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 85 insertions(+), 18 deletions(-) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index ccab4c2d32..5c69e99f89 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -11,7 +11,7 @@ import warnings import zlib from dataclasses import dataclass, field -from functools import cached_property +from functools import cached_property, lru_cache from pathlib import Path from subprocess import PIPE, Popen from typing import TYPE_CHECKING, Any, Dict, Optional @@ -83,7 +83,7 @@ class DbtNode: config: dict[str, Any] = field(default_factory=lambda: {}) has_freshness: bool = False has_test: bool = False - _graph_nodes: dict[str, "DbtNode"] = field(default_factory=dict, init=False, repr=False) + downstream_graph: DownstreamGraph = field(default=None, init=False, repr=False) @property def meta(self) -> Dict[str, Any]: @@ -175,17 +175,15 @@ def context_dict(self) -> dict[str, Any]: } @property - def downstream_nodes(self) -> list["DbtNode"]: + def downstream_nodes(self) -> list[DbtNode]: """ Find all nodes that depend on this node (downstream dependencies). :returns: List of DbtNode objects that depend on this node """ - downstream_nodes = [] - for node in self._graph_nodes.values(): - if self.unique_id in node.depends_on: - downstream_nodes.append(node) - return downstream_nodes + if self.downstream_graph is None: + return [] + return self.downstream_graph.get_downstream_nodes(self.unique_id) @property def downstream_ids(self) -> list[str]: @@ -194,18 +192,83 @@ def downstream_ids(self) -> list[str]: :returns: List of unique_ids of nodes that depend on this node """ - downstream_ids = [] - for node in self._graph_nodes.values(): - if self.unique_id in node.depends_on: - downstream_ids.append(node.unique_id) - return downstream_ids + if self.downstream_graph is None: + return [] + return self.downstream_graph.get_downstream_ids(self.unique_id) - def _set_graph_nodes(self, graph_nodes: dict[str, "DbtNode"]) -> None: + def set_downstream_graph(self, downstream_graph: DownstreamGraph) -> None: """ - Internal method to set the graph nodes reference. + Internal method to set the downstream graph reference. This should only be called by DbtGraph after loading nodes. """ - self._graph_nodes = graph_nodes + self.downstream_graph = downstream_graph + + +class DownstreamGraph: + """ + Optimized data structure for computing downstream relationships in a dbt graph. + + This class precomputes and caches downstream relationships to avoid O(n) lookups + for each node when computing downstream dependencies, otherwise we would risk overloading the scheduler. + """ + + def __init__(self, nodes: dict[str, DbtNode]): + """ + + :param nodes: Dictionary mapping node unique_id to DbtNode objects + """ + self.nodes = nodes + self._downstream_map: dict[str, list[str]] = {} + self._downstream_nodes_cache: dict[str, list[DbtNode]] = {} + self._build_downstream_map() + + def _build_downstream_map(self) -> None: + """ + Build the downstream map by iterating through all nodes at once. + """ + for node_id in self.nodes: + self._downstream_map[node_id] = [] + + # Build the downstream relationships + for node in self.nodes.values(): + for dependency_id in node.depends_on: + if dependency_id in self._downstream_map: + self._downstream_map[dependency_id].append(node.unique_id) + + @lru_cache(maxsize=1024) + def get_downstream_ids(self, node_id: str) -> list[str]: + """ + Get downstream node IDs for a given node. + Uses LRU cache for performance with repeated access. + + :param node_id: The unique_id of the node + :returns: List of downstream node IDs + """ + return self._downstream_map.get(node_id, []).copy() + + def get_downstream_nodes(self, node_id: str) -> list[DbtNode]: + """ + Get downstream DbtNode objects for a given node. + Uses caching to avoid repeated node lookups. + + :param node_id: The unique_id of the node + :returns: List of downstream DbtNode objects + """ + if node_id not in self._downstream_nodes_cache: + downstream_ids = self.get_downstream_ids(node_id) + downstream_nodes = [ + self.nodes[downstream_id] + for downstream_id in downstream_ids + if downstream_id in self.nodes + ] + self._downstream_nodes_cache[node_id] = downstream_nodes + + return self._downstream_nodes_cache[node_id].copy() + + def clear_cache(self) -> None: + """Clear the internal caches""" + self.get_downstream_ids.cache_clear() + self._downstream_nodes_cache.clear() def is_freshness_effective(freshness: Optional[dict[str, Any]]) -> bool: @@ -1002,11 +1065,15 @@ def update_node_dependency(self) -> None: def _set_graph_references(self) -> None: """ - Set the graph nodes reference for all nodes so they can compute downstream dependencies. + Set the downstream graph reference for all nodes so they can compute downstream dependencies. This should be called after loading and filtering nodes. """ + # Create the optimized downstream graph using filtered_nodes + downstream_graph = DownstreamGraph(self.filtered_nodes) + + # Set the downstream graph reference for all nodes for node in self.nodes.values(): - node._set_graph_nodes(self.filtered_nodes) + node.set_downstream_graph(downstream_graph) def get_downstream_nodes(self, node_unique_id: str, use_filtered: bool = True) -> list[DbtNode]: """ From 3f7438937f97334b0ac8ff3ad77b2113e128a6c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 14:51:34 +0000 Subject: [PATCH 05/26] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/dbt/graph.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 5c69e99f89..5fa9cf8934 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -178,7 +178,7 @@ def context_dict(self) -> dict[str, Any]: def downstream_nodes(self) -> list[DbtNode]: """ Find all nodes that depend on this node (downstream dependencies). - + :returns: List of DbtNode objects that depend on this node """ if self.downstream_graph is None: @@ -189,7 +189,7 @@ def downstream_nodes(self) -> list[DbtNode]: def downstream_ids(self) -> list[str]: """ Find all node IDs that depend on this node (downstream dependencies). - + :returns: List of unique_ids of nodes that depend on this node """ if self.downstream_graph is None: @@ -207,64 +207,62 @@ def set_downstream_graph(self, downstream_graph: DownstreamGraph) -> None: class DownstreamGraph: """ Optimized data structure for computing downstream relationships in a dbt graph. - + This class precomputes and caches downstream relationships to avoid O(n) lookups for each node when computing downstream dependencies, otherwise we would risk overloading the scheduler. """ - + def __init__(self, nodes: dict[str, DbtNode]): """ - + :param nodes: Dictionary mapping node unique_id to DbtNode objects """ self.nodes = nodes self._downstream_map: dict[str, list[str]] = {} self._downstream_nodes_cache: dict[str, list[DbtNode]] = {} self._build_downstream_map() - + def _build_downstream_map(self) -> None: """ Build the downstream map by iterating through all nodes at once. """ for node_id in self.nodes: self._downstream_map[node_id] = [] - + # Build the downstream relationships for node in self.nodes.values(): for dependency_id in node.depends_on: if dependency_id in self._downstream_map: self._downstream_map[dependency_id].append(node.unique_id) - + @lru_cache(maxsize=1024) def get_downstream_ids(self, node_id: str) -> list[str]: """ Get downstream node IDs for a given node. Uses LRU cache for performance with repeated access. - + :param node_id: The unique_id of the node :returns: List of downstream node IDs """ return self._downstream_map.get(node_id, []).copy() - + def get_downstream_nodes(self, node_id: str) -> list[DbtNode]: """ Get downstream DbtNode objects for a given node. Uses caching to avoid repeated node lookups. - + :param node_id: The unique_id of the node :returns: List of downstream DbtNode objects """ if node_id not in self._downstream_nodes_cache: downstream_ids = self.get_downstream_ids(node_id) downstream_nodes = [ - self.nodes[downstream_id] - for downstream_id in downstream_ids - if downstream_id in self.nodes + self.nodes[downstream_id] for downstream_id in downstream_ids if downstream_id in self.nodes ] self._downstream_nodes_cache[node_id] = downstream_nodes - + return self._downstream_nodes_cache[node_id].copy() - + def clear_cache(self) -> None: """Clear the internal caches""" self.get_downstream_ids.cache_clear() @@ -1070,7 +1068,7 @@ def _set_graph_references(self) -> None: """ # Create the optimized downstream graph using filtered_nodes downstream_graph = DownstreamGraph(self.filtered_nodes) - + # Set the downstream graph reference for all nodes for node in self.nodes.values(): node.set_downstream_graph(downstream_graph) @@ -1078,15 +1076,15 @@ def _set_graph_references(self) -> None: def get_downstream_nodes(self, node_unique_id: str, use_filtered: bool = True) -> list[DbtNode]: """ Get all downstream nodes for a given node. - + :param node_unique_id: The unique_id of the node to find downstream dependencies for :param use_filtered: Whether to search in filtered_nodes (True) or all nodes (False) :returns: List of DbtNode objects that depend on the specified node """ nodes_dict = self.filtered_nodes if use_filtered else self.nodes - + if node_unique_id not in nodes_dict: return [] - + node = nodes_dict[node_unique_id] return node.downstream_nodes From 27c5b7423420552e4ff4e736f20e8e69cd20bd59 Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Fri, 19 Sep 2025 16:54:13 +0200 Subject: [PATCH 06/26] fix issue --- dev/dags/example_source_prunning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev/dags/example_source_prunning.py b/dev/dags/example_source_prunning.py index 37934f422b..a106ff134f 100644 --- a/dev/dags/example_source_prunning.py +++ b/dev/dags/example_source_prunning.py @@ -1,5 +1,6 @@ """ -An example DAG that uses Cosmos to render a dbt project into an Airflow DAG using Cosmos source rendering. +An example DAG that uses Cosmos to render a dbt project into an Airflow DAG using Cosmos source rendering, and then +prunes unused source nodes. """ import os From de3dab10e7619794f9098cc5f0b02a8c61582e18 Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Fri, 19 Sep 2025 17:20:47 +0200 Subject: [PATCH 07/26] improve performance --- cosmos/converter.py | 6 +++++- cosmos/dbt/graph.py | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/cosmos/converter.py b/cosmos/converter.py index b3f804807d..e6df41f0df 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -285,7 +285,11 @@ def __init__( dbt_vars=project_config.dbt_vars, airflow_metadata=cache._get_airflow_metadata(dag, task_group), ) - self.dbt_graph.load(method=render_config.load_method, execution_mode=execution_config.execution_mode) + self.dbt_graph.load( + method=render_config.load_method, + execution_mode=execution_config.execution_mode, + needs_downstream_graph=render_config.source_prunning + ) self._add_dbt_project_hash_to_dag_docs(dag) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 5fa9cf8934..22096342b6 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -624,6 +624,7 @@ def load( self, method: LoadMode = LoadMode.AUTOMATIC, execution_mode: ExecutionMode = ExecutionMode.LOCAL, + needs_downstream_graph: bool = False, ) -> None: """ Load a `dbt` project into a `DbtGraph`, setting `nodes` and `filtered_nodes` accordingly. @@ -631,6 +632,8 @@ def load( :param method: How to load `nodes` from a `dbt` project (automatically, using custom parser, using dbt manifest or dbt ls) :param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.KUBERNETES) + :param needs_downstream_graph: Whether to build the downstream graph references. Set to False to skip + building downstream relationships for better performance when source pruning is not needed. Fundamentally, there are two different execution paths There is automatic, and manual. @@ -658,7 +661,10 @@ def load( load_method[method]() self.update_node_dependency() - self._set_graph_references() + + # Only build downstream graph references if needed (for performance optimization) + if needs_downstream_graph: + self._set_graph_references() logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.filtered_nodes)) From 31b4a45f2231d7ce2a60d25b653a0eb59e207ea7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:21:07 +0000 Subject: [PATCH 08/26] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/converter.py | 4 ++-- cosmos/dbt/graph.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cosmos/converter.py b/cosmos/converter.py index e6df41f0df..c648a97823 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -286,9 +286,9 @@ def __init__( airflow_metadata=cache._get_airflow_metadata(dag, task_group), ) self.dbt_graph.load( - method=render_config.load_method, + method=render_config.load_method, execution_mode=execution_config.execution_mode, - needs_downstream_graph=render_config.source_prunning + needs_downstream_graph=render_config.source_prunning, ) self._add_dbt_project_hash_to_dag_docs(dag) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 22096342b6..ec12585c80 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -661,7 +661,7 @@ def load( load_method[method]() self.update_node_dependency() - + # Only build downstream graph references if needed (for performance optimization) if needs_downstream_graph: self._set_graph_references() From 69c7673c03de441883537eb538fd4c0b654fc05d Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Fri, 19 Sep 2025 17:25:59 +0200 Subject: [PATCH 09/26] fix spelling --- cosmos/airflow/graph.py | 12 ++++++------ cosmos/config.py | 4 ++-- cosmos/converter.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 74fa6e6339..a604585a6a 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -257,7 +257,7 @@ def create_task_metadata( dbt_dag_task_group_identifier: str, use_task_group: bool = False, source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE, - source_prunning: bool = False, + source_pruning: bool = False, normalize_task_id: Callable[..., Any] | None = None, normalize_task_display_name: Callable[..., Any] | None = None, test_behavior: TestBehavior = TestBehavior.AFTER_ALL, @@ -331,7 +331,7 @@ def create_task_metadata( ): return None - if source_prunning and not node.downstream_nodes: + if source_pruning and not node.downstream_nodes: return None task_id, args = _get_task_id_and_args( node, args, use_task_group, normalize_task_id, normalize_task_display_name, "source" @@ -400,7 +400,7 @@ def generate_task_or_group( task_args: dict[str, Any], test_behavior: TestBehavior, source_rendering_behavior: SourceRenderingBehavior, - source_prunning: bool, + source_pruning: bool, test_indirect_selection: TestIndirectSelection, on_warning_callback: Callable[..., Any] | None = None, normalize_task_id: Callable[..., Any] | None = None, @@ -425,7 +425,7 @@ def generate_task_or_group( dbt_dag_task_group_identifier=_get_dbt_dag_task_group_identifier(dag, task_group), use_task_group=use_task_group, source_rendering_behavior=source_rendering_behavior, - source_prunning=source_prunning, + source_pruning=source_pruning, normalize_task_id=normalize_task_id, normalize_task_display_name=normalize_task_display_name, test_behavior=test_behavior, @@ -636,7 +636,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro node_converters = render_config.node_converters or {} test_behavior = render_config.test_behavior source_rendering_behavior = render_config.source_rendering_behavior - source_prunning = render_config.source_prunning + source_pruning = render_config.source_pruning normalize_task_id = render_config.normalize_task_id normalize_task_display_name = render_config.normalize_task_display_name enable_owner_inheritance = render_config.enable_owner_inheritance @@ -669,7 +669,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro task_args=task_args, test_behavior=test_behavior, source_rendering_behavior=source_rendering_behavior, - source_prunning=source_prunning, + source_pruning=source_pruning, test_indirect_selection=test_indirect_selection, on_warning_callback=on_warning_callback, normalize_task_id=normalize_task_id, diff --git a/cosmos/config.py b/cosmos/config.py index 0551450f94..18bd135c51 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -63,7 +63,7 @@ class RenderConfig: :param dbt_ls_path: Configures the location of an output of ``dbt ls``. Required when using ``load_method=LoadMode.DBT_LS_FILE``. :param enable_mock_profile: Allows to enable/disable mocking profile. Enabled by default. Mock profiles are useful for parsing Cosmos DAGs in the CI, but should be disabled to benefit from partial parsing (since Cosmos 1.4). :param source_rendering_behavior: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). - :param source_prunning: Determines if sources without a corresponding downstream task should be removed or not. Default is False + :param source_pruning: Determines if sources without a corresponding downstream task should be removed or not. Default is False :param airflow_vars_to_purge_dbt_ls_cache: Specify Airflow variables that will affect the LoadMode.DBT_LS cache. :param normalize_task_id: A callable that takes a dbt node as input and returns the task ID. This allows users to assign a custom node ID separate from the display name. :param normalize_task_display_name: A callable that takes a dbt node as input and returns the task display name. This allows users to assign a custom task display name separate from the node ID. @@ -87,7 +87,7 @@ class RenderConfig: project_path: Path | None = field(init=False) enable_mock_profile: bool = True source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE - source_prunning: bool = False + source_pruning: bool = False airflow_vars_to_purge_dbt_ls_cache: list[str] = field(default_factory=list) normalize_task_id: Callable[..., Any] | None = None normalize_task_display_name: Callable[..., Any] | None = None diff --git a/cosmos/converter.py b/cosmos/converter.py index e6df41f0df..bdad8c9ad7 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -288,7 +288,7 @@ def __init__( self.dbt_graph.load( method=render_config.load_method, execution_mode=execution_config.execution_mode, - needs_downstream_graph=render_config.source_prunning + needs_downstream_graph=render_config.source_pruning ) self._add_dbt_project_hash_to_dag_docs(dag) From ddbd32d70e76ea0665f6f6fcb3a9e70ca75d461a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:27:12 +0000 Subject: [PATCH 10/26] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/converter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/converter.py b/cosmos/converter.py index bdad8c9ad7..ea568a32f0 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -286,9 +286,9 @@ def __init__( airflow_metadata=cache._get_airflow_metadata(dag, task_group), ) self.dbt_graph.load( - method=render_config.load_method, + method=render_config.load_method, execution_mode=execution_config.execution_mode, - needs_downstream_graph=render_config.source_pruning + needs_downstream_graph=render_config.source_pruning, ) self._add_dbt_project_hash_to_dag_docs(dag) From ea187a6c555a98900a6c2edd558b89994152d814 Mon Sep 17 00:00:00 2001 From: Giovanni Corsetti <155465603+corsettigyg@users.noreply.github.com> Date: Fri, 19 Sep 2025 17:30:37 +0200 Subject: [PATCH 11/26] Update cosmos/config.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- cosmos/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/config.py b/cosmos/config.py index 18bd135c51..cb0993c074 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -63,7 +63,7 @@ class RenderConfig: :param dbt_ls_path: Configures the location of an output of ``dbt ls``. Required when using ``load_method=LoadMode.DBT_LS_FILE``. :param enable_mock_profile: Allows to enable/disable mocking profile. Enabled by default. Mock profiles are useful for parsing Cosmos DAGs in the CI, but should be disabled to benefit from partial parsing (since Cosmos 1.4). :param source_rendering_behavior: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). - :param source_pruning: Determines if sources without a corresponding downstream task should be removed or not. Default is False + :param source_pruning: Determines if source nodes without a corresponding downstream task should be removed or not. Default is False :param airflow_vars_to_purge_dbt_ls_cache: Specify Airflow variables that will affect the LoadMode.DBT_LS cache. :param normalize_task_id: A callable that takes a dbt node as input and returns the task ID. This allows users to assign a custom node ID separate from the display name. :param normalize_task_display_name: A callable that takes a dbt node as input and returns the task display name. This allows users to assign a custom task display name separate from the node ID. From 1758e3d1ce94a999e318be02529055335a70263a Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Fri, 19 Sep 2025 17:31:29 +0200 Subject: [PATCH 12/26] rename to source pruning --- docs/configuration/source-nodes-rendering.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration/source-nodes-rendering.rst b/docs/configuration/source-nodes-rendering.rst index 864d9b4d70..9bfcf0e97b 100644 --- a/docs/configuration/source-nodes-rendering.rst +++ b/docs/configuration/source-nodes-rendering.rst @@ -40,7 +40,7 @@ Example: Source Pruning ~~~~~~~~~~~~~~ -The ``source_prunning`` is a boolean parameter available in the ``RenderConfig``. +The ``source_pruning`` is a boolean parameter available in the ``RenderConfig``. When set to ``True``, it automatically removes (or "prunes") any dbt source nodes from your Airflow DAG that do not have any downstream dependencies within the selected portion of the dbt graph. This is particularly useful for keeping your DAGs clean and focused, especially in large dbt projects where you might be selecting only a subset of models to run. @@ -56,7 +56,7 @@ Example: jaffle_shop_pruned = DbtTaskGroup( render_config=RenderConfig( select=["+customers"], # select only customers and its parents - source_prunning=True, + source_pruning=True, ) ) From 6427f0dfeee3b5fa03e387ad122ee5f4ba92a31b Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Fri, 19 Sep 2025 17:32:49 +0200 Subject: [PATCH 13/26] rename to source pruning v2 --- dev/dags/example_source_prunning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/dags/example_source_prunning.py b/dev/dags/example_source_prunning.py index a106ff134f..c1bfddd617 100644 --- a/dev/dags/example_source_prunning.py +++ b/dev/dags/example_source_prunning.py @@ -36,7 +36,7 @@ "install_deps": True, # install any necessary dependencies before running any dbt command "full_refresh": True, # used only in dbt commands that support this flag }, - render_config=RenderConfig(source_rendering_behavior=SourceRenderingBehavior.ALL, source_prunning=True), + render_config=RenderConfig(source_rendering_behavior=SourceRenderingBehavior.ALL, source_pruning=True), # normal dag parameters schedule="@daily", start_date=datetime(2024, 1, 1), From ceee75c736cb03b1118273dd535ae19b85e12042 Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Fri, 19 Sep 2025 17:33:53 +0200 Subject: [PATCH 14/26] fix spelling in dag_id --- dev/dags/example_source_prunning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/dags/example_source_prunning.py b/dev/dags/example_source_prunning.py index c1bfddd617..49317d889c 100644 --- a/dev/dags/example_source_prunning.py +++ b/dev/dags/example_source_prunning.py @@ -41,7 +41,7 @@ schedule="@daily", start_date=datetime(2024, 1, 1), catchup=False, - dag_id="source_prunning_dag", + dag_id="source_pruning_dag", default_args={"retries": 0}, on_warning_callback=lambda context: print(context), ) From 32b7bdd9f720e115bfecd5df9babd2ad0a440a68 Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Fri, 19 Sep 2025 17:36:23 +0200 Subject: [PATCH 15/26] fix file name --- .../{example_source_prunning.py => example_source_pruning.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename dev/dags/{example_source_prunning.py => example_source_pruning.py} (100%) diff --git a/dev/dags/example_source_prunning.py b/dev/dags/example_source_pruning.py similarity index 100% rename from dev/dags/example_source_prunning.py rename to dev/dags/example_source_pruning.py From 152bc63d173f7a557538eb4edb9f63807aa9912b Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Sat, 20 Sep 2025 21:11:24 +0200 Subject: [PATCH 16/26] revert changes in node --- cosmos/airflow/graph.py | 25 +++++++- cosmos/converter.py | 6 +- cosmos/dbt/graph.py | 133 +--------------------------------------- 3 files changed, 26 insertions(+), 138 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index a604585a6a..67ee96233f 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -64,6 +64,25 @@ def calculate_operator_class( ) +def _is_source_used_by_filtered_nodes(source_node: DbtNode, filtered_nodes: dict[str, DbtNode]) -> bool: + """ + Check if a source node is referenced by any of the filtered nodes. + This is a much simpler approach than building the entire downstream graph. + + :param source_node: The source node to check + :param filtered_nodes: Dictionary of filtered nodes (nodes that will be rendered) + :returns: True if the source is used by any filtered node, False otherwise + """ + source_id = source_node.unique_id + + # Check if any filtered node depends on this source + for node in filtered_nodes.values(): + if source_id in node.depends_on: + return True + + return False + + def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[str]: """ Return a list of unique_ids for nodes that are not parents (don't have dependencies on other tasks). @@ -265,6 +284,7 @@ def create_task_metadata( on_warning_callback: Callable[..., Any] | None = None, detached_from_parent: dict[str, DbtNode] | None = None, enable_owner_inheritance: bool | None = None, + filtered_nodes: dict[str, DbtNode] | None = None, ) -> TaskMetadata | None: """ Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node. @@ -331,7 +351,7 @@ def create_task_metadata( ): return None - if source_pruning and not node.downstream_nodes: + if source_pruning and filtered_nodes and not _is_source_used_by_filtered_nodes(node, filtered_nodes): return None task_id, args = _get_task_id_and_args( node, args, use_task_group, normalize_task_id, normalize_task_display_name, "source" @@ -407,6 +427,7 @@ def generate_task_or_group( normalize_task_display_name: Callable[..., Any] | None = None, detached_from_parent: dict[str, DbtNode] | None = None, enable_owner_inheritance: bool | None = None, + filtered_nodes: dict[str, DbtNode] | None = None, **kwargs: Any, ) -> BaseOperator | TaskGroup | None: task_or_group: BaseOperator | TaskGroup | None = None @@ -433,6 +454,7 @@ def generate_task_or_group( on_warning_callback=on_warning_callback, detached_from_parent=detached_from_parent, enable_owner_inheritance=enable_owner_inheritance, + filtered_nodes=filtered_nodes, ) # In most cases, we'll map one DBT node to one Airflow task @@ -677,6 +699,7 @@ def build_airflow_graph( # noqa: C901 TODO: https://github.com/astronomer/astro node=node, detached_from_parent=detached_from_parent, enable_owner_inheritance=enable_owner_inheritance, + filtered_nodes=nodes, ) if task_or_group is not None: logger.debug(f"Conversion of <{node.unique_id}> was successful!") diff --git a/cosmos/converter.py b/cosmos/converter.py index ea568a32f0..b3f804807d 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -285,11 +285,7 @@ def __init__( dbt_vars=project_config.dbt_vars, airflow_metadata=cache._get_airflow_metadata(dag, task_group), ) - self.dbt_graph.load( - method=render_config.load_method, - execution_mode=execution_config.execution_mode, - needs_downstream_graph=render_config.source_pruning, - ) + self.dbt_graph.load(method=render_config.load_method, execution_mode=execution_config.execution_mode) self._add_dbt_project_hash_to_dag_docs(dag) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index ec12585c80..ac17a02692 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -11,7 +11,7 @@ import warnings import zlib from dataclasses import dataclass, field -from functools import cached_property, lru_cache +from functools import cached_property from pathlib import Path from subprocess import PIPE, Popen from typing import TYPE_CHECKING, Any, Dict, Optional @@ -83,7 +83,6 @@ class DbtNode: config: dict[str, Any] = field(default_factory=lambda: {}) has_freshness: bool = False has_test: bool = False - downstream_graph: DownstreamGraph = field(default=None, init=False, repr=False) @property def meta(self) -> Dict[str, Any]: @@ -171,103 +170,8 @@ def context_dict(self) -> dict[str, Any]: "has_test": self.has_test, "resource_name": self.resource_name, "name": self.name, - "downstream_ids": self.downstream_ids, } - @property - def downstream_nodes(self) -> list[DbtNode]: - """ - Find all nodes that depend on this node (downstream dependencies). - - :returns: List of DbtNode objects that depend on this node - """ - if self.downstream_graph is None: - return [] - return self.downstream_graph.get_downstream_nodes(self.unique_id) - - @property - def downstream_ids(self) -> list[str]: - """ - Find all node IDs that depend on this node (downstream dependencies). - - :returns: List of unique_ids of nodes that depend on this node - """ - if self.downstream_graph is None: - return [] - return self.downstream_graph.get_downstream_ids(self.unique_id) - - def set_downstream_graph(self, downstream_graph: DownstreamGraph) -> None: - """ - Internal method to set the downstream graph reference. - This should only be called by DbtGraph after loading nodes. - """ - self.downstream_graph = downstream_graph - - -class DownstreamGraph: - """ - Optimized data structure for computing downstream relationships in a dbt graph. - - This class precomputes and caches downstream relationships to avoid O(n) lookups - for each node when computing downstream dependencies, otherwise we would risk overloading the scheduler. - """ - - def __init__(self, nodes: dict[str, DbtNode]): - """ - - :param nodes: Dictionary mapping node unique_id to DbtNode objects - """ - self.nodes = nodes - self._downstream_map: dict[str, list[str]] = {} - self._downstream_nodes_cache: dict[str, list[DbtNode]] = {} - self._build_downstream_map() - - def _build_downstream_map(self) -> None: - """ - Build the downstream map by iterating through all nodes at once. - """ - for node_id in self.nodes: - self._downstream_map[node_id] = [] - - # Build the downstream relationships - for node in self.nodes.values(): - for dependency_id in node.depends_on: - if dependency_id in self._downstream_map: - self._downstream_map[dependency_id].append(node.unique_id) - - @lru_cache(maxsize=1024) - def get_downstream_ids(self, node_id: str) -> list[str]: - """ - Get downstream node IDs for a given node. - Uses LRU cache for performance with repeated access. - - :param node_id: The unique_id of the node - :returns: List of downstream node IDs - """ - return self._downstream_map.get(node_id, []).copy() - - def get_downstream_nodes(self, node_id: str) -> list[DbtNode]: - """ - Get downstream DbtNode objects for a given node. - Uses caching to avoid repeated node lookups. - - :param node_id: The unique_id of the node - :returns: List of downstream DbtNode objects - """ - if node_id not in self._downstream_nodes_cache: - downstream_ids = self.get_downstream_ids(node_id) - downstream_nodes = [ - self.nodes[downstream_id] for downstream_id in downstream_ids if downstream_id in self.nodes - ] - self._downstream_nodes_cache[node_id] = downstream_nodes - - return self._downstream_nodes_cache[node_id].copy() - - def clear_cache(self) -> None: - """Clear the internal caches""" - self.get_downstream_ids.cache_clear() - self._downstream_nodes_cache.clear() - def is_freshness_effective(freshness: Optional[dict[str, Any]]) -> bool: """Function to find if a source has null freshness. Scenarios where freshness @@ -624,7 +528,6 @@ def load( self, method: LoadMode = LoadMode.AUTOMATIC, execution_mode: ExecutionMode = ExecutionMode.LOCAL, - needs_downstream_graph: bool = False, ) -> None: """ Load a `dbt` project into a `DbtGraph`, setting `nodes` and `filtered_nodes` accordingly. @@ -632,8 +535,6 @@ def load( :param method: How to load `nodes` from a `dbt` project (automatically, using custom parser, using dbt manifest or dbt ls) :param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.KUBERNETES) - :param needs_downstream_graph: Whether to build the downstream graph references. Set to False to skip - building downstream relationships for better performance when source pruning is not needed. Fundamentally, there are two different execution paths There is automatic, and manual. @@ -662,10 +563,6 @@ def load( self.update_node_dependency() - # Only build downstream graph references if needed (for performance optimization) - if needs_downstream_graph: - self._set_graph_references() - logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.filtered_nodes)) @@ -1066,31 +963,3 @@ def update_node_dependency(self) -> None: if node_id in self.filtered_nodes: self.filtered_nodes[node_id].has_test = True self.filtered_nodes[node.unique_id] = node - - def _set_graph_references(self) -> None: - """ - Set the downstream graph reference for all nodes so they can compute downstream dependencies. - This should be called after loading and filtering nodes. - """ - # Create the optimized downstream graph using filtered_nodes - downstream_graph = DownstreamGraph(self.filtered_nodes) - - # Set the downstream graph reference for all nodes - for node in self.nodes.values(): - node.set_downstream_graph(downstream_graph) - - def get_downstream_nodes(self, node_unique_id: str, use_filtered: bool = True) -> list[DbtNode]: - """ - Get all downstream nodes for a given node. - - :param node_unique_id: The unique_id of the node to find downstream dependencies for - :param use_filtered: Whether to search in filtered_nodes (True) or all nodes (False) - :returns: List of DbtNode objects that depend on the specified node - """ - nodes_dict = self.filtered_nodes if use_filtered else self.nodes - - if node_unique_id not in nodes_dict: - return [] - - node = nodes_dict[node_unique_id] - return node.downstream_nodes From b5f32e1d0868fb664cf390d17c0c60f6505af82b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Sep 2025 19:11:43 +0000 Subject: [PATCH 17/26] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/airflow/graph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 67ee96233f..dd1238cd8c 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -68,18 +68,18 @@ def _is_source_used_by_filtered_nodes(source_node: DbtNode, filtered_nodes: dict """ Check if a source node is referenced by any of the filtered nodes. This is a much simpler approach than building the entire downstream graph. - + :param source_node: The source node to check :param filtered_nodes: Dictionary of filtered nodes (nodes that will be rendered) :returns: True if the source is used by any filtered node, False otherwise """ source_id = source_node.unique_id - + # Check if any filtered node depends on this source for node in filtered_nodes.values(): if source_id in node.depends_on: return True - + return False From 8140db7c885210911edf2cb987e5fa620c6c3799 Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Sat, 20 Sep 2025 21:16:29 +0200 Subject: [PATCH 18/26] fix doc --- cosmos/airflow/graph.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 67ee96233f..5935190c35 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -67,10 +67,9 @@ def calculate_operator_class( def _is_source_used_by_filtered_nodes(source_node: DbtNode, filtered_nodes: dict[str, DbtNode]) -> bool: """ Check if a source node is referenced by any of the filtered nodes. - This is a much simpler approach than building the entire downstream graph. - + :param source_node: The source node to check - :param filtered_nodes: Dictionary of filtered nodes (nodes that will be rendered) + :param filtered_nodes: Dictionary of filtered nodes :returns: True if the source is used by any filtered node, False otherwise """ source_id = source_node.unique_id From a3ee387f2ce9ab96cb8a1bfe43396c20b2980a36 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Sep 2025 19:17:24 +0000 Subject: [PATCH 19/26] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/airflow/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 5935190c35..708406d09a 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -73,12 +73,12 @@ def _is_source_used_by_filtered_nodes(source_node: DbtNode, filtered_nodes: dict :returns: True if the source is used by any filtered node, False otherwise """ source_id = source_node.unique_id - + # Check if any filtered node depends on this source for node in filtered_nodes.values(): if source_id in node.depends_on: return True - + return False From a19fb7e9ff9c45d30cceec8489ef0a951a2e33fc Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Sat, 20 Sep 2025 21:28:35 +0200 Subject: [PATCH 20/26] fix source_pruning not having default value --- cosmos/airflow/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 5935190c35..f5980fc66c 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -419,8 +419,8 @@ def generate_task_or_group( task_args: dict[str, Any], test_behavior: TestBehavior, source_rendering_behavior: SourceRenderingBehavior, - source_pruning: bool, test_indirect_selection: TestIndirectSelection, + source_pruning: bool = False, on_warning_callback: Callable[..., Any] | None = None, normalize_task_id: Callable[..., Any] | None = None, normalize_task_display_name: Callable[..., Any] | None = None, From 71266aba21ce9c50c731245b6f354eed2eb93863 Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Mon, 22 Sep 2025 10:39:08 +0200 Subject: [PATCH 21/26] add tests --- tests/dbt/test_pruning.py | 515 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 515 insertions(+) create mode 100644 tests/dbt/test_pruning.py diff --git a/tests/dbt/test_pruning.py b/tests/dbt/test_pruning.py new file mode 100644 index 0000000000..703cf7c544 --- /dev/null +++ b/tests/dbt/test_pruning.py @@ -0,0 +1,515 @@ +""" +This test suite focuses on the source pruning logic that removes sources +without downstream dependencies from the rendered DAG. +""" + +from pathlib import Path +from datetime import datetime + +from airflow.models import DAG + +from cosmos.airflow.graph import ( + _is_source_used_by_filtered_nodes, + create_task_metadata, + generate_task_or_group, +) +from cosmos.config import ProfileConfig +from cosmos.constants import ( + DbtResourceType, + ExecutionMode, + SourceRenderingBehavior, + TestBehavior, + TestIndirectSelection, +) +from cosmos.dbt.graph import DbtNode +from cosmos.profiles import PostgresUserPasswordProfileMapping + + +class TestSourcePruningLogic: + """Test the core source pruning logic.""" + + def setup_method(self): + """Set up test data for each test method.""" + # Create a graph structure for testing: + # source_used -> model1 -> model2 + # source_orphaned (no downstream dependencies) + # source_with_multiple_deps -> model3 + # -> model4 + + self.source_used = DbtNode( + unique_id="source.test_project.source_used", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_used.yml"), + package_name="test_project" + ) + + self.source_orphaned = DbtNode( + unique_id="source.test_project.source_orphaned", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_orphaned.yml"), + package_name="test_project" + ) + + self.source_with_multiple_deps = DbtNode( + unique_id="source.test_project.source_multiple", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_multiple.yml"), + package_name="test_project" + ) + + self.model1 = DbtNode( + unique_id="model.test_project.model1", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.source_used"], + file_path=Path("/test/model1.sql"), + package_name="test_project" + ) + + self.model2 = DbtNode( + unique_id="model.test_project.model2", + resource_type=DbtResourceType.MODEL, + depends_on=["model.test_project.model1"], + file_path=Path("/test/model2.sql"), + package_name="test_project" + ) + + self.model3 = DbtNode( + unique_id="model.test_project.model3", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.source_multiple"], + file_path=Path("/test/model3.sql"), + package_name="test_project" + ) + + self.model4 = DbtNode( + unique_id="model.test_project.model4", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.source_multiple"], + file_path=Path("/test/model4.sql"), + package_name="test_project" + ) + + def test_is_source_used_by_filtered_nodes_basic(self): + """Test the basic functionality of _is_source_used_by_filtered_nodes.""" + # Test case 1: Source is used by a filtered node + filtered_nodes = { + self.model1.unique_id: self.model1, + self.model2.unique_id: self.model2, + } + + assert _is_source_used_by_filtered_nodes(self.source_used, filtered_nodes) is True + assert _is_source_used_by_filtered_nodes(self.source_orphaned, filtered_nodes) is False + + # Test case 2: Source is used by multiple filtered nodes + filtered_nodes_multiple = { + self.model3.unique_id: self.model3, + self.model4.unique_id: self.model4, + } + + assert _is_source_used_by_filtered_nodes(self.source_with_multiple_deps, filtered_nodes_multiple) is True + + def test_is_source_used_by_filtered_nodes_empty_filtered_nodes(self): + """Test behavior with empty filtered_nodes.""" + empty_filtered_nodes = {} + + # No filtered nodes means no source should be considered "used" + assert _is_source_used_by_filtered_nodes(self.source_used, empty_filtered_nodes) is False + assert _is_source_used_by_filtered_nodes(self.source_orphaned, empty_filtered_nodes) is False + + def test_is_source_used_by_filtered_nodes_no_dependencies(self): + """Test with filtered nodes that have no dependencies.""" + # Create a model with no dependencies + independent_model = DbtNode( + unique_id="model.test_project.independent", + resource_type=DbtResourceType.MODEL, + depends_on=[], # No dependencies + file_path=Path("/test/independent.sql"), + package_name="test_project" + ) + + filtered_nodes = { + independent_model.unique_id: independent_model, + } + + # Sources should not be considered "used" if no models depend on them + assert _is_source_used_by_filtered_nodes(self.source_used, filtered_nodes) is False + assert _is_source_used_by_filtered_nodes(self.source_orphaned, filtered_nodes) is False + + def test_is_source_used_by_filtered_nodes_mixed_dependencies(self): + """Test with filtered nodes that have mixed dependency types.""" + # Create a model that depends on another model AND a source + complex_model = DbtNode( + unique_id="model.test_project.complex", + resource_type=DbtResourceType.MODEL, + depends_on=[ + "model.test_project.model1", # Model dependency + "source.test_project.source_used", # Source dependency + ], + file_path=Path("/test/complex.sql"), + package_name="test_project" + ) + + filtered_nodes = { + self.model1.unique_id: self.model1, + complex_model.unique_id: complex_model, + } + + # Source should be considered "used" even when model has mixed dependencies + assert _is_source_used_by_filtered_nodes(self.source_used, filtered_nodes) is True + assert _is_source_used_by_filtered_nodes(self.source_orphaned, filtered_nodes) is False + + +class TestSourcePruningIntegration: + """Test source pruning integration with create_task_metadata and generate_task_or_group.""" + + def setup_method(self): + """Set up test data for integration tests.""" + self.dag = DAG("test_source_pruning", start_date=datetime(2023, 1, 1)) + + self.profile_config = ProfileConfig( + profile_name="test", + target_name="test", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ) + + self.task_args = { + "project_dir": Path("/test/project"), + "profile_config": self.profile_config, + } + + # Create test nodes + self.source_with_downstream = DbtNode( + unique_id="source.test_project.source_with_downstream", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_with_downstream.yml"), + package_name="test_project", + has_freshness=True, # Make it renderable + ) + + self.source_orphaned = DbtNode( + unique_id="source.test_project.source_orphaned", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_orphaned.yml"), + package_name="test_project", + has_freshness=True, # Make it renderable + ) + + self.model_using_source = DbtNode( + unique_id="model.test_project.model_using_source", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.source_with_downstream"], + file_path=Path("/test/model_using_source.sql"), + package_name="test_project" + ) + + # Filtered nodes represent what will actually be rendered + self.filtered_nodes_with_model = { + self.model_using_source.unique_id: self.model_using_source, + } + + self.filtered_nodes_empty = {} + + def test_create_task_metadata_source_pruning_disabled(self): + """Test create_task_metadata when source_pruning is disabled.""" + # Source should be rendered regardless of downstream dependencies + metadata_with_downstream = create_task_metadata( + node=self.source_with_downstream, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=False, # Disabled + filtered_nodes=self.filtered_nodes_with_model, + ) + + metadata_orphaned = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=False, # Disabled + filtered_nodes=self.filtered_nodes_empty, + ) + + # Both sources should generate metadata when pruning is disabled + assert metadata_with_downstream is not None + assert metadata_orphaned is not None + assert metadata_with_downstream.id == "source_with_downstream_source" + assert metadata_orphaned.id == "source_orphaned_source" + + def test_create_task_metadata_source_pruning_enabled(self): + """Test create_task_metadata when source_pruning is enabled.""" + # Source with downstream should be rendered + metadata_with_downstream = create_task_metadata( + node=self.source_with_downstream, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, # Enabled + filtered_nodes=self.filtered_nodes_with_model, + ) + + # Orphaned source should NOT be rendered + metadata_orphaned = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, # Enabled + filtered_nodes=self.filtered_nodes_with_model, + ) + + assert metadata_with_downstream is not None + assert metadata_orphaned is None # Pruned! + assert metadata_with_downstream.id == "source_with_downstream_source" + + def test_create_task_metadata_source_pruning_edge_cases(self): + """Test create_task_metadata edge cases with filtered_nodes.""" + # When filtered_nodes is None, pruning should be skipped (condition: source_pruning and filtered_nodes) + metadata_none = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + filtered_nodes=None, # None means skip pruning check + ) + + # When filtered_nodes is empty, pruning should be skipped (empty dict is falsy) + metadata_empty = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + filtered_nodes={}, # Empty dict means skip pruning check + ) + + # When filtered_nodes contains nodes that don't use this source, it should be pruned + unrelated_model = DbtNode( + unique_id="model.test_project.unrelated", + resource_type=DbtResourceType.MODEL, + depends_on=["some.other.source"], # Doesn't depend on source_orphaned + file_path=Path("/test/unrelated.sql"), + package_name="test_project" + ) + metadata_unused = create_task_metadata( + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + filtered_nodes={unrelated_model.unique_id: unrelated_model}, # Non-empty but doesn't use this source + ) + + # Looking at the code: if source_pruning and filtered_nodes and not _is_source_used_by_filtered_nodes(...) + # When filtered_nodes is None, the condition fails at "filtered_nodes", so no pruning occurs + # When filtered_nodes is {}, the condition also fails at "filtered_nodes" (empty dict is falsy), so no pruning occurs + # When filtered_nodes is non-empty but doesn't use the source, pruning occurs + assert metadata_none is not None # Not pruned when filtered_nodes is None + assert metadata_empty is not None # Not pruned when filtered_nodes is empty dict (falsy) + assert metadata_unused is None # Pruned when filtered_nodes doesn't use this source + + def test_generate_task_or_group_source_pruning(self): + """Test generate_task_or_group with source pruning.""" + # Source with downstream dependencies should generate a task + task_with_downstream = generate_task_or_group( + dag=self.dag, + task_group=None, + node=self.source_with_downstream, + execution_mode=ExecutionMode.LOCAL, + task_args=self.task_args, + test_behavior=TestBehavior.NONE, + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + test_indirect_selection=TestIndirectSelection.EAGER, + filtered_nodes=self.filtered_nodes_with_model, + ) + + # Orphaned source should NOT generate a task + task_orphaned = generate_task_or_group( + dag=self.dag, + task_group=None, + node=self.source_orphaned, + execution_mode=ExecutionMode.LOCAL, + task_args=self.task_args, + test_behavior=TestBehavior.NONE, + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + test_indirect_selection=TestIndirectSelection.EAGER, + filtered_nodes=self.filtered_nodes_with_model, + ) + + assert task_with_downstream is not None + assert task_orphaned is None + + def test_source_pruning_with_different_source_rendering_behaviors(self): + """Test source pruning interaction with different source rendering behaviors.""" + # Create source without freshness + source_no_freshness = DbtNode( + unique_id="source.test_project.source_no_freshness", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/source_no_freshness.yml"), + package_name="test_project", + has_freshness=False, # No freshness + ) + + # Test with SourceRenderingBehavior.NONE - should return None regardless of pruning + metadata_none = create_task_metadata( + node=source_no_freshness, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.NONE, + source_pruning=True, + filtered_nodes=self.filtered_nodes_with_model, + ) + + # Test with SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS - should return None for source without freshness + metadata_conditional = create_task_metadata( + node=source_no_freshness, + execution_mode=ExecutionMode.LOCAL, + args=self.task_args, + dbt_dag_task_group_identifier="test", + source_rendering_behavior=SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS, + source_pruning=True, + filtered_nodes=self.filtered_nodes_with_model, + ) + + # Both should be None due to source rendering behavior, before pruning logic is even reached + assert metadata_none is None + assert metadata_conditional is None + + +class TestSourcePruningEdgeCases: + """Test edge cases and error conditions for source pruning.""" + + def test_non_source_node_with_pruning_function(self): + """Test that _is_source_used_by_filtered_nodes works with non-source nodes.""" + # Create a model node + model_node = DbtNode( + unique_id="model.test_project.some_model", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.some_source"], + file_path=Path("/test/some_model.sql"), + package_name="test_project" + ) + + # Create a filtered node that depends on the model + dependent_model = DbtNode( + unique_id="model.test_project.dependent_model", + resource_type=DbtResourceType.MODEL, + depends_on=["model.test_project.some_model"], + file_path=Path("/test/dependent_model.sql"), + package_name="test_project" + ) + + filtered_nodes = { + dependent_model.unique_id: dependent_model, + } + + # Function should work with any node type, not just sources + assert _is_source_used_by_filtered_nodes(model_node, filtered_nodes) is True + + def test_source_pruning_with_complex_dependencies(self): + """Test source pruning with complex dependency chains.""" + # Create a complex dependency chain + source = DbtNode( + unique_id="source.test_project.raw_data", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/raw_data.yml"), + package_name="test_project", + has_freshness=True, + ) + + staging_model = DbtNode( + unique_id="model.test_project.stg_data", + resource_type=DbtResourceType.MODEL, + depends_on=["source.test_project.raw_data"], + file_path=Path("/test/stg_data.sql"), + package_name="test_project" + ) + + mart_model = DbtNode( + unique_id="model.test_project.mart_data", + resource_type=DbtResourceType.MODEL, + depends_on=["model.test_project.stg_data"], + file_path=Path("/test/mart_data.sql"), + package_name="test_project" + ) + + # Test case 1: Only the mart model is in filtered_nodes + # Source should still be considered "unused" because the direct dependency (staging_model) is not included + filtered_nodes_mart_only = { + mart_model.unique_id: mart_model, + } + + assert _is_source_used_by_filtered_nodes(source, filtered_nodes_mart_only) is False + + # Test case 2: Both staging and mart models are in filtered_nodes + # Source should be considered "used" because staging model directly depends on it + filtered_nodes_both = { + staging_model.unique_id: staging_model, + mart_model.unique_id: mart_model, + } + + assert _is_source_used_by_filtered_nodes(source, filtered_nodes_both) is True + + def test_source_pruning_parameter_defaults(self): + """Test that source_pruning parameter defaults work correctly.""" + # Create minimal test data + source = DbtNode( + unique_id="source.test_project.test_source", + resource_type=DbtResourceType.SOURCE, + depends_on=[], + file_path=Path("/test/test_source.yml"), + package_name="test_project", + has_freshness=True, + ) + + dag = DAG("test", start_date=datetime(2023, 1, 1)) + task_args = { + "project_dir": Path("/test"), + "profile_config": ProfileConfig( + profile_name="test", + target_name="test", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + } + + # Test that generate_task_or_group works without explicitly passing source_pruning + # (should default to False) + task = generate_task_or_group( + dag=dag, + task_group=None, + node=source, + execution_mode=ExecutionMode.LOCAL, + task_args=task_args, + test_behavior=TestBehavior.NONE, + source_rendering_behavior=SourceRenderingBehavior.ALL, + test_indirect_selection=TestIndirectSelection.EAGER, + # source_pruning not specified - should default to False + # filtered_nodes not specified - should default to None + ) + + # Should create a task because pruning is disabled by default + assert task is not None From 82ea59d972d982916370d9c68ad662d25a7e9a07 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Sep 2025 08:39:27 +0000 Subject: [PATCH 22/26] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/dbt/test_pruning.py | 134 +++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/tests/dbt/test_pruning.py b/tests/dbt/test_pruning.py index 703cf7c544..9d34a0e3ea 100644 --- a/tests/dbt/test_pruning.py +++ b/tests/dbt/test_pruning.py @@ -3,8 +3,8 @@ without downstream dependencies from the rendered DAG. """ -from pathlib import Path from datetime import datetime +from pathlib import Path from airflow.models import DAG @@ -35,61 +35,61 @@ def setup_method(self): # source_orphaned (no downstream dependencies) # source_with_multiple_deps -> model3 # -> model4 - + self.source_used = DbtNode( unique_id="source.test_project.source_used", resource_type=DbtResourceType.SOURCE, depends_on=[], file_path=Path("/test/source_used.yml"), - package_name="test_project" + package_name="test_project", ) - + self.source_orphaned = DbtNode( unique_id="source.test_project.source_orphaned", resource_type=DbtResourceType.SOURCE, depends_on=[], file_path=Path("/test/source_orphaned.yml"), - package_name="test_project" + package_name="test_project", ) - + self.source_with_multiple_deps = DbtNode( unique_id="source.test_project.source_multiple", resource_type=DbtResourceType.SOURCE, depends_on=[], file_path=Path("/test/source_multiple.yml"), - package_name="test_project" + package_name="test_project", ) - + self.model1 = DbtNode( unique_id="model.test_project.model1", resource_type=DbtResourceType.MODEL, depends_on=["source.test_project.source_used"], file_path=Path("/test/model1.sql"), - package_name="test_project" + package_name="test_project", ) - + self.model2 = DbtNode( unique_id="model.test_project.model2", resource_type=DbtResourceType.MODEL, depends_on=["model.test_project.model1"], file_path=Path("/test/model2.sql"), - package_name="test_project" + package_name="test_project", ) - + self.model3 = DbtNode( unique_id="model.test_project.model3", resource_type=DbtResourceType.MODEL, depends_on=["source.test_project.source_multiple"], file_path=Path("/test/model3.sql"), - package_name="test_project" + package_name="test_project", ) - + self.model4 = DbtNode( unique_id="model.test_project.model4", resource_type=DbtResourceType.MODEL, depends_on=["source.test_project.source_multiple"], file_path=Path("/test/model4.sql"), - package_name="test_project" + package_name="test_project", ) def test_is_source_used_by_filtered_nodes_basic(self): @@ -99,22 +99,22 @@ def test_is_source_used_by_filtered_nodes_basic(self): self.model1.unique_id: self.model1, self.model2.unique_id: self.model2, } - + assert _is_source_used_by_filtered_nodes(self.source_used, filtered_nodes) is True assert _is_source_used_by_filtered_nodes(self.source_orphaned, filtered_nodes) is False - + # Test case 2: Source is used by multiple filtered nodes filtered_nodes_multiple = { self.model3.unique_id: self.model3, self.model4.unique_id: self.model4, } - + assert _is_source_used_by_filtered_nodes(self.source_with_multiple_deps, filtered_nodes_multiple) is True def test_is_source_used_by_filtered_nodes_empty_filtered_nodes(self): """Test behavior with empty filtered_nodes.""" empty_filtered_nodes = {} - + # No filtered nodes means no source should be considered "used" assert _is_source_used_by_filtered_nodes(self.source_used, empty_filtered_nodes) is False assert _is_source_used_by_filtered_nodes(self.source_orphaned, empty_filtered_nodes) is False @@ -127,13 +127,13 @@ def test_is_source_used_by_filtered_nodes_no_dependencies(self): resource_type=DbtResourceType.MODEL, depends_on=[], # No dependencies file_path=Path("/test/independent.sql"), - package_name="test_project" + package_name="test_project", ) - + filtered_nodes = { independent_model.unique_id: independent_model, } - + # Sources should not be considered "used" if no models depend on them assert _is_source_used_by_filtered_nodes(self.source_used, filtered_nodes) is False assert _is_source_used_by_filtered_nodes(self.source_orphaned, filtered_nodes) is False @@ -149,14 +149,14 @@ def test_is_source_used_by_filtered_nodes_mixed_dependencies(self): "source.test_project.source_used", # Source dependency ], file_path=Path("/test/complex.sql"), - package_name="test_project" + package_name="test_project", ) - + filtered_nodes = { self.model1.unique_id: self.model1, complex_model.unique_id: complex_model, } - + # Source should be considered "used" even when model has mixed dependencies assert _is_source_used_by_filtered_nodes(self.source_used, filtered_nodes) is True assert _is_source_used_by_filtered_nodes(self.source_orphaned, filtered_nodes) is False @@ -168,7 +168,7 @@ class TestSourcePruningIntegration: def setup_method(self): """Set up test data for integration tests.""" self.dag = DAG("test_source_pruning", start_date=datetime(2023, 1, 1)) - + self.profile_config = ProfileConfig( profile_name="test", target_name="test", @@ -177,12 +177,12 @@ def setup_method(self): profile_args={"schema": "public"}, ), ) - + self.task_args = { "project_dir": Path("/test/project"), "profile_config": self.profile_config, } - + # Create test nodes self.source_with_downstream = DbtNode( unique_id="source.test_project.source_with_downstream", @@ -192,7 +192,7 @@ def setup_method(self): package_name="test_project", has_freshness=True, # Make it renderable ) - + self.source_orphaned = DbtNode( unique_id="source.test_project.source_orphaned", resource_type=DbtResourceType.SOURCE, @@ -201,20 +201,20 @@ def setup_method(self): package_name="test_project", has_freshness=True, # Make it renderable ) - + self.model_using_source = DbtNode( unique_id="model.test_project.model_using_source", resource_type=DbtResourceType.MODEL, depends_on=["source.test_project.source_with_downstream"], file_path=Path("/test/model_using_source.sql"), - package_name="test_project" + package_name="test_project", ) - + # Filtered nodes represent what will actually be rendered self.filtered_nodes_with_model = { self.model_using_source.unique_id: self.model_using_source, } - + self.filtered_nodes_empty = {} def test_create_task_metadata_source_pruning_disabled(self): @@ -229,7 +229,7 @@ def test_create_task_metadata_source_pruning_disabled(self): source_pruning=False, # Disabled filtered_nodes=self.filtered_nodes_with_model, ) - + metadata_orphaned = create_task_metadata( node=self.source_orphaned, execution_mode=ExecutionMode.LOCAL, @@ -239,7 +239,7 @@ def test_create_task_metadata_source_pruning_disabled(self): source_pruning=False, # Disabled filtered_nodes=self.filtered_nodes_empty, ) - + # Both sources should generate metadata when pruning is disabled assert metadata_with_downstream is not None assert metadata_orphaned is not None @@ -258,7 +258,7 @@ def test_create_task_metadata_source_pruning_enabled(self): source_pruning=True, # Enabled filtered_nodes=self.filtered_nodes_with_model, ) - + # Orphaned source should NOT be rendered metadata_orphaned = create_task_metadata( node=self.source_orphaned, @@ -269,7 +269,7 @@ def test_create_task_metadata_source_pruning_enabled(self): source_pruning=True, # Enabled filtered_nodes=self.filtered_nodes_with_model, ) - + assert metadata_with_downstream is not None assert metadata_orphaned is None # Pruned! assert metadata_with_downstream.id == "source_with_downstream_source" @@ -286,7 +286,7 @@ def test_create_task_metadata_source_pruning_edge_cases(self): source_pruning=True, filtered_nodes=None, # None means skip pruning check ) - + # When filtered_nodes is empty, pruning should be skipped (empty dict is falsy) metadata_empty = create_task_metadata( node=self.source_orphaned, @@ -297,14 +297,14 @@ def test_create_task_metadata_source_pruning_edge_cases(self): source_pruning=True, filtered_nodes={}, # Empty dict means skip pruning check ) - + # When filtered_nodes contains nodes that don't use this source, it should be pruned unrelated_model = DbtNode( unique_id="model.test_project.unrelated", resource_type=DbtResourceType.MODEL, depends_on=["some.other.source"], # Doesn't depend on source_orphaned file_path=Path("/test/unrelated.sql"), - package_name="test_project" + package_name="test_project", ) metadata_unused = create_task_metadata( node=self.source_orphaned, @@ -315,14 +315,14 @@ def test_create_task_metadata_source_pruning_edge_cases(self): source_pruning=True, filtered_nodes={unrelated_model.unique_id: unrelated_model}, # Non-empty but doesn't use this source ) - + # Looking at the code: if source_pruning and filtered_nodes and not _is_source_used_by_filtered_nodes(...) # When filtered_nodes is None, the condition fails at "filtered_nodes", so no pruning occurs - # When filtered_nodes is {}, the condition also fails at "filtered_nodes" (empty dict is falsy), so no pruning occurs + # When filtered_nodes is {}, the condition also fails at "filtered_nodes" (empty dict is falsy), so no pruning occurs # When filtered_nodes is non-empty but doesn't use the source, pruning occurs - assert metadata_none is not None # Not pruned when filtered_nodes is None + assert metadata_none is not None # Not pruned when filtered_nodes is None assert metadata_empty is not None # Not pruned when filtered_nodes is empty dict (falsy) - assert metadata_unused is None # Pruned when filtered_nodes doesn't use this source + assert metadata_unused is None # Pruned when filtered_nodes doesn't use this source def test_generate_task_or_group_source_pruning(self): """Test generate_task_or_group with source pruning.""" @@ -339,7 +339,7 @@ def test_generate_task_or_group_source_pruning(self): test_indirect_selection=TestIndirectSelection.EAGER, filtered_nodes=self.filtered_nodes_with_model, ) - + # Orphaned source should NOT generate a task task_orphaned = generate_task_or_group( dag=self.dag, @@ -353,7 +353,7 @@ def test_generate_task_or_group_source_pruning(self): test_indirect_selection=TestIndirectSelection.EAGER, filtered_nodes=self.filtered_nodes_with_model, ) - + assert task_with_downstream is not None assert task_orphaned is None @@ -368,7 +368,7 @@ def test_source_pruning_with_different_source_rendering_behaviors(self): package_name="test_project", has_freshness=False, # No freshness ) - + # Test with SourceRenderingBehavior.NONE - should return None regardless of pruning metadata_none = create_task_metadata( node=source_no_freshness, @@ -379,7 +379,7 @@ def test_source_pruning_with_different_source_rendering_behaviors(self): source_pruning=True, filtered_nodes=self.filtered_nodes_with_model, ) - + # Test with SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS - should return None for source without freshness metadata_conditional = create_task_metadata( node=source_no_freshness, @@ -390,7 +390,7 @@ def test_source_pruning_with_different_source_rendering_behaviors(self): source_pruning=True, filtered_nodes=self.filtered_nodes_with_model, ) - + # Both should be None due to source rendering behavior, before pruning logic is even reached assert metadata_none is None assert metadata_conditional is None @@ -398,7 +398,7 @@ def test_source_pruning_with_different_source_rendering_behaviors(self): class TestSourcePruningEdgeCases: """Test edge cases and error conditions for source pruning.""" - + def test_non_source_node_with_pruning_function(self): """Test that _is_source_used_by_filtered_nodes works with non-source nodes.""" # Create a model node @@ -407,25 +407,25 @@ def test_non_source_node_with_pruning_function(self): resource_type=DbtResourceType.MODEL, depends_on=["source.test_project.some_source"], file_path=Path("/test/some_model.sql"), - package_name="test_project" + package_name="test_project", ) - + # Create a filtered node that depends on the model dependent_model = DbtNode( unique_id="model.test_project.dependent_model", resource_type=DbtResourceType.MODEL, depends_on=["model.test_project.some_model"], file_path=Path("/test/dependent_model.sql"), - package_name="test_project" + package_name="test_project", ) - + filtered_nodes = { dependent_model.unique_id: dependent_model, } - + # Function should work with any node type, not just sources assert _is_source_used_by_filtered_nodes(model_node, filtered_nodes) is True - + def test_source_pruning_with_complex_dependencies(self): """Test source pruning with complex dependency chains.""" # Create a complex dependency chain @@ -437,38 +437,38 @@ def test_source_pruning_with_complex_dependencies(self): package_name="test_project", has_freshness=True, ) - + staging_model = DbtNode( unique_id="model.test_project.stg_data", resource_type=DbtResourceType.MODEL, depends_on=["source.test_project.raw_data"], file_path=Path("/test/stg_data.sql"), - package_name="test_project" + package_name="test_project", ) - + mart_model = DbtNode( unique_id="model.test_project.mart_data", resource_type=DbtResourceType.MODEL, depends_on=["model.test_project.stg_data"], file_path=Path("/test/mart_data.sql"), - package_name="test_project" + package_name="test_project", ) - + # Test case 1: Only the mart model is in filtered_nodes # Source should still be considered "unused" because the direct dependency (staging_model) is not included filtered_nodes_mart_only = { mart_model.unique_id: mart_model, } - + assert _is_source_used_by_filtered_nodes(source, filtered_nodes_mart_only) is False - + # Test case 2: Both staging and mart models are in filtered_nodes # Source should be considered "used" because staging model directly depends on it filtered_nodes_both = { staging_model.unique_id: staging_model, mart_model.unique_id: mart_model, } - + assert _is_source_used_by_filtered_nodes(source, filtered_nodes_both) is True def test_source_pruning_parameter_defaults(self): @@ -482,7 +482,7 @@ def test_source_pruning_parameter_defaults(self): package_name="test_project", has_freshness=True, ) - + dag = DAG("test", start_date=datetime(2023, 1, 1)) task_args = { "project_dir": Path("/test"), @@ -495,7 +495,7 @@ def test_source_pruning_parameter_defaults(self): ), ), } - + # Test that generate_task_or_group works without explicitly passing source_pruning # (should default to False) task = generate_task_or_group( @@ -510,6 +510,6 @@ def test_source_pruning_parameter_defaults(self): # source_pruning not specified - should default to False # filtered_nodes not specified - should default to None ) - + # Should create a task because pruning is disabled by default assert task is not None From 26adfd8450ac0a9881302688c4203f6f5a12e0e5 Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Mon, 22 Sep 2025 10:53:39 +0200 Subject: [PATCH 23/26] improve documentation --- docs/configuration/render-config.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/configuration/render-config.rst b/docs/configuration/render-config.rst index 26af118fe0..d58553f90f 100644 --- a/docs/configuration/render-config.rst +++ b/docs/configuration/render-config.rst @@ -22,7 +22,8 @@ The ``RenderConfig`` class takes the following arguments: - ``dbt_project_path``: Configures the DBT project location accessible on their airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM`` - ``airflow_vars_to_purge_cache``: (new in v1.5) Specify Airflow variables that will affect the ``LoadMode.DBT_LS`` cache. See `Caching <./caching.html>`_ for more information. - ``source_rendering_behavior``: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). See `Source Nodes Rendering <./source-nodes-rendering.html>`_ for more information. -- ``normalize_task_id``: A callable that takes a dbt node as input and returns the task ID. This function allows users to set a custom task_id independently of the model name, which can be specified as the task’s display_name. This way, task_id can be modified using a user-defined function, while the model name remains as the task’s display name. The display_name parameter is available in Airflow 2.9 and above. See `Task display name <./task-display-name.html>`_ for more information. +- ``source_pruning``: When set to ``True``, automatically removes (or "prunes") any dbt source nodes from your Airflow DAG that do not have any downstream dependencies within the selected portion of the dbt graph. Defaults to ``False``. See `Source Nodes Rendering <./source-nodes-rendering.html>`_ for more information. +- ``normalize_task_id``: A callable that takes a dbt node as input and returns the task ID. This function allows users to set a custom task_id independently of the model name, which can be specified as the task's display_name. This way, task_id can be modified using a user-defined function, while the model name remains as the task's display name. The display_name parameter is available in Airflow 2.9 and above. See `Task display name <./task-display-name.html>`_ for more information. - ``normalize_task_display_name``: This function allows users to set a custom user-defined function to alter the display name independently of the model name. This way, the task_id can be preserved while the model display name is modified. - ``should_detach_multiple_parents_tests``: A boolean to control if tests that depend on multiple parents should be run as standalone tasks. See `Parsing Methods `_ for more information. - ``enable_owner_inheritance``: (introduced in 1.10.2) A boolean to control if dbt owners should be imported as part of the airflow DAG owners. Defaults to True. From dfcda48cb9d41f8b8bde45b9a106782c6c9b229a Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Mon, 22 Sep 2025 14:10:45 +0200 Subject: [PATCH 24/26] adding extra node to the test_graph --- tests/dbt/test_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 82ca1df655..bbe2186f63 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -654,7 +654,7 @@ def test_load_via_dbt_ls_with_exclude(postgres_profile_config): @pytest.mark.integration @pytest.mark.parametrize( "project_dir,node_count", - [(DBT_PROJECTS_ROOT_DIR / ALTERED_DBT_PROJECT_NAME, 39), (DBT_PROJECTS_ROOT_DIR / "jaffle_shop_python", 28)], + [(DBT_PROJECTS_ROOT_DIR / ALTERED_DBT_PROJECT_NAME, 40), (DBT_PROJECTS_ROOT_DIR / "jaffle_shop_python", 28)], ) def test_load_via_dbt_ls_without_exclude(project_dir, node_count, postgres_profile_config): project_config = ProjectConfig(dbt_project_path=project_dir) From 7187e42138d484c7c76cbcaea3b40ee715c21863 Mon Sep 17 00:00:00 2001 From: corsettigyg Date: Tue, 23 Sep 2025 13:10:56 +0200 Subject: [PATCH 25/26] exclude failing node to fix airflow 3 issue --- dev/dags/example_source_pruning.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dev/dags/example_source_pruning.py b/dev/dags/example_source_pruning.py index 49317d889c..68761c24fc 100644 --- a/dev/dags/example_source_pruning.py +++ b/dev/dags/example_source_pruning.py @@ -36,7 +36,11 @@ "install_deps": True, # install any necessary dependencies before running any dbt command "full_refresh": True, # used only in dbt commands that support this flag }, - render_config=RenderConfig(source_rendering_behavior=SourceRenderingBehavior.ALL, source_pruning=True), + render_config=RenderConfig( + source_rendering_behavior=SourceRenderingBehavior.ALL, + source_pruning=True, + exclude=["multibyte"], # Exclude multibyte model to avoid Airflow 3.0 AssetAlias ASCII validation issues + ), # normal dag parameters schedule="@daily", start_date=datetime(2024, 1, 1), From 3eb6888ac50b6cc73f3f38387cb9bbff5f1d0fbb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Sep 2025 11:12:21 +0000 Subject: [PATCH 26/26] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dev/dags/example_source_pruning.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dev/dags/example_source_pruning.py b/dev/dags/example_source_pruning.py index 68761c24fc..c8f6f316f3 100644 --- a/dev/dags/example_source_pruning.py +++ b/dev/dags/example_source_pruning.py @@ -37,9 +37,11 @@ "full_refresh": True, # used only in dbt commands that support this flag }, render_config=RenderConfig( - source_rendering_behavior=SourceRenderingBehavior.ALL, + source_rendering_behavior=SourceRenderingBehavior.ALL, source_pruning=True, - exclude=["multibyte"], # Exclude multibyte model to avoid Airflow 3.0 AssetAlias ASCII validation issues + exclude=[ + "multibyte" + ], # Exclude multibyte model to avoid Airflow 3.0 AssetAlias ASCII validation issues ), # normal dag parameters schedule="@daily",