Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ def project_path(self) -> Path:
)
return path.absolute()

def _add_vars_arg(self, cmd_args: list[str]) -> None:
"""
Change args list in-place so they include dbt vars, if they are set.
"""
if self.project.dbt_vars:
cmd_args.extend(["--vars", json.dumps(self.project.dbt_vars, sort_keys=True)])

@cached_property
def dbt_ls_args(self) -> list[str]:
"""
Expand All @@ -213,8 +220,7 @@ def dbt_ls_args(self) -> list[str]:
if self.render_config.select:
ls_args.extend(["--select", *self.render_config.select])

if self.project.dbt_vars:
ls_args.extend(["--vars", json.dumps(self.project.dbt_vars, sort_keys=True)])
self._add_vars_arg(ls_args)

if self.render_config.selector:
ls_args.extend(["--selector", self.render_config.selector])
Expand Down Expand Up @@ -408,6 +414,16 @@ def should_use_partial_parse_cache(self) -> bool:
"""Identify if Cosmos should use/store dbt partial parse cache or not."""
return settings.enable_cache_partial_parse and settings.enable_cache and bool(self.cache_dir)

def run_dbt_deps(self, dbt_cmd: str, dbt_project_path: Path, env: dict[str, str]) -> None:
"""
Given the dbt command path and the dbt project path, build and run the dbt deps command.
"""
deps_command = [dbt_cmd, "deps"]
deps_command.extend(self.local_flags)
self._add_vars_arg(deps_command)
stdout = run_command(deps_command, dbt_project_path, env)
logger.debug("dbt deps output: %s", stdout)

def load_via_dbt_ls_without_cache(self) -> None:
"""
This is the most accurate way of loading `dbt` projects and filtering them out, since it uses the `dbt` command
Expand Down Expand Up @@ -461,16 +477,14 @@ def load_via_dbt_ls_without_cache(self) -> None:
"--target",
self.profile_config.target_name,
]

self.log_dir = Path(env.get(DBT_LOG_PATH_ENVVAR) or tmpdir_path / DBT_LOG_DIR_NAME)
self.target_dir = Path(env.get(DBT_TARGET_PATH_ENVVAR) or tmpdir_path / DBT_TARGET_DIR_NAME)
env[DBT_LOG_PATH_ENVVAR] = str(self.log_dir)
env[DBT_TARGET_PATH_ENVVAR] = str(self.target_dir)

if self.render_config.dbt_deps and has_non_empty_dependencies_file(self.project_path):
deps_command = [dbt_cmd, "deps"]
deps_command.extend(self.local_flags)
stdout = run_command(deps_command, tmpdir_path, env)
logger.debug("dbt deps output: %s", stdout)
self.run_dbt_deps(dbt_cmd, tmpdir_path, env)

nodes = self.run_dbt_ls(dbt_cmd, self.project_path, tmpdir_path, env)

Expand Down
11 changes: 10 additions & 1 deletion tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,15 @@ def test_dbt_ls_cache_key_args_sorts_envvars():
assert graph.dbt_ls_cache_key_args == ['{"5": "May", "11": "November", "12": "December"}']


@patch("cosmos.dbt.graph.run_command")
def test_run_dbt_deps(run_command_mock):
project_config = ProjectConfig(dbt_vars={"var-key": "var-value"})
graph = DbtGraph(project=project_config)
graph.local_flags = []
graph.run_dbt_deps("dbt", "/some/path", {})
run_command_mock.assert_called_with(["dbt", "deps", "--vars", '{"var-key": "var-value"}'], "/some/path", {})


@pytest.fixture()
def airflow_variable():
key = "cosmos_cache__undefined"
Expand Down Expand Up @@ -1413,7 +1422,7 @@ def test_save_dbt_ls_cache(mock_variable_set, mock_datetime, tmp_dbt_project_dir
hash_dir, hash_args = version.split(",")
assert hash_args == "d41d8cd98f00b204e9800998ecf8427e"
if sys.platform == "darwin":
assert hash_dir == "465fc0735d8bef08a0d375b2315069bb"
assert hash_dir == "18b97e2bff2684161f71db817f1f50e2"
else:
assert hash_dir == "6c662da10b64a8390c469c884af88321"

Expand Down