Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,48 @@ def profile_config_to_override(self) -> dict[str, Any]:
)
return operator_kwargs

@staticmethod
def get_resource_name_from_unique_id(unique_id: str) -> str:
    """
    Return the ``resource_name`` segment of a dbt node ``unique_id``.

    Per the `dbt manifest spec
    <https://docs.getdbt.com/reference/artifacts/manifest-json#resource-details>`_,
    a node ``unique_id`` is ``<resource_type>.<package>.<resource_name>``.
    Both ``resource_type`` and ``package`` are constrained identifiers that
    cannot contain dots, so the first two dots are unambiguous separators
    and everything after the second dot is the full resource name.

    For versioned models, dbt appends a fourth segment:
    ``model.<package>.<resource_name>.<version>`` (see
    `node_args.py <https://github.com/dbt-labs/dbt-core/blob/main/core/dbt/contracts/graph/node_args.py#L26C3-L31>`_).
    Splitting with ``maxsplit=2`` preserves that suffix:
    ``model.pkg.my_model.v1`` -> ``my_model.v1``.

    :param unique_id: the dbt node ``unique_id`` to parse.
    :returns: everything after the second dot of ``unique_id``.
    :raises ValueError: if ``unique_id`` does not have the expected
        ``<resource_type>.<package>.<resource_name>`` shape, i.e. fewer
        than two dots or any empty segment (e.g. ``model..name``, ``..``,
        ``model.pkg.``). Malformed inputs are surfaced loudly rather than
        silently mis-parsed.
    """
    # ``maxsplit=2`` caps the result at 3 elements, so a well-formed
    # unique_id always yields exactly 3 non-empty parts (the versioned/source
    # suffixes stay attached to the third part).
    parts = unique_id.split(".", 2)
    if len(parts) != 3 or not all(parts):
        raise ValueError(
            f"Malformed dbt unique_id, expected '<resource_type>.<package>.<resource_name>': {unique_id!r}"
        )
    return parts[2]

@property
def resource_name(self) -> str:
    """
    Use this property to retrieve the resource name for command generation, for instance: ["dbt", "run", "--models", f"{resource_name}"].

    Delegates to :meth:`get_resource_name_from_unique_id`, which documents the dbt ``unique_id`` format
    (including the versioned-model variant ``model.<package>.<resource_name>.<version>``).

    :returns: the resource-name portion of ``self.unique_id``.
    :raises ValueError: propagated from :meth:`get_resource_name_from_unique_id` when ``self.unique_id`` is malformed.
    """
    # Single source of truth for unique_id parsing: do not re-implement the split here.
    return self.get_resource_name_from_unique_id(self.unique_id)

@property
def name(self) -> str:
Expand Down
3 changes: 2 additions & 1 deletion cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
PRODUCER_WATCHER_TASK_ID,
WATCHER_TASK_WEIGHT_RULE,
)
from cosmos.dbt.graph import DbtNode
from cosmos.listeners.dag_run_listener import EventStatus
from cosmos.log import get_logger
from cosmos.operators._watcher.aggregation import get_tests_status_xcom_key, push_test_result_or_aggregate
Expand Down Expand Up @@ -512,7 +513,7 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo
raw_flags = upstream_task.add_cmd_flags()
extra_flags = self._filter_flags(raw_flags)

model_selector = self.model_unique_id.split(".", 2)[2]
model_selector = DbtNode.get_resource_name_from_unique_id(self.model_unique_id)
cmd_flags = extra_flags + ["--select", model_selector]

