diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py
index ed28b7c13d1f..57a592029bb1 100644
--- a/paddlenlp/trainer/plugins/unified_checkpoint.py
+++ b/paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -424,6 +424,26 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
             for key in list(master_weights.keys()):
                 master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
 
+        no_sync_kname = []
+        model_state_dict = get_expected_state_dict(model)
+        for k, v in model_state_dict.items():
+            if getattr(v, "no_sync", False):
+                no_sync_kname.append(k)
+
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_group = hcg.get_data_parallel_group()
+        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+        if self.args.use_expert_parallel:
+            for k in list(optim_state_dict.keys()):
+                model_k = k.split("/")[0]
+                if dp_rank > 0 and model_k not in no_sync_kname:
+                    optim_state_dict.pop(k)
+            if master_weights is not None:
+                for k in list(master_weights.keys()):
+                    model_k = k.split("/")[0]
+                    if dp_rank > 0 and model_k not in no_sync_kname:
+                        master_weights.pop(k)
+
         optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
         master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)
 
@@ -464,7 +484,10 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
             key_name = key.split("/")
             static_name = struct2static_name_mappings[key_name[0]]
             if has_master_weights:
-                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+                    key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                else:
+                    key_name = "_".join([static_name, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
             with device_guard():
@@ -566,7 +589,7 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint)
             return optim_state_dict
 
         if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
-            if self.args.data_parallel_rank == 0:
+            if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
                 returned_optim_state_dict = self.load_non_merge_optimizer(
                     model,
                     optimizer,
@@ -1942,7 +1965,7 @@ def get_sharded_file_name(args, file_name, is_optimizer=False):
     else:
         hcg = fleet.get_hybrid_communicate_group()
         dp_group = hcg.get_data_parallel_group()
-        size = args.world_size if args.use_expert_parallel else dp_group.nranks
+        size = dp_group.nranks if not args.use_expert_parallel else 1
         shard_file = file_name.replace(
             ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams"
         )
@@ -2246,7 +2269,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
 
     if len(tp_actions) > 0:
         for x in tp_actions.keys():
-            logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.")
+            logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.")
 
     return state_dict_to_save
 
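Below is a standalone sketch (not part of the patch) of the save-side filtering that the first hunk introduces in `save_non_merge_optimizer`: under expert parallelism, a non-zero data-parallel rank keeps only the optimizer entries whose parameters are marked `no_sync` (i.e. expert weights that are not replicated across data-parallel ranks), while everything else is left for rank 0 to save. The parameter keys, the `no_sync` list, and the rank values here are invented for illustration; the real code derives them from `get_expected_state_dict(model)` and fleet's hybrid communicate group.

```python
# Illustrative sketch only; mirrors the filtering logic of the patch on plain dicts.

def filter_optim_state_for_dp_rank(optim_state_dict, no_sync_kname, dp_rank, use_expert_parallel):
    """Return the subset of optimizer entries this data-parallel rank should persist.

    Optimizer keys have the form "<model_key>/<moment_name>". On dp_rank > 0 an entry
    survives only if its model key is in `no_sync_kname` (an expert weight unique to
    this rank); replicated parameters are already covered by rank 0.
    """
    if not use_expert_parallel:
        return dict(optim_state_dict)
    kept = {}
    for k, v in optim_state_dict.items():
        model_k = k.split("/")[0]
        if dp_rank > 0 and model_k not in no_sync_kname:
            continue  # replicated parameter: rank 0 saves it
        kept[k] = v
    return kept


if __name__ == "__main__":
    # Hypothetical keys for illustration only.
    optim_state = {
        "experts.0.w1/moment1_0": "...",          # expert weight, marked no_sync
        "embed_tokens.weight/moment1_0": "...",   # replicated weight
    }
    no_sync = ["experts.0.w1"]
    print(filter_optim_state_for_dp_rank(optim_state, no_sync, dp_rank=1, use_expert_parallel=True))
    # On dp_rank 1 only the expert entry remains; on dp_rank 0 both would be kept.
```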