diff --git a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
index 19e88b402bcf..5988c92111c8 100644
--- a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
+++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
@@ -305,7 +305,10 @@ def load_resolved_archive_file(
                         )
                     )
                 if has_master_weights:
-                    key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                    if model_state_dict[key_name[0]].dtype != paddle.float32:
+                        key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                    else:
+                        key_name = "_".join([static_name, key_name[1]])
                 else:
                     key_name = "_".join([static_name, key_name[1]])
diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
index 14f81599e538..84253e36fcd8 100644
--- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
+++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -637,8 +637,8 @@ def unified_optimizer_into_shards(
         tp_size = tp_group.nranks
         dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
 
-    no_sync_kname = []
     if args.use_expert_parallel:
+        no_sync_kname = []
         for k, v in state_dict.items():
             if getattr(state_dict[k], "no_sync", False):
                 no_sync_kname.append(k)
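
For context, the first hunk only appends the FP32_MASTER infix to an optimizer-state key when the parameter itself is not already float32, since an fp32 parameter has no separate master-weight copy. Below is a minimal sketch of that key-resolution rule, not the PaddleNLP implementation: `resolve_optimizer_key`, `static_name`, `suffix`, and `param_dtype` are hypothetical names, and the `"fp32_master_0"` value of FP32_MASTER is an assumption.

```python
# Sketch only, assuming FP32_MASTER == "fp32_master_0"; names are hypothetical.
import paddle

FP32_MASTER = "fp32_master_0"


def resolve_optimizer_key(static_name, suffix, param_dtype, has_master_weights):
    """Build the optimizer-state key for one parameter (hypothetical helper)."""
    if has_master_weights and param_dtype != paddle.float32:
        # Low-precision param: its optimizer state is keyed by the master-weight name.
        return "_".join([static_name, FP32_MASTER, suffix])
    # fp32 param, or no master weights at all: no master-weight infix in the key.
    return "_".join([static_name, suffix])


# A bf16 parameter gets the infix, an fp32 parameter does not.
print(resolve_optimizer_key("linear_0.w_0", "moment1_0", paddle.bfloat16, True))
# -> linear_0.w_0_fp32_master_0_moment1_0
print(resolve_optimizer_key("linear_0.w_0", "moment1_0", paddle.float32, True))
# -> linear_0.w_0_moment1_0
```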