diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 06fdb151ea..ca1e475485 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -6,6 +6,7 @@ import zlib from datetime import timedelta from pathlib import Path +from threading import Lock from typing import TYPE_CHECKING, Any, Callable, List, Union import airflow @@ -14,7 +15,9 @@ if TYPE_CHECKING: # pragma: no cover try: from airflow.sdk.definitions.context import Context + from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance except ImportError: + from airflow.models.taskinstance import TaskInstance # type: ignore[assignment] from airflow.utils.context import Context # type: ignore[attr-defined] try: @@ -50,13 +53,25 @@ EventMsg = None logger = logging.getLogger(__name__) - +xcom_set_lock = Lock() CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 10 PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 9999 WEIGHT_RULE = "absolute" # the default "downstream" does not work with dag.test() +def safe_xcom_push(task_instance: TaskInstance, key: str, value: Any) -> None: + """ + Set an XCom value in a thread-safe manner in Airflow 3.0 and 3.1. + We noticed that the combination of using dbt (multi-threaded) and Airflow 3.0 and 3.1 to set XCom leads to race conditions. + These race conditions cause the producer task to get stuck while running the dbt build command. + Unfortunately, since this is non-deterministic and happens once every five runs, we were not able to write a proper test. + However, after applying this fix we ran a pipeline that used to fail once every five runs more than 20 times and no longer faced the issue. + """ + with xcom_set_lock: + task_instance.xcom_push(key=key, value=value) + + class DbtProducerWatcherOperator(DbtLocalBaseOperator): """Run dbt build and update XCom with the progress of each model, as part of the *WATCHER* execution mode. 
@@ -145,24 +160,21 @@ def _handle_node_finished( context: Context, ) -> None: logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", event_message) - ti = context["ti"] uid = event_message.data.node_info.unique_id event_message_dict = self._serialize_event(event_message) compiled_sql = self._extract_compiled_sql_for_node_event(event_message) if compiled_sql: event_message_dict["compiled_sql"] = compiled_sql payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode() - ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload) + safe_xcom_push(task_instance=context["ti"], key=f"nodefinished_{uid.replace('.', '__')}", value=payload) def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None: - ti = context["ti"] # Only push startup events; per-model statuses are available via individual nodefinished_ entries. if startup_events: - ti.xcom_push(key="dbt_startup_events", value=startup_events) + safe_xcom_push(task_instance=context["ti"], key="dbt_startup_events", value=startup_events) def _store_producer_task_state(self, context: Context) -> None: - ti = context["ti"] - ti.xcom_push(key="state", value="failed") + safe_xcom_push(task_instance=context["ti"], key="state", value="failed") def execute(self, context: Context, **kwargs: Any) -> Any: try: @@ -193,11 +205,11 @@ def _callback(event_message: EventMsg) -> None: kwargs["push_run_results_to_xcom"] = True return_value = super().execute(context=context, **kwargs) - context["ti"].xcom_push(key="task_status", value="completed") + safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed") return return_value except Exception: - context["ti"].xcom_push(key="task_status", value="completed") + safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed") raise