diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 746bcfb9d2a4..c81ef06ebed8 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2162,13 +2162,14 @@ def find_submodule_and_param_name(model, long_key):
 
     # materialize state_dict entries one by one on CPU
     for k in loaded_state_dict_keys:
-        submodule, param_name = find_submodule_and_param_name(model, k)
-        if submodule is not None:
-            param_dtype = getattr(submodule, param_name).dtype
-            new_val = state_dict[k].to(param_dtype)
-            if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
-                new_val = torch.nn.Parameter(new_val)
-            setattr(submodule, param_name, new_val)
+        if k in state_dict:
+            submodule, param_name = find_submodule_and_param_name(model, k)
+            if submodule is not None:
+                param_dtype = getattr(submodule, param_name).dtype
+                new_val = state_dict[k].to(param_dtype)
+                if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
+                    new_val = torch.nn.Parameter(new_val)
+                setattr(submodule, param_name, new_val)
 
     del state_dict
 