diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 0719700c0964..a1a0ad7d36fd 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2165,7 +2165,8 @@ def find_submodule_and_param_name(model, long_key):
         for k in loaded_state_dict_keys:
             submodule, param_name = find_submodule_and_param_name(model, k)
             if submodule is not None:
-                new_val = state_dict[k]
+                param_dtype = getattr(submodule, param_name).dtype
+                new_val = state_dict[k].to(param_dtype)
                 if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
                     new_val = torch.nn.Parameter(new_val)
                 setattr(submodule, param_name, new_val)
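
For context, here is a minimal standalone sketch of the behavior the added cast restores: when the target module was instantiated in a reduced precision (e.g. fp16) but the checkpoint tensors are stored in fp32, assigning them directly would silently flip the parameters back to fp32, so the loaded tensor is cast to the existing parameter's dtype first. The `model` and `state_dict` names below are illustrative only, not the surrounding transformers code path.

```python
import torch
import torch.nn as nn

# Target module created in fp16.
model = nn.Linear(4, 4).to(torch.float16)
# Checkpoint tensor saved in fp32 (the default).
state_dict = {"weight": torch.randn(4, 4)}

# Cast the loaded tensor to the existing parameter's dtype before assignment,
# mirroring the `param_dtype` cast introduced in the diff above.
param_dtype = model.weight.dtype  # torch.float16
new_val = state_dict["weight"].to(param_dtype)
model.weight = nn.Parameter(new_val)

assert model.weight.dtype == torch.float16
```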