diff --git a/paddleformers/mergekit/merge_model.py b/paddleformers/mergekit/merge_model.py
index ade00dc80dc..1ec9b05a03b 100644
--- a/paddleformers/mergekit/merge_model.py
+++ b/paddleformers/mergekit/merge_model.py
@@ -590,7 +590,7 @@ def shard_lora_merge(self, base_index, shard_file, lora_config, file_type_list,
                 tensor = paddle.Tensor.__call__(tensor, zero_copy=True)
                 lora_A_tensor = paddle.Tensor.__call__(lora_A_tensor, zero_copy=True)
                 lora_B_tensor = paddle.Tensor.__call__(lora_B_tensor, zero_copy=True)
-                if self.is_cpu and is_bf16 or self.merge_config.save_to_hf:
+                if self.is_cpu and is_bf16:
                     tensor = tensor.astype("float32")
                     lora_A_tensor = lora_A_tensor.astype("float32")
                     lora_B_tensor = lora_B_tensor.astype("float32")
diff --git a/paddleformers/transformers/model_utils.py b/paddleformers/transformers/model_utils.py
index f7fce5b0d94..a7ef574cd88 100644
--- a/paddleformers/transformers/model_utils.py
+++ b/paddleformers/transformers/model_utils.py
@@ -598,7 +598,7 @@ def load_state_dict(
 def prepare_safe_save_state_dict(state_dict, save_to_hf=False):
     for k in list(state_dict.keys()):
         if isinstance(state_dict[k], paddle.Tensor):
-            if save_to_hf:
+            if state_dict[k].dtype == paddle.bfloat16:
                 state_dict[k] = state_dict.pop(k).astype("float32").cpu().numpy().astype(ml_dtypes.bfloat16)
             else:
                 state_dict[k] = state_dict.pop(k).cpu().numpy()
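
A note on the rationale, with a minimal sketch (not part of the patch): both hunks key bf16 handling off the tensor's own state rather than the save_to_hf flag. The LoRA merge now upcasts to float32 only when doing bf16 math on CPU, and prepare_safe_save_state_dict decides per tensor dtype. Since numpy has no native bfloat16, a bf16 Paddle tensor must be upcast to float32, moved to host memory, and only then cast to ml_dtypes.bfloat16 before serialization. The helper name below is hypothetical.

import ml_dtypes
import numpy as np
import paddle

def to_numpy_for_safetensors(tensor: paddle.Tensor) -> np.ndarray:
    # Hypothetical helper mirroring the patched branch of
    # prepare_safe_save_state_dict: branch on dtype, not on save_to_hf.
    if tensor.dtype == paddle.bfloat16:
        # numpy cannot represent bf16 natively: upcast to fp32 inside
        # Paddle, convert to a numpy array, then cast down to
        # ml_dtypes.bfloat16 so the saved tensor keeps its 16-bit width.
        return tensor.astype("float32").cpu().numpy().astype(ml_dtypes.bfloat16)
    # fp16/fp32 and other dtypes convert directly, no intermediate upcast.
    return tensor.cpu().numpy()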