Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions cosmos/dbt/selector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from pathlib import Path
import copy

from typing import TYPE_CHECKING

Expand Down Expand Up @@ -43,6 +44,10 @@ def __init__(self, project_dir: Path, statement: str):
self.other: list[str] = []
self.load_from_statement(statement)

@property
def is_empty(self) -> bool:
    """
    Tell whether this selector holds no selection criteria.

    :returns: ``True`` when ``paths``, ``tags``, ``config`` and ``other`` are all
        empty (falsy); ``False`` as soon as any one of them has content.
    """
    return not any((self.paths, self.tags, self.config, self.other))

def load_from_statement(self, statement: str) -> None:
"""
Load in-place select parameters.
Expand Down Expand Up @@ -84,27 +89,30 @@ def select_nodes_ids_by_intersection(nodes: dict[str, DbtNode], config: Selector
https://docs.getdbt.com/reference/node-selection/yaml-selectors
"""
selected_nodes = set()
for node_id, node in nodes.items():
if config.tags and not (sorted(node.tags) == sorted(config.tags)):
continue

supported_node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG}
if config.config:
config_tag = config.config.get("tags")
if config_tag and config_tag not in supported_node_config.get("tags", []):
if not config.is_empty:
for node_id, node in nodes.items():
if config.tags and not (sorted(node.tags) == sorted(config.tags)):
continue

# Remove 'tags' as they've already been filtered for
config.config.pop("tags", None)
supported_node_config.pop("tags", None)
supported_node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG}
config_tag = config.config.get("tags")
if config.config:
if config_tag and config_tag not in supported_node_config.get("tags", []):
continue

# Remove 'tags' as they've already been filtered for
config_copy = copy.deepcopy(config.config)
config_copy.pop("tags", None)
supported_node_config.pop("tags", None)

if not (config.config.items() <= supported_node_config.items()):
continue
if not (config_copy.items() <= supported_node_config.items()):
continue

if config.paths and not (set(config.paths).issubset(set(node.file_path.parents))):
continue
if config.paths and not (set(config.paths).issubset(set(node.file_path.parents))):
continue

selected_nodes.add(node_id)
selected_nodes.add(node_id)

return selected_nodes

