Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,56 @@ def _get_status_from_run_results(self, ti: Any, context: Context) -> Any:
def _get_producer_task_state(self, ti: Any) -> Any:
return ti.xcom_pull(task_ids=self.producer_task_id, key="state")

def _get_producer_task_status(self, context: Context) -> str | None:
"""
Get the task status of the producer task for both Airflow 2 and Airflow 3.

Returns the state of the producer task instance, or None if not found.
"""
ti = context["ti"]
run_id = context["run_id"]
dag_id = ti.dag_id

if AIRFLOW_VERSION < Version("3.0.0"):
# Airflow 2: Query TaskInstance from database
try:
from airflow.models import TaskInstance
from airflow.utils.session import create_session
except ImportError as exc: # pragma: no cover - defensive fallback for tests without DB
logger.warning("Could not import create_session to read producer state: %s", exc)
return None

with create_session() as session:
producer_ti = (
session.query(TaskInstance)
.filter_by(
dag_id=dag_id,
task_id=self.producer_task_id,
run_id=run_id,
)
.first()
)
if producer_ti:
return str(producer_ti.state)
return None
else:
# Airflow 3: Use RuntimeTaskInstance.get_task_states
try:
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

task_states = RuntimeTaskInstance.get_task_states(
dag_id=dag_id,
task_ids=[self.producer_task_id],
run_ids=[run_id],
)
state = task_states.get(run_id, {}).get(self.producer_task_id)
if state is not None:
return str(state)
return None
except (ImportError, NameError) as exc:
logger.warning("Could not retrieve producer task status via RuntimeTaskInstance: %s", exc)
return None

def execute(self, context: Context, **kwargs: Any) -> None:
if not self.deferrable:
super().execute(context)
Expand Down Expand Up @@ -433,7 +483,7 @@ def poke(self, context: Context) -> bool:
return self._fallback_to_local_run(try_number, context)

# We have assumption here that both the build producer and the sensor task will have same invocation mode
producer_task_state = self._get_producer_task_state(ti)
producer_task_state = self._get_producer_task_status(context)
if self._use_event():
status = self._get_status_from_events(ti, context)
else:
Expand Down
115 changes: 112 additions & 3 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,117 @@ def make_sensor(self, **kwargs):
)

sensor.invocation_mode = "DBT_RUNNER"
sensor._get_producer_task_status = MagicMock(return_value=None)
Comment thread
pankajkoti marked this conversation as resolved.
return sensor

def make_context(self, ti_mock):
return {"ti": ti_mock}
def make_context(self, ti_mock, *, run_id: str = "test-run", map_index: int = 0):
return {
"ti": ti_mock,
"run_id": run_id,
"task_instance": MagicMock(map_index=map_index),
}

@pytest.mark.skipif(AIRFLOW_VERSION >= Version("3.0.0"), reason="RuntimeTaskInstance path in Airflow >= 3.0")
@patch("airflow.utils.session.create_session")
@patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("2.7.0"))
def test_get_producer_task_status_airflow2(self, mock_create_session):
sensor = self.make_sensor()
sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__(
sensor, DbtConsumerWatcherSensor
)
ti = MagicMock()
ti.dag_id = "example_dag"
context = self.make_context(ti, run_id="run_1")

mock_state_ti = MagicMock()
mock_state_ti.state = "success"
session_cm = mock_create_session.return_value
session_cm.__enter__.return_value.query.return_value.filter_by.return_value.first.return_value = mock_state_ti

status = sensor._get_producer_task_status(context)

assert status == "success"
mock_create_session.assert_called_once()

@pytest.mark.skipif(AIRFLOW_VERSION >= Version("3.0.0"), reason="RuntimeTaskInstance path in Airflow >= 3.0")
@patch("airflow.utils.session.create_session")
@patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("2.7.0"))
def test_get_producer_task_status_airflow2_missing_instance(self, mock_create_session):
sensor = self.make_sensor()
sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__(
sensor, DbtConsumerWatcherSensor
)
ti = MagicMock()
ti.dag_id = "example_dag"
context = self.make_context(ti, run_id="run_2")

session_cm = mock_create_session.return_value
session_cm.__enter__.return_value.query.return_value.filter_by.return_value.first.return_value = None

status = sensor._get_producer_task_status(context)

assert status is None

@pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0")
@patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0"))
@patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
def test_get_producer_task_status_airflow3(self, mock_get_task_states):
sensor = self.make_sensor()
sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__(
sensor, DbtConsumerWatcherSensor
)
ti = MagicMock()
ti.dag_id = "example_dag"
context = self.make_context(ti, run_id="run_3")

mock_get_task_states.return_value = {"run_3": {sensor.producer_task_id: "running"}}

status = sensor._get_producer_task_status(context)

assert status == "running"
mock_get_task_states.assert_called_once_with(
dag_id="example_dag", task_ids=[sensor.producer_task_id], run_ids=["run_3"]
)

@pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0")
@patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0"))
@patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
def test_get_producer_task_status_airflow3_missing_state(self, mock_get_task_states):
sensor = self.make_sensor()
sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__(
sensor, DbtConsumerWatcherSensor
)
ti = MagicMock()
ti.dag_id = "example_dag"
context = self.make_context(ti, run_id="run_3_missing")

mock_get_task_states.return_value = {"run_3_missing": {}}

status = sensor._get_producer_task_status(context)

assert status is None
mock_get_task_states.assert_called_once_with(
dag_id="example_dag", task_ids=[sensor.producer_task_id], run_ids=["run_3_missing"]
)

@pytest.mark.skipif(AIRFLOW_VERSION < Version("3.0.0"), reason="Database lookup path in Airflow < 3.0")
@patch("cosmos.operators.watcher.AIRFLOW_VERSION", new=Version("3.0.0"))
@patch(
"airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states",
side_effect=ImportError("missing runtime"),
)
def test_get_producer_task_status_airflow3_import_error(self, _mock_get_task_states):
sensor = self.make_sensor()
sensor._get_producer_task_status = DbtConsumerWatcherSensor._get_producer_task_status.__get__(
sensor, DbtConsumerWatcherSensor
)
ti = MagicMock()
ti.dag_id = "example_dag"
context = self.make_context(ti, run_id="run_4")

status = sensor._get_producer_task_status(context)

assert status is None

@patch("cosmos.operators.watcher.EventMsg")
def test_poke_status_none_from_events(self, MockEventMsg):
Expand Down Expand Up @@ -544,7 +651,7 @@ def test_poke_success_from_run_results(self):
result = sensor.poke(context)
assert result is True

@patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_state", return_value=None)
@patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_producer_task_status", return_value=None)
def _fallback_to_local_run(self, mock_get_producer_task_state):
sensor = self.make_sensor()
sensor.invocation_mode = None
Expand Down Expand Up @@ -663,6 +770,7 @@ def test_get_status_from_events_sets_compiled_sql(self):
@patch("cosmos.operators.watcher.DbtConsumerWatcherSensor._get_status_from_run_results")
def test_producer_state_failed(self, mock_run_result):
sensor = self.make_sensor()
sensor._get_producer_task_status.return_value = "failed"
ti = MagicMock()
ti.try_number = 1
sensor.poke_retry_number = 1
Expand All @@ -687,6 +795,7 @@ def test_producer_state_does_not_fail_if_previously_upstream_failed(
More details: https://github.com/astronomer/astronomer-cosmos/pull/2062
"""
sensor = self.make_sensor()
sensor._get_producer_task_status.return_value = "failed"
ti = MagicMock()
ti.try_number = 1
sensor.poke_retry_number = 0
Expand Down