Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
_iso_to_string,
_log_dbt_event,
build_producer_state_fetcher,
get_compiled_sql_xcom_key,
get_dbt_event_xcom_key,
get_status_xcom_key,
get_xcom_val,
is_dbt_node_status_skipped,
is_dbt_node_status_success,
Expand Down Expand Up @@ -118,8 +121,7 @@ def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any]) -> None:
"msg": msg,
}

xcom_key = f"{unique_id.replace('.', '__')}_dbt_event"
safe_xcom_push(task_instance=task_instance, key=xcom_key, value=dbt_event)
safe_xcom_push(task_instance=task_instance, key=get_dbt_event_xcom_key(unique_id), value=dbt_event)


def _extract_compiled_sql(
Expand Down Expand Up @@ -155,7 +157,7 @@ def _push_compiled_sql_for_model(task_instance: Any, unique_id: str, compiled_sq
"""
safe_xcom_push(
task_instance=task_instance,
key=f"{unique_id.replace('.', '__')}_compiled_sql",
key=get_compiled_sql_xcom_key(unique_id),
value=compiled_sql,
)

Expand Down Expand Up @@ -332,7 +334,7 @@ def store_dbt_resource_status_from_log(
}
safe_xcom_push(
task_instance=context["ti"],
key=f"{unique_id.replace('.', '__')}_status",
key=get_status_xcom_key(unique_id),
value=status_value,
)

Expand Down Expand Up @@ -477,10 +479,11 @@ def _execute_core(self, context: Context) -> None:
if not self.deferrable:
super().execute(context)
elif not self.poke(context):
if self.is_test_sensor:
xcom_key = get_tests_status_xcom_key(self.model_unique_id)
else:
xcom_key = f"{self.model_unique_id.replace('.', '__')}_status"
xcom_key = (
get_tests_status_xcom_key(self.model_unique_id)
if self.is_test_sensor
else get_status_xcom_key(self.model_unique_id)
)
logger.info(
"Deferring %s '%s'. The trigger will poll XCom key '%s' from producer task '%s'.",
self._resource_label.lower(),
Expand Down Expand Up @@ -545,7 +548,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:

dbt_events = get_xcom_val(
task_instance=context["ti"],
key=f"{self.model_unique_id.replace('.', '__')}_dbt_event",
key=get_dbt_event_xcom_key(self.model_unique_id),
task_ids=self.producer_task_id,
)
_log_dbt_event(dbt_events)
Expand Down Expand Up @@ -590,19 +593,16 @@ def _get_node_status(self, ti: Any, context: Context) -> Any:
dataset emission.
"""
if self.is_test_sensor:
xcom_key = get_tests_status_xcom_key(self.model_unique_id)
return get_xcom_val(ti, self.producer_task_id, xcom_key)
xcom_val = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status")
return get_xcom_val(ti, self.producer_task_id, get_tests_status_xcom_key(self.model_unique_id))
xcom_val = get_xcom_val(ti, self.producer_task_id, get_status_xcom_key(self.model_unique_id))
if xcom_val is None:
return None
self._outlet_uris = xcom_val.get("outlet_uris", [])
return xcom_val.get("status")

def _cache_compiled_sql(self, ti: Any, context: Context) -> None:
"""Pull compiled_sql from XCom and cache it on the sensor instance."""
compiled_sql = get_xcom_val(
ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_compiled_sql"
)
compiled_sql = get_xcom_val(ti, self.producer_task_id, get_compiled_sql_xcom_key(self.model_unique_id))
if compiled_sql:
self.compiled_sql = compiled_sql
if hasattr(self, "_override_rtif"):
Expand Down Expand Up @@ -655,10 +655,11 @@ def poke(self, context: Context) -> bool:
ti = context["ti"]
try_number = ti.try_number

if self.is_test_sensor:
xcom_key = get_tests_status_xcom_key(self.model_unique_id)
else:
xcom_key = f"{self.model_unique_id.replace('.', '__')}_status"
xcom_key = (
get_tests_status_xcom_key(self.model_unique_id)
if self.is_test_sensor
else get_status_xcom_key(self.model_unique_id)
)
logger.info(
"Try number #%s, poke attempt #%s: Pulling status from task_id '%s' via XCom key '%s' for %s '%s'",
try_number,
Expand Down Expand Up @@ -686,7 +687,7 @@ def poke(self, context: Context) -> bool:

dbt_events = get_xcom_val(
task_instance=context["ti"],
key=f"{self.model_unique_id.replace('.', '__')}_dbt_event",
key=get_dbt_event_xcom_key(self.model_unique_id),
task_ids=self.producer_task_id,
)
_log_dbt_event(dbt_events)
Expand Down
15 changes: 15 additions & 0 deletions cosmos/operators/_watcher/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ def is_producer_task_terminated(state: str | None) -> bool:
return state in PRODUCER_TERMINAL_STATES


def get_status_xcom_key(unique_id: str) -> str:
"""Build the XCom key used to publish a dbt node's status, sanitising dots in ``unique_id`` to ``__``."""
return f"{unique_id.replace('.', '__')}_status"


def get_dbt_event_xcom_key(unique_id: str) -> str:
"""Build the XCom key used to publish a dbt node's structured event, sanitising dots in ``unique_id`` to ``__``."""
return f"{unique_id.replace('.', '__')}_dbt_event"


def get_compiled_sql_xcom_key(unique_id: str) -> str:
Comment thread
pankajastro marked this conversation as resolved.
"""Build the XCom key used to publish a dbt node's compiled SQL, sanitising dots in ``unique_id`` to ``__``."""
return f"{unique_id.replace('.', '__')}_compiled_sql"
Comment thread
pankajastro marked this conversation as resolved.


xcom_set_lock = Lock()


Expand Down
17 changes: 9 additions & 8 deletions cosmos/operators/_watcher/triggerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
PRODUCER_FINAL_STATES,
_log_dbt_event,
build_producer_state_fetcher,
get_compiled_sql_xcom_key,
get_dbt_event_xcom_key,
get_status_xcom_key,
is_dbt_node_status_failed,
is_dbt_node_status_skipped,
is_dbt_node_status_success,
Expand Down Expand Up @@ -125,8 +128,7 @@ async def _get_node_status(self) -> Any | None:
The XCom value is always a dict with ``status`` and ``outlet_uris`` keys.
Stores outlet URIs on ``self._outlet_uris`` for later dataset emission.
"""
status_key = f"{self.model_unique_id.replace('.', '__')}_status"
xcom_val = await self.get_xcom_val(status_key)
xcom_val = await self.get_xcom_val(get_status_xcom_key(self.model_unique_id))
if xcom_val is None:
return None
self._outlet_uris = xcom_val.get("outlet_uris", [])
Expand All @@ -149,14 +151,13 @@ async def _parse_dbt_node_status_and_compiled_sql(self) -> tuple[str | None, str
if self.is_test_sensor:
from cosmos.operators._watcher.aggregation import get_tests_status_xcom_key

status_key = get_tests_status_xcom_key(self.model_unique_id)
status = await self.get_xcom_val(status_key)
status = await self.get_xcom_val(get_tests_status_xcom_key(self.model_unique_id))
return status, None

compiled_sql_key = f"{self.model_unique_id.replace('.', '__')}_compiled_sql"

status = await self._get_node_status()
compiled_sql = await self.get_xcom_val(compiled_sql_key) if status is not None else None
compiled_sql = (
await self.get_xcom_val(get_compiled_sql_xcom_key(self.model_unique_id)) if status is not None else None
)
return status, compiled_sql

async def _get_producer_task_status(self) -> str | None:
Expand Down Expand Up @@ -229,7 +230,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

while True:
producer_task_state = await self._get_producer_task_status()
dbt_log_event = await self.get_xcom_val(f"{self.model_unique_id.replace('.', '__')}_dbt_event")
dbt_log_event = await self.get_xcom_val(get_dbt_event_xcom_key(self.model_unique_id))
_log_dbt_event(dbt_log_event)
dbt_node_status, compiled_sql = await self._parse_dbt_node_status_and_compiled_sql()
if is_dbt_node_status_success(dbt_node_status):
Expand Down
29 changes: 29 additions & 0 deletions tests/operators/_watcher/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

from cosmos.operators._watcher.state import (
_log_dbt_event,
get_compiled_sql_xcom_key,
get_dbt_event_xcom_key,
get_status_xcom_key,
is_dbt_node_status_failed,
is_dbt_node_status_skipped,
is_dbt_node_status_success,
Expand Down Expand Up @@ -65,6 +68,32 @@ def test_is_dbt_node_status_terminal_false(self, status: str | None):
assert is_dbt_node_status_terminal(status) is False


class TestXcomKeyHelpers:
"""Tests for the watcher XCom-key builder helpers."""

@pytest.mark.parametrize(
"helper,suffix",
[
(get_status_xcom_key, "_status"),
(get_dbt_event_xcom_key, "_dbt_event"),
(get_compiled_sql_xcom_key, "_compiled_sql"),
],
)
def test_replaces_dots_and_appends_suffix(self, helper, suffix):
assert helper("model.my_project.my_model") == f"model__my_project__my_model{suffix}"

@pytest.mark.parametrize(
"helper,suffix",
[
(get_status_xcom_key, "_status"),
(get_dbt_event_xcom_key, "_dbt_event"),
(get_compiled_sql_xcom_key, "_compiled_sql"),
],
)
def test_unique_id_without_dots(self, helper, suffix):
assert helper("plain_id") == f"plain_id{suffix}"


class TestProducerTaskTerminated:
"""Tests for is_producer_task_terminated helper."""

Expand Down
Loading