Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ def _load_state_dict_into_meta_model(
param_name = param_name[len(start_prefix) :]

module_name = param_name
set_module_kwargs = {}

# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
Expand All @@ -619,6 +620,11 @@ def _load_state_dict_into_meta_model(
and dtype == torch.float16
):
param = param.to(torch.float32)

# For backward compatibility with older versions of `accelerate`
# TODO: @sgugger replace this check with version check at the next `accelerate` release
if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters):
set_module_kwargs["dtype"] = torch.float32
else:
param = param.to(dtype)

Expand All @@ -634,6 +640,8 @@ def _load_state_dict_into_meta_model(
if old_param is not None:
param = param.to(old_param.dtype)

set_module_kwargs["value"] = param

if device_map is None:
param_device = "cpu"
else:
Expand All @@ -651,7 +659,8 @@ def _load_state_dict_into_meta_model(
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif not load_in_8bit:
set_module_tensor_to_device(model, param_name, param_device, value=param)
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)

Expand Down