Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions cosmos/dbt/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
TAG_SELECTOR = "tag:"
CONFIG_SELECTOR = "config."
PLUS_SELECTOR = "+"
GRAPH_SELECTOR_REGEX = r"^([0-9]*\+)?([^\+]+)(\+[0-9]*)?$|"
AT_SELECTOR = "@"
GRAPH_SELECTOR_REGEX = r"^(@|[0-9]*\+)?([^\+]+)(\+[0-9]*)?$|"

logger = get_logger(__name__)

Expand All @@ -35,6 +36,7 @@ class GraphSelector:
+model_d+
2+model_e
model_f+3
@model_g
+/path/to/model_g+
path:/path/to/model_h+
+tag:nightly
Expand All @@ -46,6 +48,7 @@ class GraphSelector:
node_name: str
precursors: str | None
descendants: str | None
at_operator: bool = False

@property
def precursors_depth(self) -> int:
Expand All @@ -56,6 +59,8 @@ def precursors_depth(self) -> int:
0: if it shouldn't return any precursors
>0: upperbound number of parent generations
"""
if self.at_operator:
return -1
if not self.precursors:
return 0
if self.precursors == "+":
Expand Down Expand Up @@ -90,7 +95,13 @@ def parse(text: str) -> GraphSelector | None:
precursors, node_name, descendants = regex_match.groups()
if "/" in node_name and not node_name.startswith(PATH_SELECTOR):
node_name = f"{PATH_SELECTOR}{node_name}"
return GraphSelector(node_name, precursors, descendants)

at_operator = precursors == AT_SELECTOR
if at_operator:
precursors = None
descendants = "+" # @ implies all descendants

return GraphSelector(node_name, precursors, descendants, at_operator)
return None

def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None:
Expand All @@ -101,7 +112,7 @@ def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, select
:param root_id: Unique identifier of self.node_name
:param selected_nodes: Set where precursor nodes will be added to.
"""
if self.precursors:
if self.precursors or self.at_operator:
depth = self.precursors_depth
previous_generation = {root_id}
processed_nodes = set()
Expand Down Expand Up @@ -203,16 +214,39 @@ def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]:
root_id = node_by_name[self.node_name]
root_nodes.add(root_id)
else:
logger.warn(f"Selector {self.node_name} not found.")
logger.warning(f"Selector {self.node_name} not found.")
return selected_nodes

selected_nodes.update(root_nodes)

for root_id in root_nodes:
self.select_node_precursors(nodes, root_id, selected_nodes)
self.select_node_descendants(nodes, root_id, selected_nodes)
self._select_nodes(nodes, root_nodes, selected_nodes)

return selected_nodes

def _select_nodes(self, nodes: dict[str, DbtNode], root_nodes: set[str], selected_nodes: set[str]) -> None:
    """
    Expand the root nodes into ``selected_nodes`` according to this selector.

    Without the ``@`` operator, each root node contributes its precursors and
    its descendants. With the ``@`` operator, the descendants of every root are
    gathered first, and then precursors are collected for every root and every
    gathered descendant.

    :param nodes: dbt project nodes
    :param root_nodes: Set of root node ids
    :param selected_nodes: Set that is updated in place with the selected node ids.
    """
    if not self.at_operator:
        # Plain graph selector: walk both directions from each root.
        for root_id in root_nodes:
            self.select_node_precursors(nodes, root_id, selected_nodes)
            self.select_node_descendants(nodes, root_id, selected_nodes)
        return

    # "@" selector, step 1: collect everything downstream of the roots.
    downstream: set[str] = set()
    for root_id in root_nodes:
        self.select_node_descendants(nodes, root_id, downstream)
    selected_nodes.update(downstream)

    # "@" selector, step 2: collect precursors of the roots and of each downstream node.
    for node_id in root_nodes | downstream:
        self.select_node_precursors(nodes, node_id, selected_nodes)


