diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4cd62c8c..9fb9d5db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: name: Detect secrets - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.0.257' + rev: 'v0.11.7' hooks: - id: ruff name: Linting with ruff diff --git a/airflow_dbt_python/hooks/fs/git.py b/airflow_dbt_python/hooks/fs/git.py index 70949ee9..bcab442b 100644 --- a/airflow_dbt_python/hooks/fs/git.py +++ b/airflow_dbt_python/hooks/fs/git.py @@ -157,7 +157,11 @@ def _download( client, path, branch = self.get_git_client_path(source) client.clone( - path, str(destination), mkdir=not destination.exists(), branch=branch + path, + str(destination), + mkdir=not destination.exists(), + # NOTE: Dulwich expects branch to be bytes if defined. + branch=branch.encode("utf-8") if isinstance(branch, str) else branch, ) def get_git_client_path(self, url: URL) -> Tuple[GitClients, str, Optional[str]]: diff --git a/tests/hooks/test_git_hook.py b/tests/hooks/test_git_hook.py index 3c7e3bd7..daf60b62 100644 --- a/tests/hooks/test_git_hook.py +++ b/tests/hooks/test_git_hook.py @@ -4,6 +4,7 @@ import os import platform import shutil +import typing import pytest from dulwich.repo import Repo @@ -294,6 +295,16 @@ def repo_name(): return "test/test_shop" +@pytest.fixture +def repo_branch(request) -> typing.Optional[bytes]: + """A configurable local git repo branch.""" + try: + return request.param + except AttributeError: + # Default to dulwich's + return None + + @pytest.fixture def repo_dir(tmp_path): """A testing local git repo directory.""" @@ -303,9 +314,9 @@ def repo_dir(tmp_path): @pytest.fixture -def repo(repo_dir, dbt_project_file, test_files, profiles_file): +def repo(repo_dir, dbt_project_file, test_files, profiles_file, repo_branch): """Initialize a git repo with some dbt test files.""" - repo = Repo.init(repo_dir) + repo = Repo.init(repo_dir, default_branch=repo_branch) shutil.copyfile(dbt_project_file, repo_dir / "dbt_project.yml") repo.stage("dbt_project.yml") @@ -364,6 +375,31 @@ def test_download_dbt_project_with_local_server( assert_dir_contents(local_repo_path, expected, exact=False) +@no_git_local_server +@pytest.mark.parametrize("repo_branch", ["test-branch".encode("utf-8")], indirect=True) +def test_download_dbt_project_with_custom_branch_from_local_server( + tmp_path, git_server, repo_name, assert_dir_contents, repo_branch +): + """Test downloading a dbt project from a local git server.""" + local_path = tmp_path / "local" + fs_hook = DbtGitFSHook() + server_address, server_port = git_server + source = URL( + f"git://{server_address}:{server_port}/{repo_name}@{repo_branch.decode('utf-8')}" + ) + local_repo_path = fs_hook.download_dbt_project(source, local_path) + + expected = [ + URL(local_repo_path / "dbt_project.yml"), + URL(local_repo_path / "models" / "a_model.sql"), + URL(local_repo_path / "models" / "another_model.sql"), + URL(local_repo_path / "seeds" / "a_seed.csv"), + ] + + assert local_repo_path.exists() + assert_dir_contents(local_repo_path, expected, exact=False) + + @pytest.fixture def pre_run(hook, repo_dir): """Fixture to run a dbt run task."""