diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6780f9b19f14..21afe6fa896e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -609,6 +609,7 @@ def _load_state_dict_into_meta_model(
             param_name = param_name[len(start_prefix) :]
 
         module_name = param_name
+        set_module_kwargs = {}
 
         # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
         # in int/uint/bool and not cast them.
@@ -619,6 +620,11 @@
                 and dtype == torch.float16
             ):
                 param = param.to(torch.float32)
+
+                # For backward compatibility with older versions of `accelerate`
+                # TODO: @sgugger replace this check with version check at the next `accelerate` release
+                if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters):
+                    set_module_kwargs["dtype"] = torch.float32
             else:
                 param = param.to(dtype)
 
@@ -634,6 +640,8 @@
         if old_param is not None:
             param = param.to(old_param.dtype)
 
+        set_module_kwargs["value"] = param
+
         if device_map is None:
             param_device = "cpu"
         else:
@@ -651,7 +659,8 @@
         elif param_device == "cpu" and state_dict_index is not None:
             state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
         elif not load_in_8bit:
-            set_module_tensor_to_device(model, param_name, param_device, value=param)
+            # For backward compatibility with older versions of `accelerate`
+            set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
         else:
             set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)
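
For context, the guard added above relies on runtime signature introspection: older `accelerate` releases expose `set_module_tensor_to_device` without a `dtype` parameter, so the kwarg is forwarded only when the installed function actually accepts it. Below is a minimal, self-contained sketch of that pattern; the `set_module_tensor_to_device` stub is a hypothetical stand-in, not `accelerate`'s real implementation, and only the `inspect.signature` check mirrors the diff.

import inspect

import torch
from torch import nn


# Hypothetical stand-in for accelerate's helper, written as if it were the
# *newer* signature that accepts an optional `dtype` (older releases did not).
def set_module_tensor_to_device(module, tensor_name, device, value=None, dtype=None):
    if dtype is not None:
        value = value.to(dtype)
    setattr(module, tensor_name, nn.Parameter(value.to(device)))


model = nn.Module()
set_module_kwargs = {"value": torch.zeros(4, dtype=torch.float16)}

# The diff's pattern: only forward `dtype` when the installed function's
# signature declares it, so the same call keeps working against older
# versions that would raise TypeError on an unknown keyword argument.
if "dtype" in inspect.signature(set_module_tensor_to_device).parameters:
    set_module_kwargs["dtype"] = torch.float32

set_module_tensor_to_device(model, "weight", "cpu", **set_module_kwargs)
print(model.weight.dtype)  # torch.float32 when the dtype kwarg was forwarded

The TODO in the diff notes why this is a stopgap: once the next `accelerate` release is the minimum supported version, the per-call signature probe can be replaced by a single version check.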