Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import zlib
from datetime import timedelta
from pathlib import Path
from threading import Lock
from typing import TYPE_CHECKING, Any, Callable, List, Union

import airflow
Expand All @@ -14,7 +15,9 @@
if TYPE_CHECKING: # pragma: no cover
try:
from airflow.sdk.definitions.context import Context
from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance
except ImportError:
from airflow.models.taskinstance import TaskInstance # type: ignore[assignment]
from airflow.utils.context import Context # type: ignore[attr-defined]

try:
Expand Down Expand Up @@ -50,13 +53,25 @@
EventMsg = None

logger = logging.getLogger(__name__)

xcom_set_lock = Lock()
Comment thread
tatiana marked this conversation as resolved.

CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 10
PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 9999
WEIGHT_RULE = "absolute" # the default "downstream" does not work with dag.test()


def safe_xcom_push(task_instance: TaskInstance, key: str, value: Any) -> None:
"""
Safely set an XCom value in a thread-safe manner in Airflow 3.0 and 3.1.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to mention specific Airflow versions here? Because we are not altering the behavior here based on Airflow versions here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should mention we're using this because of these versions bug, but it is also harmless to use for others. It simplifies our code-base to not have to handle this differently based on Airflow version. If Airflow fixes this, we can then do the version check

We noticed that the combination of using dbt (multi-threaded) and Airflow 3.0 and 3.1 to set XCom lead to race conditions.
This leads the producer task to get stuck while running the dbt build command.
Unfortunately, since this is non-deterministic, and happens once every five runs, we were not able to have a proper test.
However, we applied this fix and run over 20 times a pipeline that would fail every 5 runs and this allowed us to no longer face the issue.
"""
Comment thread
tatiana marked this conversation as resolved.
with xcom_set_lock:
task_instance.xcom_push(key=key, value=value)


class DbtProducerWatcherOperator(DbtLocalBaseOperator):
"""Run dbt build and update XCom with the progress of each model, as part of the *WATCHER* execution mode.

Expand Down Expand Up @@ -145,24 +160,21 @@ def _handle_node_finished(
context: Context,
) -> None:
logger.debug("DbtProducerWatcherOperator: handling node finished event: %s", event_message)
ti = context["ti"]
uid = event_message.data.node_info.unique_id
event_message_dict = self._serialize_event(event_message)
compiled_sql = self._extract_compiled_sql_for_node_event(event_message)
if compiled_sql:
event_message_dict["compiled_sql"] = compiled_sql
payload = base64.b64encode(zlib.compress(json.dumps(event_message_dict).encode())).decode()
ti.xcom_push(key=f"nodefinished_{uid.replace('.', '__')}", value=payload)
safe_xcom_push(task_instance=context["ti"], key=f"nodefinished_{uid.replace('.', '__')}", value=payload)

def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> None:
ti = context["ti"]
# Only push startup events; per-model statuses are available via individual nodefinished_<uid> entries.
if startup_events:
ti.xcom_push(key="dbt_startup_events", value=startup_events)
safe_xcom_push(task_instance=context["ti"], key="dbt_startup_events", value=startup_events)

def _store_producer_task_state(self, context: Context) -> None:
ti = context["ti"]
ti.xcom_push(key="state", value="failed")
safe_xcom_push(task_instance=context["ti"], key="state", value="failed")

def execute(self, context: Context, **kwargs: Any) -> Any:
try:
Expand Down Expand Up @@ -193,11 +205,11 @@ def _callback(event_message: EventMsg) -> None:
kwargs["push_run_results_to_xcom"] = True
return_value = super().execute(context=context, **kwargs)

context["ti"].xcom_push(key="task_status", value="completed")
safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed")
return return_value

except Exception:
context["ti"].xcom_push(key="task_status", value="completed")
safe_xcom_push(task_instance=context["ti"], key="task_status", value="completed")
raise


Expand Down