Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 105 additions & 132 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Union

Expand All @@ -9,6 +10,7 @@
from .file_download import hf_hub_download, is_torch_available
from .hf_api import HfApi
from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
from .utils._deprecation import _deprecate_positional_args


if is_torch_available():
Expand All @@ -19,34 +21,36 @@

class ModelHubMixin:
"""
A generic Hub mixin for machine learning models. Define your own mixin for
any framework by inheriting from this class and overwriting the
[`_from_pretrained`] and [`_save_pretrained`] methods to define custom logic
for saving and loading your classes. See [`PyTorchModelHubMixin`] for an
example.
A generic Hub mixin for machine learning models. Define your own mixin for any framework
by inheriting from this class and overwriting the [`_from_pretrained`] and [`_save_pretrained`]
methods to define custom logic for saving and loading your classes.

See [`PyTorchModelHubMixin`] for an example.
"""

@_deprecate_positional_args(version="0.16")
def save_pretrained(
self,
save_directory: Union[str, Path],
*,
config: Optional[dict] = None,
repo_id: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
):
) -> Optional[str]:
"""
Save weights in local directory.

Parameters:
save_directory (`str` or `Path`):
Specify directory in which you want to save weights.
Path to directory in which the model weights and configuration will be saved.
config (`dict`, *optional*):
Specify config (must be dict) in case you want to save
it.
Model configuration specified as a key/value dictionary.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face Hub after
saving it. You can specify the repository you want to push to with
`repo_id` (will default to the name of `save_directory` in your
namespace).
Whether or not to push your model to the Huggingface Hub after saving it.
repo_id (`str`, *optional*):
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will
default to the folder name if not provided.
kwargs:
Additional key word arguments passed along to the
[`~utils.PushToHubMixin.push_to_hub`] method.
Expand All @@ -67,98 +71,75 @@ def save_pretrained(
if config is not None: # kwarg for `push_to_hub`
kwargs["config"] = config

if kwargs.get("repo_id") is None:
if repo_id is None:
# Repo name defaults to `save_directory` name
kwargs["repo_id"] = Path(save_directory).name
repo_id = Path(save_directory).name

return self.push_to_hub(**kwargs)
return self.push_to_hub(repo_id=repo_id, **kwargs)
return None

def _save_pretrained(self, save_directory: Union[str, Path]):
def _save_pretrained(self, save_directory: Union[str, Path]) -> None:
"""
Overwrite this method in subclass to define how to save your model.
"""
raise NotImplementedError

@classmethod
@validate_hf_hub_args
@_deprecate_positional_args(version="0.16")
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
pretrained_model_name_or_path: Union[str, Path],
*,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict] = None,
token: Optional[Union[str, bool]] = None,
cache_dir: Optional[str] = None,
cache_dir: Optional[Union[str, Path]] = None,
local_files_only: bool = False,
revision: Optional[str] = None,
**model_kwargs,
):
r"""
Download and instantiate a model from the Hugging Face Hub.

Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- A string, the `model id` of a pretrained model
hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level,
like `bert-base-uncased`, or namespaced under a
user or organization name, like
`dbmdz/bert-base-german-cased`.
- You can add `revision` by appending `@` at the end
of model_id simply like this:
`dbmdz/bert-base-german-cased@main` Revision is
the specific model version to use. It can be a
branch name, a tag name, or a commit id, since we
use a git-based system for storing models and
other artifacts on huggingface.co, so `revision`
can be any identifier allowed by git.
- A path to a `directory` containing model weights
saved using
[`~transformers.PreTrainedModel.save_pretrained`],
e.g., `./my_model_directory/`.
- `None` if you are both providing the configuration
and state dictionary (resp. with keyword arguments
`config` and `state_dict`).
force_download (`bool`, *optional*, defaults to `False`):
Whether to force the (re-)download of the model weights
and configuration files, overriding the cached versions
if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether to delete incompletely received files. Will
attempt to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or
endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are
used on each request.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote
files. If `True`, will use the token generated when
running `transformers-cli login` (stored in
`~/.huggingface`).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained
model configuration should be cached if the standard
cache should not be used.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only look at local files (i.e., do not try to
download the model).
model_kwargs (`Dict`, *optional*):
model_kwargs will be passed to the model during
initialization

<Tip>

Passing `token=True` is required when you want to use a
private model.

