Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def _default_freshness_callback(
task_group: TaskGroup | None,
nodes: dict[str, DbtNode] | None,
sources_json: dict[str, Any] | None,
) -> tuple[list[str], str]:
"""Return unique_ids of nodes that must be skipped due to stale sources, plus the status ``"skip"``.
) -> list[tuple[str, str]]:
"""Return a list of ``(unique_id, state)`` tuples for nodes that must be skipped due to stale sources.

Stale sources are those with ``status`` of ``"error"`` or ``"warn"`` in ``sources_json["results"]``.

Expand All @@ -74,11 +74,11 @@ def _default_freshness_callback(
Traversal is a DFS over the reverse-dependency graph built from ``nodes``.
"""
if not nodes or not sources_json:
return [], "skip"
return []

stale_source_ids = {r["unique_id"] for r in sources_json.get("results", []) if r.get("status") in ("error", "warn")}
if not stale_source_ids:
return [], "skip"
return []

# Build reverse map: dep_id -> set of node_ids that directly depend on it
dependents: dict[str, set[str]] = {}
Expand Down Expand Up @@ -110,7 +110,7 @@ def _default_freshness_callback(
# and test hash-suffixed unique_ids are not valid dbt --exclude selectors.
excludable = [uid for uid in visited if nodes.get(uid) and nodes[uid].resource_type in _excludable_resource_types]
logger.info("Nodes to skip due to stale sources: %s", excludable)
return excludable, "skip"
return [(uid, "skipped") for uid in excludable]


class _NullWriter:
Expand Down Expand Up @@ -168,7 +168,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._check_source_freshness: bool = kwargs.pop("_check_source_freshness", False)
self._freshness_callback: Callable[
[Context, Any, TaskGroup | None, dict[str, DbtNode] | None, dict[str, Any] | None],
tuple[list[str], str],
list[tuple[str, str]],
] = _default_freshness_callback
# Do not publish compiled_sql to the producer's rendered_template: it would contain SQL for
# all models run by the producer, is often truncated in the UI due to size, and is of no use
Expand Down Expand Up @@ -271,13 +271,13 @@ def _event_callback(event: Any) -> None:
return result
return super().run_dbt_runner(command, env, cwd, **kwargs)

def _push_skipped_xcom_for_model(self, ti: Any, unique_id: str) -> None:
"""Push a synthetic ``"skipped"`` status XCom for a model excluded due to a stale upstream source.
def _push_node_state_xcom(self, ti: Any, unique_id: str, state: str) -> None:
"""Push a synthetic status XCom for a node using the given ``state``.

Uses the unified ``*_status`` XCom key that consumer sensors already poll.
"""
uid_key = unique_id.replace(".", "__")
safe_xcom_push(task_instance=ti, key=f"{uid_key}_status", value={"status": "skipped", "outlet_uris": []})
safe_xcom_push(task_instance=ti, key=f"{uid_key}_status", value={"status": state, "outlet_uris": []})

def _run_source_freshness(self, context: Context) -> None:
"""Run ``dbt source freshness`` via ``build_cmd`` and ``run_command``.
Expand Down Expand Up @@ -306,25 +306,25 @@ def _run_source_freshness(self, context: Context) -> None:
self.dbt_cmd_flags = original_dbt_cmd_flags
context.pop("_check_source_freshness", None) # type: ignore[typeddict-item]

def _skipped_node_token(self, context: Context, node_unique_ids: list[str]) -> None:
if not node_unique_ids:
def _apply_node_state_tokens(self, context: Context, node_state_pairs: list[tuple[str, str]]) -> None:
if not node_state_pairs:
return

ti = context["ti"]

for unique_id in node_unique_ids:
logger.info(
"Marking resource '%s' as skipped (stale upstream source)",
unique_id,
)
self._push_skipped_xcom_for_model(ti, unique_id)
for unique_id, state in node_state_pairs:
logger.info("Marking resource '%s' as %s (stale upstream source)", unique_id, state)
self._push_node_state_xcom(ti, unique_id, state)

# Only exclude nodes in the "skipped" state from future dbt runs.
# Use the same parsing as DbtNode.resource_name: unique_id.split(".", 2)[2]
# This preserves version suffixes (e.g. model.pkg.my_model.v1 -> my_model.v1)
model_names = sorted({uid.split(".", 2)[2] for uid in node_unique_ids if len(uid.split(".", 2)) == 3})

current_exclude = getattr(self, "exclude", None)
skipped_ids = [uid for uid, state in node_state_pairs if state == "skipped"]
if not skipped_ids:
return
model_names = sorted({uid.split(".", 2)[2] for uid in skipped_ids if len(uid.split(".", 2)) == 3})
exclude_str = " ".join(model_names)
current_exclude = getattr(self, "exclude", None)
if current_exclude:
self.exclude = f"{current_exclude} {exclude_str}"
else:
Expand Down Expand Up @@ -368,8 +368,8 @@ def _apply_source_freshness(self, context: Context) -> None:
tg_dbt_graph = getattr(task_group, "dbt_graph", None)
nodes = getattr(tg_dbt_graph, "nodes", None)

node_ids_to_skip, _ = self._freshness_callback(context, dag, task_group, nodes, self._sources_json)
self._skipped_node_token(context, node_ids_to_skip)
freshness_results = self._freshness_callback(context, dag, task_group, nodes, self._sources_json)
self._apply_node_state_tokens(context, freshness_results)

def execute(self, context: Context, **kwargs: Any) -> Any:
# Pre-compute the dataset namespace for per-model outlet URI generation.
Expand Down
70 changes: 37 additions & 33 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1978,11 +1978,10 @@ class TestDefaultFreshnessCallback:
"""Tests for the _default_freshness_callback function."""

def test_returns_empty_when_no_nodes(self):
node_ids, status = _default_freshness_callback(
result = _default_freshness_callback(
context=MagicMock(), dag=None, task_group=None, nodes=None, sources_json=None
)
assert node_ids == []
assert status == "skip"
assert result == []

def test_returns_empty_when_no_stale_sources(self):
from cosmos.constants import DbtResourceType
Expand All @@ -1998,11 +1997,10 @@ def test_returns_empty_when_no_stale_sources(self):
),
}
sources_json = {"results": [{"unique_id": "source.pkg.src1", "status": "pass"}]}
node_ids, status = _default_freshness_callback(
result = _default_freshness_callback(
context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json
)
assert node_ids == []
assert status == "skip"
assert result == []

def test_returns_transitive_dependents_of_stale_source(self):
from cosmos.constants import DbtResourceType
Expand Down Expand Up @@ -2032,11 +2030,11 @@ def test_returns_transitive_dependents_of_stale_source(self):
),
}
sources_json = {"results": [{"unique_id": "source.pkg.src1", "status": "error"}]}
node_ids, status = _default_freshness_callback(
result = _default_freshness_callback(
context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json
)
assert set(node_ids) == {"model.pkg.m1", "model.pkg.m2"}
assert status == "skip"
assert {uid for uid, _ in result} == {"model.pkg.m1", "model.pkg.m2"}
assert all(state == "skipped" for _, state in result)

def test_excludes_test_nodes(self):
from cosmos.constants import DbtResourceType
Expand Down Expand Up @@ -2066,12 +2064,11 @@ def test_excludes_test_nodes(self):
),
}
sources_json = {"results": [{"unique_id": "source.pkg.src1", "status": "warn"}]}
node_ids, status = _default_freshness_callback(
result = _default_freshness_callback(
context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json
)
# Only model nodes, not test nodes
assert node_ids == ["model.pkg.m1"]
assert status == "skip"
assert result == [("model.pkg.m1", "skipped")]

def test_node_with_clean_upstream_not_skipped(self):
"""A node that depends on both a stale source and a clean model should not be skipped.
Expand Down Expand Up @@ -2109,12 +2106,11 @@ def test_node_with_clean_upstream_not_skipped(self):
),
}
sources_json = {"results": [{"unique_id": "source.pkg.stale_src", "status": "warn"}]}
node_ids, status = _default_freshness_callback(
result = _default_freshness_callback(
context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json
)
# A has a clean path via clean_model → neither A nor C should be skipped
assert node_ids == []
assert status == "skip"
assert result == []

def test_node_skipped_only_when_all_upstreams_stale(self):
"""A node whose every upstream is stale or already skipped must be skipped.
Expand Down Expand Up @@ -2163,11 +2159,11 @@ def test_node_skipped_only_when_all_upstreams_stale(self):
{"unique_id": "source.pkg.stale_src2", "status": "error"},
]
}
node_ids, status = _default_freshness_callback(
result = _default_freshness_callback(
context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json
)
assert set(node_ids) == {"model.pkg.A", "model.pkg.B", "model.pkg.C", "model.pkg.D"}
assert status == "skip"
assert {uid for uid, _ in result} == {"model.pkg.A", "model.pkg.B", "model.pkg.C", "model.pkg.D"}
assert all(state == "skipped" for _, state in result)

def test_already_visited_dependent_not_processed_twice(self):
"""A dependent reachable via two stale paths is only processed once.
Expand Down Expand Up @@ -2208,11 +2204,11 @@ def test_already_visited_dependent_not_processed_twice(self):
),
}
sources_json = {"results": [{"unique_id": "source.pkg.stale_src", "status": "error"}]}
node_ids, status = _default_freshness_callback(
result = _default_freshness_callback(
context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json
)
assert set(node_ids) == {"model.pkg.A", "model.pkg.B", "model.pkg.C"}
assert status == "skip"
assert {uid for uid, _ in result} == {"model.pkg.A", "model.pkg.B", "model.pkg.C"}
assert all(state == "skipped" for _, state in result)

def test_dependent_node_missing_from_nodes_is_skipped(self):
"""A dependent_id whose node cannot be resolved via ``nodes.get`` is silently ignored.
Expand Down Expand Up @@ -2249,11 +2245,10 @@ def get(self, key, default=None): # type: ignore[override]
# nodes.get("model.pkg.A") will return None → the node is silently skipped
nodes = _NullOnGet({"model.pkg.A"}, raw_nodes)
sources_json = {"results": [{"unique_id": "source.pkg.stale_src", "status": "error"}]}
node_ids, status = _default_freshness_callback(
result = _default_freshness_callback(
context=MagicMock(), dag=None, task_group=None, nodes=nodes, sources_json=sources_json
)
assert node_ids == []
assert status == "skip"
assert result == []


class TestProducerSourceFreshness:
Expand All @@ -2278,44 +2273,53 @@ def test_init_stores_check_source_freshness_flag(self):
producer = self._make_producer(_check_source_freshness=True)
assert producer._check_source_freshness is True

def test_push_skipped_xcom_for_model(self):
def test_push_node_state_xcom(self):
producer = self._make_producer()
ti = MagicMock()
producer._push_skipped_xcom_for_model(ti, "model.pkg.my_model")
producer._push_node_state_xcom(ti, "model.pkg.my_model", "skipped")
ti.xcom_push.assert_called_once_with(
key="model__pkg__my_model_status", value={"status": "skipped", "outlet_uris": []}
)

def test_skipped_node_token_updates_exclude(self):
def test_apply_node_state_tokens_updates_exclude(self):
producer = self._make_producer()
producer.exclude = None
ti = MagicMock()
context = {"ti": ti}
producer._skipped_node_token(context, ["model.pkg.m1", "model.pkg.m2"])
# Both models should be pushed as skipped
producer._apply_node_state_tokens(context, [("model.pkg.m1", "skipped"), ("model.pkg.m2", "skipped")])
# Both models should be pushed with their state
assert ti.xcom_push.call_count == 2
# Exclude should contain the model short names
assert "m1" in producer.exclude
assert "m2" in producer.exclude

def test_skipped_node_token_appends_to_existing_exclude(self):
def test_apply_node_state_tokens_appends_to_existing_exclude(self):
producer = self._make_producer()
producer.exclude = "existing_model"
ti = MagicMock()
context = {"ti": ti}
producer._skipped_node_token(context, ["model.pkg.m1"])
producer._apply_node_state_tokens(context, [("model.pkg.m1", "skipped")])
assert "existing_model" in producer.exclude
assert "m1" in producer.exclude

def test_skipped_node_token_noop_when_empty(self):
def test_apply_node_state_tokens_noop_when_empty(self):
producer = self._make_producer()
producer.exclude = None
ti = MagicMock()
context = {"ti": ti}
producer._skipped_node_token(context, [])
producer._apply_node_state_tokens(context, [])
ti.xcom_push.assert_not_called()
assert producer.exclude is None

def test_apply_node_state_tokens_non_skipped_state_does_not_update_exclude(self):
producer = self._make_producer()
producer.exclude = None
ti = MagicMock()
context = {"ti": ti}
producer._apply_node_state_tokens(context, [("model.pkg.m1", "failed")])
ti.xcom_push.assert_called_once_with(key="model__pkg__m1_status", value={"status": "failed", "outlet_uris": []})
assert producer.exclude is None

def test_run_dbt_runner_skips_callback_during_source_freshness(self):
"""run_dbt_runner must not register the XCom-pushing callback during the source freshness
pre-check. Registering it would leave a stale entry in _dbt_runner_callbacks that fires
Expand Down
Loading