Skip to content

Commit

Permalink
fixed communicating HF Cache folder to hf
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMrSheldon committed Sep 3, 2024
1 parent a72d806 commit 6a08472
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
13 changes: 11 additions & 2 deletions application/src/tira/huggingface_hub_integration.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
from typing import Iterable

from huggingface_hub import scan_cache_dir, snapshot_download
from django.conf import settings
from huggingface_hub import scan_cache_dir
from huggingface_hub import snapshot_download as hfsnapshot_download

import tira.io_utils as tira_cli_io_utils

# Hugging Face home directory on the TIRA host, located inside the TIRA root that is configured in the
# Django settings. The value is path-like: it is joined with "/" elsewhere in this module.
TIRA_HOST_HF_HOME = tira_cli_io_utils.default_hf_home_in_tira_host(settings.TIRA_ROOT)
# NOTE(review): presumably filled lazily by code outside this excerpt — confirm against the full module.
HF_CACHE = None


Expand Down Expand Up @@ -39,3 +41,10 @@ def huggingface_model_mounts(models: Iterable[str]):
def snapshot_download_hf_model(model: str):
    """Point HF_HOME at the TIRA host's Hugging Face home and download a snapshot of the model.

    Args:
        model (str): The model identifier. Every "--" is replaced by "/" to obtain the repository id
            (e.g., "prajjwal1--bert-tiny" becomes "prajjwal1/bert-tiny").
    """
    # os.environ only accepts strings, but TIRA_HOST_HF_HOME may be a path object; convert explicitly so the
    # assignment does not raise a TypeError.
    os.environ["HF_HOME"] = str(TIRA_HOST_HF_HOME)
    snapshot_download(repo_id=model.replace("--", "/"))


def snapshot_download(*args, **kwargs) -> str:
    # Thin wrapper around huggingface_hub's snapshot_download that always passes the "hub" directory inside
    # TIRA's Hugging Face home as ``cache_dir``. Because ``cache_dir`` is supplied here, callers must not pass
    # it themselves (Python rejects a keyword argument that is given twice).
    hub_cache_dir = TIRA_HOST_HF_HOME / "hub"
    return hfsnapshot_download(*args, cache_dir=hub_cache_dir, **kwargs)


# Reuse the upstream documentation for the wrapper; this is why the function carries no docstring of its own.
snapshot_download.__doc__ = hfsnapshot_download.__doc__
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import os
import unittest

from tira.huggingface_hub_integration import TIRA_HOST_HF_HOME, _hf_repos, huggingface_model_mounts

os.environ["HF_HOME"] = TIRA_HOST_HF_HOME
from tira.huggingface_hub_integration import _hf_repos, huggingface_model_mounts, snapshot_download


class TestHfMountsAreParsed(unittest.TestCase):
def fail_if_hf_is_not_installed(self):
os.environ["HF_HOME"] = TIRA_HOST_HF_HOME
from huggingface_hub import snapshot_download

snapshot_download(repo_id="prajjwal1/bert-tiny")
self.assertGreater(len(_hf_repos()), 0)
del os.environ["HF_HOME"]

def test_hf_is_installed(self):
self.fail_if_hf_is_not_installed()
Expand All @@ -40,8 +33,6 @@ def test_non_existing_hf_models_can_not_be_mounted(self):

def test_existing_hf_model_can_be_mounted(self):
self.fail_if_hf_is_not_installed()
os.environ["HF_HOME"] = TIRA_HOST_HF_HOME

actual = huggingface_model_mounts(["prajjwal1/bert-tiny"])
del os.environ["HF_HOME"]
self.assertEqual("prajjwal1/bert-tiny", actual["MOUNT_HF_MODEL"])
8 changes: 6 additions & 2 deletions python-client/tira/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,17 @@ def _ln_huggingface_model_mounts(models: str) -> str:
return "; ".join(ret + [f'echo "mounted {len(models)} models"'])


def default_hf_home_in_tira_host(tira_base: Union[str, Path]) -> Path:
    """Returns the location of the hf home on the tira hosts that are mounted read-only into the pods.

    Args:
        tira_base (Union[str, Path]): The base path to TIRA's file structure. The hugging face home directory
            will be located inside it.

    Returns:
        Path: the HF_HOME on a tira host, i.e., ``<tira_base>/data/publicly-shared-datasets/huggingface``.
    """
    # Coerce to Path first so that a base given as a plain string does not break the "/" joins below.
    return Path(tira_base) / "data" / "publicly-shared-datasets" / "huggingface"


def all_lines_to_pandas(input_file: Union[str, Iterable[str]], load_default_text: bool) -> pd.DataFrame:
Expand Down

0 comments on commit 6a08472

Please sign in to comment.