</Tip>
"""
Download a model from the Huggingface Hub and instantiate it.

Parameters:
pretrained_model_name_or_path (`str`, `Path`):
- Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
- Or a path to a `directory` containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
revision (`str`, *optional*):
Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
Defaults to the latest commit on `main` branch.
force_download (`bool`, *optional*, defaults to `False`):
Whether to force (re-)downloading the model weights and configuration
files from the Hub, overriding the existing cache.
resume_download (`bool`, *optional*, defaults to `False`):
Whether to delete incompletely received files. Will attempt to resume the
download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. By default,
it will use the token cached when running `huggingface-cli login`.
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
model_kwargs (`Dict`, *optional*):
Additional kwargs to pass to the model during initialization.
"""
model_id = pretrained_model_name_or_path

revision = None
if len(model_id.split("@")) == 2:
if isinstance(model_id, str) and len(model_id.split("@")) == 2:
warnings.warn(
"Passing a revision using 'namespace/model_id@revision' pattern is"
" deprecated and will be removed in version v0.16. Please pass"
" 'revision=...' as argument.",
FutureWarning,
)
model_id, revision = model_id.split("@")

config_file: Optional[str] = None
Expand All @@ -167,10 +148,10 @@ def from_pretrained(
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
elif isinstance(model_id, str):
try:
config_file = hf_hub_download(
repo_id=model_id,
repo_id=str(model_id),
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
Expand All @@ -181,36 +162,38 @@ def from_pretrained(
local_files_only=local_files_only,
)
except requests.exceptions.RequestException:
logger.warning(f"{CONFIG_NAME} not found in HuggingFace Hub")
logger.warning(f"{CONFIG_NAME} not found in HuggingFace Hub.")

if config_file is not None:
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
model_kwargs.update({"config": config})

return cls._from_pretrained(
model_id,
revision,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
token,
model_id=model_id,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
**model_kwargs,
)

@classmethod
@_deprecate_positional_args(version="0.16")
def _from_pretrained(
cls,
model_id,
revision,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
token,
*,
model_id: str,
revision: str,
cache_dir: str,
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Union[str, bool, None],
**model_kwargs,
):
"""Overwrite this method in subclass to define how to load your model from
Expand Down Expand Up @@ -268,13 +251,9 @@ def push_to_hub(
Returns:
The url of the commit of your model in the given repository.
"""
api = HfApi(endpoint=api_endpoint)
api = HfApi(endpoint=api_endpoint, token=token)
api.create_repo(
repo_id=repo_id,
repo_type="model",
token=token,
private=private,
exist_ok=True,
repo_id=repo_id, repo_type="model", private=private, exist_ok=True
)

# Push the files to the repo in a single commit
Expand All @@ -284,7 +263,6 @@ def push_to_hub(
return api.upload_folder(
repo_id=repo_id,
repo_type="model",
token=token,
folder_path=saved_path,
commit_message=commit_message,
revision=branch,
Expand Down Expand Up @@ -330,35 +308,30 @@ class PyTorchModelHubMixin(ModelHubMixin):
```
"""

def _save_pretrained(self, save_directory):
"""
Overwrite this method if you wish to save specific layers instead of the
complete model.
"""
def _save_pretrained(self, save_directory: Union[str, Path]):
"""Save weights from a Pytorch model to a local directory."""
path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
model_to_save = self.module if hasattr(self, "module") else self
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
torch.save(model_to_save.state_dict(), path)

@classmethod
@_deprecate_positional_args(version="0.16")
def _from_pretrained(
cls,
model_id,
revision,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
token,
map_location="cpu",
strict=False,
*,
model_id: str,
revision: str,
cache_dir: str,
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Union[str, bool, None],
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
):
"""
Overwrite this method to initialize your model in a different way.
"""
map_location = torch.device(map_location)

"""Load Pytorch pretrained weights and return the loaded model."""
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
Expand All @@ -376,8 +349,8 @@ def _from_pretrained(
)
model = cls(**model_kwargs)

state_dict = torch.load(model_file, map_location=map_location)
model.load_state_dict(state_dict, strict=strict)
model.eval()
state_dict = torch.load(model_file, map_location=torch.device(map_location))
model.load_state_dict(state_dict, strict=strict) # type: ignore
model.eval() # type: ignore

return model
Loading