diff --git a/cosmos/hooks/subprocess.py b/cosmos/hooks/subprocess.py index e028d57af6..2e12876625 100644 --- a/cosmos/hooks/subprocess.py +++ b/cosmos/hooks/subprocess.py @@ -18,6 +18,10 @@ except ImportError: from airflow.hooks.base import BaseHook +from cosmos.log import get_logger + +logger = get_logger(__name__) + class FullOutputSubprocessResult(NamedTuple): exit_code: int @@ -60,7 +64,7 @@ def run_command( ``output``: the last line from stderr or stdout ``full_output``: all lines from stderr or stdout. """ - self.log.info("Tmp dir root location: \n %s", gettempdir()) + logger.info("Tmp dir root location: \n %s", gettempdir()) log_lines = [] with contextlib.ExitStack() as stack: if cwd is None: @@ -73,7 +77,7 @@ def pre_exec() -> None: signal.signal(getattr(signal, sig), signal.SIG_DFL) os.setsid() - self.log.info("Running command: %s", command) + logger.info("Running command: %s", command) self.sub_process = Popen( command, @@ -91,7 +95,7 @@ def pre_exec() -> None: if self.sub_process is None: raise RuntimeError("The subprocess should be created here and is None!") - self.log.info("Command output:") + logger.info("Command output:") last_line: str = "" assert self.sub_process.stdout is not None @@ -102,23 +106,23 @@ def pre_exec() -> None: if process_log_line: process_log_line(line, kwargs) else: - self.log.info("%s", line) + logger.info("%s", line) # Wait until process completes return_code = self.sub_process.wait() - self.log.info("Command exited with return code %s", return_code) + logger.info("Command exited with return code %s", return_code) return FullOutputSubprocessResult(exit_code=return_code, output=last_line, full_output=log_lines) def send_sigterm(self) -> None: """Sends SIGTERM signal to ``self.sub_process`` if one exists.""" - self.log.info("Sending SIGTERM signal to process group") + logger.info("Sending SIGTERM signal to process group") if self.sub_process and hasattr(self.sub_process, "pid"): os.killpg(os.getpgid(self.sub_process.pid), signal.SIGTERM) def send_sigint(self) -> None: """Sends SIGINT signal to ``self.sub_process`` if one exists.""" - self.log.info("Sending SIGINT signal to process group") + logger.info("Sending SIGINT signal to process group") if self.sub_process and hasattr(self.sub_process, "pid"): os.killpg(os.getpgid(self.sub_process.pid), signal.SIGINT) diff --git a/cosmos/operators/_watcher/base.py b/cosmos/operators/_watcher/base.py index 4428f5e7ad..965bbdadbf 100644 --- a/cosmos/operators/_watcher/base.py +++ b/cosmos/operators/_watcher/base.py @@ -40,17 +40,21 @@ def store_dbt_resource_status_from_log(line: str, extra_kwargs: Any) -> None: except json.JSONDecodeError: logger.debug("Failed to parse log: %s", line) log_line = {} - node_info = log_line.get("data", {}).get("node_info", {}) - node_status = node_info.get("node_status") - unique_id = node_info.get("unique_id") - - logger.debug("Model: %s is in %s state", unique_id, node_status) - - # TODO: Handle and store all possible node statuses, not just the current success and failed - if node_status in ["success", "failed"]: - context = extra_kwargs.get("context") - assert context is not None # Make MyPy happy - safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status) + else: + logger.debug("Log line: %s", log_line) + if "info" in log_line and "msg" in log_line["info"]: + logger.info(log_line["info"]["msg"]) + node_info = log_line.get("data", {}).get("node_info", {}) + node_status = node_info.get("node_status") + unique_id = node_info.get("unique_id") + + logger.debug("Model: %s is in %s state", unique_id, node_status) + + # TODO: Handle and store all possible node statuses, not just the current success and failed + if node_status in ["success", "failed"]: + context = extra_kwargs.get("context") + assert context is not None # Make MyPy happy + safe_xcom_push(task_instance=context["ti"], key=f"{unique_id.replace('.', '__')}_status", value=node_status) # Additionally, log the message from dbt logs log_info = log_line.get("info", {}) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index bab5f4f21c..ca17d95ed9 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -463,9 +463,6 @@ def run_subprocess( process_log_line=self._process_log_line_callable, **kwargs, ) - # Logging changed in Airflow 3.1 and we needed to replace the output by the full output: - output = "".join(subprocess_result.full_output) - logger.info(output) return subprocess_result def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> dbtRunnerResult: diff --git a/cosmos/operators/watcher.py b/cosmos/operators/watcher.py index 771ed821de..a999ea54e6 100644 --- a/cosmos/operators/watcher.py +++ b/cosmos/operators/watcher.py @@ -3,7 +3,7 @@ import base64 import json import zlib -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING, Any @@ -83,7 +83,7 @@ class DbtProducerWatcherOperator(DbtBuildMixin, DbtLocalBaseOperator): template_fields = DbtLocalBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator] # Use staticmethod to prevent Python's descriptor protocol from binding the function to `self` # when accessed via instance, which would incorrectly pass `self` as the first argument - _process_log_line_callable = staticmethod(store_dbt_resource_status_from_log) + _process_log_line_callable: Callable[[str, Any], None] | None = None def __init__(self, *args: Any, **kwargs: Any) -> None: task_id = kwargs.pop("task_id", PRODUCER_WATCHER_TASK_ID) @@ -148,7 +148,23 @@ def _finalize(self, context: Context, startup_events: list[dict[str, Any]]) -> N if startup_events: safe_xcom_push(task_instance=context["ti"], key="dbt_startup_events", value=startup_events) + def _set_invocation_mode_if_not_set(self) -> None: + if not self.invocation_mode: + logger.info("No invocation mode provided, discovering it") + self._discover_invocation_mode() + + def _set_process_log_line_callable_if_subprocess(self) -> None: + if self.invocation_mode == InvocationMode.SUBPROCESS: + logger.info( + "DbtProducerWatcherOperator: Setting log_format to json and process_log_line_callable to store_dbt_resource_status_from_log" + ) + self.log_format = "json" + self._process_log_line_callable = store_dbt_resource_status_from_log + def execute(self, context: Context, **kwargs: Any) -> Any: + self._set_invocation_mode_if_not_set() + self._set_process_log_line_callable_if_subprocess() + task_instance = context.get("ti") if task_instance is None: raise AirflowException("DbtProducerWatcherOperator expects a task instance in the execution context") @@ -169,9 +185,6 @@ def execute(self, context: Context, **kwargs: Any) -> Any: ) try: - if not self.invocation_mode: - self._discover_invocation_mode() - use_events = self.invocation_mode == InvocationMode.DBT_RUNNER and EventMsg is not None logger.debug("DbtProducerWatcherOperator: use_events=%s", use_events) diff --git a/scripts/test/integration.sh b/scripts/test/integration.sh index f988198076..a844e37357 100644 --- a/scripts/test/integration.sh +++ b/scripts/test/integration.sh @@ -11,8 +11,12 @@ ls $AIRFLOW_HOME airflow db check +python -m venv venv-subprocess +venv-subprocess/bin/pip install -U dbt-postgres + rm -rf dbt/jaffle_shop/dbt_packages; + pytest -vv \ --cov=cosmos \ --cov-report=term-missing \ diff --git a/tests/dbt/test_executable.py b/tests/dbt/test_executable.py new file mode 100644 index 0000000000..78c8bb50e7 --- /dev/null +++ b/tests/dbt/test_executable.py @@ -0,0 +1,39 @@ +from unittest.mock import MagicMock, patch + +from cosmos.dbt.executable import get_system_dbt, is_dbt_installed_in_same_environment + + +class TestGetSystemDbt: + @patch("shutil.which") + def test_get_system_dbt_returns_path_when_found(self, mock_which): + """Test that get_system_dbt returns the path when dbt is found.""" + mock_which.return_value = "/usr/local/bin/dbt" + result = get_system_dbt() + assert result == "/usr/local/bin/dbt" + mock_which.assert_called_once_with("dbt") + + @patch("shutil.which") + def test_get_system_dbt_returns_dbt_when_not_found(self, mock_which): + """Test that get_system_dbt returns 'dbt' when dbt is not found.""" + mock_which.return_value = None + result = get_system_dbt() + assert result == "dbt" + mock_which.assert_called_once_with("dbt") + + +class TestIsDbtInstalledInSameEnvironment: + @patch("cosmos.dbt.executable.find_spec") + def test_is_dbt_installed_in_same_environment_returns_true_when_dbt_found(self, mock_find_spec): + """Test that is_dbt_installed_in_same_environment returns True when find_spec finds dbt.""" + mock_find_spec.return_value = MagicMock() # Simulates dbt module spec found + result = is_dbt_installed_in_same_environment() + assert result is True + mock_find_spec.assert_called_once_with("dbt") + + @patch("cosmos.dbt.executable.find_spec") + def test_is_dbt_installed_in_same_environment_returns_false_when_import_error(self, mock_find_spec): + """Test that is_dbt_installed_in_same_environment returns False when find_spec raises ImportError.""" + mock_find_spec.side_effect = ImportError("No module named 'dbt'") + result = is_dbt_installed_in_same_environment() + assert result is False + mock_find_spec.assert_called_once_with("dbt") diff --git a/tests/operators/test_watcher.py b/tests/operators/test_watcher.py index 1fcb9f1c6b..9f743a72eb 100644 --- a/tests/operators/test_watcher.py +++ b/tests/operators/test_watcher.py @@ -33,6 +33,8 @@ DBT_PROJECT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" DBT_PROFILES_YAML_FILEPATH = DBT_PROJECT_PATH / "profiles.yml" + +DBT_EXECUTABLE_PATH = Path(__file__).parent.parent.parent / "venv-subprocess/bin/dbt" DBT_PROJECT_WITH_EMPTY_MODEL_PATH = Path(__file__).parent.parent / "sample/dbt_project_with_empty_model" project_config = ProjectConfig( @@ -548,48 +550,10 @@ def test_store_dbt_resource_status_from_log_outputs_dbt_info(self, caplog, msg, assert msg in caplog.text assert any(record.levelname == logging.getLevelName(dynamic_level) for record in caplog.records) - def test_process_log_line_callable_is_not_bound_method(self): - """Test that _process_log_line_callable is not bound as a method when accessed through an instance. - - This test verifies the fix for the bug where accessing _process_log_line_callable through - an instance would create a bound method, causing 'self' to be passed as the first argument. - """ - import inspect - - op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - - # Access the callable through the instance - callable_from_instance = op._process_log_line_callable - - # Verify it's not a bound method (which would have __self__ attribute) - assert not inspect.ismethod( - callable_from_instance - ), "_process_log_line_callable should not be a bound method when accessed through instance" - - # Verify it's the original function - assert callable_from_instance is store_dbt_resource_status_from_log - - def test_process_log_line_callable_accepts_two_arguments(self): - """Test that the callable can be called with exactly 2 arguments (line, kwargs). - - This tests the integration pattern used in subprocess.py where process_log_line(line, kwargs) is called. - """ - op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) - callable_from_instance = op._process_log_line_callable - - ti = _MockTI() - ctx = {"ti": ti} - - log_line = json.dumps({"data": {"node_info": {"node_status": "success", "unique_id": "model.pkg.test_model"}}}) - - # This should NOT raise TypeError about wrong number of arguments - callable_from_instance(log_line, {"context": ctx}) - - assert ti.store.get("model__pkg__test_model_status") == "success" - def test_process_log_line_callable_integration_with_subprocess_pattern(self): """Test the exact pattern used in subprocess.py: process_log_line(line, kwargs).""" op = DbtProducerWatcherOperator(project_dir=".", profile_config=None) + op._process_log_line_callable = store_dbt_resource_status_from_log ti = _MockTI() ctx = {"ti": ti} @@ -604,10 +568,8 @@ def test_process_log_line_callable_integration_with_subprocess_pattern(self): ] # Simulate the subprocess.py pattern - process_log_line = op._process_log_line_callable for line in log_lines: - if process_log_line: - process_log_line(line, kwargs) + op._process_log_line_callable(line, kwargs) assert ti.store.get("model__pkg__model_a_status") == "success" assert ti.store.get("model__pkg__model_b_status") == "failed" @@ -1036,7 +998,7 @@ def test_dbt_build_watcher_operator_raises_not_implemented_error(self): @pytest.mark.skipif(AIRFLOW_VERSION < Version("2.7"), reason="Airflow did not have dag.test() until the 2.6 release") @pytest.mark.integration -def test_dbt_dag_with_watcher(): +def test_dbt_dag_with_watcher(capsys): """ Run a DbtDag using `ExecutionMode.WATCHER`. Confirm the right amount of tasks is created and that tasks are in the expected topological order. @@ -1049,6 +1011,7 @@ def test_dbt_dag_with_watcher(): dag_id="watcher_dag", execution_config=ExecutionConfig( execution_mode=ExecutionMode.WATCHER, + invocation_mode=InvocationMode.DBT_RUNNER, ), render_config=RenderConfig(emit_datasets=False), operator_args={"trigger_rule": "all_success", "execution_timeout": timedelta(seconds=120)}, @@ -1099,6 +1062,57 @@ def test_dbt_dag_with_watcher(): "raw_customers_seed", } + # dbt runner logs are not captured by caplog, so we need to capture them using capsys + capsys_output = capsys.readouterr() + stdout = capsys_output.out + + assert ( + '''"node_status": "success", "resource_type": "seed", "unique_id": "seed.jaffle_shop.raw_orders"''' + not in stdout + ) + + assert "OK loaded seed file public.raw_orders" in stdout + + +@pytest.mark.skipif(AIRFLOW_VERSION < Version("2.7"), reason="Airflow did not have dag.test() until the 2.6 release") +@pytest.mark.integration +def test_dbt_dag_with_watcher_and_subprocess(caplog): + """ + Run a DbtDag using `ExecutionMode.WATCHER`. + Confirm the right amount of tasks is created and that tasks are in the expected topological order. + Confirm that the producer watcher task is created and that it is the parent of the root dbt nodes. + """ + watcher_dag = DbtDag( + project_config=project_config, + profile_config=profile_config, + start_date=datetime(2023, 1, 1), + dag_id="watcher_dag", + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.WATCHER, + invocation_mode=InvocationMode.SUBPROCESS, + dbt_executable_path=DBT_EXECUTABLE_PATH, + ), + render_config=RenderConfig(emit_datasets=False, select=["raw_orders"], test_behavior=TestBehavior.AFTER_ALL), + operator_args={"trigger_rule": "all_success", "execution_timeout": timedelta(seconds=120)}, + ) + dag_run = new_test_dag(watcher_dag) + assert dag_run.state == DagRunState.SUCCESS + + assert len(watcher_dag.dbt_graph.filtered_nodes) == 1 + + assert len(watcher_dag.task_dict) == 3 + tasks_names = [task.task_id for task in watcher_dag.topological_sort()] + expected_task_names = ["dbt_producer_watcher", "raw_orders_seed", "jaffle_shop_test"] + assert tasks_names == expected_task_names + # Confirm that the dbt command was successfully run using the given dbt executable path: + assert "venv-subprocess/bin/dbt'), 'build'" in caplog.text + # Confirm that the seed was successfully run and the log output was JSON: + assert ( + '''"node_status": "success", "resource_type": "seed", "unique_id": "seed.jaffle_shop.raw_orders"''' + not in caplog.text + ) + assert "OK loaded seed file public.raw_orders" in caplog.text + # Airflow 3.0.0 hangs indefinitely, while Airflow 3.0.6 fails due to this Airflow bug: # https://github.com/apache/airflow/issues/51816