-
Notifications
You must be signed in to change notification settings - Fork 297
Fix ExecutionMode.WATCHER deadlock in Airflow 3.0 & 3.1
#2087
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| import zlib | ||
| from datetime import timedelta | ||
| from pathlib import Path | ||
| from threading import Lock | ||
| from typing import TYPE_CHECKING, Any, Callable, List, Union | ||
|
|
||
| import airflow | ||
|
|
@@ -14,7 +15,9 @@ | |
| if TYPE_CHECKING: # pragma: no cover | ||
| try: | ||
| from airflow.sdk.definitions.context import Context | ||
| from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance | ||
| except ImportError: | ||
| from airflow.models.taskinstance import TaskInstance # type: ignore[assignment] | ||
| from airflow.utils.context import Context # type: ignore[attr-defined] | ||
|
|
||
| try: | ||
|
|
@@ -50,13 +53,25 @@ | |
| EventMsg = None | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| xcom_set_lock = Lock() | ||
|
|
||
| CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 10 | ||
| PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 9999 | ||
| WEIGHT_RULE = "absolute" # the default "downstream" does not work with dag.test() | ||
|
|
||
|
|
||
| def safe_xcom_push(task_instance: TaskInstance, key: str, value: Any) -> None: | ||
| """ | ||
| Safely set an XCom value in a thread-safe manner in Airflow 3.0 and 3.1. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need to mention specific Airflow versions here? We are not altering the behavior based on the Airflow version.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we should mention that we're using this because of a bug in these versions, but it is also harmless to use with other versions. It simplifies our codebase to not have to handle this differently based on the Airflow version. If Airflow fixes this, we can then add the version check. |
||
| We noticed that using dbt (multi-threaded) together with Airflow 3.0 and 3.1 to set XCom leads to race conditions. | ||
| This causes the producer task to get stuck while running the dbt build command. | ||
| Unfortunately, since this is non-deterministic and happens roughly once every five runs, we were not able to write a proper test. | ||
| However, after applying this fix, we ran a pipeline that used to fail roughly once every five runs more than 20 times and no longer faced the issue. | ||
| """ | ||
|
tatiana marked this conversation as resolved.
|
||
| with xcom_set_lock: | ||
| task_instance.xcom_push(key=key, value=value) | ||
|
|
||
|
|
||
| class DbtProducerWatcherOperator(DbtLocalBaseOperator): | ||
| """Run dbt build and update XCom with the progress of each model, as part of the *WATCHER* execution mode. | ||
|
|
||
|
|
@@ -145,24 +160,21 @@ def _handle_node_finished( | |
| context: Context, | ||
| ) -> None: | ||
| logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", event_message) | ||
| ti = context["ti"] | ||
| uid = event_message.data.node_info.unique_id | ||
| event_message_dict = self._serialize_event(event_message) | ||
| compiled_sql = self._extract_compiled_sql_for_node_event(event_message) | ||
| if compiled_sql: | ||
| event_message_dict["compiled_sql"] = compiled_sql | ||
| payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode() | ||
| ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload) | ||
| safe_xcom_push(task_instance=context["ti"], key=f"nodefinished_{uid.replace('.', '__')}", value=payload) | ||
|
|
||
| def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None: | ||
| ti = context["ti"] | ||
| # Only push startup events; per-model statuses are available via individual nodefinished_<uid> entries. | ||
| if startup_events: | ||
| ti.xcom_push(key="dbt_startup_events", value=startup_events) | ||
| safe_xcom_push(task_instance=context["ti"], key="dbt_startup_events", value=startup_events) | ||
|
|
||
| def _store_producer_task_state(self, context: Context) -> None: | ||
| ti = context["ti"] | ||
| ti.xcom_push(key="state", value="failed") | ||
| safe_xcom_push(task_instance=context["ti"], key="state", value="failed") | ||
|
|
||
| def execute(self, context: Context, **kwargs: Any) -> Any: | ||
| try: | ||
|
|
@@ -193,11 +205,11 @@ def _callback(event_message: EventMsg) -> None: | |
| kwargs["push_run_results_to_xcom"] = True | ||
| return_value = super().execute(context=context, **kwargs) | ||
|
|
||
| context["ti"].xcom_push(key="task_status", value="completed") | ||
| safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed") | ||
| return return_value | ||
|
|
||
| except Exception: | ||
| context["ti"].xcom_push(key="task_status", value="completed") | ||
| safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed") | ||
| raise | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.