Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cosmos/_triggers/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from asgiref.sync import sync_to_async
from packaging.version import Version

from cosmos._utils.watcher_state import build_producer_state_fetcher
from cosmos.constants import AIRFLOW_VERSION
from cosmos.operators._watcher.state import build_producer_state_fetcher


class WatcherTrigger(BaseTrigger):
Expand Down
34 changes: 4 additions & 30 deletions cosmos/hooks/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,18 @@
from __future__ import annotations

import contextlib
import json
import os
import signal
from subprocess import PIPE, STDOUT, Popen
from tempfile import TemporaryDirectory, gettempdir
from typing import Any, NamedTuple
from typing import Any, Callable, NamedTuple

try:
# Airflow 3.1 onwards
from airflow.sdk.bases.hook import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook

from cosmos._utils.watcher_state import safe_xcom_push


class FullOutputSubprocessResult(NamedTuple):
exit_code: int
Expand All @@ -34,37 +31,13 @@ def __init__(self) -> None:
self.sub_process: Popen[str] | None = None
super().__init__() # type: ignore[no-untyped-call]

def _store_dbt_resource_status_from_log(self, line: str, **kwargs: Any) -> None:
"""
Parses a single line from dbt JSON logs and stores node status to Airflow XCom.

This method parses each log line from dbt when --log-format json is used,
extracts node status information, and pushes it to XCom for consumption
by downstream watcher sensors.
"""
try:
log_line = json.loads(line)
except json.JSONDecodeError:
self.log.debug("Failed to parse log: %s", line)
log_line = {}

node_status = log_line.get("data", {}).get("node_info", {}).get("node_status")
unique_id = log_line.get("data", {}).get("node_info", {}).get("unique_id")

self.log.debug("Model: %s is in %s state", unique_id, node_status)

# TODO: Handle and store all possible node statuses, not just the current success and failed
if node_status in ["success", "failed"]:
context = kwargs.get("context")
assert context is not None # Make MyPy happy
safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status)

def run_command(
self,
command: list[str],
env: dict[str, str] | None = None,
output_encoding: str = "utf-8",
cwd: str | None = None,
process_log_line: Callable[[str, Any], None] | None = None,
**kwargs: Any,
) -> FullOutputSubprocessResult:
"""
Expand Down Expand Up @@ -126,7 +99,8 @@ def pre_exec() -> None:
last_line = line
log_lines.append(line)
self.log.info("%s", line)
self._store_dbt_resource_status_from_log(line, **kwargs)
if process_log_line:
process_log_line(line, kwargs)

# Wait until process completes
return_code = self.sub_process.wait()
Expand Down
5 changes: 5 additions & 0 deletions cosmos/operators/_watcher/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

__all__ = ["get_xcom_val", "safe_xcom_push", "build_producer_state_fetcher"]

from cosmos.operators._watcher.state import build_producer_state_fetcher, get_xcom_val, safe_xcom_push
File renamed without changes.
2 changes: 2 additions & 0 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class AbstractDbtLocalBase(AbstractDbtBase):
"compiled_sql": "sql",
"freshness": "json",
}
_process_log_line_callable: Callable[[str, Any], None] | None = None
Comment thread
tatiana marked this conversation as resolved.

def __init__(
self,
Expand Down Expand Up @@ -466,6 +467,7 @@ def run_subprocess(
env=env,
cwd=cwd,
output_encoding=self.output_encoding,
process_log_line=self._process_log_line_callable,
**kwargs,
)
# Logging changed in Airflow 3.1 and we needed to replace the output by the full output:
Expand Down
33 changes: 30 additions & 3 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import zlib
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Callable

from cosmos._triggers.watcher import WatcherTrigger, _parse_compressed_xcom
from cosmos._utils.watcher_state import get_xcom_val, safe_xcom_push
from cosmos.operators._watcher.state import get_xcom_val, safe_xcom_push

if TYPE_CHECKING: # pragma: no cover
try:
Expand All @@ -29,9 +29,9 @@
from airflow.operators.empty import EmptyOperator # type: ignore[no-redef]


from cosmos._utils.watcher_state import build_producer_state_fetcher
from cosmos.config import ProfileConfig
from cosmos.constants import AIRFLOW_VERSION, PRODUCER_WATCHER_TASK_ID, InvocationMode
from cosmos.operators._watcher.state import build_producer_state_fetcher
from cosmos.operators.base import (
DbtBuildMixin,
DbtRunMixin,
Expand All @@ -56,6 +56,32 @@
WEIGHT_RULE = "absolute" # the default "downstream" does not work with dag.test()


def _store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None:
"""
Parses a single line from dbt JSON logs and stores node status to Airflow XCom.