class SelectorConfig:
"""
Expand Down
12 changes: 12 additions & 0 deletions docs/configuration/selecting-excluding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The ``select`` and ``exclude`` parameters are lists, with values like the follow
- ``config.materialized:table``: include/exclude models with the config ``materialized: table``
- ``path:analytics/tables``: include/exclude models in the ``analytics/tables`` directory
- ``+node_name+1`` (graph operators): include/exclude the node with name ``node_name``, all its parents, and its first generation of children (`dbt graph selector docs <https://docs.getdbt.com/reference/node-selection/graph-operators>`_)
- ``@node_name`` (@ operator): include/exclude the node with name ``node_name``, all its descendants, and all ancestors of the node and of those descendants. This is useful in CI environments where you want to build a model and all its descendants, but you need the ancestors of those descendants to exist first (`dbt graph selector docs <https://docs.getdbt.com/reference/node-selection/graph-operators>`_)
- ``tag:my_tag,+node_name`` (intersection): include/exclude ``node_name`` and its parents if they have the tag ``my_tag`` (`dbt set operator docs <https://docs.getdbt.com/reference/node-selection/set-operators>`_)
- ``['tag:first_tag', 'tag:second_tag']`` (union): include/exclude nodes that have either ``tag:first_tag`` or ``tag:second_tag``

Expand Down Expand Up @@ -91,6 +92,17 @@ Examples:
)
)

.. code-block:: python

from cosmos import DbtDag, RenderConfig

jaffle_shop = DbtDag(
render_config=RenderConfig(
select=["@my_model"], # selects my_model, all its descendants,
# and all ancestors needed to build those descendants
)
)

Using ``selector``
--------------------------------
.. note::
Expand Down
80 changes: 80 additions & 0 deletions tests/dbt/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,83 @@ def test_should_include_node_without_depends_on(selector_config):
def test_select_using_graph_operators(select_statement, expected):
selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=select_statement)
assert sorted(selected.keys()) == expected


def test_select_nodes_by_at_operator():
    """Selecting ``@parent`` must yield exactly the expected node ids."""
    expected_ids = [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
        "model.dbt-proj.sibling1",
        "model.dbt-proj.sibling2",
    ]
    result = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["@parent"],
    )
    selected_ids = sorted(result.keys())
    assert selected_ids == expected_ids


def test_select_nodes_by_at_operator_leaf_node():
    """Selecting ``@child`` (described as a leaf node) must yield exactly the expected node ids."""
    expected_ids = [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
    ]
    result = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["@child"],
    )
    selected_ids = sorted(result.keys())
    assert selected_ids == expected_ids


def test_select_nodes_by_at_operator_root_node():
    """Selecting ``@grandparent`` (described as a root node) must yield exactly the expected node ids."""
    expected_ids = [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
        "model.dbt-proj.sibling1",
        "model.dbt-proj.sibling2",
    ]
    result = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["@grandparent"],
    )
    selected_ids = sorted(result.keys())
    assert selected_ids == expected_ids


def test_select_nodes_by_at_operator_union():
    """A union of ``@child`` with ``tag:has_child`` must yield exactly the expected node ids."""
    expected_ids = [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
    ]
    result = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["@child", "tag:has_child"],
    )
    selected_ids = sorted(result.keys())
    assert selected_ids == expected_ids


def test_select_nodes_by_at_operator_with_path():
    """Selecting ``@gen2/models`` (a path-style selector) must yield exactly the expected node ids."""
    expected_ids = [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
        "model.dbt-proj.sibling1",
        "model.dbt-proj.sibling2",
    ]
    result = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["@gen2/models"],
    )
    selected_ids = sorted(result.keys())
    assert selected_ids == expected_ids


def test_select_nodes_by_at_operator_nonexistent_node():
    """Selecting ``@nonexistent`` must yield no node ids at all."""
    result = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        select=["@nonexistent"],
    )
    selected_ids = sorted(result.keys())
    assert selected_ids == []


def test_exclude_with_at_operator():
    """Excluding ``@parent`` must leave exactly the expected remaining node ids."""
    expected_ids = ["model.dbt-proj.orphaned"]
    result = select_nodes(
        project_dir=SAMPLE_PROJ_PATH,
        nodes=sample_nodes,
        exclude=["@parent"],
    )
    remaining_ids = sorted(result.keys())
    assert remaining_ids == expected_ids