Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from cosmos import settings

__version__ = "1.12.0a2"
__version__ = "1.12.0a3"

if not settings.enable_memory_optimised_imports:
from cosmos.airflow.dag import DbtDag
Expand Down
24 changes: 11 additions & 13 deletions cosmos/_triggers/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,21 @@ async def get_xcom_val(self, key: str) -> Any | None:
return await self.get_xcom_val_af3(key)

async def _parse_node_status(self) -> str | None:
key = f"nodefinished_{self.model_unique_id.replace('.', '__')}" if self.use_event else "run_results"

compressed_xcom_val = await self.get_xcom_val(key)
if not compressed_xcom_val:
return None

data_json = _parse_compressed_xcom(compressed_xcom_val)
key = (
f"nodefinished_{self.model_unique_id.replace('.', '__')}"
if self.use_event
else f"{self.model_unique_id.replace('.', '__')}_status"
)

if self.use_event:
compressed_xcom_val = await self.get_xcom_val(key)
if not compressed_xcom_val:
return None

data_json = _parse_compressed_xcom(compressed_xcom_val)
return data_json.get("data", {}).get("run_result", {}).get("status") # type: ignore[no-any-return]

results = data_json.get("results", [])
node_result: dict[str, Any] = next(
(r for r in results if r.get("unique_id") == self.model_unique_id),
{},
)
return node_result.get("status")
Comment thread
pankajastro marked this conversation as resolved.
return await self.get_xcom_val(key)

async def _get_producer_task_status(self) -> str | None:
"""Retrieve the producer task state for both Airflow 2 and Airflow 3."""
Expand Down
26 changes: 26 additions & 0 deletions cosmos/_utils/watcher_state.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,39 @@
from __future__ import annotations

import logging
from threading import Lock
from typing import Any, Callable

try:
from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance
except ImportError:
from airflow.models.taskinstance import TaskInstance # type: ignore[assignment]

from packaging.version import Version

ProducerStateFetcher = Callable[[], str | None]


xcom_set_lock = Lock()


def safe_xcom_push(task_instance: TaskInstance, key: str, value: Any) -> None:
"""
Safely set an XCom value in a thread-safe manner in Airflow 3.0 and 3.1.
We noticed that the combination of using dbt (multi-threaded) and Airflow 3.0 and 3.1 to set XCom lead to race conditions.
This leads the producer task to get stuck while running the dbt build command.
Unfortunately, since this is non-deterministic, and happens once every five runs, we were not able to have a proper test.
However, we applied this fix and run over 20 times a pipeline that would fail every 5 runs and this allowed us to no longer face the issue.
"""
with xcom_set_lock:
task_instance.xcom_push(key=key, value=value)


# TODO: Unify the Airflow call from cosmos/_triggers/watcher.py and cosmos/operators/watcher.py
def get_xcom_val(task_instance: TaskInstance, task_ids: str | list[str], key: str) -> Any:
return task_instance.xcom_pull(task_ids, key=key)


def _load_airflow2_dependencies() -> tuple[Any, Callable[[], Any]]:
from airflow.models import TaskInstance
from airflow.utils.session import create_session
Expand Down
4 changes: 2 additions & 2 deletions cosmos/dbt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from functools import lru_cache
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

from cosmos.dbt.project import change_working_directory, environ
from cosmos.exceptions import CosmosDbtRunError
Expand Down Expand Up @@ -61,7 +61,7 @@ def get_runner(callbacks: list[Callable] | None = None) -> dbtRunner: # type: i


def run_command(
command: list[str], env: dict[str, str], cwd: str, callbacks: list[Callable] | None = None # type: ignore[type-arg]
command: list[str], env: dict[str, str], cwd: str, callbacks: list[Callable] | None = None, **kwargs: Any # type: ignore[type-arg]
) -> dbtRunnerResult:
"""
Invokes the dbt command programmatically.
Expand Down
65 changes: 50 additions & 15 deletions cosmos/hooks/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@
from __future__ import annotations

import contextlib
import json
import os
import signal
from subprocess import PIPE, STDOUT, Popen
from tempfile import TemporaryDirectory, gettempdir
from typing import NamedTuple
from typing import Any, NamedTuple

try:
# Airflow 3.1 onwards
from airflow.sdk.bases.hook import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook

from cosmos._utils.watcher_state import safe_xcom_push


class FullOutputSubprocessResult(NamedTuple):
exit_code: int
Expand All @@ -28,15 +31,41 @@ class FullOutputSubprocessHook(BaseHook): # type: ignore[misc]
"""Hook for running processes with the ``subprocess`` module."""

def __init__(self) -> None:
self.sub_process: Popen[bytes] | None = None
self.sub_process: Popen[str] | None = None
Comment thread
pankajastro marked this conversation as resolved.
super().__init__() # type: ignore[no-untyped-call]

def _store_dbt_resource_status_from_log(self, line: str, **kwargs: Any) -> None:
Comment thread
pankajastro marked this conversation as resolved.
"""
Parses a single line from dbt JSON logs and stores node status to Airflow XCom.

