diff --git a/cosmos/config.py b/cosmos/config.py index c5e7a69a30..3b332931fb 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -40,6 +40,7 @@ class RenderConfig: :param load_method: The parsing method for loading the dbt model. Defaults to AUTOMATIC :param select: A list of dbt select arguments (e.g. 'config.materialized:incremental') :param exclude: A list of dbt exclude arguments (e.g. 'tag:nightly') + :param selector: Name of a dbt YAML selector to use for parsing. Only supported when using ``load_method=LoadMode.DBT_LS``. :param dbt_deps: Configure to run dbt deps when using dbt ls for dag parsing :param node_converters: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``. :param dbt_executable_path: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. @@ -52,6 +53,7 @@ class RenderConfig: load_method: LoadMode = LoadMode.AUTOMATIC select: list[str] = field(default_factory=list) exclude: list[str] = field(default_factory=list) + selector: str | None = None dbt_deps: bool = True node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None dbt_executable_path: str | Path = get_system_dbt() diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 92ef5e66fb..e943d95270 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -195,6 +195,9 @@ def run_dbt_ls( if self.project.dbt_vars: ls_command.extend(["--vars", yaml.dump(self.project.dbt_vars)]) + if self.render_config.selector: + ls_command.extend(["--selector", self.render_config.selector]) + ls_command.extend(self.local_flags) stdout = run_command(ls_command, tmp_dir, env_vars) @@ -291,6 +294,11 @@ def load_via_custom_parser(self) -> None: """ logger.info("Trying to parse the dbt project `%s` using a custom Cosmos method...", self.project.project_name) + if self.render_config.selector: + raise 
CosmosLoadDbtException( + "RenderConfig.selector is not yet supported when loading dbt projects using the LoadMode.CUSTOM parser." + ) + if not self.render_config.project_path or not self.execution_config.project_path: raise CosmosLoadDbtException( "Unable to load dbt project without RenderConfig.dbt_project_path and ExecutionConfig.dbt_project_path" @@ -349,6 +357,11 @@ def load_from_dbt_manifest(self) -> None: """ logger.info("Trying to parse the dbt project `%s` using a dbt manifest...", self.project.project_name) + if self.render_config.selector: + raise CosmosLoadDbtException( + "RenderConfig.selector is not yet supported when loading dbt projects using the LoadMode.DBT_MANIFEST parser." + ) + if not self.project.is_manifest_available(): raise CosmosLoadDbtException(f"Unable to load manifest using {self.project.manifest_path}") diff --git a/docs/configuration/render-config.rst b/docs/configuration/render-config.rst index 1028ecf622..6d669d0a5d 100644 --- a/docs/configuration/render-config.rst +++ b/docs/configuration/render-config.rst @@ -11,6 +11,7 @@ The ``RenderConfig`` class takes the following arguments: - ``test_behavior``: how to run tests. Defaults to running a model's tests immediately after the model is run. For more information, see the `Testing Behavior `_ section. - ``load_method``: how to load your dbt project. See `Parsing Methods `_ for more information. - ``select`` and ``exclude``: which models to include or exclude from your DAGs. See `Selecting & Excluding `_ for more information. +- ``selector``: (new in v1.3) name of a dbt YAML selector to use for DAG parsing. Only supported when using ``load_method=LoadMode.DBT_LS``. See `Selecting & Excluding `_ for more information. - ``dbt_deps``: A Boolean to run dbt deps when using dbt ls for dag parsing. Default True - ``node_converters``: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. 
Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``. Find more information below. - ``dbt_executable_path``: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. diff --git a/docs/configuration/selecting-excluding.rst b/docs/configuration/selecting-excluding.rst index dfa4a96c59..01ee536b0a 100644 --- a/docs/configuration/selecting-excluding.rst +++ b/docs/configuration/selecting-excluding.rst @@ -3,7 +3,13 @@ Selecting & Excluding ======================= -Cosmos allows you to filter to a subset of your dbt project in each ``DbtDag`` / ``DbtTaskGroup`` using the ``select`` and ``exclude`` parameters in the ``RenderConfig`` class. +Cosmos allows you to filter to a subset of your dbt project in each ``DbtDag`` / ``DbtTaskGroup`` using the ``select`` and ``exclude`` parameters in the ``RenderConfig`` class. + + Since Cosmos 1.3, the ``selector`` parameter is also available in ``RenderConfig`` when using ``LoadMode.DBT_LS`` to parse the dbt project into Airflow. + + +Using ``select`` and ``exclude`` +-------------------------------- The ``select`` and ``exclude`` parameters are lists, with values like the following: @@ -84,3 +90,24 @@ Examples: exclude=["node_name+"], # node_name and its children ) ) + +Using ``selector`` +-------------------------------- +.. note:: + Only currently supported using the ``dbt_ls`` parsing method since Cosmos 1.3 where the selector is passed directly to the dbt CLI command. \ + If ``select`` and/or ``exclude`` are used with ``selector``, dbt will ignore the ``select`` and ``exclude`` parameters. + +The ``selector`` parameter is a string that references a `dbt YAML selector `_ already defined in a dbt project. + +Examples: + +..
code-block:: python + + from cosmos import DbtDag, RenderConfig, LoadMode + + jaffle_shop = DbtDag( + render_config=RenderConfig( + selector="my_selector", # this selector must be defined in your dbt project + load_method=LoadMode.DBT_LS, + ) + ) diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 3b80424b61..2816fd07a5 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -1,7 +1,7 @@ import shutil import tempfile from pathlib import Path -from unittest.mock import patch +from unittest.mock import patch, MagicMock import yaml import pytest @@ -44,6 +44,18 @@ def tmp_dbt_project_dir(): shutil.rmtree(tmp_dir, ignore_errors=True) # delete directory +@pytest.fixture +def postgres_profile_config() -> ProfileConfig: + return ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="airflow_db", + profile_args={"schema": "public"}, + ), + ) + + @pytest.mark.parametrize( "unique_id,expected_name, expected_select", [ @@ -220,7 +232,9 @@ def test_load( @pytest.mark.integration @patch("cosmos.dbt.graph.Popen") -def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder(mock_popen, tmp_dbt_project_dir): +def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder( + mock_popen, tmp_dbt_project_dir, postgres_profile_config +): mock_popen().communicate.return_value = ("", "") mock_popen().returncode = 0 assert not (tmp_dbt_project_dir / "target").exists() @@ -233,14 +247,7 @@ def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder(mock_pop project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) dbt_graph.load_via_dbt_ls() assert not (tmp_dbt_project_dir / 
"target").exists() @@ -252,7 +259,7 @@ def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder(mock_pop @pytest.mark.integration -def test_load_via_dbt_ls_with_exclude(): +def test_load_via_dbt_ls_with_exclude(postgres_profile_config): project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) render_config = RenderConfig( dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, select=["*customers*"], exclude=["*orders*"] @@ -262,14 +269,7 @@ def test_load_via_dbt_ls_with_exclude(): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) dbt_graph.load_via_dbt_ls() @@ -301,7 +301,7 @@ def test_load_via_dbt_ls_with_exclude(): @pytest.mark.integration @pytest.mark.parametrize("project_name", ("jaffle_shop", "jaffle_shop_python")) -def test_load_via_dbt_ls_without_exclude(project_name): +def test_load_via_dbt_ls_without_exclude(project_name, postgres_profile_config): project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name) render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) execution_config = ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) @@ -309,14 +309,7 @@ def test_load_via_dbt_ls_without_exclude(project_name): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) dbt_graph.load_via_dbt_ls() @@ -413,7 +406,7 @@ def test_load_via_dbt_ls_with_sources(load_method): 
@pytest.mark.integration -def test_load_via_dbt_ls_without_dbt_deps(): +def test_load_via_dbt_ls_without_dbt_deps(postgres_profile_config): project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, dbt_deps=False) execution_config = ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) @@ -421,14 +414,7 @@ def test_load_via_dbt_ls_without_dbt_deps(): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) with pytest.raises(CosmosLoadDbtException) as err_info: @@ -439,7 +425,7 @@ def test_load_via_dbt_ls_without_dbt_deps(): @pytest.mark.integration -def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(tmp_dbt_project_dir): +def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(tmp_dbt_project_dir, postgres_profile_config): local_flags = [ "--project-dir", tmp_dbt_project_dir / DBT_PROJECT_NAME, @@ -469,14 +455,7 @@ def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(tmp_dbt_ project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) dbt_graph.load_via_dbt_ls() # does not raise exception @@ -484,7 +463,9 @@ def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(tmp_dbt_ @pytest.mark.integration @patch("cosmos.dbt.graph.Popen") -def 
test_load_via_dbt_ls_with_zero_returncode_and_non_empty_stderr(mock_popen, tmp_dbt_project_dir): +def test_load_via_dbt_ls_with_zero_returncode_and_non_empty_stderr( + mock_popen, tmp_dbt_project_dir, postgres_profile_config +): mock_popen().communicate.return_value = ("", "Some stderr warnings") mock_popen().returncode = 0 @@ -495,14 +476,7 @@ def test_load_via_dbt_ls_with_zero_returncode_and_non_empty_stderr(mock_popen, t project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) dbt_graph.load_via_dbt_ls() # does not raise exception @@ -510,7 +484,7 @@ def test_load_via_dbt_ls_with_zero_returncode_and_non_empty_stderr(mock_popen, t @pytest.mark.integration @patch("cosmos.dbt.graph.Popen") -def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen): +def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen, postgres_profile_config): mock_popen().communicate.return_value = ("", "Some stderr message") mock_popen().returncode = 1 @@ -521,14 +495,7 @@ def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) expected = r"Unable to run \['.+dbt', 'deps', .*\] due to the error:\nSome stderr message" with pytest.raises(CosmosLoadDbtException, match=expected): @@ -537,7 +504,7 @@ def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen): @pytest.mark.integration @patch("cosmos.dbt.graph.Popen.communicate", return_value=("Some Runtime Error", "")) 
-def test_load_via_dbt_ls_with_runtime_error_in_stdout(mock_popen_communicate): +def test_load_via_dbt_ls_with_runtime_error_in_stdout(mock_popen_communicate, postgres_profile_config): # It may seem strange, but at least until dbt 1.6.0, there are circumstances when it outputs errors to stdout project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) @@ -546,14 +513,7 @@ def test_load_via_dbt_ls_with_runtime_error_in_stdout(mock_popen_communicate): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) expected = r"Unable to run \['.+dbt', 'deps', .*\] due to the error:\nSome Runtime Error" with pytest.raises(CosmosLoadDbtException, match=expected): @@ -673,7 +633,7 @@ def test_tag_selected_node_test_exist(): @pytest.mark.integration @pytest.mark.parametrize("load_method", ["load_via_dbt_ls", "load_from_dbt_manifest"]) -def test_load_dbt_ls_and_manifest_with_model_version(load_method): +def test_load_dbt_ls_and_manifest_with_model_version(load_method, postgres_profile_config): dbt_graph = DbtGraph( project=ProjectConfig( dbt_project_path=DBT_PROJECTS_ROOT_DIR / "model_version", @@ -681,14 +641,7 @@ def test_load_dbt_ls_and_manifest_with_model_version(load_method): ), render_config=RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / "model_version"), execution_config=ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / "model_version"), - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + 
profile_config=postgres_profile_config, ) getattr(dbt_graph, load_method)() expected_dbt_nodes = { @@ -826,6 +779,55 @@ def test_load_via_dbt_ls_project_config_dbt_vars(mock_validate, mock_update_node assert ls_command[ls_command.index("--vars") + 1] == yaml.dump(dbt_vars) +@patch("cosmos.dbt.graph.Popen") +@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency") +@patch("cosmos.config.RenderConfig.validate_dbt_command") +def test_load_via_dbt_ls_render_config_selector_arg_is_used( + mock_validate, mock_update_nodes, mock_popen, tmp_dbt_project_dir +): + """Tests that the dbt ls command in the subprocess has "--selector" with the RenderConfig.selector.""" + mock_popen().communicate.return_value = ("", "") + mock_popen().returncode = 0 + selector = "my_selector" + project_config = ProjectConfig() + render_config = RenderConfig( + dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME, + load_method=LoadMode.DBT_LS, + selector=selector, + ) + profile_config = ProfileConfig( + profile_name="test", + target_name="test", + profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml", + ) + execution_config = MagicMock() + dbt_graph = DbtGraph( + project=project_config, + render_config=render_config, + execution_config=execution_config, + profile_config=profile_config, + ) + dbt_graph.load_via_dbt_ls() + ls_command = mock_popen.call_args.args[0] + assert "--selector" in ls_command + assert ls_command[ls_command.index("--selector") + 1] == selector + + +@pytest.mark.parametrize("load_method", [LoadMode.DBT_MANIFEST, LoadMode.CUSTOM]) +def test_load_method_with_unsupported_render_config_selector_arg(load_method): + """Tests that error is raised when RenderConfig.selector is used with LoadMode.DBT_MANIFEST or LoadMode.CUSTOM.""" + + expected_error_msg = ( + f"RenderConfig.selector is not yet supported when loading dbt projects using the {load_method} parser." 
+ ) + dbt_graph = DbtGraph( + render_config=RenderConfig(load_method=load_method, selector="my_selector"), + project=MagicMock(), + ) + with pytest.raises(CosmosLoadDbtException, match=expected_error_msg): + dbt_graph.load(method=load_method) + + @pytest.mark.sqlite @pytest.mark.integration def test_load_via_dbt_ls_with_project_config_vars(): @@ -853,3 +855,45 @@ def test_load_via_dbt_ls_with_project_config_vars(): ) dbt_graph.load_via_dbt_ls() assert dbt_graph.nodes["model.simple.top_animations"].config["alias"] == "top_5_animated_movies" + + +@pytest.mark.integration +def test_load_via_dbt_ls_with_selector_arg(tmp_dbt_project_dir, postgres_profile_config): + """ + Tests that the dbt ls load method is successful if a selector arg is used with RenderConfig + and that the filtered nodes are expected. + """ + # Add a selectors yaml file to the project that will select the stg_customers model and all + # parents (raw_customers) + selectors_yaml = """ + selectors: + - name: stage_customers + definition: + method: fqn + value: stg_customers + parents: true + """ + with open(tmp_dbt_project_dir / DBT_PROJECT_NAME / "selectors.yml", "w") as f: + f.write(selectors_yaml) + + project_config = ProjectConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + execution_config = ExecutionConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + render_config = RenderConfig( + dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME, + selector="stage_customers", + ) + + dbt_graph = DbtGraph( + project=project_config, + render_config=render_config, + execution_config=execution_config, + profile_config=postgres_profile_config, + ) + dbt_graph.load_via_dbt_ls() + + filtered_nodes = dbt_graph.filtered_nodes.keys() + assert len(filtered_nodes) == 4 + assert "model.jaffle_shop.stg_customers" in filtered_nodes + assert "seed.jaffle_shop.raw_customers" in filtered_nodes + # Two tests should be filtered + assert sum(node.startswith("test.jaffle_shop") for node in 
filtered_nodes) == 2