self.build_and_run_cmd(context, cmd_flags=cmd_flags) # type: ignore[attr-defined]
Expand Down
14 changes: 9 additions & 5 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,16 @@ def _apply_node_state_tokens(self, context: Context, node_state_pairs: list[tupl
# "error") that a custom freshness_callback may return. Without exclusion,
# dbt would run the model anyway and either overwrite the pre-set XCom status
# or trigger a race condition with the consumer sensor.
# Use the same parsing as DbtNode.resource_name: unique_id.split(".", 2)[2]
# This preserves version suffixes (e.g. model.pkg.my_model.v1 -> my_model.v1)
excluded_ids = [uid for uid, state in node_state_pairs if state not in DBT_SUCCESS_STATUSES]
if not excluded_ids:
return
model_names = sorted({uid.split(".", 2)[2] for uid in excluded_ids if len(uid.split(".", 2)) == 3})
resource_names = set()
for uid in excluded_ids:
try:
resource_names.add(DbtNode.get_resource_name_from_unique_id(uid))
except ValueError:
logger.warning("Skipping malformed dbt unique_id while building source-freshness exclude list: %s", uid)
model_names = sorted(resource_names)
exclude_str = " ".join(model_names)
if exclude_str:
current_exclude = getattr(self, "exclude", None)
Expand Down Expand Up @@ -591,7 +595,7 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo
self.model_unique_id,
self.project_dir,
)
resource_name = self.model_unique_id.split(".", 2)[2]
resource_name = DbtNode.get_resource_name_from_unique_id(self.model_unique_id)
cmd_flags = ["--select", f"source:{resource_name}"]
self.build_and_run_cmd(context, cmd_flags=cmd_flags)
logger.info("dbt source freshness completed successfully on retry for source '%s'", self.model_unique_id)
Expand Down Expand Up @@ -658,7 +662,7 @@ def _fallback_to_non_watcher_run(self, try_number: int, context: Context) -> boo
try_number,
)

model_selector = self.model_unique_id.split(".", 2)[2]
model_selector = DbtNode.get_resource_name_from_unique_id(self.model_unique_id)
cmd_flags = ["--select", model_selector]
self.build_and_run_cmd(context, cmd_flags=cmd_flags)
logger.info("dbt test completed successfully for model '%s'", self.model_unique_id)
Expand Down
16 changes: 16 additions & 0 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,22 @@ def test_dbt_node_name_and_select(unique_id, expected_name, expected_select):
assert node.resource_name == expected_select


class TestGetResourceNameFromUniqueId:
    """Unit tests for ``DbtNode.get_resource_name_from_unique_id``."""

    def test_plain_model(self):
        """A three-segment unique_id yields its third segment."""
        parsed = DbtNode.get_resource_name_from_unique_id("model.my_pkg.my_model")
        assert parsed == "my_model"

    def test_versioned_model_preserves_version_suffix(self):
        """The version suffix of a versioned model stays attached to the resource name."""
        parsed = DbtNode.get_resource_name_from_unique_id("model.my_pkg.my_model.v1")
        assert parsed == "my_model.v1"

    @pytest.mark.parametrize(
        "malformed",
        [
            "",
            "foo",
            "foo.bar",
            "model..name",
            "..",
            "model.pkg.",
            ".pkg.name",
        ],
    )
    def test_malformed_unique_id_raises(self, malformed):
        """Ids with fewer than two dots or with an empty segment raise ``ValueError``."""
        with pytest.raises(ValueError):
            DbtNode.get_resource_name_from_unique_id(malformed)