This method parses each log line from dbt when --log-format json is used,
extracts node status information, and pushes it to XCom for consumption
by downstream watcher sensors.
"""
try:
log_line = json.loads(line)
Comment thread
pankajastro marked this conversation as resolved.
except json.JSONDecodeError:
self.log.debug("Failed to parse log: %s", line)
log_line = {}

node_status = log_line.get("data", {}).get("node_info", {}).get("node_status")
unique_id = log_line.get("data", {}).get("node_info", {}).get("unique_id")

self.log.debug("Model: %s is in %s state", unique_id, node_status)

# TODO: Handle and store all possible node statuses, not just the current success and failed
if node_status in ["success", "failed"]:
context = kwargs.get("context")
assert context is not None # Make MyPy happy
safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status)

def run_command(
self,
command: list[str],
env: dict[str, str] | None = None,
output_encoding: str = "utf-8",
cwd: str | None = None,
**kwargs: Any,
) -> FullOutputSubprocessResult:
"""
Execute the command.
Expand Down Expand Up @@ -79,26 +108,32 @@ def pre_exec() -> None:
cwd=cwd,
env=env if env or env == {} else os.environ,
preexec_fn=pre_exec,
bufsize=1, # line-buffered (works only in text mode)
text=True,
encoding=output_encoding,
errors="backslashreplace",
Comment thread
pankajastro marked this conversation as resolved.
)

self.log.info("Command output:")
line = ""

if self.sub_process is None:
raise RuntimeError("The subprocess should be created here and is None!")
if self.sub_process.stdout is not None:
for raw_line in iter(self.sub_process.stdout.readline, b""):
line = raw_line.decode(output_encoding, errors="backslashreplace").rstrip()
# storing the warn & error lines to be used later
log_lines.append(line)
self.log.info("%s", line)

self.sub_process.wait()
self.log.info("Command output:")

last_line: str = ""
assert self.sub_process.stdout is not None
for line in self.sub_process.stdout:
line = line.rstrip("\n")
last_line = line
log_lines.append(line)
self.log.info("%s", line)
self._store_dbt_resource_status_from_log(line, **kwargs)

# Wait until process completes
return_code = self.sub_process.wait()

self.log.info("Command exited with return code %s", self.sub_process.returncode)
return_code: int = self.sub_process.returncode
self.log.info("Command exited with return code %s", return_code)

return FullOutputSubprocessResult(exit_code=return_code, output=line, full_output=log_lines)
return FullOutputSubprocessResult(exit_code=return_code, output=last_line, full_output=log_lines)

def send_sigterm(self) -> None:
"""Sends SIGTERM signal to ``self.sub_process`` if one exists."""
Expand Down
6 changes: 4 additions & 2 deletions cosmos/operators/_asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def __init__(self, *args: Any, **kwargs: Any):
kwargs["emit_datasets"] = False
super().__init__(*args, **kwargs)

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
def run_subprocess(
self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any
) -> FullOutputSubprocessResult:
profile_type = self.profile_config.get_profile_type()
if not self._py_bin:
raise AttributeError("_py_bin attribute not set for VirtualEnv operator")
Expand All @@ -49,7 +51,7 @@ def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> F
with open(dbt_executable_path, "w") as f:
f.writelines(dbt_entrypoint_script)

return super().run_subprocess(command, env, cwd)
return super().run_subprocess(command, env, cwd, **kwargs)

def execute(self, context: Context, **kwargs: Any) -> None:
async_context = {"profile_type": self.profile_config.get_profile_type(), "run_id": context["run_id"]}
Expand Down
14 changes: 12 additions & 2 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,21 @@ def execute(self, context: Context, **kwargs) -> Any | None: # type: ignore


class DbtBuildMixin:
"""Mixin for dbt build command."""
"""
Mixin for dbt build command.

:param full_refresh: whether to add the flag --full-refresh to the dbt build command
:param log_format: format for dbt logs (e.g., 'json', 'text'). If provided, adds --log-format flag
"""

base_cmd = ["build"]
ui_color = "#8194E0"

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None:
def __init__(self, full_refresh: bool | str = False, log_format: str | None = None, **kwargs: Any) -> None:
self.full_refresh = full_refresh
self.log_format = log_format
super().__init__(**kwargs)
Comment thread
pankajastro marked this conversation as resolved.

def add_cmd_flags(self) -> list[str]:
Expand All @@ -341,6 +347,10 @@ def add_cmd_flags(self) -> list[str]:
if full_refresh is True:
flags.append("--full-refresh")

