diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 409dd88d0c78..dbf93b853ab1 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -593,13 +593,22 @@ def _load_state_dict_into_meta_model(
 
         module_name = param_name
 
-        # We convert floating dtypes to the `dtype` passed.We want to keep the buffers/params
+        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
         # in int/uint/bool and not cast them.
         if dtype is not None and torch.is_floating_point(param):
             param = param.to(dtype)
-        # For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
-        if is_safetensors and dtype is None and torch.is_floating_point(param):
-            param = param.to(torch.float32)
+
+        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
+        if dtype is None:
+            old_param = model
+            splits = param_name.split(".")
+            for split in splits:
+                old_param = getattr(old_param, split)
+                if old_param is None:
+                    break
+
+            if old_param is not None:
+                param = param.to(old_param.dtype)
 
         if device_map is None:
             param_device = "cpu"
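
For context, the comment in the new branch refers to the behavior of `torch.nn.Module.load_state_dict`, which copies each incoming tensor into the existing parameter with `Tensor.copy_`, so the tensor is cast to the parameter's current dtype rather than keeping the checkpoint's dtype. A minimal standalone demonstration of that behavior (the `nn.Linear` module and float32 checkpoint here are illustrative, not taken from the PR):

```python
import torch
from torch import nn

# Model whose parameters live in float16.
model = nn.Linear(4, 4).to(torch.float16)

# Checkpoint saved in float32.
state_dict = {k: v.float() for k, v in model.state_dict().items()}

# load_state_dict copies via Tensor.copy_, which casts each incoming
# tensor to the dtype of the parameter already in the model.
model.load_state_dict(state_dict)
print(model.weight.dtype)  # torch.float16
```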
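
The added `dtype is None` branch reproduces that behavior on the meta-model loading path by walking the dotted parameter name down the module tree to find the existing parameter and reusing its dtype. Below is a sketch of that traversal as a free-standing helper; `dtype_of_existing_param` is a hypothetical name, and unlike the diff it passes a default to `getattr` so that a missing attribute yields `None` instead of raising `AttributeError`:

```python
import torch
from torch import nn


def dtype_of_existing_param(model: nn.Module, param_name: str):
    # Hypothetical helper: walk e.g. "layers.0.weight" one attribute
    # at a time. Integer indices work because container modules such
    # as nn.Sequential / nn.ModuleList expose children as attributes
    # named "0", "1", ...
    obj = model
    for part in param_name.split("."):
        obj = getattr(obj, part, None)
        if obj is None:
            return None
    return obj.dtype


model = nn.Sequential(nn.Linear(4, 4)).to(torch.bfloat16)
print(dtype_of_existing_param(model, "0.weight"))   # torch.bfloat16
print(dtype_of_existing_param(model, "0.missing"))  # None
```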