Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions cosmos/hooks/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
except ImportError:
from airflow.hooks.base import BaseHook

from cosmos.log import get_logger

logger = get_logger(__name__)


class FullOutputSubprocessResult(NamedTuple):
exit_code: int
Expand Down Expand Up @@ -60,7 +64,7 @@ def run_command(
``output``: the last line from stderr or stdout
``full_output``: all lines from stderr or stdout.
"""
self.log.info("Tmp dir root location: \n %s", gettempdir())
logger.info("Tmp dir root location: \n %s", gettempdir())
log_lines = []
with contextlib.ExitStack() as stack:
if cwd is None:
Expand All @@ -73,7 +77,7 @@ def pre_exec() -> None:
signal.signal(getattr(signal, sig), signal.SIG_DFL)
os.setsid()

self.log.info("Running command: %s", command)
logger.info("Running command: %s", command)

self.sub_process = Popen(
command,
Expand All @@ -91,7 +95,7 @@ def pre_exec() -> None:
if self.sub_process is None:
raise RuntimeError("The subprocess should be created here and is None!")

self.log.info("Command output:")
logger.info("Command output:")

last_line: str = ""
assert self.sub_process.stdout is not None
Expand All @@ -102,23 +106,23 @@ def pre_exec() -> None:
if process_log_line:
process_log_line(line, kwargs)
else:
self.log.info("%s", line)
logger.info("%s", line)

# Wait until process completes
return_code = self.sub_process.wait()

self.log.info("Command exited with return code %s", return_code)
logger.info("Command exited with return code %s", return_code)

return FullOutputSubprocessResult(exit_code=return_code, output=last_line, full_output=log_lines)

def send_sigterm(self) -> None:
"""Sends SIGTERM signal to ``self.sub_process`` if one exists."""
self.log.info("Sending SIGTERM signal to process group")
logger.info("Sending SIGTERM signal to process group")
if self.sub_process and hasattr(self.sub_process, "pid"):
os.killpg(os.getpgid(self.sub_process.pid), signal.SIGTERM)

def send_sigint(self) -> None:
"""Sends SIGINT signal to ``self.sub_process`` if one exists."""
self.log.info("Sending SIGINT signal to process group")
logger.info("Sending SIGINT signal to process group")
if self.sub_process and hasattr(self.sub_process, "pid"):
os.killpg(os.getpgid(self.sub_process.pid), signal.SIGINT)
26 changes: 15 additions & 11 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,21 @@ def store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None:
except json.JSONDecodeError:
logger.debug("Failed to parse log: %s", line)
log_line = {}
node_info = log_line.get("data", {}).get("node_info", {})
node_status = node_info.get("node_status")
unique_id = node_info.get("unique_id")

logger.debug("Model: %s is in %s state", unique_id, node_status)

# TODO: Handle and store all possible node statuses, not just the current success and failed
if node_status in ["success", "failed"]:
context = extra_kwargs.get("context")
assert context is not None # Make MyPy happy
safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status)
else:
logger.debug("Log line: %s", log_line)
if "info" in log_line and "msg" in log_line["info"]:
logger.info(log_line["info"]["msg"])
Comment thread
tatiana marked this conversation as resolved.
node_info = log_line.get("data", {}).get("node_info", {})
node_status = node_info.get("node_status")
unique_id = node_info.get("unique_id")

logger.debug("Model: %s is in %s state", unique_id, node_status)

# TODO: Handle and store all possible node statuses, not just the current success and failed
if node_status in ["success", "failed"]:
context = extra_kwargs.get("context")
assert context is not None # Make MyPy happy
safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status)

# Additionally, log the message from dbt logs
log_info = log_line.get("info", {})
Expand Down
3 changes: 0 additions & 3 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,6 @@ def run_subprocess(
process_log_line=self._process_log_line_callable,
**kwargs,
)
# Logging changed in Airflow 3.1 and we needed to replace the output by the full output:
output = "".join(subprocess_result.full_output)
logger.info(output)
Comment thread
tatiana marked this conversation as resolved.
return subprocess_result

def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> dbtRunnerResult:
Expand Down
23 changes: 18 additions & 5 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import base64
import json
import zlib
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -83,7 +83,7 @@ class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator):
template_fields = DbtLocalBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator]
# Use staticmethod to prevent Python's descriptor protocol from binding the function to `self`
# when accessed via instance, which would incorrectly pass `self` as the first argument
_process_log_line_callable = staticmethod(store_dbt_resource_status_from_log)
_process_log_line_callable: Callable[[str, Any], None] | None = None

def __init__(self, *args: Any, **kwargs: Any) -> None:
task_id = kwargs.pop("task_id", PRODUCER_WATCHER_TASK_ID)
Expand Down Expand Up @@ -148,7 +148,23 @@ def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> N
if startup_events:
safe_xcom_push(task_instance=context["ti"], key="dbt_startup_events", value=startup_events)

def _set_invocation_mode_if_not_set(self) -> None:
if not self.invocation_mode:
logger.info("No invocation mode provided, discovering it")
self._discover_invocation_mode()
Comment thread
tatiana marked this conversation as resolved.

def _set_process_log_line_callable_if_subprocess(self) -> None:
if self.invocation_mode == InvocationMode.SUBPROCESS:
logger.info(
"DbtProducerWatcherOperator: Setting log_format to json and process_log_line_callable to store_dbt_resource_status_from_log"
)
self.log_format = "json"
self._process_log_line_callable = store_dbt_resource_status_from_log

def execute(self, context: Context, **kwargs: Any) -> Any:
self._set_invocation_mode_if_not_set()
self._set_process_log_line_callable_if_subprocess()

task_instance = context.get("ti")
if task_instance is None:
raise AirflowException("DbtProducerWatcherOperator expects a task instance in the execution context")
Expand All @@ -169,9 +185,6 @@ def execute(self, context: Context, **kwargs: Any) -> Any:
)

try:
if not self.invocation_mode:
self._discover_invocation_mode()

use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None
logger.debug("DbtProducerWatcherOperator: use_events=%s", use_events)

Expand Down
4 changes: 4 additions & 0 deletions scripts/test/integration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ ls $AIRFLOW_HOME

airflow db check

python -m venv venv-subprocess
venv-subprocess/bin/pip install -U dbt-postgres


rm -rf dbt/jaffle_shop/dbt_packages;

pytest -vv \
--cov=cosmos \
--cov-report=term-missing \
Expand Down
39 changes: 39 additions & 0 deletions tests/dbt/test_executable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from unittest.mock import MagicMock, patch

from cosmos.dbt.executable import get_system_dbt, is_dbt_installed_in_same_environment


class TestGetSystemDbt:
@patch("shutil.which")
def test_get_system_dbt_returns_path_when_found(self, mock_which):
"""Test that get_system_dbt returns the path when dbt is found."""
mock_which.return_value = "/usr/local/bin/dbt"
result = get_system_dbt()
assert result == "/usr/local/bin/dbt"
mock_which.assert_called_once_with("dbt")

@patch("shutil.which")
def test_get_system_dbt_returns_dbt_when_not_found(self, mock_which):
"""Test that get_system_dbt returns 'dbt' when dbt is not found."""
mock_which.return_value = None
result = get_system_dbt()
assert result == "dbt"
mock_which.assert_called_once_with("dbt")


class TestIsDbtInstalledInSameEnvironment:
@patch("cosmos.dbt.executable.find_spec")
def test_is_dbt_installed_in_same_environment_returns_true_when_dbt_found(self, mock_find_spec):
"""Test that is_dbt_installed_in_same_environment returns True when find_spec finds dbt."""
mock_find_spec.return_value = MagicMock() # Simulates dbt module spec found
result = is_dbt_installed_in_same_environment()
assert result is True
mock_find_spec.assert_called_once_with("dbt")

@patch("cosmos.dbt.executable.find_spec")
def test_is_dbt_installed_in_same_environment_returns_false_when_import_error(self, mock_find_spec):
"""Test that is_dbt_installed_in_same_environment returns False when find_spec raises ImportError."""
mock_find_spec.side_effect = ImportError("No module named 'dbt'")
result = is_dbt_installed_in_same_environment()
assert result is False
mock_find_spec.assert_called_once_with("dbt")
100 changes: 57 additions & 43 deletions tests/operators/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

DBT_PROJECT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop"
DBT_PROFILES_YAML_FILEPATH = DBT_PROJECT_PATH / "profiles.yml"

DBT_EXECUTABLE_PATH = Path(__file__).parent.parent.parent / "venv-subprocess/bin/dbt"
DBT_PROJECT_WITH_EMPTY_MODEL_PATH = Path(__file__).parent.parent / "sample/dbt_project_with_empty_model"

project_config = ProjectConfig(
Expand Down Expand Up @@ -548,48 +550,10 @@ def test_store_dbt_resource_status_from_log_outputs_dbt_info(self, caplog, msg,
assert msg in caplog.text
assert any(record.levelname == logging.getLevelName(dynamic_level) for record in caplog.records)

def test_process_log_line_callable_is_not_bound_method(self):
"""Test that _process_log_line_callable is not bound as a method when accessed through an instance.

This test verifies the fix for the bug where accessing _process_log_line_callable through
an instance would create a bound method, causing 'self' to be passed as the first argument.
"""
import inspect

op = DbtProducerWatcherOperator(project_dir=".", profile_config=None)

# Access the callable through the instance
callable_from_instance = op._process_log_line_callable

# Verify it's not a bound method (which would have __self__ attribute)
assert not inspect.ismethod(
callable_from_instance
), "_process_log_line_callable should not be a bound method when accessed through instance"

# Verify it's the original function
assert callable_from_instance is store_dbt_resource_status_from_log

def test_process_log_line_callable_accepts_two_arguments(self):
"""Test that the callable can be called with exactly 2 arguments (line, kwargs).

This tests the integration pattern used in subprocess.py where process_log_line(line, kwargs) is called.
"""
op = DbtProducerWatcherOperator(project_dir=".", profile_config=None)
callable_from_instance = op._process_log_line_callable

ti = _MockTI()
ctx = {"ti": ti}

log_line = json.dumps({"data": {"node_info": {"node_status": "success", "unique_id": "model.pkg.test_model"}}})

# This should NOT raise TypeError about wrong number of arguments
callable_from_instance(log_line, {"context": ctx})

assert ti.store.get("model__pkg__test_model_status") == "success"
Comment thread
tatiana marked this conversation as resolved.

def test_process_log_line_callable_integration_with_subprocess_pattern(self):
"""Test the exact pattern used in subprocess.py: process_log_line(line, kwargs)."""
op = DbtProducerWatcherOperator(project_dir=".", profile_config=None)
op._process_log_line_callable = store_dbt_resource_status_from_log

ti = _MockTI()
ctx = {"ti": ti}
Expand All @@ -604,10 +568,8 @@ def test_process_log_line_callable_integration_with_subprocess_pattern(self):
]

# Simulate the subprocess.py pattern
process_log_line = op._process_log_line_callable
for line in log_lines:
if process_log_line:
process_log_line(line, kwargs)
op._process_log_line_callable(line, kwargs)

assert ti.store.get("model__pkg__model_a_status") == "success"
assert ti.store.get("model__pkg__model_b_status") == "failed"
Expand Down Expand Up @@ -1036,7 +998,7 @@ def test_dbt_build_watcher_operator_raises_not_implemented_error(self):

@pytest.mark.skipif(AIRFLOW_VERSION < Version("2.7"), reason="Airflow did not have dag.test() until the 2.6 release")
@pytest.mark.integration
def test_dbt_dag_with_watcher():
def test_dbt_dag_with_watcher(capsys):
"""
Run a DbtDag using `ExecutionMode.WATCHER`.
Confirm the right amount of tasks is created and that tasks are in the expected topological order.
Expand All @@ -1049,6 +1011,7 @@ def test_dbt_dag_with_watcher():
dag_id="watcher_dag",
execution_config=ExecutionConfig(
execution_mode=ExecutionMode.WATCHER,
invocation_mode=InvocationMode.DBT_RUNNER,
Comment thread
tatiana marked this conversation as resolved.
),
render_config=RenderConfig(emit_datasets=False),
operator_args={"trigger_rule": "all_success", "execution_timeout": timedelta(seconds=120)},
Expand Down Expand Up @@ -1099,6 +1062,57 @@ def test_dbt_dag_with_watcher():
"raw_customers_seed",
}

