Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import zlib
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, List, Union
Comment thread
tatiana marked this conversation as resolved.

import airflow
from packaging.version import Version
Expand Down Expand Up @@ -303,7 +303,7 @@ def _get_status_from_events(self, ti: Any, context: Context) -> Any:

return event_json.get("data", {}).get("run_result", {}).get("status")

def _get_status_from_run_results(self, ti: Any) -> Any:
def _get_status_from_run_results(self, ti: Any, context: Context) -> Any:
compressed_b64_run_results = ti.xcom_pull(task_ids=self.producer_task_id, key="run_results")

if not compressed_b64_run_results:
Expand All @@ -323,6 +323,10 @@ def _get_status_from_run_results(self, ti: Any) -> Any:
return None

logger.info("Node Info: %s", run_results_str)
self.compiled_sql = node_result.get("compiled_code")
if self.compiled_sql:
self._override_rtif(context)

return node_result.get("status")

def _get_producer_task_state(self, ti: Any) -> Any:
Expand Down Expand Up @@ -356,7 +360,7 @@ def poke(self, context: Context) -> bool:
if use_events:
status = self._get_status_from_events(ti, context)
else:
status = self._get_status_from_run_results(ti)
status = self._get_status_from_run_results(ti, context)

if status is None:

Expand Down Expand Up @@ -411,7 +415,7 @@ class DbtSourceWatcherOperator(DbtSourceLocalOperator):
Executes a dbt source freshness command, synchronously, as ExecutionMode.LOCAL.
"""

template_fields: Sequence[str] = DbtSourceLocalOperator.template_fields
template_fields: tuple[str, ...] = DbtConsumerWatcherSensor.template_fields


class DbtRunWatcherOperator(DbtConsumerWatcherSensor):
Expand Down
32 changes: 30 additions & 2 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import base64
import json
import zlib
Expand Down Expand Up @@ -474,15 +476,15 @@ def test_get_status_from_run_results_success(self):
ti = MagicMock()
ti.xcom_pull.return_value = ENCODED_RUN_RESULTS

result = sensor._get_status_from_run_results(ti)
result = sensor._get_status_from_run_results(ti, _MockContext(ti=ti))
assert result == "success"

def test_get_status_from_run_results_none(self):
sensor = self.make_sensor()
ti = MagicMock()
ti.xcom_pull.return_value = None

result = sensor._get_status_from_run_results(ti)
result = sensor._get_status_from_run_results(ti, _MockContext(ti=ti))
assert result is None

def test_get_status_from_events_success(self):
Expand Down Expand Up @@ -553,6 +555,32 @@ def test_producer_state_does_not_fail_if_previously_upstream_failed(
sensor.poke(context)
mock_fallback_to_local_run.assert_called_once()

@patch("cosmos.operators.local.AbstractDbtLocalBase._override_rtif")
def test_get_status_from_run_results_with_compiled_sql(self, mock_override_rtif, monkeypatch):
sensor = self.make_sensor()
sensor.model_unique_id = "model.test_table"

# Create a fake run_results payload containing compiled_code and status
run_results = {
"results": [
{
"unique_id": "model.test_table",
"compiled_code": "SELECT * FROM dummy_table;",
"status": "success",
}
]
}

compressed = zlib.compress(json.dumps(run_results).encode())
encoded = base64.b64encode(compressed).decode()

# Mock TaskInstance.xcom_pull to return encoded results
ti = MagicMock()
ti.xcom_pull.return_value = encoded
context = {"ti": ti}
sensor._get_status_from_run_results(ti, context)
mock_override_rtif.assert_called_with(context)


class TestDbtBuildWatcherOperator:

Expand Down