From faeed0d139d056c327a18965688c31336ff61fcb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 10 Jan 2025 09:25:34 -0800 Subject: [PATCH 1/6] An attempt to fix #29554. Include 'LayerNorm.' in gamma/beta rename scope, reduce number of characters searched on every load considerably. --- src/transformers/modeling_utils.py | 43 ++++++++++--------- .../timm_wrapper/modeling_timm_wrapper.py | 6 +-- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8eb2d7439ef3..9d5266f2bcbe 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4336,26 +4336,27 @@ def from_pretrained( return model @staticmethod - def _fix_state_dict_key_on_load(key): + def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" - if "beta" in key: - return key.replace("beta", "bias") - if "gamma" in key: - return key.replace("gamma", "weight") + if key.endswith("LayerNorm.beta"): + return key.replace("LayerNorm.beta", "LayerNorm.bias"), True + elif key.endswith("LayerNorm.gamma"): + return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True # to avoid logging parametrized weight norm renaming if hasattr(nn.utils.parametrizations, "weight_norm"): if "weight_g" in key: - return key.replace("weight_g", "parametrizations.weight.original0") + return key.replace("weight_g", "parametrizations.weight.original0"), True if "weight_v" in key: - return key.replace("weight_v", "parametrizations.weight.original1") + return key.replace("weight_v", "parametrizations.weight.original1"), True else: if "parametrizations.weight.original0" in key: - return key.replace("parametrizations.weight.original0", "weight_g") + return key.replace("parametrizations.weight.original0", "weight_g"), True if "parametrizations.weight.original1" in key: - return key.replace("parametrizations.weight.original1", "weight_v") - return key + return key.replace("parametrizations.weight.original1", "weight_v"), True + + return key, False @classmethod def _fix_state_dict_keys_on_load(cls, state_dict): @@ -4366,15 +4367,15 @@ def _fix_state_dict_keys_on_load(cls, state_dict): renamed_keys = {} state_dict_keys = list(state_dict.keys()) for key in state_dict_keys: - new_key = cls._fix_state_dict_key_on_load(key) - if new_key != key: + new_key, has_changed = cls._fix_state_dict_key_on_load(key) + if has_changed: state_dict[new_key] = state_dict.pop(key) - # add it once for logging - if "gamma" in key and "gamma" not in renamed_keys: - renamed_keys["gamma"] = (key, new_key) - if "beta" in key and "beta" not in renamed_keys: - renamed_keys["beta"] = (key, new_key) + # track gamma/beta rename for logging + if key.endswith("LayerNorm.gamma"): + renamed_keys["LayerNorm.gamma"] = (key, new_key) + elif key.endswith("LayerNorm.beta"): + renamed_keys["LayerNorm.beta"] = (key, new_key) if renamed_keys: warning_msg = f"A pretrained model of type `{cls.__name__}` " @@ -4387,19 +4388,19 @@ def _fix_state_dict_keys_on_load(cls, state_dict): return state_dict @staticmethod - def _fix_state_dict_key_on_save(key): + def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]: """ Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save. - Do nothing by default, but can be overriden in particular models. + Do nothing by default, but can be overridden in particular models. 
""" - return key + return key, False def _fix_state_dict_keys_on_save(self, state_dict): """ Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save. Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`. """ - return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()} + return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()} @classmethod def _load_pretrained_model( diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 47e8944583b4..e160a965c4a9 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -90,15 +90,15 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @staticmethod - def _fix_state_dict_key_on_load(key): + def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: """ Overrides original method that renames `gamma` and `beta` to `weight` and `bias`. We don't want this behavior for timm wrapped models. Instead, this method adds a "timm_model." prefix to enable loading official timm Hub checkpoints. """ if "timm_model." not in key: - return f"timm_model.{key}" - return key + return f"timm_model.{key}", True + return key, False def _fix_state_dict_key_on_save(self, key): """ From eed2570d354be99526ccb524c64b58021780037f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 10 Jan 2025 10:04:35 -0800 Subject: [PATCH 2/6] Fix fix on load issue --- src/transformers/modeling_utils.py | 2 +- src/transformers/models/timm_wrapper/modeling_timm_wrapper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9d5266f2bcbe..edc0289468c4 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4458,7 +4458,7 @@ def _load_pretrained_model( expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys) original_loaded_keys = loaded_keys - loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys] + loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys] if len(prefix) > 0: has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index e160a965c4a9..a74202ce5aa5 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -105,7 +105,7 @@ def _fix_state_dict_key_on_save(self, key): Overrides original method to remove "timm_model." prefix from state_dict keys. Makes the saved checkpoint compatible with the `timm` library. 
""" - return key.replace("timm_model.", "") + return key.replace("timm_model.", ""), True def load_state_dict(self, state_dict, *args, **kwargs): """ From 9a9641c5c9664b1da08355d80b8458a12728a7a0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 10 Jan 2025 13:11:14 -0800 Subject: [PATCH 3/6] Fix gamma/beta warning test --- tests/utils/test_modeling_utils.py | 60 +++++++++++++----------------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 383f0cbe60e1..ae43ec88f4b0 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1562,57 +1562,49 @@ def test_model_from_pretrained_from_mlx(self): self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"])) def test_warning_for_beta_gamma_parameters(self): - class TestModelGamma(PreTrainedModel): + class TestGammaBetaNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.gamma = torch.nn.Parameter(torch.ones(1)) + self.beta = torch.nn.Parameter(torch.zeros(1)) + + def forward(self): + return self.gamma.sum() + self.beta.sum() + + class TestModelGammaBeta(PreTrainedModel): def __init__(self, config): super().__init__(config) - self.gamma_param = nn.Parameter(torch.ones(10)) + self.LayerNorm = TestGammaBetaNorm() self.post_init() def forward(self): - return self.gamma_param.sum() + return self.LayerNorm() logger = logging.get_logger("transformers.modeling_utils") config = PretrainedConfig() - warning_msg_gamma = "`gamma_param` -> `weight_param`" - model = TestModelGamma(config) + warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`" + warning_msg_beta = "`LayerNorm.beta` -> `LayerNorm.bias`" + model = TestModelGammaBeta(config) with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) with LoggingLevel(logging.INFO): with CaptureLogger(logger) as cl1: - _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True) + _, loading_info = TestModelGammaBeta.from_pretrained( + tmp_dir, + config=config, + output_loading_info=True + ) missing_keys = loading_info["missing_keys"] unexpected_keys = loading_info["unexpected_keys"] - self.assertIn("`TestModelGamma`", cl1.out) + self.assertIn("`TestModelGammaBeta`", cl1.out) self.assertIn(warning_msg_gamma, cl1.out) - self.assertIn("gamma_param", missing_keys) - self.assertIn("weight_param", unexpected_keys) - - class TestModelBeta(PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.beta_param = nn.Parameter(torch.ones(10)) - self.post_init() - - def forward(self): - return self.beta_param.sum() - - warning_msg_beta = "`beta_param` -> `bias_param`" - model = TestModelBeta(config) - - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir) - with LoggingLevel(logging.INFO): - with CaptureLogger(logger) as cl2: - _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True) - - missing_keys = loading_info["missing_keys"] - unexpected_keys = loading_info["unexpected_keys"] - self.assertIn("`TestModelBeta`", cl2.out) - self.assertIn(warning_msg_beta, cl2.out) - self.assertIn("beta_param", missing_keys) - self.assertIn("bias_param", unexpected_keys) + self.assertIn(warning_msg_beta, cl1.out) + self.assertIn("LayerNorm.gamma", missing_keys) + self.assertIn("LayerNorm.weight", unexpected_keys) + self.assertIn("LayerNorm.beta", missing_keys) + self.assertIn("LayerNorm.bias", unexpected_keys) def 
test_isin_mps_friendly(self): """tests that our custom `isin_mps_friendly` matches `torch.isin`""" From 0d2211ea13e60277114a6ad249498285fde59613 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 10 Jan 2025 16:53:38 -0800 Subject: [PATCH 4/6] A style complaint --- tests/utils/test_modeling_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index ae43ec88f4b0..e90f8aa7d039 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1591,9 +1591,7 @@ def forward(self): with LoggingLevel(logging.INFO): with CaptureLogger(logger) as cl1: _, loading_info = TestModelGammaBeta.from_pretrained( - tmp_dir, - config=config, - output_loading_info=True + tmp_dir, config=config, output_loading_info=True ) missing_keys = loading_info["missing_keys"] From 725d86d04af1211f61c846e1402e63ca5da2f685 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 14 Jan 2025 10:23:52 -0800 Subject: [PATCH 5/6] Improve efficiency of weight norm key rename. Add better comments about weight norm and layer norm renaming. --- src/transformers/modeling_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index edc0289468c4..55dbc314383e 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4339,21 +4339,25 @@ def from_pretrained( def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" + # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert) + # This rename is logged. if key.endswith("LayerNorm.beta"): return key.replace("LayerNorm.beta", "LayerNorm.bias"), True elif key.endswith("LayerNorm.gamma"): return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True - # to avoid logging parametrized weight norm renaming + # Rename weight norm parametrizations to match changes across torch versions. + # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others. + # This rename is not logged. if hasattr(nn.utils.parametrizations, "weight_norm"): - if "weight_g" in key: + if key.endswith("weight_g"): return key.replace("weight_g", "parametrizations.weight.original0"), True - if "weight_v" in key: + elif key.endswith("weight_v"): return key.replace("weight_v", "parametrizations.weight.original1"), True else: - if "parametrizations.weight.original0" in key: + if key.endswith("parametrizations.weight.original0"): return key.replace("parametrizations.weight.original0", "weight_g"), True - if "parametrizations.weight.original1" in key: + elif key.endswith("parametrizations.weight.original1"): return key.replace("parametrizations.weight.original1", "weight_v"), True return key, False From 5bc205c11ff36167b718a34a724374599987c24d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 14 Jan 2025 10:26:09 -0800 Subject: [PATCH 6/6] Habitual elif redunant with the return --- src/transformers/modeling_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 55dbc314383e..573a9322efd1 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4343,7 +4343,7 @@ def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: # This rename is logged. 
if key.endswith("LayerNorm.beta"): return key.replace("LayerNorm.beta", "LayerNorm.bias"), True - elif key.endswith("LayerNorm.gamma"): + if key.endswith("LayerNorm.gamma"): return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True # Rename weight norm parametrizations to match changes across torch versions. @@ -4352,12 +4352,12 @@ def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: if hasattr(nn.utils.parametrizations, "weight_norm"): if key.endswith("weight_g"): return key.replace("weight_g", "parametrizations.weight.original0"), True - elif key.endswith("weight_v"): + if key.endswith("weight_v"): return key.replace("weight_v", "parametrizations.weight.original1"), True else: if key.endswith("parametrizations.weight.original0"): return key.replace("parametrizations.weight.original0", "weight_g"), True - elif key.endswith("parametrizations.weight.original1"): + if key.endswith("parametrizations.weight.original1"): return key.replace("parametrizations.weight.original1", "weight_v"), True return key, False