diff --git a/paddlenlp/trainer/unified_checkpoint/load_local.py b/paddlenlp/trainer/unified_checkpoint/load_local.py
index c2ccbe4d240b..4e37f44b4d19 100644
--- a/paddlenlp/trainer/unified_checkpoint/load_local.py
+++ b/paddlenlp/trainer/unified_checkpoint/load_local.py
@@ -149,14 +149,6 @@ def _remove_unused_keys(
 
 
 def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
-    # Special process with split param.
-    if is_sharding_split_param_mode(args):
-        returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint)
-        return returned_optim_state_dict
-
-    # init and get optimizer LR_Scheduler
-    returned_optim_state_dict = nested_copy(optimizer.state_dict())
-
     if not safe_serialization:
         index_filename, index_filename_master_weights = (
             PADDLE_OPTIMIZER_INDEX_NAME,
@@ -165,6 +157,23 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
     else:
         index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
 
+    with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
+        index = json.loads(f.read())
+
+    ckpt_quant_stage = "O0"
+    if "ckpt_quant_stage" in index:
+        ckpt_quant_stage = index["ckpt_quant_stage"]
+
+    # Special process with split param.
+    if is_sharding_split_param_mode(args):
+        returned_optim_state_dict = load_unified_optimizer_split_param(
+            args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
+        )
+        return returned_optim_state_dict
+
+    # init and get optimizer LR_Scheduler
+    returned_optim_state_dict = nested_copy(optimizer.state_dict())
+
     resolved_archive_file, sharded_metadata = get_optimizer_shard_files(
         optimizer_path=resume_from_checkpoint,
         index_filename=os.path.join(resume_from_checkpoint, index_filename),
@@ -184,13 +193,6 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")
 
-    with open(os.path.join(resume_from_checkpoint, index_filename), "r") as f:
-        index = json.loads(f.read())
-
-    ckpt_quant_stage = "O0"
-    if "ckpt_quant_stage" in index:
-        ckpt_quant_stage = index["ckpt_quant_stage"]
-
     # update has_master_weights and index_filename_master_weights
     # 1. if the master weight exists, only has_master_weights is set True and loaded when needed
     # 2. if master weight does not exist, convert model weight to master weight when needed
diff --git a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
index fda80fca0a61..19e88b402bcf 100644
--- a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
+++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
@@ -36,18 +36,25 @@
     get_expected_state_dict,
     get_optimizer_shard_files,
     mapping_optimizer_tp_actions,
+    update_master_weight_status,
 )
 
 __all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"]
 
 
 def merge_splited_param(
-    state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
+    state_dict,
+    partial_tensor_list,
+    param_shape_info,
+    send_table,
+    recv_table,
+    is_master_weights=False,
+    ckpt_quant_stage="O0",
 ):
     """Merge the splited param in sharding group."""
     global_rank = dist.get_rank()
     for key in list(state_dict.keys()):
-        if state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+        if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
             continue
 
         static_name = key if is_master_weights else generate_base_static_name(key)[0]
@@ -89,10 +96,21 @@ def merge_splited_param(
                     )
                     dist.stream.send(tensor, dst=recv_rank)
                     state_dict.pop(key)
+
+    if ckpt_quant_stage != "O0":
+        for key in list(state_dict.keys()):
+            if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
+                static_name = key if is_master_weights else generate_base_static_name(key)[0]
+                if static_name in partial_tensor_list:
+                    recv_rank = recv_table[static_name]
+                    send_info = send_table[static_name]
+                    if global_rank != recv_rank:
+                        state_dict.pop(key)
+
     return state_dict
 
 
-def gather_splited_param_for_optimizer(optimizer):
+def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
     hcg = fleet.get_hybrid_communicate_group()
     sharding_group = hcg.get_sharding_parallel_group()
     global_rank = dist.get_rank()
@@ -127,7 +145,7 @@ def gather_splited_param_for_optimizer(optimizer):
     for key in list(optim_state_dict.keys()):
         static_name, _ = generate_base_static_name(key)
         if static_name in param_slice_info.keys():
-            if optim_state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+            if int(optim_state_dict[key].numel()) == 1:  # for example: beta1, beta2
                 continue
             begin, end = param_slice_info[static_name]
             shape, numel, _, _ = param_shape_info[static_name]
@@ -149,13 +167,15 @@ def gather_splited_param_for_optimizer(optimizer):
         recv_table[key] = sharding_ranklist[0][0]  # which sharding_rank to recv the splited tensor
         send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]
 
-    merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
+    merge_splited_param(
+        optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
+    )
     if master_weights is not None:
         merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
     return optim_state_dict, master_weights
 
 
-def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
+def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())
 
     index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
@@ -208,6 +228,10 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")
 
+    has_master_weights, index_filename_master_weights = update_master_weight_status(
+        args, optimizer, has_master_weights, safe_serialization=True
+    )
+
     if has_master_weights:
         returned_optim_state_dict["master_weights"] = {}
         resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files(
@@ -217,7 +241,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
         if len(resolved_archive_file_mw) > 1:
             resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")
 
-    def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
+    def load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
+    ):
         returned_state_dict = {}
 
         if model.config.tensor_parallel_degree > 1:
@@ -232,9 +258,21 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
             if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                 continue
             if model.config.tensor_parallel_degree > 1:
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    tp_actions,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
             else:
-                state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    None,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
             returned_state_dict.update(state_dict)
             del state_dict
             gc.collect()
@@ -242,14 +280,16 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
         return returned_state_dict
 
     # get tp params
-    state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
+    state_dict_optim = load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
+    )
 
     # need to split param for different sharding rank, maybe need to deal with oom issue.
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
         static_name = struct2static_name_mappings.get(key_name[0], None)
 
-        if state_dict_optim[key].numel().item() > 1:
+        if int(state_dict_optim[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
             shape, numel, index, padded_size = param_shape_info[static_name]
             state_dict_optim[key] = state_dict_optim[key].reshape([-1])
@@ -284,7 +324,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
 
         for key in list(state_dict_master_weight.keys()):
             static_name = struct2static_name_mappings.get(key, None)
-            if state_dict_master_weight[key].numel().item() > 1:
+            if int(state_dict_master_weight[key].numel()) > 1:
                 begin, end = param_slice_info[static_name]
                 shape, numel, index, padded_size = param_shape_info[static_name]
                 state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
@@ -303,6 +343,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                 paddle.framework._current_expected_place(), False
             )
             returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
+
+            # master weight cast (only in remove_master_weight)
+            if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32:
+                returned_optim_state_dict["master_weights"][static_name] = paddle.cast(
+                    returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32
+                )
+
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
     return returned_optim_state_dict
diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
index 0cb38bec94bb..2e453c54e80d 100644
--- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
+++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -344,7 +344,9 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
             return
 
         if is_sharding_split_param_mode(self.args):
-            optim_state_dict, master_weights = gather_splited_param_for_optimizer(optimizer)
+            optim_state_dict, master_weights = gather_splited_param_for_optimizer(
+                optimizer, self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0"
+            )
         else:
             optim_state_dict = nested_copy(optimizer.state_dict())
             master_weights = None