def test_dbt_node_meta():
valid_node = DbtNode(
unique_id="some-id",
Expand Down
41 changes: 38 additions & 3 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ExecutionMode,
SourceRenderingBehavior,
)
from cosmos.dbt.graph import DbtNode
from cosmos.operators._watcher.base import store_compiled_sql_for_model
from cosmos.operators._watcher.triggerer import WatcherEventReason, WatcherTrigger
from cosmos.operators.watcher import (
Expand Down Expand Up @@ -1227,7 +1228,7 @@ def test_fallback_to_non_watcher_run(self):
sensor.build_and_run_cmd.assert_called_once()
args, kwargs = sensor.build_and_run_cmd.call_args
assert "--select" in kwargs["cmd_flags"]
assert MODEL_UNIQUE_ID.split(".", 2)[2] in kwargs["cmd_flags"]
assert DbtNode.get_resource_name_from_unique_id(MODEL_UNIQUE_ID) in kwargs["cmd_flags"]

def test_fallback_strips_producer_log_format_by_default(self):
"""Producer's ``--log-format json`` (internal, used for event-stream parsing) must not leak into
Expand Down Expand Up @@ -2493,6 +2494,31 @@ def test_dbt_source_watcher_operator_template_fields():
assert field in DbtSourceWatcherOperator.template_fields


def test_dbt_source_watcher_operator_fallback_runs_source_freshness():
    """On retry the source sensor should run ``dbt source freshness --select source:<resource_name>``
    locally for its specific source.
    """
    from cosmos.operators.watcher import DbtSourceWatcherOperator

    # Four-segment source id: the expected selector keeps everything after the package.
    unique_id = "source.jaffle_shop.raw.orders"
    operator = DbtSourceWatcherOperator(
        task_id="raw_orders.source",
        project_dir="/tmp/project",
        profile_config=None,
        extra_context={"dbt_node_config": {"unique_id": unique_id}},
    )
    operator.build_and_run_cmd = MagicMock()
    airflow_context = MagicMock()

    fell_back = operator._fallback_to_non_watcher_run(2, airflow_context)

    assert fell_back is True
    operator.build_and_run_cmd.assert_called_once()
    _positional, call_kwargs = operator.build_and_run_cmd.call_args
    assert call_kwargs["cmd_flags"] == ["--select", "source:raw.orders"]


class TestDbtTestWatcherOperator:
"""Tests for DbtTestWatcherOperator — the sensor that watches aggregated test results."""

Expand Down Expand Up @@ -2589,7 +2615,7 @@ def test_fallback_runs_dbt_test_on_retry(self):
mock_fallback_to_non_watcher_run.assert_called_once()
sensor.build_and_run_cmd.assert_called_once()
_, kwargs = sensor.build_and_run_cmd.call_args
assert kwargs["cmd_flags"] == ["--select", self.MODEL_UID.split(".", 2)[2]]
assert kwargs["cmd_flags"] == ["--select", DbtNode.get_resource_name_from_unique_id(self.MODEL_UID)]

def test_fallback_via_poke_does_not_forward_producer_flags(self):
"""Driving through ``poke`` on retry, the fallback should issue ``dbt test`` with
Expand All @@ -2611,7 +2637,7 @@ def test_fallback_via_poke_does_not_forward_producer_flags(self):

mock_fallback_to_non_watcher_run.assert_called_once()
_, kwargs = sensor.build_and_run_cmd.call_args
assert kwargs["cmd_flags"] == ["--select", self.MODEL_UID.split(".", 2)[2]]
assert kwargs["cmd_flags"] == ["--select", DbtNode.get_resource_name_from_unique_id(self.MODEL_UID)]
assert "--full-refresh" not in kwargs["cmd_flags"]
assert sensor.base_cmd == ["test"]

Expand Down Expand Up @@ -2983,6 +3009,15 @@ def test_apply_node_state_tokens_appends_to_existing_exclude(self):
assert "existing_model" in producer.exclude
assert "m1" in producer.exclude

def test_apply_node_state_tokens_skips_malformed_unique_id(self):
    """A malformed unique_id is skipped while the valid one still makes it into ``exclude``."""
    producer = self._make_producer()
    producer.exclude = None
    task_instance = MagicMock()
    node_state_pairs = [
        ("malformed_uid", "skipped"),
        ("model.pkg.m1", "skipped"),
    ]

    producer._apply_node_state_tokens({"ti": task_instance}, node_state_pairs)

    # Only the well-formed id contributes a resource name.
    assert producer.exclude == "m1"

def test_apply_node_state_tokens_noop_when_empty(self):
producer = self._make_producer()
producer.exclude = None
Expand Down
Loading