Skip to content

Commit

Permalink
removed default_hf_home_in_tira_host
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMrSheldon committed Sep 3, 2024
1 parent 6a08472 commit d520ccb
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 31 deletions.
5 changes: 4 additions & 1 deletion application/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ exclude = [
[tool.pytest.ini_options]
DJANGO_SETTINGS_MODULE = "settings_test"
pythonpath = ["./src", "./test"]
python_files = "test_*.py"
python_files = "test_*.py"

[tool.pytest_env]
HF_HOME = "./tira-root/huggingface"
1 change: 1 addition & 0 deletions application/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ test =
parameterized
approvaltests==7.3.0
pytest-django
pytest-env==1.1.3
dev =
coverage
coverage-badge
Expand Down
21 changes: 5 additions & 16 deletions application/src/tira/huggingface_hub_integration.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import os
from typing import Iterable
from typing import Iterable, Optional

from django.conf import settings
from huggingface_hub import scan_cache_dir
from huggingface_hub import snapshot_download as hfsnapshot_download
from huggingface_hub import HFCacheInfo, scan_cache_dir, snapshot_download
from huggingface_hub.constants import HF_HOME

import tira.io_utils as tira_cli_io_utils

TIRA_HOST_HF_HOME = tira_cli_io_utils.default_hf_home_in_tira_host(settings.TIRA_ROOT)
HF_CACHE = None
HF_CACHE: Optional[HFCacheInfo] = None


def _hf_repos() -> dict[str, str]:
Expand All @@ -35,16 +32,8 @@ def huggingface_model_mounts(models: Iterable[str]):
else:
raise Exception(f"Model {model} is not available in the Huggingface cache")

return {"MOUNT_HF_MODEL": " ".join(models), "HF_HOME": TIRA_HOST_HF_HOME, "HF_CACHE_SCAN": ret}
return {"MOUNT_HF_MODEL": " ".join(models), "HF_HOME": HF_HOME, "HF_CACHE_SCAN": ret}


def snapshot_download_hf_model(model: str):
    """Download the snapshot of a Hugging Face model via this module's ``snapshot_download``.

    Sets the ``HF_HOME`` environment variable to ``TIRA_HOST_HF_HOME`` before starting the download.

    Args:
        model (str): The model identifier. Every ``--`` is replaced by ``/`` to form the repository id that is
            passed to ``snapshot_download``.
    """
    # os.environ only accepts str values. TIRA_HOST_HF_HOME is assembled with the "/" path operator, so it is a
    # path object and assigning it directly raises a TypeError; os.fspath yields its string form (and leaves a
    # plain str untouched).
    os.environ["HF_HOME"] = os.fspath(TIRA_HOST_HF_HOME)
    snapshot_download(repo_id=model.replace("--", "/"))


def snapshot_download(*args, **kwargs) -> str:
    # Thin wrapper around huggingface_hub's snapshot_download (imported as hfsnapshot_download) that pins the
    # cache directory to the "hub" folder inside TIRA_HOST_HF_HOME. All other positional and keyword arguments
    # are forwarded unchanged. A comment is used instead of a docstring because __doc__ is assigned from the
    # wrapped function right after this definition.
    hub_cache_dir = TIRA_HOST_HF_HOME / "hub"
    return hfsnapshot_download(*args, cache_dir=hub_cache_dir, **kwargs)


snapshot_download.__doc__ = hfsnapshot_download.__doc__
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import unittest

from tira.huggingface_hub_integration import _hf_repos, huggingface_model_mounts, snapshot_download
from huggingface_hub import snapshot_download

from tira.huggingface_hub_integration import _hf_repos, huggingface_model_mounts


class TestHfMountsAreParsed(unittest.TestCase):
Expand Down
13 changes: 0 additions & 13 deletions python-client/tira/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,6 @@ def _ln_huggingface_model_mounts(models: str) -> str:
return "; ".join(ret + [f'echo "mounted {len(models)} models"'])


def default_hf_home_in_tira_host(tira_base: Path) -> Path:
    """Returns the location of the hf home on the tira hosts that are mounted read-only into the pods.

    Args:
        tira_base (Path): The base path to TIRA's file structure. The hugging face home directory will be located
            inside it.

    Returns:
        Path: the HF_HOME on a tira host, i.e. ``<tira_base>/data/publicly-shared-datasets/huggingface``. Callers
            that need a string (e.g. for an environment variable) must convert it explicitly.
    """
    # The "/" operator on a Path yields a Path, so the result is a path object rather than a str.
    return tira_base / "data" / "publicly-shared-datasets" / "huggingface"


def all_lines_to_pandas(input_file: Union[str, Iterable[str]], load_default_text: bool) -> pd.DataFrame:
"""
.. todo:: add documentation
Expand Down

0 comments on commit d520ccb

Please sign in to comment.