Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,13 +593,22 @@ def _load_state_dict_into_meta_model(

module_name = param_name

# We convert floating dtypes to the `dtype` passed.We want to keep the buffers/params
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param):
param = param.to(dtype)
# For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
if is_safetensors and dtype is None and torch.is_floating_point(param):
param = param.to(torch.float32)

# For compatibility with PyTorch's `load_state_dict`, which converts each state-dict tensor to the dtype of the existing parameter in the model
if dtype is None:
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
if old_param is None:
break

if old_param is not None:
param = param.to(old_param.dtype)

if device_map is None:
param_device = "cpu"
Expand Down