diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5ab58d71eae2..4547f9ca4787 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2322,17 +2322,24 @@ def _init_weights(self, module): init.copy_(module.inv_freq, buffer_value) init.copy_(module.original_inv_freq, buffer_value) - def _initialize_weights(self, module): + def _initialize_weights(self, module, is_remote_code: bool = False): """ Initialize the weights if they are not already initialized. """ if getattr(module, "_is_hf_initialized", False): return + # This check is for remote code that does NOT use either `torch.nn.init` or `transformers.initialization` in `_init_weights`, + # which allow checking the flag directly on the param. Such code writes the params in-place instead, so without this check + # already-initialized params would be reinitialized if ( - (weight := getattr(module, "weight", None)) is not None - and getattr(weight, "_is_hf_initialized", False) - and not list(module.named_buffers()) + is_remote_code + and all(getattr(param, "_is_hf_initialized", False) for param in module.parameters(recurse=False)) + and all( + getattr(buffer, "_is_hf_initialized", False) + for buffer in module.buffers(recurse=False) + if buffer is not None + ) ): module._is_hf_initialized = True return @@ -2353,20 +2360,20 @@ def initialize_weights(self): if not hasattr(torch.nn.Module, "smart_apply"): # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjust the function # to apply as we go down the graph - def smart_apply(self, fn): + def smart_apply(self, fn, is_remote_code): for module in self.children(): # We found a sub-model: recursively dispatch its own init function now! 
if isinstance(module, PreTrainedModel): - module.smart_apply(module._initialize_weights) + module.smart_apply(module._initialize_weights, is_remote_code) else: - module.smart_apply(fn) - fn(self) + module.smart_apply(fn, is_remote_code) + fn(self, is_remote_code) return self torch.nn.Module.smart_apply = smart_apply # Let the magic happen with this simple call - self.smart_apply(self._initialize_weights) + self.smart_apply(self._initialize_weights, self.is_remote_code()) def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict: r""" @@ -4017,7 +4024,7 @@ def from_pretrained( use_safetensors=use_safetensors, download_kwargs=download_kwargs_with_commit, user_agent=user_agent, - is_remote_code=cls._auto_class is not None, + is_remote_code=cls.is_remote_code(), transformers_explicit_filename=getattr(config, "transformers_weights", None), ) @@ -4227,11 +4234,8 @@ def _finalize_model_loading( missing keys from meta device to their expected device, reinitializing missing weights according to proper distributions, tying the weights and logging the loading report.""" try: - # Adjust `all_tied_weights_keys` before marking them as initialized - model._adjust_tied_keys_with_tied_pointers(loading_info.missing_and_mismatched()) - # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency) - model.mark_tied_weights_as_initialized() + model.mark_tied_weights_as_initialized(loading_info) # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from # meta device (because they were not moved when loading the weights as they were not in the loaded state dict) @@ -4445,35 +4449,6 @@ def get_compiled_call(self, compile_config: CompileConfig | None) -> Callable: def is_backend_compatible(cls): return cls._supports_attention_backend - def _adjust_tied_keys_with_tied_pointers(self, missing_keys: list[str]) -> None: - """ - Adds keys to `self.all_tied_weights_keys` by 
checking if any group of params - share the same data ptr. It helps us support remote code where the weight tying is - done in old-T5 style, by manually assigning the same module to different param names. - If we don't add them back in `self.all_tied_weights_keys`, they will be re-initialized - and all params in tied group get random weights. - """ - param_pointers = defaultdict(list) - for param_name, param_value in self.state_dict().items(): - param_pointers[param_value.data_ptr()].append(param_name) - - # Filter out params that are already in `self.all_tied_weights_keys` or if all - # are missing params. Missing param groups share the same data ptr by being on `meta` - tied_param_names = [ - names - for names in param_pointers.values() - if len(names) > 1 - and not any(name in self.all_tied_weights_keys.keys() for name in names) - and not all(name in missing_keys for name in names) - ] - - # Create a dummy mapping, it doesn't matter which one is source/target - # because they are already tied - tied_weights_keys_by_pointers = { - param_name: group[0] for group in tied_param_names for param_name in group[1:] - } - self.all_tied_weights_keys.update(tied_weights_keys_by_pointers) - def _move_missing_keys_from_meta_to_device( self, missing_keys: list[str], @@ -4574,7 +4549,7 @@ def _adjust_missing_and_unexpected_keys(self, loading_info: LoadStateDictInfo) - key for key in loading_info.unexpected_keys if ignore_unexpected_regex.search(key) is None } - def mark_tied_weights_as_initialized(self): + def mark_tied_weights_as_initialized(self, loading_info): """Adds the `_is_hf_initialized` flag on parameters that will be tied, in order to avoid initializing them later as they will be tied (overwritten) anyway. 
This is very important as most embeddings are tied, and they are huge params (vocabularies are often 256k), so @@ -4583,6 +4558,23 @@ def mark_tied_weights_as_initialized(self): param = self.get_parameter(tied_param) param._is_hf_initialized = True + # Some remote code models define module tying (not parameter tying) in their `__init__`. When the modules themselves are shared, + # the weights inside both modules appear in the `state_dict`, but only one copy appears in the safetensors checkpoint, + # as they are inherently tied because the 2 modules are the same object. In this case, once we load a parameter + # inside one of the 2 modules, the same parameter in the other module is automatically loaded too and has the `_is_hf_initialized` + # flag (because we call `setattr` with the loaded param on the module, which is the same object), but the key of that other copy + # still appears as a missing key, as it is never removed from the set (because it appears in the `state_dict` as well). + # So we remove it now - otherwise it would be considered missing and wrongly reinitialized + # Note: this is never an issue in main Transformers, as we never do module-tying, only parameter-tying, and we know + # which params are supposed to be tied to which other params + if self.is_remote_code(): + # Remove those that are already initialized, but appear as missing due to module tying + loading_info.missing_keys = { + key + for key in loading_info.missing_keys + if not getattr(self.get_parameter_or_buffer(key), "_is_hf_initialized", False) + } + def get_parameter_or_buffer(self, target: str): """ Return the parameter or buffer given by `target` if it exists, otherwise throw an error. 
This combines @@ -4629,6 +4621,10 @@ def train(self, mode: bool = True): def eval(self): return self.train(False) + @classmethod + def is_remote_code(cls) -> bool: + return cls._auto_class is not None + PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) if PreTrainedModel.push_to_hub.__doc__ is not None: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a3abfeae295b..f72cce49da98 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -5604,9 +5604,9 @@ def seeded_weight_init(): # `_init_weights` so that it can add the seed for composite models as well) original_initialize_weights = PreTrainedModel._initialize_weights - def seeded_initialize_weights(self, module): + def seeded_initialize_weights(*args, **kwargs): set_seed(42) - original_initialize_weights(self, module) + original_initialize_weights(*args, **kwargs) PreTrainedModel._initialize_weights = seeded_initialize_weights @@ -5623,7 +5623,7 @@ def skip_weight_init(): original_initialize_weights = PreTrainedModel._initialize_weights # Just do nothing instead - def skip_initialize_weights(self, module): + def skip_initialize_weights(*args, **kwargs): pass PreTrainedModel._initialize_weights = skip_initialize_weights