if self.log_format:
flags.append("--log-format")
flags.append(self.log_format)

return flags


Expand Down
10 changes: 7 additions & 3 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,27 +457,30 @@ def _override_rtif_airflow_2_x(session: Session = NEW_SESSION) -> None:

_override_rtif_airflow_2_x()

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
def run_subprocess(
self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any
) -> FullOutputSubprocessResult:
logger.info("Trying to run the command:\n %s\nFrom %s", command, cwd)
subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(
command=command,
env=env,
cwd=cwd,
output_encoding=self.output_encoding,
**kwargs,
)
# Logging changed in Airflow 3.1 and we needed to replace the output by the full output:
output = "".join(subprocess_result.full_output)
logger.info(output)
return subprocess_result

def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str) -> dbtRunnerResult:
def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> dbtRunnerResult:
"""Invokes the dbt command programmatically."""
if not dbt_runner.is_available():
raise CosmosDbtRunError(
"Could not import dbt core. Ensure that dbt-core >= v1.5 is installed and available in the environment where the operator is running."
)

return dbt_runner.run_command(command, env, cwd, callbacks=self._dbt_runner_callbacks)
return dbt_runner.run_command(command, env, cwd, callbacks=self._dbt_runner_callbacks, **kwargs)

def _cache_package_lockfile(self, tmp_project_dir: Path) -> None:
project_dir = Path(self.project_dir)
Expand Down Expand Up @@ -684,6 +687,7 @@ def run_command( # noqa: C901
command=full_cmd,
env=env,
cwd=tmp_project_dir,
context=context,
)
if is_openlineage_common_available:
self.calculate_openlineage_events_completes(env, tmp_dir_path)
Expand Down
6 changes: 4 additions & 2 deletions cosmos/operators/virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ def __init__(
if not self.py_requirements:
self.log.error("Cosmos virtualenv operators require the `py_requirements` parameter")

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
def run_subprocess(
self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any
) -> FullOutputSubprocessResult:
if self._py_bin is not None:
self.log.info(f"Using Python binary from virtualenv: {self._py_bin}")
command[0] = str(Path(self._py_bin).parent / "dbt")
return super().run_subprocess(command, env, cwd)
return super().run_subprocess(command, env, cwd, **kwargs)

def run_command(
self,
Expand Down
20 changes: 3 additions & 17 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,15 @@
import zlib
from datetime import timedelta
from pathlib import Path
from threading import Lock
from typing import TYPE_CHECKING, Any

from cosmos._triggers.watcher import WatcherTrigger, _parse_compressed_xcom
from cosmos._utils.watcher_state import get_xcom_val, safe_xcom_push

if TYPE_CHECKING: # pragma: no cover
try:
from airflow.sdk.definitions.context import Context
from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance
except ImportError:
from airflow.models.taskinstance import TaskInstance # type: ignore[assignment]
from airflow.utils.context import Context # type: ignore[attr-defined]

try:
Expand Down Expand Up @@ -52,25 +50,12 @@
EventMsg = None

logger = logging.getLogger(__name__)
xcom_set_lock = Lock()

CONSUMER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 10
PRODUCER_OPERATOR_DEFAULT_PRIORITY_WEIGHT = 9999
WEIGHT_RULE = "absolute" # the default "downstream" does not work with dag.test()


def safe_xcom_push(task_instance: TaskInstance, key: str, value: Any) -> None:
"""
Safely set an XCom value in a thread-safe manner in Airflow 3.0 and 3.1.
We noticed that the combination of using dbt (multi-threaded) and Airflow 3.0 and 3.1 to set XCom lead to race conditions.
This leads the producer task to get stuck while running the dbt build command.
Unfortunately, since this is non-deterministic, and happens once every five runs, we were not able to have a proper test.
However, we applied this fix and run over 20 times a pipeline that would fail every 5 runs and this allowed us to no longer face the issue.
"""
with xcom_set_lock:
task_instance.xcom_push(key=key, value=value)


class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator):
"""Run dbt build and update XCom with the progress of each model, as part of the *WATCHER* execution mode.

Expand Down Expand Up @@ -109,6 +94,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
default_args["retries"] = 0
kwargs["default_args"] = default_args
kwargs["retries"] = 0
kwargs["log_format"] = "json"

super().__init__(task_id=task_id, *args, **kwargs)

Expand Down Expand Up @@ -433,7 +419,7 @@ def poke(self, context: Context) -> bool:
if self._use_event():
status = self._get_status_from_events(ti, context)
else:
status = self._get_status_from_run_results(ti, context)
status = get_xcom_val(ti, self.producer_task_id, f"{self.model_unique_id.replace('.', '__')}_status")

if status is None:

Expand Down
Loading