This method parses each log line from dbt when --log-format json is used,
extracts node status information, and pushes it to XCom for consumption
by downstream watcher sensors.
"""
try:
log_line = json.loads(line)
except json.JSONDecodeError:
logger.debug("Failed to parse log: %s", line)
log_line = {}
node_info = log_line.get("data", {}).get("node_info", {})
node_status = node_info.get("node_status")
unique_id = node_info.get("unique_id")

logger.debug("Model: %s is in %s state", unique_id, node_status)

# TODO: Handle and store all possible node statuses, not just the current success and failed
if node_status in ["success", "failed"]:
context = extra_kwargs.get("context")
assert context is not None # Make MyPy happy
safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status)


class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator):
"""Run dbt build and update XCom with the progress of each model, as part of the *WATCHER* execution mode.

Expand All @@ -82,6 +108,7 @@ class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator):
"""

template_fields = DbtLocalBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator]
_process_log_line_callable: Callable[[str, dict[str, Any]], None] | None = _store_dbt_resource_status_from_log

def __init__(self, *args: Any, **kwargs: Any) -> None:
task_id = kwargs.pop("task_id", "dbt_producer_watcher_operator")
Expand Down
15 changes: 6 additions & 9 deletions tests/hooks/test_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from cosmos.hooks.subprocess import FullOutputSubprocessHook
from cosmos.operators.watcher import _store_dbt_resource_status_from_log

OS_ENV_KEY = "SUBPROCESS_ENV_TEST"
OS_ENV_VAL = "this-is-from-os-environ"
Expand Down Expand Up @@ -91,18 +92,16 @@ def test_send_sigterm(mock_killpg, mock_getpgid):
],
)
def test_store_dbt_resource_status_from_log_param(status, context, should_push, expect_assert):
trigger = FullOutputSubprocessHook()

# Prepare log line
log_line = {"data": {"node_info": {"node_status": status, "unique_id": "model.jaffle_shop.stg_orders"}}}
line = json.dumps(log_line)

with patch("cosmos.hooks.subprocess.safe_xcom_push") as mock_push:
with patch("cosmos.operators.watcher.safe_xcom_push") as mock_push:
if expect_assert:
with pytest.raises(AssertionError):
trigger._store_dbt_resource_status_from_log(line, context=context)
_store_dbt_resource_status_from_log(line, {"context": context})
else:
trigger._store_dbt_resource_status_from_log(line, context=context)
_store_dbt_resource_status_from_log(line, {"context": context})
if should_push:
mock_push.assert_called_once_with(
task_instance=context["ti"], key="model__jaffle_shop__stg_orders_status", value=status
Expand All @@ -112,10 +111,8 @@ def test_store_dbt_resource_status_from_log_param(status, context, should_push,


def test_store_dbt_resource_status_from_log_invalid_json():
trigger = FullOutputSubprocessHook()

invalid_line = "{not a valid json}"

with patch("cosmos.hooks.subprocess.safe_xcom_push") as mock_push:
trigger._store_dbt_resource_status_from_log(invalid_line, context={"ti": MagicMock()})
with patch("cosmos.operators.watcher.safe_xcom_push") as mock_push:
_store_dbt_resource_status_from_log(invalid_line, {"context": {"ti": MagicMock()}})
mock_push.assert_not_called()