Expand Down Expand Up @@ -166,9 +174,10 @@ def select_nodes(

nodes_ids = set(nodes.keys())

exclude_ids: set[str] = set()
for statement in exclude:
config = SelectorConfig(project_dir, statement)
exclude_ids = select_nodes_ids_by_intersection(nodes, config)
exclude_ids = exclude_ids.union(set(select_nodes_ids_by_intersection(nodes, config)))
Comment thread
jensenity marked this conversation as resolved.
subset_ids = set(nodes_ids) - set(exclude_ids)

return {id_: nodes[id_] for id_ in subset_ids}
96 changes: 93 additions & 3 deletions tests/dbt/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,42 @@

import pytest

from cosmos.dbt.selector import SelectorConfig
from cosmos.constants import DbtResourceType
from cosmos.dbt.graph import DbtNode
from cosmos.dbt.selector import select_nodes
from cosmos.exceptions import CosmosValueError

SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/")


@pytest.fixture
def selector_config():
    """
    Provide a ``SelectorConfig`` built from an empty statement.

    Tests overwrite its attributes afterwards, so the project directory used
    here is only a placeholder path.
    """
    project_dir = Path("/path/to/project")
    statement = ""
    return SelectorConfig(project_dir, statement)


@pytest.mark.parametrize(
    "paths, tags, config, other, expected",
    [
        ([], [], {}, [], True),
        ([Path("path1")], [], {}, [], False),
        ([], ["tag:has_child"], {}, [], False),
        # ``config`` is used as a dict by the selector code, so use a dict here too.
        ([], [], {"tags": "test"}, [], False),
        ([], [], {}, ["other"], False),
        ([Path("path1")], ["tag:has_child"], {"tags": "test"}, ["other"], False),
    ],
)
def test_is_empty_config(selector_config, paths, tags, config, other, expected):
    """
    Check ``SelectorConfig.is_empty`` for each combination of populated attributes.

    The fixture instance is mutated in place with the parametrized values, then
    ``is_empty`` is compared against the expected flag.
    """
    selector_config.paths = paths
    selector_config.tags = tags
    selector_config.config = config
    selector_config.other = other

    assert selector_config.is_empty == expected


grandparent_node = DbtNode(
name="grandparent",
unique_id="grandparent",
Expand Down Expand Up @@ -37,10 +66,32 @@
config={"materialized": "table", "tags": ["is_child"]},
)

# Two sibling model nodes that both depend on "parent" and share the "nightly"
# tag and the "deprecated" config tag; they differ only in their second config
# tag ("test" vs "test2"), which the config-tag selector tests rely on.
grandchild_1_test_node = DbtNode(
name="grandchild_1",
unique_id="grandchild_1",
resource_type=DbtResourceType.MODEL,
depends_on=["parent"],
file_path=SAMPLE_PROJ_PATH / "gen3/models/grandchild_1.sql",
tags=["nightly"],
config={"materialized": "table", "tags": ["deprecated", "test"]},
)

grandchild_2_test_node = DbtNode(
name="grandchild_2",
unique_id="grandchild_2",
resource_type=DbtResourceType.MODEL,
depends_on=["parent"],
file_path=SAMPLE_PROJ_PATH / "gen3/models/grandchild_2.sql",
tags=["nightly"],
config={"materialized": "table", "tags": ["deprecated", "test2"]},
)

# Node collection passed to select_nodes in the tests, keyed by unique_id.
sample_nodes = {
grandparent_node.unique_id: grandparent_node,
parent_node.unique_id: parent_node,
child_node.unique_id: child_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
}


Expand All @@ -52,13 +103,19 @@ def test_select_nodes_by_select_tag():

def test_select_nodes_by_select_config():
    """
    Selecting ``config.materialized:table`` must return the child node and both
    grandchild nodes from ``sample_nodes``.
    """
    selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.materialized:table"])
    expected = {
        child_node.unique_id: child_node,
        grandchild_1_test_node.unique_id: grandchild_1_test_node,
        grandchild_2_test_node.unique_id: grandchild_2_test_node,
    }
    assert selected == expected


def test_select_nodes_by_select_config_tag():
    """
    Selecting ``config.tags:is_child`` must return only the child node.
    """
    selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.tags:is_child"])
    expected = {
        child_node.unique_id: child_node,
    }
    assert selected == expected


Expand All @@ -74,6 +131,21 @@ def test_select_nodes_by_select_union_config_tag():
assert selected == expected


def test_select_nodes_by_select_union_config_test_tags():
    """
    Several separate select statements must be combined: two different config
    tags plus a materialization selector together yield four of the sample nodes.
    """
    selected = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["config.tags:test", "config.tags:test2", "config.materialized:view"],
    )
    expected = {
        grandparent_node.unique_id: grandparent_node,
        parent_node.unique_id: parent_node,
        grandchild_1_test_node.unique_id: grandchild_1_test_node,
        grandchild_2_test_node.unique_id: grandchild_2_test_node,
    }
    assert selected == expected


def test_select_nodes_by_select_intersection_config_tag():
selected = select_nodes(
project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.tags:is_child,config.materialized:view"]
Expand All @@ -95,6 +167,8 @@ def test_select_nodes_by_select_union():
grandparent_node.unique_id: grandparent_node,
parent_node.unique_id: parent_node,
child_node.unique_id: child_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
}
assert selected == expected

Expand All @@ -106,7 +180,11 @@ def test_select_nodes_by_select_intersection():

def test_select_nodes_by_exclude_tag():
    """
    Excluding ``tag:has_child`` must leave the child node and both grandchild
    nodes in the result.
    """
    selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["tag:has_child"])
    expected = {
        child_node.unique_id: child_node,
        grandchild_1_test_node.unique_id: grandchild_1_test_node,
        grandchild_2_test_node.unique_id: grandchild_2_test_node,
    }
    assert selected == expected


Expand All @@ -122,3 +200,15 @@ def test_select_nodes_by_select_union_exclude_tags():
)
expected = {}
assert selected == expected


def test_select_nodes_by_exclude_union_config_test_tags():
    """
    Every exclude statement must be applied, not just the last one: excluding
    both ``config.tags:test`` and ``config.tags:test2`` drops both grandchild
    nodes and keeps the other three sample nodes.
    """
    selected = select_nodes(
        project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["config.tags:test", "config.tags:test2"]
    )
    expected = {
        grandparent_node.unique_id: grandparent_node,
        parent_node.unique_id: parent_node,
        child_node.unique_id: child_node,
    }
    assert selected == expected