Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload

import torch
from packaging import version
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
from typing_extensions import Literal, Self, TypedDict

from .activations import get_activation
from .configuration_utils import PretrainedConfig
Expand Down Expand Up @@ -1020,6 +1021,13 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


class LoadingInfo(TypedDict):
    """
    Diagnostics describing how a checkpoint's weights lined up with a model.

    This is the dictionary type that the `from_pretrained` overloads return as the second element of the
    `(model, loading_info)` tuple when `output_loading_info=True`.
    """

    # Keys reported as missing (presumably expected by the model but absent from the checkpoint -- confirm).
    missing_keys: List[str]
    # Keys reported as unexpected (presumably present in the checkpoint but unused by the model -- confirm).
    unexpected_keys: List[str]
    # One `(key, shape, shape)` triple per mismatched entry; presumably the checkpoint shape followed by the
    # model shape -- confirm the order against the loading code.
    mismatched_keys: List[Tuple[str, Tuple[int, ...], Tuple[int, ...]]]
    # Error messages collected while loading.
    error_msgs: List[str]


class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r"""
Base class for all models.
Expand Down Expand Up @@ -1917,8 +1925,47 @@ def float(self, *args):
else:
return super().float(*args)

@overload
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    *model_args: Any,
    output_loading_info: Literal[False] = ...,
    **kwargs: Any,
) -> Self:
    # Typing-only overload: when `output_loading_info` is omitted or is the literal `False`,
    # the call is typed as returning just the model instance (`Self`).
    ...

@overload
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    *model_args: Any,
    output_loading_info: Literal[True],
    **kwargs: Any,
) -> Tuple[Self, LoadingInfo]:
    # Typing-only overload: when `output_loading_info` is the literal `True`, the call is typed
    # as returning a `(model, loading_info)` tuple.
    ...

@overload
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    *model_args: Any,
    output_loading_info: bool,
    **kwargs: Any,
) -> Union[Self, Tuple[Self, LoadingInfo]]:
    # Typing-only fallback overload: when `output_loading_info` is a `bool` whose value is not
    # known statically, the result is typed as either the model alone or the
    # `(model, loading_info)` tuple.
    ...

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args: Any,
output_loading_info: bool = False,
**kwargs: Any,
) -> Union[Self, Tuple[Self, LoadingInfo]]:
r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration.

Expand Down Expand Up @@ -1952,6 +1999,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
arguments `config` and `state_dict`).
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments will be passed to the underlying model's `__init__` method.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys, mismatched keys and error messages.
config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
Can be either:

Expand Down Expand Up @@ -1995,8 +2044,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or `bool`, *optional*):
Expand Down Expand Up @@ -2155,7 +2202,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
Expand Down