Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,8 @@ def _add_vars_arg(self, cmd_args: list[str]) -> None:
"""
Change args list in-place so they include dbt vars, if they are set.
"""
if self.project.dbt_vars:
cmd_args.extend(["--vars", json.dumps(self.project.dbt_vars, sort_keys=True)])
if self.dbt_vars:
cmd_args.extend(["--vars", json.dumps(self.dbt_vars, sort_keys=True)])

@cached_property
def dbt_ls_args(self) -> list[str]:
Expand Down
62 changes: 49 additions & 13 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,8 +1508,7 @@ def test_load_via_dbt_ls_project_config_dbt_vars(
"""Tests that the dbt ls command in the subprocess has "--vars" with the project config dbt_vars."""
mock_popen().communicate.return_value = ("", "")
mock_popen().returncode = 0
dbt_vars = {"my_var1": "my_value1", "my_var2": "my_value2"}
project_config = ProjectConfig(dbt_vars=dbt_vars)
project_config = ProjectConfig(dbt_vars={"my_var1": "my_value1", "my_var2": "my_value2"})
render_config = RenderConfig(
dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME,
source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR,
Expand All @@ -1528,8 +1527,42 @@ def test_load_via_dbt_ls_project_config_dbt_vars(
)
dbt_graph.load_via_dbt_ls()
ls_command = mock_popen.call_args.args[0]
assert "--vars" not in ls_command


@patch("cosmos.dbt.graph.DbtGraph.should_use_dbt_ls_cache", return_value=False)
@patch("cosmos.dbt.graph.Popen")
@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency")
@patch("cosmos.config.RenderConfig.validate_dbt_command")
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_load_via_dbt_ls_dbt_graph_dbt_vars(
mock_validate, mock_update_nodes, mock_popen, mock_use_case, tmp_dbt_project_dir
):
"""Tests that the dbt ls command in the subprocess has "--vars" with the DbtGraph dbt_vars."""
mock_popen().communicate.return_value = ("", "")
mock_popen().returncode = 0
dbt_vars = {"my_var3": "my_value3"}
render_config = RenderConfig(
dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME,
source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR,
)
profile_config = ProfileConfig(
profile_name="test",
target_name="test",
profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml",
)
execution_config = ExecutionConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME)
dbt_graph = DbtGraph(
project=ProjectConfig(),
render_config=render_config,
execution_config=execution_config,
profile_config=profile_config,
dbt_vars=dbt_vars,
)
dbt_graph.load_via_dbt_ls()
ls_command = mock_popen.call_args.args[0]
assert "--vars" in ls_command
assert ls_command[ls_command.index("--vars") + 1] == '{"my_var1": "my_value1", "my_var2": "my_value2"}'
assert ls_command[ls_command.index("--vars") + 1] == json.dumps(dbt_vars, sort_keys=True)


@patch("cosmos.dbt.graph.DbtGraph.should_use_dbt_ls_cache", return_value=False)
Expand Down Expand Up @@ -1721,21 +1754,23 @@ def test_project_path_fails():


@pytest.mark.parametrize(
"render_config,project_config,expected_dbt_ls_args",
"render_config,project_config,dbt_vars,expected_dbt_ls_args",
[
(RenderConfig(), ProjectConfig(), []),
(RenderConfig(exclude=["package:snowplow"]), ProjectConfig(), ["--exclude", "package:snowplow"]),
(RenderConfig(), ProjectConfig(), None, []),
(RenderConfig(exclude=["package:snowplow"]), ProjectConfig(), None, ["--exclude", "package:snowplow"]),
(
RenderConfig(select=["tag:prod", "config.materialized:incremental"]),
ProjectConfig(),
None,
["--select", "tag:prod", "config.materialized:incremental"],
),
(RenderConfig(selector="nightly"), ProjectConfig(), ["--selector", "nightly"]),
(RenderConfig(), ProjectConfig(dbt_vars={"a": 1}), ["--vars", '{"a": 1}']),
(RenderConfig(), ProjectConfig(partial_parse=False), ["--no-partial-parse"]),
(RenderConfig(selector="nightly"), ProjectConfig(), None, ["--selector", "nightly"]),
(RenderConfig(), ProjectConfig(dbt_vars={"a": 1}), {"k": "v"}, ["--vars", '{"k": "v"}']),
(RenderConfig(), ProjectConfig(partial_parse=False), None, ["--no-partial-parse"]),
(
RenderConfig(exclude=["1", "2"], select=["a", "b"], selector="nightly"),
ProjectConfig(dbt_vars={"a": 1}, partial_parse=False),
{"k": "v"},
[
"--exclude",
"1",
Expand All @@ -1744,18 +1779,19 @@ def test_project_path_fails():
"a",
"b",
"--vars",
'{"a": 1}',
'{"k": "v"}',
"--selector",
"nightly",
"--no-partial-parse",
],
),
],
)
def test_dbt_ls_args(render_config, project_config, expected_dbt_ls_args):
def test_dbt_ls_args(render_config, project_config, dbt_vars, expected_dbt_ls_args):
graph = DbtGraph(
project=project_config,
render_config=render_config,
dbt_vars=dbt_vars,
)
assert graph.dbt_ls_args == expected_dbt_ls_args

Expand All @@ -1768,8 +1804,8 @@ def test_dbt_ls_cache_key_args_sorts_envvars():

@patch("cosmos.dbt.graph.run_command")
def test_run_dbt_deps(run_command_mock):
project_config = ProjectConfig(dbt_vars={"var-key": "var-value"})
graph = DbtGraph(project=project_config)
project_config = ProjectConfig()
graph = DbtGraph(project=project_config, dbt_vars={"var-key": "var-value"})
graph.local_flags = []
graph.run_dbt_deps("dbt", "/some/path", {})
run_command_mock.assert_called_with(
Expand Down