feat: Add loading sharded GGUF files from HuggingFace with Llama.from_pretrained(additional_files=[...]). Closes #1341

Co-authored-by: Andrei <[email protected]>
Gnurro and abetlen committed Sep 19, 2024
1 parent 29afcfd commit 84c0920
Showing 1 changed file with 33 additions and 0 deletions.
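
For context, a minimal usage sketch of the new parameter; the repo id and shard file names below are hypothetical, chosen only to illustrate the usual -NNNNN-of-NNNNN shard naming, and each additional_files entry must resolve to exactly one file:

from llama_cpp import Llama

# The first shard goes in `filename`; the remaining shards are listed in
# `additional_files` as exact names or glob patterns (one entry per shard).
# Per the comment in the diff, loading the first shard then loads the rest
# from the same directory.
llm = Llama.from_pretrained(
    repo_id="example-org/Example-Model-GGUF",               # hypothetical repo
    filename="example-model-Q4_K_M-00001-of-00003.gguf",    # first shard
    additional_files=[
        "example-model-Q4_K_M-00002-of-00003.gguf",
        "example-model-Q4_K_M-00003-of-00003.gguf",
    ],
)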
33 changes: 33 additions & 0 deletions llama_cpp/llama.py
@@ -2227,6 +2227,7 @@ def from_pretrained(
cls,
repo_id: str,
filename: Optional[str],
additional_files: Optional[List] = None,
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
@@ -2239,6 +2240,7 @@
Args:
repo_id: The model repo id.
filename: A filename or glob pattern to match the model file in the repo.
additional_files: A list of filenames or glob patterns to match additional model files in the repo.
local_dir: The local directory to save the model to.
local_dir_use_symlinks: Whether to use symlinks when downloading the model.
**kwargs: Additional keyword arguments to pass to the Llama constructor.
@@ -2269,6 +2271,7 @@
rel_path = Path(file).relative_to(repo_id)
file_list.append(str(rel_path))

# find the only/first shard file:
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore

if len(matching_files) == 0:
@@ -2298,6 +2301,35 @@
cache_dir=cache_dir,
)

if additional_files:
for additional_file_name in additional_files:
# find the additional shard file:
matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additional_file_name)]

if len(matching_additional_files) == 0:
raise ValueError(
f"No file found in {repo_id} that match {additonal_file_name}\n\n"
f"Available Files:\n{json.dumps(file_list)}"
)

if len(matching_additional_files) > 1:
raise ValueError(
f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n"
f"Available Files:\n{json.dumps(files)}"
)

(matching_additional_file,) = matching_additional_files

# download the additional file
hf_hub_download(
repo_id=repo_id,
filename=matching_additional_file,
subfolder=subfolder,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cache_dir,
)

if local_dir is None:
model_path = hf_hub_download(
repo_id=repo_id,
@@ -2311,6 +2343,7 @@
else:
model_path = os.path.join(local_dir, filename)

# loading the first file of a sharded GGUF loads all remaining shard files in the subfolder
return cls(
model_path=model_path,
**kwargs,
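As a standalone illustration of the matching step in the diff above, a short sketch of how an additional_files entry is resolved against the repo listing with fnmatch; the file names and pattern are hypothetical, and file_list stands in for the listing built earlier in from_pretrained:

import fnmatch

# Hypothetical repo listing, standing in for `file_list` in from_pretrained.
file_list = [
    "example-model-Q4_K_M-00001-of-00003.gguf",
    "example-model-Q4_K_M-00002-of-00003.gguf",
    "example-model-Q4_K_M-00003-of-00003.gguf",
]

# Each `additional_files` entry may be an exact name or a glob pattern, but it
# must match exactly one file: zero or multiple matches raise ValueError above.
pattern = "*-00002-of-00003.gguf"
matches = [f for f in file_list if fnmatch.fnmatch(f, pattern)]
assert matches == ["example-model-Q4_K_M-00002-of-00003.gguf"]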
