src/transformers/modeling_utils.py (84 changes: 40 additions, 44 deletions)
@@ -2322,17 +2322,24 @@ def _init_weights(self, module):
             init.copy_(module.inv_freq, buffer_value)
             init.copy_(module.original_inv_freq, buffer_value)

-    def _initialize_weights(self, module):
+    def _initialize_weights(self, module, is_remote_code: bool = False):
         """
         Initialize the weights if they are not already initialized.
         """
         if getattr(module, "_is_hf_initialized", False):
             return

+        # This check is for remote code whose `_init_weights` uses neither `torch.init` nor `transformers.initialization`,
+        # which would allow checking the flag directly on each param. Since such code writes the params in-place instead,
+        # the params would otherwise be reinitialized.
         if (
-            (weight := getattr(module, "weight", None)) is not None
-            and getattr(weight, "_is_hf_initialized", False)
-            and not list(module.named_buffers())
+            is_remote_code
+            and all(getattr(param, "_is_hf_initialized", False) for param in module.parameters(recurse=False))
+            and all(
+                getattr(buffer, "_is_hf_initialized", False)
+                for buffer in module.buffers(recurse=False)
+                if buffer is not None
+            )
         ):
             module._is_hf_initialized = True
             return
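The new condition only trusts per-parameter (and per-buffer) flags. A minimal sketch (hypothetical, not part of the diff) of why this matters: remote-code `_init_weights` implementations often write tensors in-place, so the only reliable signal is the `_is_hf_initialized` flag set on each param during loading; once every param and buffer of a module carries it, the module itself can be marked and skipped.

import torch
from torch import nn

linear = nn.Linear(4, 4)

# During loading, each param that actually came from the checkpoint is flagged individually
for param in linear.parameters(recurse=False):
    param._is_hf_initialized = True

# Remote-code `_init_weights` may write in-place (e.g. `module.weight.data.normal_()`), so the
# flag can only be trusted at the param/buffer level. If everything is already flagged, the
# whole module can be marked and skipped:
if all(getattr(p, "_is_hf_initialized", False) for p in linear.parameters(recurse=False)) and all(
    getattr(b, "_is_hf_initialized", False) for b in linear.buffers(recurse=False) if b is not None
):
    linear._is_hf_initialized = True  # `_initialize_weights` would now early-return for this module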
@@ -2353,20 +2360,20 @@ def initialize_weights(self):
         if not hasattr(torch.nn.Module, "smart_apply"):
             # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjust the function
             # to apply as we go down the graph
-            def smart_apply(self, fn):
+            def smart_apply(self, fn, is_remote_code):
                 for module in self.children():
                     # We found a sub-model: recursively dispatch its own init function now!
                     if isinstance(module, PreTrainedModel):
-                        module.smart_apply(module._initialize_weights)
+                        module.smart_apply(module._initialize_weights, is_remote_code)
                     else:
-                        module.smart_apply(fn)
-                fn(self)
+                        module.smart_apply(fn, is_remote_code)
+                fn(self, is_remote_code)
                 return self

             torch.nn.Module.smart_apply = smart_apply

         # Let the magic happen with this simple call
-        self.smart_apply(self._initialize_weights)
+        self.smart_apply(self._initialize_weights, self.is_remote_code())

     def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
         r"""
@@ -4017,7 +4024,7 @@ def from_pretrained(
             use_safetensors=use_safetensors,
             download_kwargs=download_kwargs_with_commit,
             user_agent=user_agent,
-            is_remote_code=cls._auto_class is not None,
+            is_remote_code=cls.is_remote_code(),
             transformers_explicit_filename=getattr(config, "transformers_weights", None),
         )

@@ -4227,11 +4234,8 @@ def _finalize_model_loading(
     missing keys from meta device to their expected device, reinitializing missing weights according to proper
     distributions, tying the weights and logging the loading report."""
     try:
-        # Adjust `all_tied_weights_keys` before marking them as initialized
-        model._adjust_tied_keys_with_tied_pointers(loading_info.missing_and_mismatched())
-
         # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
-        model.mark_tied_weights_as_initialized()
+        model.mark_tied_weights_as_initialized(loading_info)

         # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
         # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
@@ -4445,35 +4449,6 @@ def get_compiled_call(self, compile_config: CompileConfig | None) -> Callable:
     def is_backend_compatible(cls):
         return cls._supports_attention_backend

-    def _adjust_tied_keys_with_tied_pointers(self, missing_keys: list[str]) -> None:
-        """
-        Adds keys to `self.all_tied_weights_keys` by checking if any group of params
-        share the same data ptr. It helps us support remote code where the weight tying is
-        done in old-T5 style, by manually assigning the same module to different param names.
-        If we don't add them back in `self.all_tied_weights_keys`, they will be re-initialized
-        and all params in tied group get random weights.
-        """
-        param_pointers = defaultdict(list)
-        for param_name, param_value in self.state_dict().items():
-            param_pointers[param_value.data_ptr()].append(param_name)
-
-        # Filter out params that are already in `self.all_tied_weights_keys` or if all
-        # are missing params. Missing param groups share the same data ptr by being on `meta`
-        tied_param_names = [
-            names
-            for names in param_pointers.values()
-            if len(names) > 1
-            and not any(name in self.all_tied_weights_keys.keys() for name in names)
-            and not all(name in missing_keys for name in names)
-        ]
-
-        # Create a dummy mapping, it doesn't matter which one is source/target
-        # because they are already tied
-        tied_weights_keys_by_pointers = {
-            param_name: group[0] for group in tied_param_names for param_name in group[1:]
-        }
-        self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
-
     def _move_missing_keys_from_meta_to_device(
         self,
         missing_keys: list[str],
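The deleted helper's detection trick is worth noting: tensors that share storage report the same `data_ptr()`, which is how old-T5-style tying was discovered at load time. A minimal sketch of that grouping (toy model with hypothetical names; parameter-level tying shown for brevity, shared modules produce the same pointers):

from collections import defaultdict
from torch import nn

class TiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # old-T5-style manual tying

model = TiedToy()
pointers = defaultdict(list)
for name, tensor in model.state_dict().items():
    pointers[tensor.data_ptr()].append(name)

tied_groups = [names for names in pointers.values() if len(names) > 1]
print(tied_groups)  # [['embed.weight', 'lm_head.weight']]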
@@ -4574,7 +4549,7 @@ def _adjust_missing_and_unexpected_keys(self, loading_info: LoadStateDictInfo) -
             key for key in loading_info.unexpected_keys if ignore_unexpected_regex.search(key) is None
         }

-    def mark_tied_weights_as_initialized(self):
+    def mark_tied_weights_as_initialized(self, loading_info):
         """Adds the `_is_hf_initialized` flag on parameters that will be tied, in order to avoid initializing them
         later as they will be tied (overwritten) anyway.
         This is very important as most embeddings are tied, and they are huge params (vocabularies are often 256k), so
@@ -4583,6 +4558,23 @@ def mark_tied_weights_as_initialized(self):
             param = self.get_parameter(tied_param)
             param._is_hf_initialized = True

+        # Some remote code models define module tying (not parameter tying) in their __init__. When the modules
+        # themselves are shared, the weights of both modules appear in the `state_dict`, but only one appears in the
+        # safetensors checkpoint, as they are inherently tied (the 2 modules are the same object). In this case, once
+        # we load a parameter inside one of the 2 modules, the other is automatically loaded as well and gets the
+        # `_is_hf_initialized` flag (we call `setattr` with the loaded param on the module, which is the same object),
+        # but its counterpart still appears as a missing key, as we never remove it from the set (it appears in the
+        # state_dict as well). So we remove it now - otherwise it would be considered missing and wrongly reinitialized.
+        # Note: this is never an issue in main Transformers, as we never do module-tying, only parameter-tying, and we
+        # know which params are supposed to be tied to which other params
+        if self.is_remote_code():
+            # Remove those that are already initialized, but appear as missing due to module tying
+            loading_info.missing_keys = {
+                key
+                for key in loading_info.missing_keys
+                if not getattr(self.get_parameter_or_buffer(key), "_is_hf_initialized", False)
+            }
+
     def get_parameter_or_buffer(self, target: str):
         """
         Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
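A toy illustration (hypothetical names, not the diff's code) of the module-tying situation the new comment describes: both keys show up in the `state_dict`, and loading a tensor through one name fills, and flags, the other.

import torch
from torch import nn

class ModuleTiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_embed = nn.Embedding(10, 4)
        self.decoder_embed = self.encoder_embed  # module tying: the very same object

model = ModuleTiedToy()
print(sorted(model.state_dict()))  # ['decoder_embed.weight', 'encoder_embed.weight'] - two keys, one tensor

# "Load" a checkpoint tensor into one of the two names, as loading does via setattr:
loaded = nn.Parameter(torch.ones(10, 4))
loaded._is_hf_initialized = True
model.encoder_embed.weight = loaded

# The counterpart is filled (and flagged) too, so it must not be reported as a missing key:
assert model.decoder_embed.weight is loaded
assert getattr(model.get_parameter("decoder_embed.weight"), "_is_hf_initialized", False)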
@@ -4629,6 +4621,10 @@ def train(self, mode: bool = True):
     def eval(self):
         return self.train(False)

+    @classmethod
+    def is_remote_code(cls) -> bool:
+        return cls._auto_class is not None
+

 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
 if PreTrainedModel.push_to_hub.__doc__ is not None:
tests/test_modeling_common.py (6 changes: 3 additions, 3 deletions)
@@ -5604,9 +5604,9 @@ def seeded_weight_init():
     # `_init_weights` so that it can add the seed for composite models as well)
     original_initialize_weights = PreTrainedModel._initialize_weights

-    def seeded_initialize_weights(self, module):
+    def seeded_initialize_weights(*args, **kwargs):
         set_seed(42)
-        original_initialize_weights(self, module)
+        original_initialize_weights(*args, **kwargs)

     PreTrainedModel._initialize_weights = seeded_initialize_weights

@@ -5623,7 +5623,7 @@ def skip_weight_init():
     original_initialize_weights = PreTrainedModel._initialize_weights

     # Just do nothing instead
-    def skip_initialize_weights(self, module):
+    def skip_initialize_weights(*args, **kwargs):
         pass

     PreTrainedModel._initialize_weights = skip_initialize_weights
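Since `_initialize_weights` now takes an extra argument, both test monkeypatches switch to `*args, **kwargs` so they stay valid if the signature changes again. The pattern in isolation (hypothetical standalone example, not the test suite's code):

from contextlib import contextmanager

class Target:
    def method(self, a, is_remote_code=False):
        return (a, is_remote_code)

@contextmanager
def patched_method():
    original = Target.method

    # Accepting *args/**kwargs keeps the wrapper valid no matter how many params `method` grows
    def wrapper(*args, **kwargs):
        return original(*args, **kwargs)

    Target.method = wrapper
    try:
        yield
    finally:
        Target.method = original

with patched_method():
    assert Target().method(1, is_remote_code=True) == (1, True)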