From 9fe6336c7df6d401ef744741990a95710f80c0ec Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Mon, 5 Feb 2024 15:37:06 -0800 Subject: [PATCH 01/16] add InvocationMode to ExecutionConfig --- cosmos/config.py | 16 +++++++++++++++- cosmos/constants.py | 9 +++++++++ tests/test_config.py | 19 ++++++++++++++++++- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/cosmos/config.py b/cosmos/config.py index dc33c0eba5..29fd131f8d 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -10,7 +10,14 @@ import warnings from typing import Any, Iterator, Callable -from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, LoadMode, TestIndirectSelection +from cosmos.constants import ( + DbtResourceType, + TestBehavior, + ExecutionMode, + LoadMode, + TestIndirectSelection, + InvocationMode, +) from cosmos.dbt.executable import get_system_dbt from cosmos.exceptions import CosmosValueError from cosmos.log import get_logger @@ -290,12 +297,15 @@ class ExecutionConfig: Contains configuration about how to execute dbt. :param execution_mode: The execution mode for dbt. Defaults to local + :param invocation_mode: The invocation mode for the dbt command. This is only configurable for ExecutionMode.LOCAL or ExecutionMode.VIRTUALENV + execution modes. :param test_indirect_selection: The mode to configure the test behavior when performing indirect selection. :param dbt_executable_path: The path to the dbt executable for runtime execution. Defaults to dbt if available on the path. :param dbt_project_path Configures the DBT project location accessible at runtime for dag execution. This is the project path in a docker container for ExecutionMode.DOCKER or ExecutionMode.KUBERNETES. Mutually Exclusive with ProjectConfig.dbt_project_path """ execution_mode: ExecutionMode = ExecutionMode.LOCAL + invocation_mode: InvocationMode | None = None test_indirect_selection: TestIndirectSelection = TestIndirectSelection.EAGER dbt_executable_path: str | Path = field(default_factory=get_system_dbt) @@ -303,4 +313,8 @@ class ExecutionConfig: project_path: Path | None = field(init=False) def __post_init__(self, dbt_project_path: str | Path | None) -> None: + if self.invocation_mode and self.execution_mode not in {ExecutionMode.LOCAL, ExecutionMode.VIRTUALENV}: + raise CosmosValueError( + "ExecutionConfig.invocation_mode is only configurable for ExecutionMode.LOCAL or ExecutionMode.VIRTUALENV modes." + ) self.project_path = Path(dbt_project_path) if dbt_project_path else None diff --git a/cosmos/constants.py b/cosmos/constants.py index 4741d621d6..65bacff88f 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -53,6 +53,15 @@ class ExecutionMode(Enum): AZURE_CONTAINER_INSTANCE = "azure_container_instance" +class InvocationMode(Enum): + """ + How the dbt command should be invoked. + """ + + SUBPROCESS = "subprocess" + DBT_RUNNER = "dbt_runner" + + class TestIndirectSelection(Enum): """ Modes to configure the test behavior when performing indirect selection. diff --git a/tests/test_config.py b/tests/test_config.py index 795fcffb69..d7c456938a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,10 +1,12 @@ from pathlib import Path from unittest.mock import patch from cosmos.profiles.postgres.user_pass import PostgresUserPasswordProfileMapping +from contextlib import nullcontext as does_not_raise import pytest -from cosmos.config import ProfileConfig, ProjectConfig, RenderConfig, CosmosConfigException +from cosmos.constants import ExecutionMode, InvocationMode +from cosmos.config import ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig, CosmosConfigException from cosmos.exceptions import CosmosValueError @@ -195,3 +197,18 @@ def test_render_config_env_vars_deprecated(): """RenderConfig.env_vars is deprecated since Cosmos 1.3, should warn user.""" with pytest.deprecated_call(): RenderConfig(env_vars={"VAR": "value"}) + + +@pytest.mark.parametrize( + "execution_mode, expectation", + [ + (ExecutionMode.LOCAL, does_not_raise()), + (ExecutionMode.VIRTUALENV, does_not_raise()), + (ExecutionMode.KUBERNETES, pytest.raises(CosmosValueError)), + (ExecutionMode.DOCKER, pytest.raises(CosmosValueError)), + (ExecutionMode.AZURE_CONTAINER_INSTANCE, pytest.raises(CosmosValueError)), + ], +) +def test_execution_config_with_invocation_option(execution_mode, expectation): + with expectation: + ExecutionConfig(execution_mode=execution_mode, invocation_mode=InvocationMode.DBT_RUNNER) From fc60b9167c3a3a6ac3717155d313c81d93637aeb Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Mon, 5 Feb 2024 15:52:05 -0800 Subject: [PATCH 02/16] pass invocation_mode in task_args only if not None since it's only valid for local/venv operators --- cosmos/converter.py | 2 ++ tests/test_converter.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/cosmos/converter.py b/cosmos/converter.py index 97e8190dd1..5b62cbab1c 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -253,6 +253,8 @@ def __init__( } if execution_config.dbt_executable_path: task_args["dbt_executable_path"] = execution_config.dbt_executable_path + if execution_config.invocation_mode: + task_args["invocation_mode"] = execution_config.invocation_mode validate_arguments( render_config.select, diff --git a/tests/test_converter.py b/tests/test_converter.py index d84249aaee..7becc3a8be 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -7,7 +7,7 @@ from airflow.models import DAG from cosmos.converter import DbtToAirflowConverter, validate_arguments, validate_initial_user_config -from cosmos.constants import DbtResourceType, ExecutionMode, LoadMode +from cosmos.constants import DbtResourceType, ExecutionMode, LoadMode, InvocationMode from cosmos.config import ProjectConfig, ProfileConfig, ExecutionConfig, RenderConfig, CosmosConfigException from cosmos.dbt.graph import DbtNode from cosmos.exceptions import CosmosValueError @@ -405,3 +405,32 @@ def test_converter_project_config_dbt_vars_with_custom_load_mode( ) _, kwargs = mock_legacy_dbt_project.call_args assert kwargs["dbt_vars"] == {"key": "value"} + + +@pytest.mark.parametrize("invocation_mode", [None, InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) +@patch("cosmos.config.ProjectConfig.validate_project") +@patch("cosmos.converter.build_airflow_graph") +def test_converter_invocation_mode_added_to_task_args(mock_build_airflow_graph, mock_validate_project, invocation_mode): + """Tests that the `task_args` passed to build_airflow_graph has invocation_mode if + it is not None. + """ + project_config = ProjectConfig(project_name="fake-project", dbt_project_path="/some/project/path") + execution_config = ExecutionConfig(invocation_mode=invocation_mode) + render_config = RenderConfig() + profile_config = MagicMock() + + with DAG("test-id", start_date=datetime(2024, 1, 1)) as dag: + DbtToAirflowConverter( + dag=dag, + nodes=nodes, + project_config=project_config, + profile_config=profile_config, + execution_config=execution_config, + render_config=render_config, + operator_args={}, + ) + _, kwargs = mock_build_airflow_graph.call_args + if invocation_mode: + assert kwargs["task_args"]["invocation_mode"] == invocation_mode + else: + assert "invocation_mode" not in kwargs["task_args"] From 2f7d2ab48bfc9201fff6ce699e416c8fab29da06 Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Mon, 5 Feb 2024 18:52:48 -0800 Subject: [PATCH 03/16] allow DbtLocalBaseOperator to use dbtRunner for invocation --- cosmos/dbt/parser/output.py | 66 ++++++++++++--- cosmos/operators/local.py | 88 ++++++++++++++------ tests/dbt/parser/test_output.py | 58 +++++++++++-- tests/operators/test_local.py | 127 +++++++++++++++++++++++++++-- tests/operators/test_virtualenv.py | 2 +- 5 files changed, 289 insertions(+), 52 deletions(-) diff --git a/cosmos/dbt/parser/output.py b/cosmos/dbt/parser/output.py index 791c4b6057..f74c78c6cf 100644 --- a/cosmos/dbt/parser/output.py +++ b/cosmos/dbt/parser/output.py @@ -1,33 +1,53 @@ +from __future__ import annotations + import logging import re -from typing import List, Tuple +from typing import List, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from dbt.cli.main import dbtRunnerResult from cosmos.hooks.subprocess import FullOutputSubprocessResult -def parse_output(result: FullOutputSubprocessResult, keyword: str) -> int: +DBT_NO_TESTS_MSG = "Nothing to do" +DBT_WARN_MSG = "WARN" + + +def parse_number_of_warnings_subprocess(result: FullOutputSubprocessResult) -> int: """ - Parses the dbt test output message and returns the number of errors or warnings. + Parses the dbt test output message and returns the number of warnings. :param result: String containing the output to be parsed. - :param keyword: String representing the keyword to search for in the output (WARN, ERROR). :return: An integer value associated with the keyword, or 0 if parsing fails. Usage: ----- output_str = "Done. PASS=15 WARN=1 ERROR=0 SKIP=0 TOTAL=16" - keyword = "WARN" - num_warns = parse_output(output_str, keyword) + num_warns = parse_output(output_str) print(num_warns) # Output: 1 """ output = result.output - try: - num = int(output.split(f"{keyword}=")[1].split()[0]) - except ValueError: - logging.error( - f"Could not parse number of {keyword}s. Check your dbt/airflow version or if --quiet is not being used" - ) + num = 0 + if DBT_NO_TESTS_MSG not in result.output and DBT_WARN_MSG in result.output: + try: + num = int(output.split(f"{DBT_WARN_MSG}=")[1].split()[0]) + except ValueError: + logging.error( + f"Could not parse number of {DBT_WARN_MSG}s. Check your dbt/airflow version or if --quiet is not being used" + ) + return num + + +def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int: + """Parses a dbt runner result and returns the number of warnings found. This only works for dbtRunnerResult + from invoking dbt build, compile, run, seed, snapshot, test, or run-operation. + """ + num = 0 + for run_result in result.result.results: # type: ignore + if run_result.status == "warn": + num += 1 return num @@ -67,3 +87,25 @@ def clean_line(line: str) -> str: test_results.append(test_result) return test_names, test_results + + +def extract_dbt_runner_issues(result: dbtRunnerResult) -> Tuple[List[str], List[str]]: + """ + Extracts warning messages from the dbt runner result and returns them as a formatted string. + + This function searches for warning messages in dbt run. It extracts and formats the relevant + information and appends it to a list of warnings. + + :param result: dbtRunnerResult object containing the output to be parsed. + :return: two lists of strings, the first one containing the test names and the second one + containing the test results. + """ + test_names = [] + test_results = [] + + for run_result in result.result.results: # type: ignore + if run_result.status == "warn": + test_names.append(str(run_result.node.name)) + test_results.append(str(run_result.message)) + + return test_names, test_results diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 58090ef50d..fac2a76f5d 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -18,6 +18,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context from airflow.utils.session import NEW_SESSION, create_session, provide_session +from cosmos.constants import InvocationMode try: from openlineage.common.provider.dbt.local import DbtLocalArtifactProcessor @@ -31,6 +32,8 @@ if TYPE_CHECKING: from airflow.datasets import Dataset # noqa: F811 from openlineage.client.run import RunEvent + from dbt.cli.main import dbtRunner, dbtRunnerResult + from typing import Callable from sqlalchemy.orm import Session @@ -51,11 +54,14 @@ FullOutputSubprocessHook, FullOutputSubprocessResult, ) -from cosmos.dbt.parser.output import extract_log_issues, parse_output -from cosmos.dbt.project import create_symlinks +from cosmos.dbt.parser.output import ( + extract_dbt_runner_issues, + extract_log_issues, + parse_number_of_warnings_dbt_runner, + parse_number_of_warnings_subprocess, +) +from cosmos.dbt.project import create_symlinks, environ -DBT_NO_TESTS_MSG = "Nothing to do" -DBT_WARN_MSG = "WARN" logger = get_logger(__name__) @@ -111,6 +117,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator): def __init__( self, profile_config: ProfileConfig, + invocation_mode: InvocationMode = InvocationMode.SUBPROCESS, install_deps: bool = False, callback: Callable[[str], None] | None = None, should_store_compiled_sql: bool = True, @@ -122,6 +129,10 @@ def __init__( self.compiled_sql = "" self.should_store_compiled_sql = should_store_compiled_sql self.openlineage_events_completes: list[RunEvent] = [] + self.invocation_mode = invocation_mode + self.invoke_dbt = getattr(self, f"run_{invocation_mode.value}") + self.handle_exception = getattr(self, f"handle_exception_{invocation_mode.value}") + self._dbt_runner: dbtRunner | None = None kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes super().__init__(**kwargs) @@ -130,7 +141,7 @@ def subprocess_hook(self) -> FullOutputSubprocessHook: """Returns hook for running the bash command.""" return FullOutputSubprocessHook() - def exception_handling(self, result: FullOutputSubprocessResult) -> None: + def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None: if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code: raise AirflowSkipException(f"dbt command returned exit code {self.skip_exit_code}. Skipping.") elif result.exit_code != 0: @@ -139,6 +150,11 @@ def exception_handling(self, result: FullOutputSubprocessResult) -> None: *result.full_output, ) + def handle_exception_dbt_runner(self, result: dbtRunnerResult) -> None: + """dbtRunnerResult has an attribute `success` that is False if the command failed.""" + if not result.success: + raise AirflowException("dbt command failed. See logs above for details.") + @provide_session def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None: """ @@ -188,14 +204,33 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se def run_subprocess(self, *args: Any, **kwargs: Any) -> FullOutputSubprocessResult: subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(*args, **kwargs) + logger.info(subprocess_result.output) return subprocess_result + def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> dbtRunnerResult: + """Invokes the dbt command programmatically.""" + try: + from dbt.cli.main import dbtRunner + except ImportError: + raise ImportError( + "Could not import dbt core. Ensure that dbt is installed and available in the environment where the operator is running." + ) + + if self._dbt_runner is None: + self._dbt_runner = dbtRunner() + + # The dbt runner will cd into the project directory and restore the cwd when completed + cli_args = command[1:] + ["--project-dir", cwd] + with environ(env): + result = self._dbt_runner.invoke(cli_args) + return result + def run_command( self, cmd: list[str], env: dict[str, str | bytes | os.PathLike[Any]], context: Context, - ) -> FullOutputSubprocessResult: + ) -> FullOutputSubprocessResult | dbtRunnerResult: """ Copies the dbt project to a temporary directory and runs the command. """ @@ -224,7 +259,7 @@ def run_command( if self.install_deps: deps_command = [self.dbt_executable_path, "deps"] deps_command.extend(flags) - self.run_subprocess( + self.invoke_dbt( command=deps_command, env=env, output_encoding=self.output_encoding, @@ -235,7 +270,7 @@ def run_command( logger.info("Trying to run the command:\n %s\nFrom %s", full_cmd, tmp_project_dir) logger.info("Using environment variables keys: %s", env.keys()) - result = self.run_subprocess( + result = self.invoke_dbt( command=full_cmd, env=env, output_encoding=self.output_encoding, @@ -255,11 +290,11 @@ def run_command( self.register_dataset(inlets, outlets) self.store_compiled_sql(tmp_project_dir, context) - self.exception_handling(result) + self.handle_exception(result) if self.callback: self.callback(tmp_project_dir) - return result + return result # type: ignore def calculate_openlineage_events_completes( self, env: dict[str, str | os.PathLike[Any] | bytes], project_dir: Path @@ -365,11 +400,12 @@ def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> Ope job_facets=job_facets, ) - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> FullOutputSubprocessResult: + def build_and_run_cmd( + self, context: Context, cmd_flags: list[str] | None = None + ) -> FullOutputSubprocessResult | dbtRunnerResult: dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags) dbt_cmd = dbt_cmd or [] result = self.run_command(cmd=dbt_cmd, env=env, context=context) - logger.info(result.output) return result def on_kill(self) -> None: @@ -429,8 +465,16 @@ def __init__( ) -> None: super().__init__(**kwargs) self.on_warning_callback = on_warning_callback - - def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) -> None: + self.extract_issues = { + InvocationMode.SUBPROCESS: lambda result: extract_log_issues(result.full_output), + InvocationMode.DBT_RUNNER: extract_dbt_runner_issues, + }[self.invocation_mode] + self.parse_number_of_warnings = { + InvocationMode.SUBPROCESS: parse_number_of_warnings_subprocess, + InvocationMode.DBT_RUNNER: parse_number_of_warnings_dbt_runner, + }[self.invocation_mode] + + def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None: """ Handles warnings by extracting log issues, creating additional context, and calling the on_warning_callback with the updated context. @@ -438,7 +482,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) :param result: The result object from the build and run command. :param context: The original airflow context in which the build and run command was executed. """ - test_names, test_results = extract_log_issues(result.full_output) + test_names, test_results = self.extract_issues(result) warning_context = dict(context) warning_context["test_names"] = test_names @@ -448,17 +492,9 @@ def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) def execute(self, context: Context) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) - should_trigger_callback = all( - [ - self.on_warning_callback, - DBT_NO_TESTS_MSG not in result.output, - DBT_WARN_MSG in result.output, - ] - ) - if should_trigger_callback: - warnings = parse_output(result, "WARN") - if warnings > 0: - self._handle_warnings(result, context) + number_of_warnings = self.parse_number_of_warnings(result) # type: ignore + if self.on_warning_callback and number_of_warnings > 0: + self._handle_warnings(result, context) class DbtRunOperationLocalOperator(DbtRunOperationMixin, DbtLocalBaseOperator): diff --git a/tests/dbt/parser/test_output.py b/tests/dbt/parser/test_output.py index 0f4ba56cde..0d33198ff6 100644 --- a/tests/dbt/parser/test_output.py +++ b/tests/dbt/parser/test_output.py @@ -1,18 +1,40 @@ +import pytest +from unittest.mock import MagicMock from airflow.hooks.subprocess import SubprocessResult from cosmos.dbt.parser.output import ( + extract_dbt_runner_issues, extract_log_issues, - parse_output, + parse_number_of_warnings_subprocess, + parse_number_of_warnings_dbt_runner, ) -def test_parse_output() -> None: - for warnings in range(0, 3): - output_str = f"Done. PASS=15 WARN={warnings} ERROR=0 SKIP=0 TOTAL=16" - keyword = "WARN" - result = SubprocessResult(exit_code=0, output=output_str) - num_warns = parse_output(result, keyword) - assert num_warns == warnings +@pytest.mark.parametrize( + "output_str, expected_warnings", + [ + ("Done. PASS=15 WARN=1 ERROR=0 SKIP=0 TOTAL=16", 1), + ("Done. PASS=15 WARN=0 ERROR=0 SKIP=0 TOTAL=16", 0), + ("Done. PASS=15 WARN=2 ERROR=0 SKIP=0 TOTAL=16", 2), + ("Nothing to do. Exiting without running tests.", 0), + ], +) +def test_parse_number_of_warnings_subprocess(output_str: str, expected_warnings) -> None: + result = SubprocessResult(exit_code=0, output=output_str) + num_warns = parse_number_of_warnings_subprocess(result) + assert num_warns == expected_warnings + + +def test_parse_number_of_warnings_dbt_runner_with_warnings(): + runner_result = MagicMock() + runner_result.result.results = [ + MagicMock(status="pass"), + MagicMock(status="warn"), + MagicMock(status="pass"), + MagicMock(status="warn"), + ] + num_warns = parse_number_of_warnings_dbt_runner(runner_result) + assert num_warns == 2 def test_extract_log_issues() -> None: @@ -37,3 +59,23 @@ def test_extract_log_issues() -> None: test_names_no_warns, test_results_no_warns = extract_log_issues(log_list_no_warning) assert test_names_no_warns == [] assert test_results_no_warns == [] + + +def test_extract_dbt_runner_issues(): + """Tests that the function extracts the correct test names and results from a dbt runner result + for only warnings. + """ + runner_result = MagicMock() + runner_result.result.results = [ + MagicMock(status="pass"), + MagicMock(status="warn", message="A warning message", node=MagicMock()), + MagicMock(status="pass"), + MagicMock(status="warn", message="A different warning message", node=MagicMock()), + ] + runner_result.result.results[1].node.name = "a_test" + runner_result.result.results[3].node.name = "another_test" + + test_names, test_results = extract_dbt_runner_issues(runner_result) + + assert test_names == ["a_test", "another_test"] + assert test_results == ["A warning message", "A different warning message"] diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index babb425ef0..357efe0b0f 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -31,8 +31,13 @@ DbtRunOperationLocalOperator, ) from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.constants import InvocationMode from tests.utils import test_dag as run_test_dag - +from cosmos.dbt.parser.output import ( + extract_dbt_runner_issues, + parse_number_of_warnings_subprocess, + parse_number_of_warnings_dbt_runner, +) DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" MINI_DBT_PROJ_DIR = Path(__file__).parent.parent / "sample/mini" @@ -122,6 +127,24 @@ def test_dbt_base_operator_add_user_supplied_global_flags() -> None: assert cmd[-1] == "cmd" +@pytest.mark.parametrize( + "invocation_mode, invoke_dbt_method, handle_exception_method", + [ + (InvocationMode.SUBPROCESS, "run_subprocess", "handle_exception_subprocess"), + (InvocationMode.DBT_RUNNER, "run_dbt_runner", "handle_exception_dbt_runner"), + ], +) +def test_dbt_base_operator_invocation_methods_set(invocation_mode, invoke_dbt_method, handle_exception_method): + """Tests that the right methods are mapped to DbtLocalBaseOperator.invoke_dbt and + DbtLocalBaseOperator.handle_exception based on the invocation mode passed. + """ + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, task_id="my-task", project_dir="my/dir", invocation_mode=invocation_mode + ) + assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method + assert dbt_base_operator.handle_exception.__name__ == handle_exception_method + + @pytest.mark.parametrize( "indirect_selection_type", [None, "cautious", "buildable", "empty"], @@ -145,6 +168,58 @@ def test_dbt_base_operator_use_indirect_selection(indirect_selection_type) -> No assert cmd[1] == "cmd" +def test_dbt_base_operator_run_dbt_runner_cannot_import(): + """Tests that the right error message is raised if dbtRunner cannot be imported.""" + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, + ) + expected_error_message = "Could not import dbt core. Ensure that dbt is installed and available in the environment where the operator is running." + with patch.dict(sys.modules, {"dbt.cli.main": None}): + with pytest.raises(ImportError, match=expected_error_message): + dbt_base_operator.run_dbt_runner(command=["cmd"], env={}, cwd="some-project") + + +def test_dbt_base_operator_run_dbt_runner(): + """Tests that dbtRunner.invoke() is called with the expected cli args.""" + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, + ) + full_dbt_cmd = ["dbt", "run", "some_model"] + + mock_dbt = MagicMock() + with patch.dict(sys.modules, {"dbt.cli.main": mock_dbt}): + dbt_base_operator.run_dbt_runner(command=full_dbt_cmd, env={}, cwd="some-dir") + + mock_dbt_runner = mock_dbt.dbtRunner.return_value + expected_cli_args = ["run", "some_model", "--project-dir", "some-dir"] + + assert mock_dbt_runner.invoke.call_count == 1 + assert mock_dbt_runner.invoke.call_args[0][0] == expected_cli_args + + +def test_dbt_base_operator_run_dbt_runner_is_cached(): + """Tests that if run_dbt_runner is called multiple times a cached runner is used.""" + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, + ) + mock_dbt = MagicMock() + with patch.dict(sys.modules, {"dbt.cli.main": mock_dbt}): + for _ in range(3): + dbt_base_operator.run_dbt_runner(command=["cmd"], env={}, cwd="some-project") + mock_dbt_runner = mock_dbt.dbtRunner + assert mock_dbt_runner.call_count == 1 + assert dbt_base_operator._dbt_runner is not None + + @pytest.mark.parametrize( ["skip_exception", "exception_code_returned", "expected_exception"], [ @@ -166,9 +241,24 @@ def test_dbt_base_operator_exception_handling(skip_exception, exception_code_ret ) if expected_exception: with pytest.raises(expected_exception): - dbt_base_operator.exception_handling(SubprocessResult(exception_code_returned, None)) + dbt_base_operator.handle_exception(SubprocessResult(exception_code_returned, None)) else: - dbt_base_operator.exception_handling(SubprocessResult(exception_code_returned, None)) + dbt_base_operator.handle_exception(SubprocessResult(exception_code_returned, None)) + + +def test_dbt_base_operator_handle_exception_dbt_runner(): + """Tests that an AirflowException is raised if the dbtRunner result is not successful.""" + operator = ConcreteDbtLocalBaseOperator( + profile_config=MagicMock(), + task_id="my-task", + project_dir="my/dir", + ) + result = MagicMock() + result.success = False + expected_error_message = "dbt command failed. See logs above for details." + + with pytest.raises(AirflowException, match=expected_error_message): + operator.handle_exception_dbt_runner(result) @patch("cosmos.operators.base.context_to_airflow_vars") @@ -201,6 +291,31 @@ def test_dbt_base_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None assert env == expected_env +@patch("cosmos.operators.local.extract_log_issues") +def test_dbt_test_local_operator_invocation_mode_functions(mock_extract_log_issues): + # test subprocess invocation mode + operator = DbtTestLocalOperator( + profile_config=profile_config, + invocation_mode=InvocationMode.SUBPROCESS, + task_id="my-task", + project_dir="my/dir", + ) + assert operator.parse_number_of_warnings == parse_number_of_warnings_subprocess + result = MagicMock(full_output="some output") + operator.extract_issues(result) + mock_extract_log_issues.assert_called_once_with("some output") + + # test dbt runner invocation mode + operator = DbtTestLocalOperator( + profile_config=profile_config, + invocation_mode=InvocationMode.DBT_RUNNER, + task_id="my-task", + project_dir="my/dir", + ) + assert operator.extract_issues == extract_dbt_runner_issues + assert operator.parse_number_of_warnings == parse_number_of_warnings_dbt_runner + + @pytest.mark.skipif( version.parse(airflow_version) < version.parse("2.4"), reason="Airflow DAG did not have datasets until the 2.4 release", @@ -235,7 +350,8 @@ def test_run_operator_dataset_inlets_and_outlets(): @pytest.mark.integration -def test_run_test_operator_with_callback(failing_test_dbt_project): +@pytest.mark.parametrize("invocation_mode", [InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) +def test_run_test_operator_with_callback(invocation_mode, failing_test_dbt_project): on_warning_callback = MagicMock() with DAG("test-id-2", start_date=datetime(2022, 1, 1)) as dag: @@ -251,6 +367,7 @@ def test_run_test_operator_with_callback(failing_test_dbt_project): task_id="test", append_env=True, on_warning_callback=on_warning_callback, + invocation_mode=invocation_mode, ) run_operator >> test_operator run_test_dag(dag) @@ -473,7 +590,7 @@ def test_dbt_docs_gcs_local_operator(): @patch("cosmos.operators.local.DbtLocalBaseOperator.store_compiled_sql") -@patch("cosmos.operators.local.DbtLocalBaseOperator.exception_handling") +@patch("cosmos.operators.local.DbtLocalBaseOperator.handle_exception_subprocess") @patch("cosmos.config.ProfileConfig.ensure_profile") @patch("cosmos.operators.local.DbtLocalBaseOperator.run_subprocess") def test_operator_execute_deps_parameters( diff --git a/tests/operators/test_virtualenv.py b/tests/operators/test_virtualenv.py index 86796308b1..1ac508e3a9 100644 --- a/tests/operators/test_virtualenv.py +++ b/tests/operators/test_virtualenv.py @@ -25,7 +25,7 @@ class ConcreteDbtVirtualenvBaseOperator(DbtVirtualenvBaseOperator): @patch("airflow.utils.python_virtualenv.execute_in_subprocess") @patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.calculate_openlineage_events_completes") @patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.store_compiled_sql") -@patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.exception_handling") +@patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.handle_exception_subprocess") @patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.subprocess_hook") @patch("airflow.hooks.base.BaseHook.get_connection") def test_run_command( From 77cd3cb6a46af0ef376c39bb8a9d33471294b776 Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Tue, 6 Feb 2024 10:46:24 -0800 Subject: [PATCH 04/16] improve type hints for subprocess hooks --- cosmos/operators/local.py | 50 ++++++++++++++++++------------ cosmos/operators/virtualenv.py | 9 ++++-- tests/operators/test_virtualenv.py | 12 +++---- tests/test_converter.py | 12 ++++--- 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index fac2a76f5d..99a0357c2a 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -130,9 +130,15 @@ def __init__( self.should_store_compiled_sql = should_store_compiled_sql self.openlineage_events_completes: list[RunEvent] = [] self.invocation_mode = invocation_mode - self.invoke_dbt = getattr(self, f"run_{invocation_mode.value}") - self.handle_exception = getattr(self, f"handle_exception_{invocation_mode.value}") - self._dbt_runner: dbtRunner | None = None + self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult] + self.handle_exception: Callable[..., None] + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.invoke_dbt = self.run_subprocess + self.handle_exception = self.handle_exception_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.invoke_dbt = self.run_dbt_runner + self.handle_exception = self.handle_exception_dbt_runner + self._dbt_runner: dbtRunner | None = None kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes super().__init__(**kwargs) @@ -202,12 +208,18 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se else: logger.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.") - def run_subprocess(self, *args: Any, **kwargs: Any) -> FullOutputSubprocessResult: - subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(*args, **kwargs) + def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult: + logger.info("Trying to run the command:\n %s\nFrom %s", command, cwd) + subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command( + command=command, + env=env, + cwd=cwd, + output_encoding=self.output_encoding, + ) logger.info(subprocess_result.output) return subprocess_result - def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> dbtRunnerResult: + def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str) -> dbtRunnerResult: """Invokes the dbt command programmatically.""" try: from dbt.cli.main import dbtRunner @@ -221,7 +233,8 @@ def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str, **kw # The dbt runner will cd into the project directory and restore the cwd when completed cli_args = command[1:] + ["--project-dir", cwd] - with environ(env): + logger.info("Trying to run dbtRunner with:\n %s", cli_args) + with environ({k: str(v) for k, v in env.items()}): result = self._dbt_runner.invoke(cli_args) return result @@ -240,7 +253,7 @@ def run_command( tmp_project_dir, self.project_dir, ) - + env = {k: str(v) for k, v in env.items()} create_symlinks(Path(self.project_dir), Path(tmp_project_dir), self.install_deps) with self.profile_config.ensure_profile() as profile_values: @@ -262,18 +275,15 @@ def run_command( self.invoke_dbt( command=deps_command, env=env, - output_encoding=self.output_encoding, cwd=tmp_project_dir, ) full_cmd = cmd + flags - logger.info("Trying to run the command:\n %s\nFrom %s", full_cmd, tmp_project_dir) logger.info("Using environment variables keys: %s", env.keys()) result = self.invoke_dbt( command=full_cmd, env=env, - output_encoding=self.output_encoding, cwd=tmp_project_dir, ) if is_openlineage_available: @@ -294,7 +304,7 @@ def run_command( if self.callback: self.callback(tmp_project_dir) - return result # type: ignore + return result def calculate_openlineage_events_completes( self, env: dict[str, str | os.PathLike[Any] | bytes], project_dir: Path @@ -465,14 +475,14 @@ def __init__( ) -> None: super().__init__(**kwargs) self.on_warning_callback = on_warning_callback - self.extract_issues = { - InvocationMode.SUBPROCESS: lambda result: extract_log_issues(result.full_output), - InvocationMode.DBT_RUNNER: extract_dbt_runner_issues, - }[self.invocation_mode] - self.parse_number_of_warnings = { - InvocationMode.SUBPROCESS: parse_number_of_warnings_subprocess, - InvocationMode.DBT_RUNNER: parse_number_of_warnings_dbt_runner, - }[self.invocation_mode] + self.extract_issues: Callable[..., tuple[list[str], list[str]]] + self.parse_number_of_warnings: Callable[..., int] + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.extract_issues = lambda result: extract_log_issues(result.full_output) + self.parse_number_of_warnings = parse_number_of_warnings_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.extract_issues = extract_dbt_runner_issues + self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None: """ diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index bad88e2346..6612ab8b83 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -85,11 +85,16 @@ def venv_dbt_path( self.log.info("Using dbt version %s available at %s", dbt_version, dbt_binary) return str(dbt_binary) - def run_subprocess(self, *args: Any, command: list[str], **kwargs: Any) -> FullOutputSubprocessResult: + def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult: if self.py_requirements: command[0] = self.venv_dbt_path - subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(command, *args, **kwargs) + subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command( + command=command, + env=env, + cwd=cwd, + output_encoding=self.output_encoding, + ) return subprocess_result def execute(self, context: Context) -> None: diff --git a/tests/operators/test_virtualenv.py b/tests/operators/test_virtualenv.py index 1ac508e3a9..9d180b3e26 100644 --- a/tests/operators/test_virtualenv.py +++ b/tests/operators/test_virtualenv.py @@ -60,12 +60,12 @@ def test_run_command( run_command_args = mock_subprocess_hook.run_command.call_args_list assert len(run_command_args) == 3 python_cmd = run_command_args[0] - dbt_deps = run_command_args[1] - dbt_cmd = run_command_args[2] + dbt_deps = run_command_args[1].kwargs + dbt_cmd = run_command_args[2].kwargs assert python_cmd[0][0][0].endswith("/bin/python") assert python_cmd[0][-1][-1] == "from importlib.metadata import version; print(version('dbt-core'))" - assert dbt_deps[0][0][1] == "deps" - assert dbt_deps[0][0][0].endswith("/bin/dbt") - assert dbt_deps[0][0][0] == dbt_cmd[0][0][0] - assert dbt_cmd[0][0][1] == "do-something" + assert dbt_deps["command"][1] == "deps" + assert dbt_deps["command"][0].endswith("/bin/dbt") + assert dbt_deps["command"][0] == dbt_cmd["command"][0] + assert dbt_cmd["command"][1] == "do-something" assert mock_execute.call_count == 2 diff --git a/tests/test_converter.py b/tests/test_converter.py index 7becc3a8be..16edf02547 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -409,14 +409,16 @@ def test_converter_project_config_dbt_vars_with_custom_load_mode( @pytest.mark.parametrize("invocation_mode", [None, InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) @patch("cosmos.config.ProjectConfig.validate_project") +@patch("cosmos.converter.validate_initial_user_config") +@patch("cosmos.converter.DbtGraph") @patch("cosmos.converter.build_airflow_graph") -def test_converter_invocation_mode_added_to_task_args(mock_build_airflow_graph, mock_validate_project, invocation_mode): - """Tests that the `task_args` passed to build_airflow_graph has invocation_mode if - it is not None. - """ +def test_converter_invocation_mode_added_to_task_args( + mock_build_airflow_graph, mock_user_config, mock_dbt_graph, mock_validate_project, invocation_mode +): + """Tests that the `task_args` passed to build_airflow_graph has invocation_mode if it is not None.""" project_config = ProjectConfig(project_name="fake-project", dbt_project_path="/some/project/path") execution_config = ExecutionConfig(invocation_mode=invocation_mode) - render_config = RenderConfig() + render_config = MagicMock() profile_config = MagicMock() with DAG("test-id", start_date=datetime(2024, 1, 1)) as dag: From 7247487844a538e517d03e4b3acf95576a639699 Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Tue, 6 Feb 2024 13:04:53 -0800 Subject: [PATCH 05/16] add change_working_directory context manager and add integration tests --- cosmos/dbt/project.py | 13 +++++++++++++ cosmos/operators/local.py | 13 ++++++++----- dev/dags/basic_cosmos_task_group.py | 3 ++- tests/dbt/test_project.py | 17 ++++++++++++++++- tests/operators/test_local.py | 15 +++++++++++---- 5 files changed, 50 insertions(+), 11 deletions(-) diff --git a/cosmos/dbt/project.py b/cosmos/dbt/project.py index aff6ed03ec..cadc1fa82b 100644 --- a/cosmos/dbt/project.py +++ b/cosmos/dbt/project.py @@ -36,3 +36,16 @@ def environ(env_vars: dict[str, str]) -> Generator[None, None, None]: del os.environ[key] else: os.environ[key] = value + + +@contextmanager +def change_working_directory(path: str) -> Generator[None, None, None]: + """Temporarily changes the working directory to the given path, and then restores + back to the previous value on exit. + """ + previous_cwd = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(previous_cwd) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 99a0357c2a..dd0ab00500 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -60,7 +60,7 @@ parse_number_of_warnings_dbt_runner, parse_number_of_warnings_subprocess, ) -from cosmos.dbt.project import create_symlinks, environ +from cosmos.dbt.project import create_symlinks, environ, change_working_directory logger = get_logger(__name__) @@ -231,11 +231,14 @@ def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str) -> d if self._dbt_runner is None: self._dbt_runner = dbtRunner() - # The dbt runner will cd into the project directory and restore the cwd when completed - cli_args = command[1:] + ["--project-dir", cwd] - logger.info("Trying to run dbtRunner with:\n %s", cli_args) - with environ({k: str(v) for k, v in env.items()}): + # Exclude the dbt executable path from the command + cli_args = command[1:] + + logger.info("Trying to run dbtRunner with:\n %s\n in %s", cli_args, cwd) + + with change_working_directory(cwd), environ(env): result = self._dbt_runner.invoke(cli_args) + return result def run_command( diff --git a/dev/dags/basic_cosmos_task_group.py b/dev/dags/basic_cosmos_task_group.py index 4b6aae71e1..06b24f2918 100644 --- a/dev/dags/basic_cosmos_task_group.py +++ b/dev/dags/basic_cosmos_task_group.py @@ -12,6 +12,7 @@ from cosmos import DbtTaskGroup, ProjectConfig, ProfileConfig, RenderConfig, ExecutionConfig from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.constants import InvocationMode DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) @@ -25,7 +26,7 @@ ), ) -shared_execution_config = ExecutionConfig() +shared_execution_config = ExecutionConfig(invocation_mode=InvocationMode.DBT_RUNNER) @dag( diff --git a/tests/dbt/test_project.py b/tests/dbt/test_project.py index 000ad06bdc..d965e71f28 100644 --- a/tests/dbt/test_project.py +++ b/tests/dbt/test_project.py @@ -1,5 +1,5 @@ from pathlib import Path -from cosmos.dbt.project import create_symlinks, environ +from cosmos.dbt.project import create_symlinks, environ, change_working_directory import os from unittest.mock import patch @@ -33,3 +33,18 @@ def test_environ_context_manager(): # Check if the original environment variables are still set assert "value1" == os.environ.get("VAR1") assert "value2" == os.environ.get("VAR2") + + +@patch("os.chdir") +def test_change_working_directory(mock_chdir): + """Tests that the working directory is changed and then restored correctly.""" + # Define the path to change the working directory to + path = "/path/to/directory" + + # Use the change_working_directory context manager + with change_working_directory(path): + # Check if os.chdir is called with the correct path + mock_chdir.assert_called_once_with(path) + + # Check if os.chdir is called with the previous working directory + mock_chdir.assert_called_with(os.getcwd()) diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 357efe0b0f..d716512426 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -182,7 +182,8 @@ def test_dbt_base_operator_run_dbt_runner_cannot_import(): dbt_base_operator.run_dbt_runner(command=["cmd"], env={}, cwd="some-project") -def test_dbt_base_operator_run_dbt_runner(): +@patch("cosmos.dbt.project.os.chdir") +def test_dbt_base_operator_run_dbt_runner(mock_chdir): """Tests that dbtRunner.invoke() is called with the expected cli args.""" dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, @@ -197,13 +198,16 @@ def test_dbt_base_operator_run_dbt_runner(): dbt_base_operator.run_dbt_runner(command=full_dbt_cmd, env={}, cwd="some-dir") mock_dbt_runner = mock_dbt.dbtRunner.return_value - expected_cli_args = ["run", "some_model", "--project-dir", "some-dir"] + expected_cli_args = ["run", "some_model"] assert mock_dbt_runner.invoke.call_count == 1 assert mock_dbt_runner.invoke.call_args[0][0] == expected_cli_args + assert mock_chdir.call_count == 2 + assert mock_chdir.call_args_list[0][0][0] == "some-dir" -def test_dbt_base_operator_run_dbt_runner_is_cached(): +@patch("cosmos.dbt.project.os.chdir") +def test_dbt_base_operator_run_dbt_runner_is_cached(mock_chdir): """Tests that if run_dbt_runner is called multiple times a cached runner is used.""" dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, @@ -375,7 +379,8 @@ def test_run_test_operator_with_callback(invocation_mode, failing_test_dbt_proje @pytest.mark.integration -def test_run_test_operator_without_callback(): +@pytest.mark.parametrize("invocation_mode", [InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) +def test_run_test_operator_without_callback(invocation_mode): on_warning_callback = MagicMock() with DAG("test-id-3", start_date=datetime(2022, 1, 1)) as dag: @@ -384,6 +389,7 @@ def test_run_test_operator_without_callback(): project_dir=MINI_DBT_PROJ_DIR, task_id="run", append_env=True, + invocation_mode=invocation_mode, ) test_operator = DbtTestLocalOperator( profile_config=mini_profile_config, @@ -391,6 +397,7 @@ def test_run_test_operator_without_callback(): task_id="test", append_env=True, on_warning_callback=on_warning_callback, + invocation_mode=invocation_mode, ) run_operator >> test_operator run_test_dag(dag) From a1c5da0524877db2ddc40a6aababa4676612aaff Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Tue, 6 Feb 2024 13:07:09 -0800 Subject: [PATCH 06/16] rm duplicate import --- cosmos/operators/local.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index dd0ab00500..94290ae59c 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -33,7 +33,6 @@ from airflow.datasets import Dataset # noqa: F811 from openlineage.client.run import RunEvent from dbt.cli.main import dbtRunner, dbtRunnerResult - from typing import Callable from sqlalchemy.orm import Session From 25e669ec816d746ace3a20d2a6e8080470b33184 Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Tue, 6 Feb 2024 13:15:20 -0800 Subject: [PATCH 07/16] add test coverage for env vars context --- tests/operators/test_local.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index d716512426..44284cc997 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -182,9 +182,11 @@ def test_dbt_base_operator_run_dbt_runner_cannot_import(): dbt_base_operator.run_dbt_runner(command=["cmd"], env={}, cwd="some-project") +@patch("cosmos.dbt.project.os.environ") @patch("cosmos.dbt.project.os.chdir") -def test_dbt_base_operator_run_dbt_runner(mock_chdir): - """Tests that dbtRunner.invoke() is called with the expected cli args.""" +def test_dbt_base_operator_run_dbt_runner(mock_chdir, mock_environ): + """Tests that dbtRunner.invoke() is called with the expected cli args, that the + cwd is changed to the expected directory, and env variables are set.""" dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", @@ -192,18 +194,23 @@ def test_dbt_base_operator_run_dbt_runner(mock_chdir): invocation_mode=InvocationMode.DBT_RUNNER, ) full_dbt_cmd = ["dbt", "run", "some_model"] + env_vars = {"VAR1": "value1", "VAR2": "value2"} mock_dbt = MagicMock() with patch.dict(sys.modules, {"dbt.cli.main": mock_dbt}): - dbt_base_operator.run_dbt_runner(command=full_dbt_cmd, env={}, cwd="some-dir") + dbt_base_operator.run_dbt_runner(command=full_dbt_cmd, env=env_vars, cwd="some-dir") mock_dbt_runner = mock_dbt.dbtRunner.return_value expected_cli_args = ["run", "some_model"] - + # Assert dbtRunner.invoke was called with the expected cli args assert mock_dbt_runner.invoke.call_count == 1 assert mock_dbt_runner.invoke.call_args[0][0] == expected_cli_args + # Assert cwd was changed to the expected directory assert mock_chdir.call_count == 2 assert mock_chdir.call_args_list[0][0][0] == "some-dir" + # Assert env variables were updated + assert mock_environ.update.call_count == 1 + assert mock_environ.update.call_args[0][0] == env_vars @patch("cosmos.dbt.project.os.chdir") From 2adda9c8cc24b103e10dda141ab07cf46bc4faca Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Tue, 6 Feb 2024 14:23:14 -0800 Subject: [PATCH 08/16] add invocation mode to docs --- docs/configuration/execution-config.rst | 1 + docs/getting_started/execution-modes.rst | 27 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/docs/configuration/execution-config.rst b/docs/configuration/execution-config.rst index c118590d85..dc24c22806 100644 --- a/docs/configuration/execution-config.rst +++ b/docs/configuration/execution-config.rst @@ -7,6 +7,7 @@ It does this by exposing a ``cosmos.config.ExecutionConfig`` class that you can The ``ExecutionConfig`` class takes the following arguments: - ``execution_mode``: The way dbt is run when executing within airflow. For more information, see the `execution modes <../getting_started/execution-modes.html>`_ page. +- ``invocation_mode`` (new in v1.4): The way dbt is invoked within the execution mode. This is only configurable for ``ExecutionMode.LOCAL`` and ``ExecutionMode.VIRTUALENV``. For more information, see `invocation modes <../getting_started/execution-modes.html#invocation-modes>`_. - ``test_indirect_selection``: The mode to configure the test behavior when performing indirect selection. - ``dbt_executable_path``: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. - ``dbt_project_path``: Configures the DBT project location accessible on their airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM`` diff --git a/docs/getting_started/execution-modes.rst b/docs/getting_started/execution-modes.rst index 924e4ba129..7f081811b2 100644 --- a/docs/getting_started/execution-modes.rst +++ b/docs/getting_started/execution-modes.rst @@ -180,3 +180,30 @@ Each task will create a new container on Azure, giving full isolation. This, how "image": "dbt-jaffle-shop:1.0.0", }, ) + + +.. _invocation_modes: +Invocation Modes +================ +.. versionadded:: 1.4 + +For ``ExecutionMode.LOCAL`` and ``ExecutionMode.VIRTUALENV`` execution modes, Cosmos supports two invocation modes for running dbt: + +1. ``InvocationMode.SUBPROCESS``: This is currently the default mode and does not need to be specified. In this mode, Cosmos runs dbt cli commands using the Python ``subprocess`` module and parses the output to capture logs and to raise exceptions. + +2. ``InvocationMode.DBT_RUNNER``: In this mode, Cosmos uses the ``dbtRunner`` available for `dbt programmatic invocations `__ to run dbt commands. \ + In order to use this mode, dbt must be installed in the same environment, either local or virtualenv for the worker. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and can be expected to be faster than ``InvocationMode.SUBPROCESS``. + +The invocation mode can be set in the ``ExecutionConfig`` as shown below: + +.. code-block:: python + + from cosmos.constants import InvocationMode + + dag = DbtDag( + # ... + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.LOCAL, + invocation_mode=InvocationMode.DBT_RUNNER, + ), + ) From 15da8882f51bb240f9039791c56b2c4f22edb1e9 Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Tue, 6 Feb 2024 14:39:52 -0800 Subject: [PATCH 09/16] fix: test coverage --- tests/dbt/parser/test_output.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/dbt/parser/test_output.py b/tests/dbt/parser/test_output.py index 0d33198ff6..9adb34ba43 100644 --- a/tests/dbt/parser/test_output.py +++ b/tests/dbt/parser/test_output.py @@ -1,4 +1,5 @@ import pytest +import logging from unittest.mock import MagicMock from airflow.hooks.subprocess import SubprocessResult @@ -19,12 +20,23 @@ ("Nothing to do. Exiting without running tests.", 0), ], ) -def test_parse_number_of_warnings_subprocess(output_str: str, expected_warnings) -> None: +def test_parse_number_of_warnings_subprocess(output_str: str, expected_warnings): result = SubprocessResult(exit_code=0, output=output_str) num_warns = parse_number_of_warnings_subprocess(result) assert num_warns == expected_warnings +def test_parse_number_of_warnings_subprocess_error_logged(caplog): + output_str = "WARN= should log an error." + with caplog.at_level(logging.ERROR): + result = SubprocessResult(exit_code=0, output=output_str) + parse_number_of_warnings_subprocess(result) + expected_error_log = ( + "Could not parse number of WARNs. Check your dbt/airflow version or if --quiet is not being used" + ) + assert expected_error_log in caplog.text + + def test_parse_number_of_warnings_dbt_runner_with_warnings(): runner_result = MagicMock() runner_result.result.results = [ From 1a90a3fa3e2d1a5dd3c7c5727253a014a1a7fdf4 Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Mon, 12 Feb 2024 17:40:58 -0800 Subject: [PATCH 10/16] on_kill check for InvocationMode.SUBPROCESS --- cosmos/operators/local.py | 13 ++++++------ tests/operators/test_local.py | 37 +++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 94290ae59c..27d5cd4db9 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -421,12 +421,13 @@ def build_and_run_cmd( return result def on_kill(self) -> None: - if self.cancel_query_on_kill: - self.subprocess_hook.log.info("Sending SIGINT signal to process group") - if self.subprocess_hook.sub_process and hasattr(self.subprocess_hook.sub_process, "pid"): - os.killpg(os.getpgid(self.subprocess_hook.sub_process.pid), signal.SIGINT) - else: - self.subprocess_hook.send_sigterm() + if self.invocation_mode == InvocationMode.SUBPROCESS: + if self.cancel_query_on_kill: + self.subprocess_hook.log.info("Sending SIGINT signal to process group") + if self.subprocess_hook.sub_process and hasattr(self.subprocess_hook.sub_process, "pid"): + os.killpg(os.getpgid(self.subprocess_hook.sub_process.pid), signal.SIGINT) + else: + self.subprocess_hook.send_sigterm() class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator): diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 44284cc997..41a00ce1cd 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -3,6 +3,7 @@ import sys import shutil import tempfile +import signal from pathlib import Path from unittest.mock import MagicMock, patch, call @@ -642,3 +643,39 @@ def test_dbt_docs_local_operator_with_static_flag(): dbt_cmd_flags=["--static"], ) assert operator.required_files == ["static_index.html"] + + +@patch("cosmos.operators.local.os.killpg") +@patch("cosmos.operators.local.os.getpgid", return_value=11111) +def test_on_kill_subprocess_cancel_query_on_kill_true(mock_getpgid, mock_killpg): + operator = ConcreteDbtLocalBaseOperator( + task_id="my-task", + profile_config=profile_config, + project_dir="my/dir", + invocation_mode=InvocationMode.SUBPROCESS, + cancel_query_on_kill=True, + ) + operator.subprocess_hook = MagicMock() + operator.subprocess_hook.sub_process = MagicMock() + operator.subprocess_hook.sub_process.pid = 12345 + + operator.on_kill() + + mock_getpgid.assert_called_once_with(12345) + mock_killpg.assert_called_once_with(11111, signal.SIGINT) + + +def test_on_kill_subprocess_cancel_query_on_kill_false(): + operator = ConcreteDbtLocalBaseOperator( + task_id="my-task", + profile_config=profile_config, + project_dir="my/dir", + invocation_mode=InvocationMode.SUBPROCESS, + cancel_query_on_kill=False, + ) + operator.subprocess_hook = MagicMock() + + with patch.object(operator.subprocess_hook, "send_sigterm") as mock_send_sigterm: + operator.on_kill() + + mock_send_sigterm.assert_called_once() From f0e03be040bdf7e94e6469bbdc33494aa4fa2893 Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Thu, 15 Feb 2024 11:31:45 -0800 Subject: [PATCH 11/16] add note of dbt >= v1.50 requirement --- cosmos/operators/local.py | 2 +- docs/getting_started/execution-modes.rst | 3 ++- tests/operators/test_local.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 27d5cd4db9..7ca42c7a20 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -224,7 +224,7 @@ def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str) -> d from dbt.cli.main import dbtRunner except ImportError: raise ImportError( - "Could not import dbt core. Ensure that dbt is installed and available in the environment where the operator is running." + "Could not import dbt core. Ensure that dbt-core >= v1.5 is installed and available in the environment where the operator is running." ) if self._dbt_runner is None: diff --git a/docs/getting_started/execution-modes.rst b/docs/getting_started/execution-modes.rst index 7f081811b2..f34a23c323 100644 --- a/docs/getting_started/execution-modes.rst +++ b/docs/getting_started/execution-modes.rst @@ -192,7 +192,8 @@ For ``ExecutionMode.LOCAL`` and ``ExecutionMode.VIRTUALENV`` execution modes, Co 1. ``InvocationMode.SUBPROCESS``: This is currently the default mode and does not need to be specified. In this mode, Cosmos runs dbt cli commands using the Python ``subprocess`` module and parses the output to capture logs and to raise exceptions. 2. ``InvocationMode.DBT_RUNNER``: In this mode, Cosmos uses the ``dbtRunner`` available for `dbt programmatic invocations `__ to run dbt commands. \ - In order to use this mode, dbt must be installed in the same environment, either local or virtualenv for the worker. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and can be expected to be faster than ``InvocationMode.SUBPROCESS``. + In order to use this mode, dbt must be installed in the same environment, either local or virtualenv for the worker. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and can be expected to be faster than ``InvocationMode.SUBPROCESS``. \ + This mode requires dbt version 1.5.0 or higher. The invocation mode can be set in the ``ExecutionConfig`` as shown below: diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 5e6227a5e2..db0310d67c 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -177,7 +177,7 @@ def test_dbt_base_operator_run_dbt_runner_cannot_import(): project_dir="my/dir", invocation_mode=InvocationMode.DBT_RUNNER, ) - expected_error_message = "Could not import dbt core. Ensure that dbt is installed and available in the environment where the operator is running." + expected_error_message = "Could not import dbt core. Ensure that dbt-core >= v1.5 is installed and available in the environment where the operator is running." with patch.dict(sys.modules, {"dbt.cli.main": None}): with pytest.raises(ImportError, match=expected_error_message): dbt_base_operator.run_dbt_runner(command=["cmd"], env={}, cwd="some-project") From 7a3b182e8104bab23a0a60fc7238090661e2a67d Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Fri, 16 Feb 2024 16:04:00 -0800 Subject: [PATCH 12/16] update perf dag to use postgres to allow latest dbt-core --- .github/workflows/test.yml | 20 ++++++++++++++++++-- dev/dags/dbt/perf/profiles.yml | 17 +++++++++-------- dev/dags/performance_dag.py | 15 +++++++++------ scripts/test/performance-setup.sh | 4 ++-- 4 files changed, 38 insertions(+), 18 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4b214ce119..1b930f6c23 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -299,7 +299,18 @@ jobs: python-version: ["3.11"] airflow-version: ["2.7"] num-models: [1, 10, 50, 100] - + services: + postgres: + image: postgres + env: + POSTGRES_PASSWORD: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 steps: - uses: actions/checkout@v3 with: @@ -336,7 +347,12 @@ jobs: AIRFLOW__CORE__DAGBAG_IMPORT_TIMEOUT: 90.0 PYTHONPATH: /home/runner/work/astronomer-cosmos/astronomer-cosmos/:$PYTHONPATH MODEL_COUNT: ${{ matrix.num-models }} - + POSTGRES_HOST: localhost + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + POSTGRES_SCHEMA: public + POSTGRES_PORT: 5432 env: AIRFLOW_HOME: /home/runner/work/astronomer-cosmos/astronomer-cosmos/ AIRFLOW_CONN_AIRFLOW_DB: postgres://postgres:postgres@0.0.0.0:5432/postgres diff --git a/dev/dags/dbt/perf/profiles.yml b/dev/dags/dbt/perf/profiles.yml index 5b3cf175d5..224f565f4a 100644 --- a/dev/dags/dbt/perf/profiles.yml +++ b/dev/dags/dbt/perf/profiles.yml @@ -1,11 +1,12 @@ -simple: +default: target: dev outputs: dev: - type: sqlite - threads: 1 - database: "database" - schema: "main" - schemas_and_paths: - main: "{{ env_var('DBT_SQLITE_PATH') }}/imdb.db" - schema_directory: "{{ env_var('DBT_SQLITE_PATH') }}" + type: postgres + host: "{{ env_var('POSTGRES_HOST') }}" + user: "{{ env_var('POSTGRES_USER') }}" + password: "{{ env_var('POSTGRES_PASSWORD') }}" + port: "{{ env_var('POSTGRES_PORT') | int }}" + dbname: "{{ env_var('POSTGRES_DB') }}" + schema: "{{ env_var('POSTGRES_SCHEMA') }}" + threads: 4 diff --git a/dev/dags/performance_dag.py b/dev/dags/performance_dag.py index caf977817d..fec5175c81 100644 --- a/dev/dags/performance_dag.py +++ b/dev/dags/performance_dag.py @@ -1,28 +1,31 @@ """ -A DAG that uses Cosmos to render a dbt project for performance testing. +An airflow DAG that uses Cosmos to render a dbt project for performance testing. """ -import airflow from datetime import datetime import os from pathlib import Path from cosmos import DbtDag, ProjectConfig, ProfileConfig, RenderConfig +from cosmos.profiles import PostgresUserPasswordProfileMapping + DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) -DBT_SQLITE_PATH = str(DEFAULT_DBT_ROOT_PATH / "data") + profile_config = ProfileConfig( - profile_name="simple", + profile_name="default", target_name="dev", - profiles_yml_filepath=(DBT_ROOT_PATH / "simple/profiles.yml"), + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="airflow_db", + profile_args={"schema": "public"}, + ), ) cosmos_perf_dag = DbtDag( project_config=ProjectConfig( DBT_ROOT_PATH / "perf", - env_vars={"DBT_SQLITE_PATH": DBT_SQLITE_PATH}, ), profile_config=profile_config, render_config=RenderConfig( diff --git a/scripts/test/performance-setup.sh b/scripts/test/performance-setup.sh index b8bce035c0..7efb917c1e 100644 --- a/scripts/test/performance-setup.sh +++ b/scripts/test/performance-setup.sh @@ -1,4 +1,4 @@ -pip uninstall -y dbt-core dbt-sqlite openlineage-airflow openlineage-integration-common; \ +pip uninstall -y dbt-core dbt-sqlite dbt-postgres openlineage-airflow openlineage-integration-common; \ rm -rf airflow.*; \ airflow db init; \ -pip install 'dbt-core==1.4' 'dbt-sqlite<=1.4' 'dbt-databricks<=1.4' 'dbt-postgres<=1.4' +pip install 'dbt-postgres' From fd92032794bd07cc6257538811b9bdf157079b6b Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Fri, 16 Feb 2024 16:07:54 -0800 Subject: [PATCH 13/16] add invocation mode discovery if none selected --- cosmos/config.py | 9 ++-- cosmos/operators/local.py | 57 +++++++++++++++------ docs/configuration/execution-config.rst | 2 +- docs/getting_started/execution-modes.rst | 8 +-- tests/operators/test_local.py | 64 +++++++++++++++++++++--- tests/operators/test_virtualenv.py | 2 + tests/test_config.py | 2 +- 7 files changed, 112 insertions(+), 32 deletions(-) diff --git a/cosmos/config.py b/cosmos/config.py index 29fd131f8d..439b668624 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -297,8 +297,7 @@ class ExecutionConfig: Contains configuration about how to execute dbt. :param execution_mode: The execution mode for dbt. Defaults to local - :param invocation_mode: The invocation mode for the dbt command. This is only configurable for ExecutionMode.LOCAL or ExecutionMode.VIRTUALENV - execution modes. + :param invocation_mode: The invocation mode for the dbt command. This is only configurable for ExecutionMode.LOCAL. :param test_indirect_selection: The mode to configure the test behavior when performing indirect selection. :param dbt_executable_path: The path to the dbt executable for runtime execution. Defaults to dbt if available on the path. :param dbt_project_path Configures the DBT project location accessible at runtime for dag execution. This is the project path in a docker container for ExecutionMode.DOCKER or ExecutionMode.KUBERNETES. Mutually Exclusive with ProjectConfig.dbt_project_path @@ -313,8 +312,6 @@ class ExecutionConfig: project_path: Path | None = field(init=False) def __post_init__(self, dbt_project_path: str | Path | None) -> None: - if self.invocation_mode and self.execution_mode not in {ExecutionMode.LOCAL, ExecutionMode.VIRTUALENV}: - raise CosmosValueError( - "ExecutionConfig.invocation_mode is only configurable for ExecutionMode.LOCAL or ExecutionMode.VIRTUALENV modes." - ) + if self.invocation_mode and self.execution_mode != ExecutionMode.LOCAL: + raise CosmosValueError("ExecutionConfig.invocation_mode is only configurable for ExecutionMode.LOCAL.") self.project_path = Path(dbt_project_path) if dbt_project_path else None diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 7ca42c7a20..bf5c51ab02 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -116,7 +116,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator): def __init__( self, profile_config: ProfileConfig, - invocation_mode: InvocationMode = InvocationMode.SUBPROCESS, + invocation_mode: InvocationMode | None = None, install_deps: bool = False, callback: Callable[[str], None] | None = None, should_store_compiled_sql: bool = True, @@ -131,13 +131,9 @@ def __init__( self.invocation_mode = invocation_mode self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult] self.handle_exception: Callable[..., None] - if self.invocation_mode == InvocationMode.SUBPROCESS: - self.invoke_dbt = self.run_subprocess - self.handle_exception = self.handle_exception_subprocess - elif self.invocation_mode == InvocationMode.DBT_RUNNER: - self.invoke_dbt = self.run_dbt_runner - self.handle_exception = self.handle_exception_dbt_runner - self._dbt_runner: dbtRunner | None = None + self._dbt_runner: dbtRunner | None = None + if self.invocation_mode: + self._set_invocation_methods() kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes super().__init__(**kwargs) @@ -146,6 +142,32 @@ def subprocess_hook(self) -> FullOutputSubprocessHook: """Returns hook for running the bash command.""" return FullOutputSubprocessHook() + def _set_invocation_methods(self) -> None: + """Checks if the invocation mode is provided, then sets the associated run and exception handling methods. + If the invocation mode is not set, will try to import dbtRunner and fall back to subprocess. + """ + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.invoke_dbt = self.run_subprocess + self.handle_exception = self.handle_exception_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.invoke_dbt = self.run_dbt_runner + self.handle_exception = self.handle_exception_dbt_runner + + def _discover_invocation_mode(self) -> None: + """Discovers the invocation mode based on the availability of dbtRunner for import. If dbtRunner is available, it will + be used since it is faster than subprocess. If dbtRunner is not available, it will fall back to subprocess. + This method is called at runtime to work in the environment where the operator is running. + """ + try: + from dbt.cli.main import dbtRunner + except ImportError: + self.invocation_mode = InvocationMode.SUBPROCESS + logger.info("Could not import dbtRunner. Falling back to subprocess for invoking dbt.") + else: + self.invocation_mode = InvocationMode.DBT_RUNNER + logger.info("dbtRunner is available. Using dbtRunner for invoking dbt.") + self._set_invocation_methods() + def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None: if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code: raise AirflowSkipException(f"dbt command returned exit code {self.skip_exit_code}. Skipping.") @@ -249,6 +271,9 @@ def run_command( """ Copies the dbt project to a temporary directory and runs the command. """ + if not self.invocation_mode: + self._discover_invocation_mode() + with tempfile.TemporaryDirectory() as tmp_project_dir: logger.info( "Cloning project to writable temp directory %s from %s", @@ -480,12 +505,6 @@ def __init__( self.on_warning_callback = on_warning_callback self.extract_issues: Callable[..., tuple[list[str], list[str]]] self.parse_number_of_warnings: Callable[..., int] - if self.invocation_mode == InvocationMode.SUBPROCESS: - self.extract_issues = lambda result: extract_log_issues(result.full_output) - self.parse_number_of_warnings = parse_number_of_warnings_subprocess - elif self.invocation_mode == InvocationMode.DBT_RUNNER: - self.extract_issues = extract_dbt_runner_issues - self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None: """ @@ -503,8 +522,18 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, self.on_warning_callback and self.on_warning_callback(warning_context) + def _set_test_result_parsing_methods(self) -> None: + """Sets the extract_issues and parse_number_of_warnings methods based on the invocation mode.""" + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.extract_issues = lambda result: extract_log_issues(result.full_output) + self.parse_number_of_warnings = parse_number_of_warnings_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.extract_issues = extract_dbt_runner_issues + self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner + def execute(self, context: Context) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) + self._set_test_result_parsing_methods() number_of_warnings = self.parse_number_of_warnings(result) # type: ignore if self.on_warning_callback and number_of_warnings > 0: self._handle_warnings(result, context) diff --git a/docs/configuration/execution-config.rst b/docs/configuration/execution-config.rst index aad40c19d2..dd9758d558 100644 --- a/docs/configuration/execution-config.rst +++ b/docs/configuration/execution-config.rst @@ -7,7 +7,7 @@ It does this by exposing a ``cosmos.config.ExecutionConfig`` class that you can The ``ExecutionConfig`` class takes the following arguments: - ``execution_mode``: The way dbt is run when executing within airflow. For more information, see the `execution modes <../getting_started/execution-modes.html>`_ page. -- ``invocation_mode`` (new in v1.4): The way dbt is invoked within the execution mode. This is only configurable for ``ExecutionMode.LOCAL`` and ``ExecutionMode.VIRTUALENV``. For more information, see `invocation modes <../getting_started/execution-modes.html#invocation-modes>`_. +- ``invocation_mode`` (new in v1.4): The way dbt is invoked within the execution mode. This is only configurable for ``ExecutionMode.LOCAL``. For more information, see `invocation modes <../getting_started/execution-modes.html#invocation-modes>`_. - ``test_indirect_selection``: The mode to configure the test behavior when performing indirect selection. - ``dbt_executable_path``: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. - ``dbt_project_path``: Configures the dbt project location accessible at runtime for dag execution. This is the project path in a docker container for ``ExecutionMode.DOCKER`` or ``ExecutionMode.KUBERNETES``. Mutually exclusive with ``ProjectConfig.dbt_project_path``. diff --git a/docs/getting_started/execution-modes.rst b/docs/getting_started/execution-modes.rst index f34a23c323..92c542a1d3 100644 --- a/docs/getting_started/execution-modes.rst +++ b/docs/getting_started/execution-modes.rst @@ -187,12 +187,12 @@ Invocation Modes ================ .. versionadded:: 1.4 -For ``ExecutionMode.LOCAL`` and ``ExecutionMode.VIRTUALENV`` execution modes, Cosmos supports two invocation modes for running dbt: +For ``ExecutionMode.LOCAL`` execution mode, Cosmos supports two invocation modes for running dbt: -1. ``InvocationMode.SUBPROCESS``: This is currently the default mode and does not need to be specified. In this mode, Cosmos runs dbt cli commands using the Python ``subprocess`` module and parses the output to capture logs and to raise exceptions. +1. ``InvocationMode.SUBPROCESS``: In this mode, Cosmos runs dbt cli commands using the Python ``subprocess`` module and parses the output to capture logs and to raise exceptions. 2. ``InvocationMode.DBT_RUNNER``: In this mode, Cosmos uses the ``dbtRunner`` available for `dbt programmatic invocations `__ to run dbt commands. \ - In order to use this mode, dbt must be installed in the same environment, either local or virtualenv for the worker. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and can be expected to be faster than ``InvocationMode.SUBPROCESS``. \ + In order to use this mode, dbt must be installed in the same environment, either local or virtualenv for the worker. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and is faster than ``InvocationMode.SUBPROCESS``. \ This mode requires dbt version 1.5.0 or higher. The invocation mode can be set in the ``ExecutionConfig`` as shown below: @@ -208,3 +208,5 @@ The invocation mode can be set in the ``ExecutionConfig`` as shown below: invocation_mode=InvocationMode.DBT_RUNNER, ), ) + +If the invocation mode is not set, Cosmos will attempt to use ``InvocationMode.DBT_RUNNER`` if dbt is installed in the same environment as the worker, otherwise it will default to ``InvocationMode.SUBPROCESS``. diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index db0310d67c..b33510f4b5 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -135,17 +135,44 @@ def test_dbt_base_operator_add_user_supplied_global_flags() -> None: (InvocationMode.DBT_RUNNER, "run_dbt_runner", "handle_exception_dbt_runner"), ], ) -def test_dbt_base_operator_invocation_methods_set(invocation_mode, invoke_dbt_method, handle_exception_method): +def test_dbt_base_operator_set_invocation_methods(invocation_mode, invoke_dbt_method, handle_exception_method): """Tests that the right methods are mapped to DbtLocalBaseOperator.invoke_dbt and - DbtLocalBaseOperator.handle_exception based on the invocation mode passed. + DbtLocalBaseOperator.handle_exception when a known invocation mode passed. """ dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", invocation_mode=invocation_mode ) + dbt_base_operator._set_invocation_methods() assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method assert dbt_base_operator.handle_exception.__name__ == handle_exception_method +@pytest.mark.parametrize( + "can_import_dbt, invoke_dbt_method, handle_exception_method", + [ + (False, "run_subprocess", "handle_exception_subprocess"), + (True, "run_dbt_runner", "handle_exception_dbt_runner"), + ], +) +def test_dbt_base_operator_discover_invocation_mode(can_import_dbt, invoke_dbt_method, handle_exception_method): + """Tests that the right methods are mapped to DbtLocalBaseOperator.invoke_dbt and + DbtLocalBaseOperator.handle_exception if dbt can be imported or not. + """ + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, task_id="my-task", project_dir="my/dir" + ) + with patch.dict(sys.modules, {"dbt.cli.main": MagicMock()} if can_import_dbt else {"dbt.cli.main": None}): + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, task_id="my-task", project_dir="my/dir" + ) + dbt_base_operator._discover_invocation_mode() + assert dbt_base_operator.invocation_mode == ( + InvocationMode.DBT_RUNNER if can_import_dbt else InvocationMode.SUBPROCESS + ) + assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method + assert dbt_base_operator.handle_exception.__name__ == handle_exception_method + + @pytest.mark.parametrize( "indirect_selection_type", [None, "cautious", "buildable", "empty"], @@ -245,11 +272,14 @@ def test_dbt_base_operator_run_dbt_runner_is_cached(mock_chdir): "No exception raised", ], ) -def test_dbt_base_operator_exception_handling(skip_exception, exception_code_returned, expected_exception) -> None: +def test_dbt_base_operator_exception_handling_subprocess( + skip_exception, exception_code_returned, expected_exception +) -> None: dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", + invocation_mode=InvocationMode.SUBPROCESS, ) if expected_exception: with pytest.raises(expected_exception): @@ -304,7 +334,7 @@ def test_dbt_base_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None @patch("cosmos.operators.local.extract_log_issues") -def test_dbt_test_local_operator_invocation_mode_functions(mock_extract_log_issues): +def test_dbt_test_local_operator_invocation_mode_methods(mock_extract_log_issues): # test subprocess invocation mode operator = DbtTestLocalOperator( profile_config=profile_config, @@ -312,6 +342,7 @@ def test_dbt_test_local_operator_invocation_mode_functions(mock_extract_log_issu task_id="my-task", project_dir="my/dir", ) + operator._set_test_result_parsing_methods() assert operator.parse_number_of_warnings == parse_number_of_warnings_subprocess result = MagicMock(full_output="some output") operator.extract_issues(result) @@ -324,6 +355,7 @@ def test_dbt_test_local_operator_invocation_mode_functions(mock_extract_log_issu task_id="my-task", project_dir="my/dir", ) + operator._set_test_result_parsing_methods() assert operator.extract_issues == extract_dbt_runner_issues assert operator.parse_number_of_warnings == parse_number_of_warnings_dbt_runner @@ -519,7 +551,13 @@ def test_store_compiled_sql() -> None: ) @patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd") def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwargs, expected_call_kwargs): - task = operator_class(profile_config=profile_config, task_id="my-task", project_dir="my/dir", **kwargs) + task = operator_class( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, + **kwargs, + ) task.execute(context={}) mock_build_and_run_cmd.assert_called_once_with(**expected_call_kwargs) @@ -548,6 +586,7 @@ def test_operator_execute_without_flags(mock_build_and_run_cmd, operator_class): profile_config=profile_config, task_id="my-task", project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, **operator_class_kwargs.get(operator_class, {}), ) task.execute(context={}) @@ -616,8 +655,15 @@ def test_dbt_docs_gcs_local_operator(): @patch("cosmos.operators.local.DbtLocalBaseOperator.handle_exception_subprocess") @patch("cosmos.config.ProfileConfig.ensure_profile") @patch("cosmos.operators.local.DbtLocalBaseOperator.run_subprocess") +@patch("cosmos.operators.local.DbtLocalBaseOperator.run_dbt_runner") +@pytest.mark.parametrize("invocation_mode", [InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) def test_operator_execute_deps_parameters( - mock_build_and_run_cmd, mock_ensure_profile, mock_exception_handling, mock_store_compiled_sql + mock_dbt_runner, + mock_subprocess, + mock_ensure_profile, + mock_exception_handling, + mock_store_compiled_sql, + invocation_mode, ): expected_call_kwargs = [ "/usr/local/bin/dbt", @@ -636,10 +682,14 @@ def test_operator_execute_deps_parameters( install_deps=True, emit_datasets=False, dbt_executable_path="/usr/local/bin/dbt", + invocation_mode=invocation_mode, ) mock_ensure_profile.return_value.__enter__.return_value = (Path("/path/to/profile"), {"ENV_VAR": "value"}) task.execute(context={"task_instance": MagicMock()}) - assert mock_build_and_run_cmd.call_args_list[0].kwargs["command"] == expected_call_kwargs + if invocation_mode == InvocationMode.SUBPROCESS: + assert mock_subprocess.call_args_list[0].kwargs["command"] == expected_call_kwargs + elif invocation_mode == InvocationMode.DBT_RUNNER: + mock_dbt_runner.all_args_list[0].kwargs["command"] == expected_call_kwargs def test_dbt_docs_local_operator_with_static_flag(): diff --git a/tests/operators/test_virtualenv.py b/tests/operators/test_virtualenv.py index 9d180b3e26..036f162de2 100644 --- a/tests/operators/test_virtualenv.py +++ b/tests/operators/test_virtualenv.py @@ -7,6 +7,7 @@ from cosmos.config import ProfileConfig from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.constants import InvocationMode profile_config = ProfileConfig( profile_name="default", @@ -53,6 +54,7 @@ def test_run_command( py_system_site_packages=False, py_requirements=["dbt-postgres==1.6.0b1"], emit_datasets=False, + invocation_mode=InvocationMode.SUBPROCESS, ) assert venv_operator._venv_tmp_dir is None # Otherwise we are creating empty directories during DAG parsing time # and not deleting them diff --git a/tests/test_config.py b/tests/test_config.py index d7c456938a..b93ad26275 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -203,7 +203,7 @@ def test_render_config_env_vars_deprecated(): "execution_mode, expectation", [ (ExecutionMode.LOCAL, does_not_raise()), - (ExecutionMode.VIRTUALENV, does_not_raise()), + (ExecutionMode.VIRTUALENV, pytest.raises(CosmosValueError)), (ExecutionMode.KUBERNETES, pytest.raises(CosmosValueError)), (ExecutionMode.DOCKER, pytest.raises(CosmosValueError)), (ExecutionMode.AZURE_CONTAINER_INSTANCE, pytest.raises(CosmosValueError)), From 56c5e845dfaab8f75bcab37b5e378a4e3dd8be87 Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Fri, 16 Feb 2024 16:10:54 -0800 Subject: [PATCH 14/16] add branch to test performance dag updates --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1b930f6c23..af64893497 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,7 +2,7 @@ name: test on: push: # Run on pushes to the default branch - branches: [main] + branches: [main, 717-add-dbtrunner-local-executor] # # TODO:remove before merge pull_request_target: # Also run on pull requests originated from forks branches: [main] From 42217c6575347946c5113940ac9dad32cb04e1cf Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Fri, 16 Feb 2024 16:28:54 -0800 Subject: [PATCH 15/16] update test env vars --- .github/workflows/test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index af64893497..73dcedf0c4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -346,13 +346,14 @@ jobs: AIRFLOW_CONN_AIRFLOW_DB: postgres://postgres:postgres@0.0.0.0:5432/postgres AIRFLOW__CORE__DAGBAG_IMPORT_TIMEOUT: 90.0 PYTHONPATH: /home/runner/work/astronomer-cosmos/astronomer-cosmos/:$PYTHONPATH - MODEL_COUNT: ${{ matrix.num-models }} + COSMOS_CONN_POSTGRES_PASSWORD: ${{ secrets.COSMOS_CONN_POSTGRES_PASSWORD }} POSTGRES_HOST: localhost POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + MODEL_COUNT: ${{ matrix.num-models }} env: AIRFLOW_HOME: /home/runner/work/astronomer-cosmos/astronomer-cosmos/ AIRFLOW_CONN_AIRFLOW_DB: postgres://postgres:postgres@0.0.0.0:5432/postgres From f761a8a7f9b0a05830c314342cebb869ecab5435 Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Fri, 16 Feb 2024 16:59:37 -0800 Subject: [PATCH 16/16] try add authorize --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 73dcedf0c4..07c2b9bc33 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -293,6 +293,7 @@ jobs: PYTHONPATH: /home/runner/work/astronomer-cosmos/astronomer-cosmos/:$PYTHONPATH Run-Performance-Tests: + needs: Authorize runs-on: ubuntu-latest strategy: matrix: