diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index d067d77b77..fc10482207 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -1,5 +1,6 @@ from __future__ import annotations from pathlib import Path +import copy from typing import TYPE_CHECKING @@ -43,6 +44,10 @@ def __init__(self, project_dir: Path, statement: str): self.other: list[str] = [] self.load_from_statement(statement) + @property + def is_empty(self) -> bool: + return not (self.paths or self.tags or self.config or self.other) + def load_from_statement(self, statement: str) -> None: """ Load in-place select parameters. @@ -84,27 +89,30 @@ def select_nodes_ids_by_intersection(nodes: dict[str, DbtNode], config: Selector https://docs.getdbt.com/reference/node-selection/yaml-selectors """ selected_nodes = set() - for node_id, node in nodes.items(): - if config.tags and not (sorted(node.tags) == sorted(config.tags)): - continue - supported_node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG} - if config.config: - config_tag = config.config.get("tags") - if config_tag and config_tag not in supported_node_config.get("tags", []): + if not config.is_empty: + for node_id, node in nodes.items(): + if config.tags and not (sorted(node.tags) == sorted(config.tags)): continue - # Remove 'tags' as they've already been filtered for - config.config.pop("tags", None) - supported_node_config.pop("tags", None) + supported_node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG} + config_tag = config.config.get("tags") + if config.config: + if config_tag and config_tag not in supported_node_config.get("tags", []): + continue + + # Remove 'tags' as they've already been filtered for + config_copy = copy.deepcopy(config.config) + config_copy.pop("tags", None) + supported_node_config.pop("tags", None) - if not (config.config.items() <= supported_node_config.items()): - continue + if not (config_copy.items() <= supported_node_config.items()): + continue - if config.paths and not (set(config.paths).issubset(set(node.file_path.parents))): - continue + if config.paths and not (set(config.paths).issubset(set(node.file_path.parents))): + continue - selected_nodes.add(node_id) + selected_nodes.add(node_id) return selected_nodes @@ -166,9 +174,10 @@ def select_nodes( nodes_ids = set(nodes.keys()) + exclude_ids: set[str] = set() for statement in exclude: config = SelectorConfig(project_dir, statement) - exclude_ids = select_nodes_ids_by_intersection(nodes, config) + exclude_ids = exclude_ids.union(set(select_nodes_ids_by_intersection(nodes, config))) subset_ids = set(nodes_ids) - set(exclude_ids) return {id_: nodes[id_] for id_ in subset_ids} diff --git a/tests/dbt/test_selector.py b/tests/dbt/test_selector.py index ac404b14a6..7c6ff32922 100644 --- a/tests/dbt/test_selector.py +++ b/tests/dbt/test_selector.py @@ -2,6 +2,7 @@ import pytest +from cosmos.dbt.selector import SelectorConfig from cosmos.constants import DbtResourceType from cosmos.dbt.graph import DbtNode from cosmos.dbt.selector import select_nodes @@ -9,6 +10,34 @@ SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/") + +@pytest.fixture +def selector_config(): + project_dir = Path("/path/to/project") + statement = "" + return SelectorConfig(project_dir, statement) + + +@pytest.mark.parametrize( + "paths, tags, config, other, expected", + [ + ([], [], {}, [], True), + ([Path("path1")], [], {}, [], False), + ([], ["tag:has_child"], {}, [], False), + ([], [], {"config.tags:test"}, [], False), + ([], [], {}, ["other"], False), + ([Path("path1")], ["tag:has_child"], {"config.tags:test"}, ["other"], False), + ], +) +def test_is_empty_config(selector_config, paths, tags, config, other, expected): + selector_config.paths = paths + selector_config.tags = tags + selector_config.config = config + selector_config.other = other + + assert selector_config.is_empty == expected + + grandparent_node = DbtNode( name="grandparent", unique_id="grandparent", @@ -37,10 +66,32 @@ config={"materialized": "table", "tags": ["is_child"]}, ) +grandchild_1_test_node = DbtNode( + name="grandchild_1", + unique_id="grandchild_1", + resource_type=DbtResourceType.MODEL, + depends_on=["parent"], + file_path=SAMPLE_PROJ_PATH / "gen3/models/grandchild_1.sql", + tags=["nightly"], + config={"materialized": "table", "tags": ["deprecated", "test"]}, +) + +grandchild_2_test_node = DbtNode( + name="grandchild_2", + unique_id="grandchild_2", + resource_type=DbtResourceType.MODEL, + depends_on=["parent"], + file_path=SAMPLE_PROJ_PATH / "gen3/models/grandchild_2.sql", + tags=["nightly"], + config={"materialized": "table", "tags": ["deprecated", "test2"]}, +) + sample_nodes = { grandparent_node.unique_id: grandparent_node, parent_node.unique_id: parent_node, child_node.unique_id: child_node, + grandchild_1_test_node.unique_id: grandchild_1_test_node, + grandchild_2_test_node.unique_id: grandchild_2_test_node, } @@ -52,13 +103,19 @@ def test_select_nodes_by_select_tag(): def test_select_nodes_by_select_config(): selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.materialized:table"]) - expected = {child_node.unique_id: child_node} + expected = { + child_node.unique_id: child_node, + grandchild_1_test_node.unique_id: grandchild_1_test_node, + grandchild_2_test_node.unique_id: grandchild_2_test_node, + } assert selected == expected def test_select_nodes_by_select_config_tag(): selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.tags:is_child"]) - expected = {child_node.unique_id: child_node} + expected = { + child_node.unique_id: child_node, + } assert selected == expected @@ -74,6 +131,21 @@ def test_select_nodes_by_select_union_config_tag(): assert selected == expected +def test_select_nodes_by_select_union_config_test_tags(): + selected = select_nodes( + project_dir=SAMPLE_PROJ_PATH, + nodes=sample_nodes, + select=["config.tags:test", "config.tags:test2", "config.materialized:view"], + ) + expected = { + grandparent_node.unique_id: grandparent_node, + parent_node.unique_id: parent_node, + grandchild_1_test_node.unique_id: grandchild_1_test_node, + grandchild_2_test_node.unique_id: grandchild_2_test_node, + } + assert selected == expected + + def test_select_nodes_by_select_intersection_config_tag(): selected = select_nodes( project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.tags:is_child,config.materialized:view"] @@ -95,6 +167,8 @@ def test_select_nodes_by_select_union(): grandparent_node.unique_id: grandparent_node, parent_node.unique_id: parent_node, child_node.unique_id: child_node, + grandchild_1_test_node.unique_id: grandchild_1_test_node, + grandchild_2_test_node.unique_id: grandchild_2_test_node, } assert selected == expected @@ -106,7 +180,11 @@ def test_select_nodes_by_select_intersection(): def test_select_nodes_by_exclude_tag(): selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["tag:has_child"]) - expected = {child_node.unique_id: child_node} + expected = { + child_node.unique_id: child_node, + grandchild_1_test_node.unique_id: grandchild_1_test_node, + grandchild_2_test_node.unique_id: grandchild_2_test_node, + } assert selected == expected @@ -122,3 +200,15 @@ def test_select_nodes_by_select_union_exclude_tags(): ) expected = {} assert selected == expected + + +def test_select_nodes_by_exclude_union_config_test_tags(): + selected = select_nodes( + project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["config.tags:test", "config.tags:test2"] + ) + expected = { + grandparent_node.unique_id: grandparent_node, + parent_node.unique_id: parent_node, + child_node.unique_id: child_node, + } + assert selected == expected