diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 19c10c7321..5462ead2c8 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -1,5 +1,6 @@ import json import os +import warnings from pathlib import Path from typing import Dict, List, Optional, Union @@ -9,6 +10,7 @@ from .file_download import hf_hub_download, is_torch_available from .hf_api import HfApi from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args +from .utils._deprecation import _deprecate_positional_args if is_torch_available(): @@ -19,34 +21,36 @@ class ModelHubMixin: """ - A generic Hub mixin for machine learning models. Define your own mixin for - any framework by inheriting from this class and overwriting the - [`_from_pretrained`] and [`_save_pretrained`] methods to define custom logic - for saving and loading your classes. See [`PyTorchModelHubMixin`] for an - example. + A generic Hub mixin for machine learning models. Define your own mixin for any framework + by inheriting from this class and overwriting the [`_from_pretrained`] and [`_save_pretrained`] + methods to define custom logic for saving and loading your classes. + + See [`PyTorchModelHubMixin`] for an example. """ + @_deprecate_positional_args(version="0.16") def save_pretrained( self, save_directory: Union[str, Path], + *, config: Optional[dict] = None, + repo_id: Optional[str] = None, push_to_hub: bool = False, **kwargs, - ): + ) -> Optional[str]: """ Save weights in local directory. Parameters: save_directory (`str` or `Path`): - Specify directory in which you want to save weights. + Path to directory in which the model weights and configuration will be saved. config (`dict`, *optional*): - Specify config (must be dict) in case you want to save - it. + Model configuration specified as a key/value dictionary. push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face Hub after - saving it. You can specify the repository you want to push to with - `repo_id` (will default to the name of `save_directory` in your - namespace). + Whether or not to push your model to the Huggingface Hub after saving it. + repo_id (`str`, *optional*): + ID of your repository on the Hub. Used only if `push_to_hub=True`. Will + default to the folder name if not provided. kwargs: Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. @@ -67,13 +71,14 @@ def save_pretrained( if config is not None: # kwarg for `push_to_hub` kwargs["config"] = config - if kwargs.get("repo_id") is None: + if repo_id is None: # Repo name defaults to `save_directory` name - kwargs["repo_id"] = Path(save_directory).name + repo_id = Path(save_directory).name - return self.push_to_hub(**kwargs) + return self.push_to_hub(repo_id=repo_id, **kwargs) + return None - def _save_pretrained(self, save_directory: Union[str, Path]): + def _save_pretrained(self, save_directory: Union[str, Path]) -> None: """ Overwrite this method in subclass to define how to save your model. """ @@ -81,84 +86,60 @@ def _save_pretrained(self, save_directory: Union[str, Path]): @classmethod @validate_hf_hub_args + @_deprecate_positional_args(version="0.16") def from_pretrained( cls, - pretrained_model_name_or_path: str, + pretrained_model_name_or_path: Union[str, Path], + *, force_download: bool = False, resume_download: bool = False, proxies: Optional[Dict] = None, token: Optional[Union[str, bool]] = None, - cache_dir: Optional[str] = None, + cache_dir: Optional[Union[str, Path]] = None, local_files_only: bool = False, + revision: Optional[str] = None, **model_kwargs, ): - r""" - Download and instantiate a model from the Hugging Face Hub. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - A string, the `model id` of a pretrained model - hosted inside a model repo on huggingface.co. - Valid model ids can be located at the root-level, - like `bert-base-uncased`, or namespaced under a - user or organization name, like - `dbmdz/bert-base-german-cased`. - - You can add `revision` by appending `@` at the end - of model_id simply like this: - `dbmdz/bert-base-german-cased@main` Revision is - the specific model version to use. It can be a - branch name, a tag name, or a commit id, since we - use a git-based system for storing models and - other artifacts on huggingface.co, so `revision` - can be any identifier allowed by git. - - A path to a `directory` containing model weights - saved using - [`~transformers.PreTrainedModel.save_pretrained`], - e.g., `./my_model_directory/`. - - `None` if you are both providing the configuration - and state dictionary (resp. with keyword arguments - `config` and `state_dict`). - force_download (`bool`, *optional*, defaults to `False`): - Whether to force the (re-)download of the model weights - and configuration files, overriding the cached versions - if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether to delete incompletely received files. Will - attempt to resume the download if such a file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or - endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are - used on each request. - token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote - files. If `True`, will use the token generated when - running `transformers-cli login` (stored in - `~/.huggingface`). - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained - model configuration should be cached if the standard - cache should not be used. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether to only look at local files (i.e., do not try to - download the model). - model_kwargs (`Dict`, *optional*): - model_kwargs will be passed to the model during - initialization - - - - Passing `token=True` is required when you want to use a - private model. - - """ + Download a model from the Huggingface Hub and instantiate it. + Parameters: + pretrained_model_name_or_path (`str`, `Path`): + - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`. + - Or a path to a `directory` containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`. + revision (`str`, *optional*): + Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. + Defaults to the latest commit on `main` branch. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force (re-)downloading the model weights and configuration + files from the Hub, overriding the existing cache. + resume_download (`bool`, *optional*, defaults to `False`): + Whether to delete incompletely received files. Will attempt to resume the + download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, + it will use the token cached when running `huggingface-cli login`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + model_kwargs (`Dict`, *optional*): + Additional kwargs to pass to the model during initialization. + """ model_id = pretrained_model_name_or_path - revision = None - if len(model_id.split("@")) == 2: + if isinstance(model_id, str) and len(model_id.split("@")) == 2: + warnings.warn( + "Passing a revision using 'namespace/model_id@revision' pattern is" + " deprecated and will be removed in version v0.16. Please pass" + " 'revision=...' as argument.", + FutureWarning, + ) model_id, revision = model_id.split("@") config_file: Optional[str] = None @@ -167,10 +148,10 @@ def from_pretrained( config_file = os.path.join(model_id, CONFIG_NAME) else: logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}") - else: + elif isinstance(model_id, str): try: config_file = hf_hub_download( - repo_id=model_id, + repo_id=str(model_id), filename=CONFIG_NAME, revision=revision, cache_dir=cache_dir, @@ -181,7 +162,7 @@ def from_pretrained( local_files_only=local_files_only, ) except requests.exceptions.RequestException: - logger.warning(f"{CONFIG_NAME} not found in HuggingFace Hub") + logger.warning(f"{CONFIG_NAME} not found in HuggingFace Hub.") if config_file is not None: with open(config_file, "r", encoding="utf-8") as f: @@ -189,28 +170,30 @@ def from_pretrained( model_kwargs.update({"config": config}) return cls._from_pretrained( - model_id, - revision, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - token, + model_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, **model_kwargs, ) @classmethod + @_deprecate_positional_args(version="0.16") def _from_pretrained( cls, - model_id, - revision, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - token, + *, + model_id: str, + revision: str, + cache_dir: str, + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Union[str, bool, None], **model_kwargs, ): """Overwrite this method in subclass to define how to load your model from @@ -268,13 +251,9 @@ def push_to_hub( Returns: The url of the commit of your model in the given repository. """ - api = HfApi(endpoint=api_endpoint) + api = HfApi(endpoint=api_endpoint, token=token) api.create_repo( - repo_id=repo_id, - repo_type="model", - token=token, - private=private, - exist_ok=True, + repo_id=repo_id, repo_type="model", private=private, exist_ok=True ) # Push the files to the repo in a single commit @@ -284,7 +263,6 @@ def push_to_hub( return api.upload_folder( repo_id=repo_id, repo_type="model", - token=token, folder_path=saved_path, commit_message=commit_message, revision=branch, @@ -330,35 +308,30 @@ class PyTorchModelHubMixin(ModelHubMixin): ``` """ - def _save_pretrained(self, save_directory): - """ - Overwrite this method if you wish to save specific layers instead of the - complete model. - """ + def _save_pretrained(self, save_directory: Union[str, Path]): + """Save weights from a Pytorch model to a local directory.""" path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME) - model_to_save = self.module if hasattr(self, "module") else self + model_to_save = self.module if hasattr(self, "module") else self # type: ignore torch.save(model_to_save.state_dict(), path) @classmethod + @_deprecate_positional_args(version="0.16") def _from_pretrained( cls, - model_id, - revision, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - token, - map_location="cpu", - strict=False, + *, + model_id: str, + revision: str, + cache_dir: str, + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", + strict: bool = False, **model_kwargs, ): - """ - Overwrite this method to initialize your model in a different way. - """ - map_location = torch.device(map_location) - + """Load Pytorch pretrained weights and return the loaded model.""" if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) @@ -376,8 +349,8 @@ def _from_pretrained( ) model = cls(**model_kwargs) - state_dict = torch.load(model_file, map_location=map_location) - model.load_state_dict(state_dict, strict=strict) - model.eval() + state_dict = torch.load(model_file, map_location=torch.device(map_location)) + model.load_state_dict(state_dict, strict=strict) # type: ignore + model.eval() # type: ignore return model diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index 46305e6ab9..52dc509440 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -1,21 +1,18 @@ import json import os import unittest -from unittest.mock import Mock +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.hub_mixin import PyTorchModelHubMixin -from huggingface_hub.utils import is_torch_available, logging +from huggingface_hub.utils import SoftTemporaryDirectory, is_torch_available from .testing_constants import ENDPOINT_STAGING, TOKEN, USER -from .testing_utils import expect_deprecation, repo_name, rmtree_with_retry - +from .testing_utils import repo_name -logger = logging.get_logger(__name__) -WORKING_REPO_SUBDIR = "fixtures/working_repo_2" -WORKING_REPO_DIR = os.path.join( - os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR -) if is_torch_available(): import torch.nn as nn @@ -35,6 +32,7 @@ def require_torch(test_case): if is_torch_available(): + CONFIG = {"num": 10, "act": "gelu_fast"} class DummyModel(nn.Module, PyTorchModelHubMixin): def __init__(self, **kwargs): @@ -50,49 +48,34 @@ def forward(self, x): @require_torch -class HubMixingCommonTest(unittest.TestCase): - _api = HfApi(endpoint=ENDPOINT_STAGING) - - -@require_torch -class HubMixingTest(HubMixingCommonTest): - def tearDown(self) -> None: - if os.path.exists(WORKING_REPO_DIR): - rmtree_with_retry(WORKING_REPO_DIR) - logger.info( - f"Does {WORKING_REPO_DIR} exist: {os.path.exists(WORKING_REPO_DIR)}" - ) +@pytest.mark.usefixtures("fx_cache_dir") +class HubMixingTest(unittest.TestCase): + cache_dir: Path @classmethod - @expect_deprecation("set_access_token") def setUpClass(cls): """ Share this valid token in all tests below. """ - cls._token = TOKEN - cls._api.token = TOKEN - cls._api.set_access_token(TOKEN) - - def test_save_pretrained(self): - REPO_NAME = repo_name("save") - model = DummyModel() + cls._api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) - model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}") - files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") + def test_save_pretrained_basic(self): + DummyModel().save_pretrained(self.cache_dir) + files = os.listdir(self.cache_dir) self.assertTrue("pytorch_model.bin" in files) self.assertEqual(len(files), 1) - model.save_pretrained( - f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"} - ) - files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") + def test_save_pretrained_with_config(self): + DummyModel().save_pretrained(self.cache_dir, config=CONFIG) + files = os.listdir(self.cache_dir) self.assertTrue("config.json" in files) self.assertTrue("pytorch_model.bin" in files) self.assertEqual(len(files), 2) def test_save_pretrained_with_push_to_hub(self): - REPO_NAME = repo_name("save") - save_directory = f"{WORKING_REPO_DIR}/{REPO_NAME}" + repo_id = repo_name("save") + save_directory = self.cache_dir / repo_id + config = {"hello": "world"} mocked_model = DummyModel() mocked_model.push_to_hub = Mock() @@ -110,40 +93,57 @@ def test_save_pretrained_with_push_to_hub(self): # Push to hub with default repo_id (based on dir name) mocked_model.save_pretrained(save_directory, push_to_hub=True, config=config) - mocked_model.push_to_hub.assert_called_with(repo_id=REPO_NAME, config=config) - - def test_rel_path_from_pretrained(self): - model = DummyModel() - model.save_pretrained( - f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED", - config={"num": 10, "act": "gelu_fast"}, - ) - - model = DummyModel.from_pretrained( - f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED" + mocked_model.push_to_hub.assert_called_with(repo_id=repo_id, config=config) + + @patch.object(DummyModel, "_from_pretrained") + def test_from_pretrained_model_id_only(self, from_pretrained_mock: Mock) -> None: + model = DummyModel.from_pretrained("namespace/repo_name") + from_pretrained_mock.assert_called_once() + self.assertIs(model, from_pretrained_mock.return_value) + + @patch.object(DummyModel, "_from_pretrained") + def test_from_pretrained_model_id_and_revision( + self, from_pretrained_mock: Mock + ) -> None: + """Regression test for #1313. + + See https://github.com/huggingface/huggingface_hub/issues/1313.""" + model = DummyModel.from_pretrained("namespace/repo_name", revision="123456789") + from_pretrained_mock.assert_called_once_with( + model_id="namespace/repo_name", + revision="123456789", # Revision is passed correctly! + cache_dir=None, + force_download=False, + proxies=None, + resume_download=False, + local_files_only=False, + token=None, ) - self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) - - def test_abs_path_from_pretrained(self): - REPO_NAME = repo_name("FROM_PRETRAINED") - model = DummyModel() - model.save_pretrained( - f"{WORKING_REPO_DIR}/{REPO_NAME}", - config={"num": 10, "act": "gelu_fast"}, - ) - - model = DummyModel.from_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}") - self.assertDictEqual(model.config, {"num": 10, "act": "gelu_fast"}) - - def test_push_to_hub_via_http_basic(self): - REPO_NAME = repo_name("PUSH_TO_HUB_via_http") - repo_id = f"{USER}/{REPO_NAME}" - + self.assertIs(model, from_pretrained_mock.return_value) + + def test_from_pretrained_to_relative_path(self): + with SoftTemporaryDirectory(dir=Path(".")) as tmp_relative_dir: + relative_save_directory = Path(tmp_relative_dir) / "model" + DummyModel().save_pretrained(relative_save_directory, config=CONFIG) + model = DummyModel.from_pretrained(relative_save_directory) + self.assertDictEqual(model.config, CONFIG) + + def test_from_pretrained_to_absolute_path(self): + save_directory = self.cache_dir / "subfolder" + DummyModel().save_pretrained(save_directory, config=CONFIG) + model = DummyModel.from_pretrained(save_directory) + self.assertDictEqual(model.config, CONFIG) + + def test_from_pretrained_to_absolute_string_path(self): + save_directory = str(self.cache_dir / "subfolder") + DummyModel().save_pretrained(save_directory, config=CONFIG) + model = DummyModel.from_pretrained(save_directory) + self.assertDictEqual(model.config, CONFIG) + + def test_push_to_hub(self): + repo_id = f"{USER}/{repo_name('push_to_hub')}" DummyModel().push_to_hub( - repo_id=repo_id, - api_endpoint=ENDPOINT_STAGING, - token=self._token, - config={"num": 7, "act": "gelu_fast"}, + repo_id=repo_id, api_endpoint=ENDPOINT_STAGING, token=TOKEN, config=CONFIG ) # Test model id exists @@ -152,11 +152,13 @@ def test_push_to_hub_via_http_basic(self): # Test config has been pushed to hub tmp_config_path = hf_hub_download( - repo_id=repo_id, filename="config.json", use_auth_token=self._token + repo_id=repo_id, + filename="config.json", + use_auth_token=TOKEN, + cache_dir=self.cache_dir, ) with open(tmp_config_path) as f: - self.assertEqual(json.load(f), {"num": 7, "act": "gelu_fast"}) + self.assertDictEqual(json.load(f), CONFIG) - # Delete tmp file and repo - os.remove(tmp_config_path) + # Delete repo self._api.delete_repo(repo_id=repo_id)