From 99ce1b811e294cbf4fcc8f744dd3be0dc9078ec1 Mon Sep 17 00:00:00 2001
From: Pankaj Singh
Date: Wed, 13 May 2026 21:11:34 +0530
Subject: [PATCH 1/2] Extract watcher XCom-key helpers and inline single-use bindings
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The watcher operators sanitised dbt unique_ids into XCom keys with the
same f"{uid.replace('.', '__')}_" pattern repeated 12 times across
base.py and triggerer.py for three suffixes (_status, _dbt_event,
_compiled_sql) — parallel to the existing get_tests_status_xcom_key in
aggregation.py for _tests_status.

Add three slim helpers in cosmos/operators/_watcher/state.py:
get_status_xcom_key, get_dbt_event_xcom_key, get_compiled_sql_xcom_key.
Each is a single-expression function so the ".replace('.', '__')"
sanitisation rule lives in one place.

Replace all 12 inline call sites and, while there, inline the
single-use xcom_key/status_key/compiled_sql_key bindings that
introduced a redundant variable just to immediately pass it as the
next argument. No behaviour change.

Co-Authored-By: Claude Opus 4.7 (1M context) --- cosmos/operators/_watcher/base.py | 41 +++++++++++++------------- cosmos/operators/_watcher/state.py | 12 ++++++++ cosmos/operators/_watcher/triggerer.py | 17 ++++++----- 3 files changed, 42 insertions(+), 28 deletions(-) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index f2a4648ef8..1e4a4b3d8c 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -25,6 +25,9 @@ _iso_to_string, _log_dbt_event, build_producer_state_fetcher, + get_compiled_sql_xcom_key, + get_dbt_event_xcom_key, + get_status_xcom_key, get_xcom_val, is_dbt_node_status_skipped, is_dbt_node_status_success, @@ -118,8 +121,7 @@ def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any]) -> None: "msg": msg, } - xcom_key = f"{unique_id.replace('.', '__')}_dbt_event" - safe_xcom_push(task_instance=task_instance, key=xcom_key, value=dbt_event) + safe_xcom_push(task_instance=task_instance, key=get_dbt_event_xcom_key(unique_id), value=dbt_event) def _extract_compiled_sql( @@ -155,7 +157,7 @@ def _push_compiled_sql_for_model(task_instance: Any, unique_id: str, compiled_sq """ safe_xcom_push( task_instance=task_instance, - key=f"{unique_id.replace('.', '__')}_compiled_sql", + key=get_compiled_sql_xcom_key(unique_id), value=compiled_sql, ) @@ -332,7 +334,7 @@ def store_dbt_resource_status_from_log( } safe_xcom_push( task_instance=context["ti"], - key=f"{unique_id.replace('.', '__')}_status", + key=get_status_xcom_key(unique_id), value=status_value, ) @@ -477,10 +479,11 @@ def _execute_core(self, context: Context) -> None: if not self.deferrable: super().execute(context) elif not self.poke(context): - if self.is_test_sensor: - xcom_key = get_tests_status_xcom_key(self.model_unique_id) - else: - xcom_key = f"{self.model_unique_id.replace('.', '__')}_status" + xcom_key = ( + get_tests_status_xcom_key(self.model_unique_id) + if self.is_test_sensor + else 
get_status_xcom_key(self.model_unique_id) + ) logger.info( "Deferring %s '%s'. The trigger will poll XCom key '%s' from producer task '%s'.", self._resource_label.lower(), @@ -545,7 +548,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: dbt_events = get_xcom_val( task_instance=context["ti"], - key=f"{self.model_unique_id.replace('.', '__')}_dbt_event", + key=get_dbt_event_xcom_key(self.model_unique_id), task_ids=self.producer_task_id, ) _log_dbt_event(dbt_events) @@ -590,9 +593,8 @@ def _get_node_status(self, ti: Any, context: Context) -> Any: dataset emission. """ if self.is_test_sensor: - xcom_key = get_tests_status_xcom_key(self.model_unique_id) - return get_xcom_val(ti, self.producer_task_id, xcom_key) - xcom_val = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status") + return get_xcom_val(ti, self.producer_task_id, get_tests_status_xcom_key(self.model_unique_id)) + xcom_val = get_xcom_val(ti, self.producer_task_id, get_status_xcom_key(self.model_unique_id)) if xcom_val is None: return None self._outlet_uris = xcom_val.get("outlet_uris", []) @@ -600,9 +602,7 @@ def _get_node_status(self, ti: Any, context: Context) -> Any: def _cache_compiled_sql(self, ti: Any, context: Context) -> None: """Pull compiled_sql from XCom and cache it on the sensor instance.""" - compiled_sql = get_xcom_val( - ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_compiled_sql" - ) + compiled_sql = get_xcom_val(ti, self.producer_task_id, get_compiled_sql_xcom_key(self.model_unique_id)) if compiled_sql: self.compiled_sql = compiled_sql if hasattr(self, "_override_rtif"): @@ -655,10 +655,11 @@ def poke(self, context: Context) -> bool: ti = context["ti"] try_number = ti.try_number - if self.is_test_sensor: - xcom_key = get_tests_status_xcom_key(self.model_unique_id) - else: - xcom_key = f"{self.model_unique_id.replace('.', '__')}_status" + xcom_key = ( + 
get_tests_status_xcom_key(self.model_unique_id) + if self.is_test_sensor + else get_status_xcom_key(self.model_unique_id) + ) logger.info( "Try number #%s, poke attempt #%s: Pulling status from task_id '%s' via XCom key '%s' for %s '%s'", try_number, @@ -686,7 +687,7 @@ def poke(self, context: Context) -> bool: dbt_events = get_xcom_val( task_instance=context["ti"], - key=f"{self.model_unique_id.replace('.', '__')}_dbt_event", + key=get_dbt_event_xcom_key(self.model_unique_id), task_ids=self.producer_task_id, ) _log_dbt_event(dbt_events) diff --git a/cosmos/operators/_watcher/state.py b/cosmos/operators/_watcher/state.py index bca9e37de1..a9fbb64a6f 100644 --- a/cosmos/operators/_watcher/state.py +++ b/cosmos/operators/_watcher/state.py @@ -71,6 +71,18 @@ def is_producer_task_terminated(state: str | None) -> bool: return state in PRODUCER_TERMINAL_STATES +def get_status_xcom_key(unique_id: str) -> str: + return f"{unique_id.replace('.', '__')}_status" + + +def get_dbt_event_xcom_key(unique_id: str) -> str: + return f"{unique_id.replace('.', '__')}_dbt_event" + + +def get_compiled_sql_xcom_key(unique_id: str) -> str: + return f"{unique_id.replace('.', '__')}_compiled_sql" + + xcom_set_lock = Lock() diff --git a/cosmos/operators/_watcher/triggerer.py b/cosmos/operators/_watcher/triggerer.py index 1d3819ac67..daff257494 100644 --- a/cosmos/operators/_watcher/triggerer.py +++ b/cosmos/operators/_watcher/triggerer.py @@ -16,6 +16,9 @@ PRODUCER_FINAL_STATES, _log_dbt_event, build_producer_state_fetcher, + get_compiled_sql_xcom_key, + get_dbt_event_xcom_key, + get_status_xcom_key, is_dbt_node_status_failed, is_dbt_node_status_skipped, is_dbt_node_status_success, @@ -125,8 +128,7 @@ async def _get_node_status(self) -> Any | None: The XCom value is always a dict with ``status`` and ``outlet_uris`` keys. Stores outlet URIs on ``self._outlet_uris`` for later dataset emission. 
""" - status_key = f"{self.model_unique_id.replace('.', '__')}_status" - xcom_val = await self.get_xcom_val(status_key) + xcom_val = await self.get_xcom_val(get_status_xcom_key(self.model_unique_id)) if xcom_val is None: return None self._outlet_uris = xcom_val.get("outlet_uris", []) @@ -149,14 +151,13 @@ async def _parse_dbt_node_status_and_compiled_sql(self) -> tuple[str | None, str if self.is_test_sensor: from cosmos.operators._watcher.aggregation import get_tests_status_xcom_key - status_key = get_tests_status_xcom_key(self.model_unique_id) - status = await self.get_xcom_val(status_key) + status = await self.get_xcom_val(get_tests_status_xcom_key(self.model_unique_id)) return status, None - compiled_sql_key = f"{self.model_unique_id.replace('.', '__')}_compiled_sql" - status = await self._get_node_status() - compiled_sql = await self.get_xcom_val(compiled_sql_key) if status is not None else None + compiled_sql = ( + await self.get_xcom_val(get_compiled_sql_xcom_key(self.model_unique_id)) if status is not None else None + ) return status, compiled_sql async def _get_producer_task_status(self) -> str | None: @@ -229,7 +230,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: while True: producer_task_state = await self._get_producer_task_status() - dbt_log_event = await self.get_xcom_val(f"{self.model_unique_id.replace('.', '__')}_dbt_event") + dbt_log_event = await self.get_xcom_val(get_dbt_event_xcom_key(self.model_unique_id)) _log_dbt_event(dbt_log_event) dbt_node_status, compiled_sql = await self._parse_dbt_node_status_and_compiled_sql() if is_dbt_node_status_success(dbt_node_status): From ac177c35b5cf409b8ca04d176e589ad648205cce Mon Sep 17 00:00:00 2001 From: Pankaj Singh Date: Thu, 14 May 2026 14:42:17 +0530 Subject: [PATCH 2/2] Add docstrings and tests for watcher XCom-key helpers Address Copilot review feedback on PR #2673: document the new get_status_xcom_key, get_dbt_event_xcom_key, and get_compiled_sql_xcom_key helpers to match the surrounding 
module style, and add parametrized unit tests that pin the dot-to-"__" sanitisation rule and each helper's suffix to guard against accidental key-format drift. Co-Authored-By: Claude Opus 4.7 (1M context) --- cosmos/operators/_watcher/state.py | 3 +++ tests/operators/_watcher/test_state.py | 29 ++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/cosmos/operators/_watcher/state.py b/cosmos/operators/_watcher/state.py index a9fbb64a6f..b27e63fafe 100644 --- a/cosmos/operators/_watcher/state.py +++ b/cosmos/operators/_watcher/state.py @@ -72,14 +72,17 @@ def is_producer_task_terminated(state: str | None) -> bool: def get_status_xcom_key(unique_id: str) -> str: + """Build the XCom key used to publish a dbt node's status, sanitising dots in ``unique_id`` to ``__``.""" return f"{unique_id.replace('.', '__')}_status" def get_dbt_event_xcom_key(unique_id: str) -> str: + """Build the XCom key used to publish a dbt node's structured event, sanitising dots in ``unique_id`` to ``__``.""" return f"{unique_id.replace('.', '__')}_dbt_event" def get_compiled_sql_xcom_key(unique_id: str) -> str: + """Build the XCom key used to publish a dbt node's compiled SQL, sanitising dots in ``unique_id`` to ``__``.""" return f"{unique_id.replace('.', '__')}_compiled_sql" diff --git a/tests/operators/_watcher/test_state.py b/tests/operators/_watcher/test_state.py index d452ab2d2f..988baf943e 100644 --- a/tests/operators/_watcher/test_state.py +++ b/tests/operators/_watcher/test_state.py @@ -12,6 +12,9 @@ from cosmos.operators._watcher.state import ( _log_dbt_event, + get_compiled_sql_xcom_key, + get_dbt_event_xcom_key, + get_status_xcom_key, is_dbt_node_status_failed, is_dbt_node_status_skipped, is_dbt_node_status_success, @@ -65,6 +68,32 @@ def test_is_dbt_node_status_terminal_false(self, status: str | None): assert is_dbt_node_status_terminal(status) is False +class TestXcomKeyHelpers: + """Tests for the watcher XCom-key builder helpers.""" + + @pytest.mark.parametrize( 
+ "helper,suffix", + [ + (get_status_xcom_key, "_status"), + (get_dbt_event_xcom_key, "_dbt_event"), + (get_compiled_sql_xcom_key, "_compiled_sql"), + ], + ) + def test_replaces_dots_and_appends_suffix(self, helper, suffix): + assert helper("model.my_project.my_model") == f"model__my_project__my_model{suffix}" + + @pytest.mark.parametrize( + "helper,suffix", + [ + (get_status_xcom_key, "_status"), + (get_dbt_event_xcom_key, "_dbt_event"), + (get_compiled_sql_xcom_key, "_compiled_sql"), + ], + ) + def test_unique_id_without_dots(self, helper, suffix): + assert helper("plain_id") == f"plain_id{suffix}" + + class TestProducerTaskTerminated: """Tests for is_producer_task_terminated helper."""