From 5fd6c8a523e4209e556ed77f9c221d86f9e67968 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Mon, 4 Apr 2022 12:57:37 +0200
Subject: [PATCH] handle torch_dtype in low cpu mem usage

---
 src/transformers/modeling_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 0719700c0964..a1a0ad7d36fd 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2165,7 +2165,8 @@ def find_submodule_and_param_name(model, long_key):
     for k in loaded_state_dict_keys:
         submodule, param_name = find_submodule_and_param_name(model, k)
         if submodule is not None:
-            new_val = state_dict[k]
+            param_dtype = getattr(submodule, param_name).dtype
+            new_val = state_dict[k].to(param_dtype)
             if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
                 new_val = torch.nn.Parameter(new_val)
             setattr(submodule, param_name, new_val)
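
For context, a minimal standalone sketch of what the added cast achieves: when the model has been instantiated with a torch_dtype (e.g. float16) under low_cpu_mem_usage, each tensor coming from the checkpoint is converted to the dtype of the parameter it replaces before assignment. The toy nn.Linear module and names below are illustrative only, not the transformers loading path.

import torch
import torch.nn as nn

# Toy stand-in for a model instantiated under low_cpu_mem_usage with torch_dtype=torch.float16.
model = nn.Linear(4, 4).to(torch.float16)

# "Checkpoint" tensors saved in float32, as state_dict entries commonly are.
state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

for name, tensor in state_dict.items():
    # Cast the incoming tensor to the dtype of the existing parameter,
    # mirroring the patched assignment above.
    param_dtype = getattr(model, name).dtype
    new_val = tensor.to(param_dtype)
    if isinstance(getattr(model, name), nn.Parameter):
        new_val = nn.Parameter(new_val)
    setattr(model, name, new_val)

assert model.weight.dtype == torch.float16  # dtype requested via torch_dtype is preserved

Without the cast, the float32 checkpoint tensors would silently replace the float16 parameters and the torch_dtype request would be lost.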