diff --git a/application/pyproject.toml b/application/pyproject.toml index 86590298..30388f6b 100644 --- a/application/pyproject.toml +++ b/application/pyproject.toml @@ -28,4 +28,7 @@ exclude = [ [tool.pytest.ini_options] DJANGO_SETTINGS_MODULE = "settings_test" pythonpath = ["./src", "./test"] -python_files = "test_*.py" \ No newline at end of file +python_files = "test_*.py" + +[tool.pytest_env] +HF_HOME = "./tira-root/huggingface" \ No newline at end of file diff --git a/application/setup.cfg b/application/setup.cfg index 7fe224de..8ed7f473 100644 --- a/application/setup.cfg +++ b/application/setup.cfg @@ -39,6 +39,7 @@ test = parameterized approvaltests==7.3.0 pytest-django + pytest-env==1.1.3 dev = coverage coverage-badge diff --git a/application/src/tira/huggingface_hub_integration.py b/application/src/tira/huggingface_hub_integration.py index ccf968a1..4724d6eb 100644 --- a/application/src/tira/huggingface_hub_integration.py +++ b/application/src/tira/huggingface_hub_integration.py @@ -1,14 +1,11 @@ -import os -from typing import Iterable +from typing import Iterable, Optional -from django.conf import settings -from huggingface_hub import scan_cache_dir -from huggingface_hub import snapshot_download as hfsnapshot_download +from huggingface_hub import HFCacheInfo, scan_cache_dir, snapshot_download +from huggingface_hub.constants import HF_HOME import tira.io_utils as tira_cli_io_utils -TIRA_HOST_HF_HOME = tira_cli_io_utils.default_hf_home_in_tira_host(settings.TIRA_ROOT) -HF_CACHE = None +HF_CACHE: Optional[HFCacheInfo] = None def _hf_repos() -> dict[str, str]: @@ -35,16 +32,8 @@ def huggingface_model_mounts(models: Iterable[str]): else: raise Exception(f"Model {model} is not available in the Huggingface cache") - return {"MOUNT_HF_MODEL": " ".join(models), "HF_HOME": TIRA_HOST_HF_HOME, "HF_CACHE_SCAN": ret} + return {"MOUNT_HF_MODEL": " ".join(models), "HF_HOME": HF_HOME, "HF_CACHE_SCAN": ret} def snapshot_download_hf_model(model: str): - os.environ["HF_HOME"] = TIRA_HOST_HF_HOME snapshot_download(repo_id=model.replace("--", "/")) - - -def snapshot_download(*args, **kwargs) -> str: - return hfsnapshot_download(*args, cache_dir=TIRA_HOST_HF_HOME / "hub", **kwargs) - - -snapshot_download.__doc__ = hfsnapshot_download.__doc__ diff --git a/application/test/hf_mount_model_tests/test_hf_mounts_are_parsed.py b/application/test/hf_mount_model_tests/test_hf_mounts_are_parsed.py index d21b9c78..cf0640e8 100644 --- a/application/test/hf_mount_model_tests/test_hf_mounts_are_parsed.py +++ b/application/test/hf_mount_model_tests/test_hf_mounts_are_parsed.py @@ -1,6 +1,8 @@ import unittest -from tira.huggingface_hub_integration import _hf_repos, huggingface_model_mounts, snapshot_download +from huggingface_hub import snapshot_download + +from tira.huggingface_hub_integration import _hf_repos, huggingface_model_mounts class TestHfMountsAreParsed(unittest.TestCase): diff --git a/python-client/tira/io_utils.py b/python-client/tira/io_utils.py index a12314f6..3be46d55 100644 --- a/python-client/tira/io_utils.py +++ b/python-client/tira/io_utils.py @@ -152,19 +152,6 @@ def _ln_huggingface_model_mounts(models: str) -> str: return "; ".join(ret + [f'echo "mounted {len(models)} models"']) -def default_hf_home_in_tira_host(tira_base: Path) -> str: - """Returns the location of the hf home on the tira hosts that are mounted read-only into the pods. - - Args: - tira_base (Path): The base path to TIRA's file structure. The hugging face home directory will be located inside - it. - - Returns: - str: the HF_HOME on a tira host. - """ - return tira_base / "data" / "publicly-shared-datasets" / "huggingface" - - def all_lines_to_pandas(input_file: Union[str, Iterable[str]], load_default_text: bool) -> pd.DataFrame: """ .. todo:: add documentation