Skip to content

Commit

Permalink
Fix snapshot download when local_dir is provided. (#2592)
Browse files Browse the repository at this point in the history
* Fix snapshot download when local_dir is provided

* Fix tests docstring

* Add comment

* Fixes post-review
  • Loading branch information
hanouticelina authored Oct 8, 2024
1 parent 39c7a8b commit c8da356
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/huggingface_hub/_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,7 @@
from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
from .utils import (
OfflineModeIsEnabled,
filter_repo_objects,
logging,
validate_hf_hub_args,
)
from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
from .utils import tqdm as hf_tqdm


Expand Down Expand Up @@ -191,6 +186,7 @@ def snapshot_download(
# => let's look if we can find the appropriate folder in the cache:
# - if the specified revision is a commit hash, look inside "snapshots".
# - if the specified revision is a branch or tag, look inside "refs".
# => if local_dir is not None, we will return the path to the local folder if it exists.
if repo_info is None:
# Try to get which commit hash corresponds to the specified revision
commit_hash = None
Expand All @@ -210,7 +206,14 @@ def snapshot_download(
# Snapshot folder exists => let's return it
# (but we can't check if all the files are actually there)
return snapshot_folder

# If local_dir is not None, return it if it exists and is not empty
if local_dir is not None:
local_dir = Path(local_dir)
if local_dir.is_dir() and any(local_dir.iterdir()):
logger.warning(
f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
)
return str(local_dir.resolve())
# If we couldn't find the appropriate folder on disk, raise an error.
if local_files_only:
raise LocalEntryNotFoundError(
Expand Down
31 changes: 31 additions & 0 deletions tests/test_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,37 @@ def test_download_model_local_only(self):
)
self.assertTrue(self.first_commit_hash in storage_folder) # has expected revision

# Test with local_dir
with SoftTemporaryDirectory() as tmpdir:
# first download folder to local_dir
snapshot_download(self.repo_id, local_dir=tmpdir)
# now load from local_dir
storage_folder = snapshot_download(self.repo_id, local_dir=tmpdir, local_files_only=True)
self.assertEquals(str(tmpdir), storage_folder)

def test_download_model_to_local_dir_with_offline_mode(self):
    """Test that an already downloaded folder is returned when there is a connection error.

    The repo is first downloaded into a temporary `local_dir`. Then, for every
    member of `OfflineSimulationMode`, `snapshot_download` is called again with
    the same `local_dir` inside the `offline` context and must return the path
    of that directory as a string.
    """
    with SoftTemporaryDirectory() as tmpdir:
        # First download the folder to local_dir while the Hub is reachable.
        snapshot_download(self.repo_id, local_dir=tmpdir)

        # Check that the folder is returned when there is a connection error.
        for offline_mode in OfflineSimulationMode:
            with offline(mode=offline_mode):
                storage_folder = snapshot_download(self.repo_id, local_dir=tmpdir)
                # `assertEquals` is a deprecated alias that was removed in
                # Python 3.12; `assertEqual` is the supported spelling.
                self.assertEqual(str(tmpdir), storage_folder)

def test_download_model_offline_mode_not_in_local_dir(self):
    """Test that an empty local_dir raises when there is a connection error.

    Two scenarios are covered, each expecting `LocalEntryNotFoundError`:
    a call with `local_files_only=True`, and a call made under every
    member of `OfflineSimulationMode`.
    """
    # Scenario 1: local-only lookup into a freshly created directory.
    with SoftTemporaryDirectory() as empty_dir, self.assertRaises(LocalEntryNotFoundError):
        snapshot_download(self.repo_id, local_dir=empty_dir, local_files_only=True)

    # Scenario 2: each simulated offline mode gets its own fresh directory.
    for simulated_mode in OfflineSimulationMode:
        with offline(mode=simulated_mode), SoftTemporaryDirectory() as empty_dir:
            with self.assertRaises(LocalEntryNotFoundError):
                snapshot_download(self.repo_id, local_dir=empty_dir)

def test_download_model_offline_mode_not_cached(self):
"""Test when connection error but cache is empty."""
with SoftTemporaryDirectory() as tmpdir:
Expand Down

0 comments on commit c8da356

Please sign in to comment.