# dbt runner logs are not captured by caplog, so we need to capture them using capsys
capsys_output = capsys.readouterr()
stdout = capsys_output.out

assert (
'''"node_status": "success", "resource_type": "seed", "unique_id": "seed.jaffle_shop.raw_orders"'''
not in stdout
)

assert "OK loaded seed file public.raw_orders" in stdout


@pytest.mark.skipif(AIRFLOW_VERSION < Version("2.7"), reason="Airflow did not have dag.test() until the 2.6 release")
@pytest.mark.integration
def test_dbt_dag_with_watcher_and_subprocess(caplog):
"""
Run a DbtDag using `ExecutionMode.WATCHER`.
Confirm the right amount of tasks is created and that tasks are in the expected topological order.
Confirm that the producer watcher task is created and that it is the parent of the root dbt nodes.
"""
watcher_dag = DbtDag(
project_config=project_config,
profile_config=profile_config,
start_date=datetime(2023, 1, 1),
dag_id="watcher_dag",
execution_config=ExecutionConfig(
execution_mode=ExecutionMode.WATCHER,
invocation_mode=InvocationMode.SUBPROCESS,
dbt_executable_path=DBT_EXECUTABLE_PATH,
),
render_config=RenderConfig(emit_datasets=False, select=["raw_orders"], test_behavior=TestBehavior.AFTER_ALL),
operator_args={"trigger_rule": "all_success", "execution_timeout": timedelta(seconds=120)},
)
dag_run = new_test_dag(watcher_dag)
assert dag_run.state == DagRunState.SUCCESS

assert len(watcher_dag.dbt_graph.filtered_nodes) == 1

assert len(watcher_dag.task_dict) == 3
tasks_names = [task.task_id for task in watcher_dag.topological_sort()]
expected_task_names = ["dbt_producer_watcher", "raw_orders_seed", "jaffle_shop_test"]
assert tasks_names == expected_task_names
# Confirm that the dbt command was successfully run using the given dbt executable path:
assert "venv-subprocess/bin/dbt'), 'build'" in caplog.text
# Confirm that the seed was successfully run and the log output was JSON:
assert (
'''"node_status": "success", "resource_type": "seed", "unique_id": "seed.jaffle_shop.raw_orders"'''
not in caplog.text
)
assert "OK loaded seed file public.raw_orders" in caplog.text


# Airflow 3.0.0 hangs indefinitely, while Airflow 3.0.6 fails due to this Airflow bug:
# https://github.com/apache/airflow/issues/51816
Expand Down