From 08dd11f8e5b4288268c0e5c5c06aa73a5d25f17b Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Mon, 5 Jun 2023 18:29:07 -0500 Subject: [PATCH] Add overloads for PretrainedModel.from_pretrained Fixes #23980; move output_loading_info from variadic kwargs to an explicit kwarg and create overloaded signatures based on its value; also add LoadingInfo TypedDict --- src/transformers/modeling_utils.py | 56 +++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fb142432863b..b2b4c983d848 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -25,12 +25,13 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload import torch from packaging import version from torch import Tensor, nn from torch.nn import CrossEntropyLoss +from typing_extensions import Literal, Self, TypedDict from .activations import get_activation from .configuration_utils import PretrainedConfig @@ -1020,6 +1021,13 @@ def floating_point_ops( return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) +class LoadingInfo(TypedDict): + missing_keys: List[str] + unexpected_keys: List[str] + mismatched_keys: List[Tuple[str, Tuple[int, ...], Tuple[int, ...]]] + error_msgs: List[str] + + class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin): r""" Base class for all models. @@ -1917,8 +1925,47 @@ def float(self, *args): else: return super().float(*args) + @overload + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args: Any, + output_loading_info: Literal[False] = ..., + **kwargs: Any, + ) -> Self: + ... + + @overload + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args: Any, + output_loading_info: Literal[True], + **kwargs: Any, + ) -> Tuple[Self, LoadingInfo]: + ... + + @overload + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args: Any, + output_loading_info: bool, + **kwargs: Any, + ) -> Union[Self, Tuple[Self, LoadingInfo]]: + ... + @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args: Any, + output_loading_info: bool = False, + **kwargs: Any, + ) -> Union[Self, Tuple[Self, LoadingInfo]]: r""" Instantiate a pretrained pytorch model from a pre-trained model configuration. @@ -1952,6 +1999,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P arguments `config` and `state_dict`). model_args (sequence of positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): Can be either: @@ -1995,8 +2044,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). use_auth_token (`str` or `bool`, *optional*): @@ -2155,7 +2202,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None)