diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py
index f35b23f95050..5f3cda836d3b 100644
--- a/paddlenlp/trainer/plugins/unified_checkpoint.py
+++ b/paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -27,6 +27,11 @@
 from paddle.distributed import fleet
 from tqdm.auto import tqdm

+try:
+    from paddle.base import core
+except ImportError:
+    core = None
+
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
 from paddlenlp.trainer.argparser import strtobool
 from paddlenlp.trainer.trainer_utils import ExplicitEnum
@@ -389,7 +394,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
             )
             return

-        if self.args.dataset_rank == 0:
+        if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
             load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

     def save_non_merge_optimizer(self, model, optimizer, output_dir):
@@ -422,6 +427,26 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
             for key in list(master_weights.keys()):
                 master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)

+        no_sync_kname = []
+        model_state_dict = get_expected_state_dict(model)
+        for k, v in model_state_dict.items():
+            if getattr(v, "no_sync", False):
+                no_sync_kname.append(k)
+
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_group = hcg.get_data_parallel_group()
+        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+        if self.args.use_expert_parallel:
+            for k in list(optim_state_dict.keys()):
+                model_k = k.split("/")[0]
+                if dp_rank > 0 and model_k not in no_sync_kname:
+                    optim_state_dict.pop(k)
+            if master_weights is not None:
+                for k in list(master_weights.keys()):
+                    model_k = k.split("/")[0]
+                    if dp_rank > 0 and model_k not in no_sync_kname:
+                        master_weights.pop(k)
+
         optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
         master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)
@@ -462,7 +487,10 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
             key_name = key.split("/")
             static_name = struct2static_name_mappings[key_name[0]]
             if has_master_weights:
-                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+                    key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                else:
+                    key_name = "_".join([static_name, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
             with device_guard():
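The dtype-dependent key naming introduced above recurs in every loader this patch touches. A minimal standalone sketch of the rule, assuming plain Paddle dtypes (`FP32_MASTER` below is a stand-in value, not the constant the module actually imports):

```python
# Hedged sketch of the optimizer-state key naming used in the hunks above.
import paddle

FP32_MASTER = "fp32_master_0"  # placeholder value for illustration only


def optimizer_state_key(static_name: str, type_name: str, param_dtype, has_master_weights: bool) -> str:
    """Return the flat optimizer-state key for one (parameter, state-type) pair."""
    if has_master_weights and param_dtype != paddle.float32:
        # Non-FP32 params keep a separate FP32 master copy, so the key carries the infix.
        return "_".join([static_name, FP32_MASTER, type_name])
    # FP32 params (or no master weights at all) use the plain key.
    return "_".join([static_name, type_name])


# Example with hypothetical parameter names: a bf16 expert weight keeps the
# master-weight infix, while an FP32 gate weight maps straight to "<static_name>_moment1_0".
print(optimizer_state_key("moe_expert_w_0", "moment1_0", paddle.bfloat16, True))
print(optimizer_state_key("moe_gate_w_0", "moment1_0", paddle.float32, True))
```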
@@ -568,7 +596,7 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint)
             )
         # If not having merge optimizer, then load non-merge optimizer.
         if not has_merge_optimizer_safetensors:
-            if self.args.data_parallel_rank == 0:
+            if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
                 returned_optim_state_dict = self.load_non_merge_optimizer(
                     model,
                     optimizer,
@@ -588,7 +616,7 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint)
                 )
             return returned_optim_state_dict

-        if self.args.data_parallel_rank == 0:
+        if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
             returned_optim_state_dict = load_unified_optimizer_locally(
                 self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True
             )
@@ -651,8 +679,11 @@ def save_single_card_optimizer(self, model, optimizer, output_dir):
         static2struct_name_mappings = {}
         state_dict = get_expected_state_dict(model)
+        fp32_weight = {}
         for k, v in state_dict.items():
             static2struct_name_mappings[v.name] = k
+            if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32:
+                fp32_weight[k] = v

         # rename optimizer param
         for key in list(optim_state_dict.keys()):
@@ -662,6 +693,7 @@ def save_single_card_optimizer(self, model, optimizer, output_dir):
         if master_weights is not None:
             for key in list(master_weights.keys()):
                 master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
+            master_weights.update(fp32_weight)

         # save index json
         index_optimizer_file, index_master_weight_file = {}, {}
@@ -744,7 +776,7 @@ def unlink_shared_memory(self):

 def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):
     """
-    Only dataset_rank == 0 can enter this function.
+    Only dataset_rank == 0, or every rank when using expert parallel, can enter this function.
     """
     index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=True)
@@ -755,7 +787,14 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
     loaded_keys = sharded_metadata["all_checkpoint_keys"]

     model_state_dict = get_expected_state_dict(model)
-    expected_keys = set(list(model_state_dict.keys()))
+    # When using expert parallel and dp_rank > 0, only the `no_sync` (expert) keys are expected here.
+    if not args.use_expert_parallel or (args.use_expert_parallel and args.data_parallel_rank == 0):
+        expected_keys = set(list(model_state_dict.keys()))
+    else:
+        expected_keys = set()
+        for key in model_state_dict.keys():
+            if getattr(model_state_dict[key], "no_sync", False):
+                expected_keys.add(key)
     missing_keys = expected_keys - set(loaded_keys)

     use_fast_set = True
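The `expected_keys` branch above reduces, for dp_rank > 0 under expert parallel, to "only expect parameters flagged `no_sync`". A hedged sketch with a plain dict standing in for the model state:

```python
# Sketch of how the expected key set shrinks for dp_rank > 0 under expert parallel.
# `no_sync` is assumed to be set on expert-local parameters, as the patch relies on.
def expected_keys_for_rank(model_state_dict: dict, use_expert_parallel: bool, data_parallel_rank: int) -> set:
    if not use_expert_parallel or data_parallel_rank == 0:
        return set(model_state_dict.keys())
    # Non-zero dp ranks only own their expert-local ("no_sync") tensors.
    return {k for k, v in model_state_dict.items() if getattr(v, "no_sync", False)}
```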
@@ -889,11 +928,17 @@ def unified_checkpoint_into_shards(
     weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME

     shard_file = get_sharded_file_name(args, weights_name)
+    # Renumber the shard_file for expert_parallel.
+    if args.use_expert_parallel:
+        shard_file = rename_shard_file(args, shard_file, weights_name)
+
     for key, weight in state_dict.items():
         index_weight_file[key] = shard_file
         total_size += weight.numel().item() * dtype_byte_size(weight.dtype)

-    index_file_list, total_size_list = gather_sharded_object(index_weight_file, total_size)
+    index_file_list, total_size_list = gather_sharded_object(
+        index_weight_file, total_size, use_expert_parallel=args.use_expert_parallel
+    )
     sharded_index = get_sharded_index(
         index_file_list,
         total_size_list,
@@ -931,7 +976,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
     model_keys = list(model_state_dict.keys())
     struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}  # get optimizer param mappings

-    expected_keys = get_expected_keys(sharded_metadata, model, optimizer)
+    expected_keys = get_expected_keys(args, sharded_metadata, model, optimizer)

     # This should always be a list but, just to be sure.
     if not isinstance(resolved_archive_file, list):
@@ -955,7 +1000,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
             index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights),
         )

-        expected_keys_mw = get_expected_keys(sharded_metadata_mw, model, optimizer)
+        expected_keys_mw = get_expected_keys(args, sharded_metadata_mw, model, optimizer, is_master_weights=True)
         if not isinstance(resolved_archive_file_mw, list):
             resolved_archive_file_mw = [resolved_archive_file_mw]
         if len(resolved_archive_file_mw) > 1:
@@ -1005,7 +1050,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                 key_name = key.split("/")
                 static_name = struct2static_name_mappings[key_name[0]]
                 if has_master_weights:
-                    key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                    if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+                        key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+                    else:
+                        key_name = "_".join([static_name, key_name[1]])
                 else:
                     key_name = "_".join([static_name, key_name[1]])
                 returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
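For reference, the per-rank bookkeeping that feeds `gather_sharded_object` amounts to mapping every key to this rank's shard file and summing bytes. A toy sketch with `(numel, dtype)` tuples standing in for tensors and a simplified byte-size table (the real helper is `dtype_byte_size` in paddlenlp):

```python
# Toy version of the per-rank index bookkeeping above: map every key to this rank's
# shard file and add up the bytes, mirroring index_weight_file / total_size.
_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2}  # simplified lookup, illustration only


def build_shard_index(state_dict: dict, shard_file: str):
    index_weight_file, total_size = {}, 0
    for key, (numel, dtype) in state_dict.items():  # (numel, dtype) stands in for a tensor
        index_weight_file[key] = shard_file
        total_size += numel * _BYTES[dtype]
    return index_weight_file, total_size


# Hypothetical keys and shard name, just to show the shape of the result.
index, size = build_shard_index(
    {"moe.expert_0.w": (1024, "bfloat16"), "moe.gate.w": (64, "float32")},
    "model-00003-of-00008.safetensors",
)
```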
@@ -1049,8 +1097,13 @@ def unified_optimizer_into_shards(
     # get optimizer param mappings
     static2struct_name_mappings = {}
     state_dict = get_expected_state_dict(model)
+    fp32_weight = {}
     for k, v in state_dict.items():
         static2struct_name_mappings[v.name] = k
+        if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32:
+            if args.dataset_rank > 0:  # only dataset_rank 0 keeps the FP32 weights as master weights.
+                continue
+            fp32_weight[k] = v

     # rename optimizer param
     for key in list(optim_state_dict.keys()):
@@ -1060,6 +1113,7 @@ def unified_optimizer_into_shards(
     if master_weights is not None:
         for key in list(master_weights.keys()):
             master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
+        master_weights.update(fp32_weight)

     # filter optimizer param
     if master_weights is not None:
@@ -1087,6 +1141,7 @@ def unified_optimizer_into_shards(
             optim_state_dict,
             tp_actions,
             filter_optim_keys,
+            state_dict if args.use_expert_parallel else None,
         )

         paddle.device.cuda.empty_cache()
@@ -1096,6 +1151,7 @@ def unified_optimizer_into_shards(
                 master_weights,
                 tp_actions,
                 filter_master_keys,
+                state_dict if args.use_expert_parallel else None,
             )
             paddle.device.cuda.empty_cache()
@@ -1119,12 +1175,18 @@ def unified_optimizer_into_shards(
             total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype)

     index_optimizer_filelist, total_optim_size_list = gather_sharded_object(
-        index_optimizer_file, total_optim_size, is_optimizer=True
+        index_optimizer_file,
+        total_optim_size,
+        is_optimizer=True,
+        use_expert_parallel=args.use_expert_parallel,
     )
     sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list)
     if master_weights is not None:
         index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object(
-            index_master_weight_file, total_master_weight_size, is_optimizer=True
+            index_master_weight_file,
+            total_master_weight_size,
+            is_optimizer=True,
+            use_expert_parallel=args.use_expert_parallel,
         )
         sharded_master_weight_index = get_sharded_index(index_master_weight_filelist, total_master_weight_size_list)
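Two adjustments now happen before optimizer shards are written: FP32 parameters are merged into `master_weights` as their own master copy, and non-zero dp ranks keep only the states of `no_sync` parameters. A hedged sketch of both steps on plain dicts (dtype is modeled as a string here; names are illustrative, not the module's API):

```python
# Hedged sketch of the two adjustments made before writing optimizer shards:
# 1) FP32 parameters become their own "master weights";
# 2) dp_rank > 0 keeps only states of expert-local (`no_sync`) parameters.
def prepare_optimizer_shard(optim_state, master_weights, model_state, dp_rank, use_expert_parallel):
    fp32_weight = {k: v for k, v in model_state.items() if getattr(v, "dtype", None) == "float32"}
    if master_weights is not None:
        master_weights.update(fp32_weight)

    if use_expert_parallel and dp_rank > 0:
        no_sync = {k for k, v in model_state.items() if getattr(v, "no_sync", False)}
        # Optimizer keys look like "<struct_name>/<type_name>", so match on the prefix.
        optim_state = {k: v for k, v in optim_state.items() if k.split("/")[0] in no_sync}
        if master_weights is not None:
            master_weights = {k: v for k, v in master_weights.items() if k in no_sync}
    return optim_state, master_weights
```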
@@ -1175,15 +1237,20 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa
     # To decide whether to load the checkpoint locally, or need to dynamically send tensors across machines.
     local_resume = True
-    if args.dataset_rank == 0:
+    if args.dataset_rank == 0 or args.use_expert_parallel:
         hcg = fleet.get_hybrid_communicate_group()
         tp_group = hcg.get_model_parallel_group()
         pp_group = hcg.get_pipe_parallel_group()
+        dp_group = hcg.get_data_parallel_group()
+        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
         need_files = set()
         state_dict = get_expected_state_dict(model)
         for key in state_dict.keys():
             filename = index["weight_map"][key]
+            # When using expert parallel and dp_rank > 0, only tensors marked `no_sync` need to be checked.
+            if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
+                continue
             need_files.add(filename)
         diff_filelist = list(need_files.difference(set(existed_files)))
         num_diff = paddle.to_tensor([len(diff_filelist)])
@@ -1191,6 +1258,8 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa
             dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group)
         if pp_group.nranks > 1:
             dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group)
+        if args.use_expert_parallel and dp_group.nranks > 1:
+            dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group)
         if num_diff.item() == 0:
             local_resume = True
         else:
@@ -1243,8 +1312,10 @@ def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     pp_group = hcg.get_pipe_parallel_group()
+    dp_group = hcg.get_data_parallel_group()
     sharding_group = hcg.get_sharding_parallel_group()
     sharding_rank = sharding_group.rank
+    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
     struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()}
     if sharding_group.nranks > 1:
         param2rank = optimizer._param2rank
@@ -1269,9 +1340,10 @@ def check_complete(all_filenames):
     def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, typename_set=None):
         # To decide whether to load the checkpoint locally, or need to dynamically distribute the checkpoint.
         local_resume = True
-        if args.data_parallel_rank == 0:
+        if args.data_parallel_rank == 0 or args.use_expert_parallel:
             need_files = set()
             state_dict = get_expected_state_dict(model)
+
             for key in state_dict.keys():
                 if sharding_group.nranks > 1:
                     static_name = struct2static_name_mappings.get(key, None)
@@ -1279,6 +1351,13 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False,
                     if param_rank != sharding_rank:
                         continue

+                # When using expert parallel and dp_rank > 0, only tensors marked `no_sync` need to be checked.
+                if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
+                    continue
+
+                if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32:
+                    continue
+
                 if not is_master_weights:
                     for type_name in typename_set:
                         type_key = key + "/" + type_name
@@ -1296,6 +1375,8 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False,
                 dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group)
             if sharding_group.nranks > 1:
                 dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=sharding_group)
+            if args.use_expert_parallel and dp_group.nranks > 1:
+                dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group)

             if num_diff.item() == 0:
                 local_resume = True
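The resume check above is a cross-rank agreement: each rank counts the checkpoint files it is missing, and a MAX all-reduce over every relevant group (tp, pp, sharding, and now dp for expert parallel) forces a common verdict. A minimal sketch, assuming an initialized distributed environment and pre-built communication groups:

```python
# Minimal sketch of the "can we resume locally?" agreement; `groups` is assumed to
# hold the relevant fleet communication groups (tp/pp/sharding, plus dp for MoE).
import paddle
import paddle.distributed as dist


def all_ranks_have_their_files(num_missing_files: int, groups) -> bool:
    num_diff = paddle.to_tensor([num_missing_files])
    for group in groups:
        if group is not None and group.nranks > 1:
            # MAX so that a single rank missing a file forces dynamic redistribution everywhere.
            dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=group)
    return num_diff.item() == 0
```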
@@ -1548,8 +1629,10 @@ def load_unified_optimizer_dynamically(args, model, optimizer, resume_from_check
     for key in index["weight_map"].keys():
         _, typename = key.split("/")
         typename_set.add(typename)
-    struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()}
-    static2struct_name_mappings = {v.name: k for k, v in get_expected_state_dict(model).items()}
+
+    model_state_dict = get_expected_state_dict(model)
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
+    static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}
     # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors.
     send_table, recv_table = create_optimizer_dispatch_table(
         args,
@@ -1671,7 +1754,10 @@ def check_optimizer_param(parameter):
         key_name = key.split("/")
         static_name = struct2static_name_mappings[key_name[0]]
         if has_master_weights:
-            key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                key_name = "_".join([static_name, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
         optim_state_dict[key_name] = optim_state_dict.pop(key)
@@ -1745,9 +1831,10 @@ def load_single_card_optimizer(args, model, optimizer, resume_from_checkpoint: s
         key_name = key.split("/")
         static_name = struct2static_name_mappings[key_name[0]]
         if has_master_weights:
-            key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
-        else:
-            key_name = "_".join([static_name, key_name[1]])
+            if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                key_name = "_".join([static_name, key_name[1]])
         returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
         returned_optim_state_dict[key_name].name = key_name
     if has_master_weights:
@@ -1872,26 +1959,29 @@ def distributed_send_recv(

 def get_sharded_file_name(args, file_name, is_optimizer=False):
     if not is_optimizer:
+        sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1
+        size = sd_degree if args.use_expert_parallel else args.dataset_world_size
         shard_file = file_name.replace(
             ".pdparams",
-            f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.pdparams",
+            f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams",
         )
         shard_file = shard_file.replace(
             ".safetensors",
-            f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.safetensors",
+            f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors",
         )
     else:
         hcg = fleet.get_hybrid_communicate_group()
         dp_group = hcg.get_data_parallel_group()
+        size = dp_group.nranks if not args.use_expert_parallel else 1
         shard_file = file_name.replace(
-            ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdparams"
+            ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams"
         )
         shard_file = shard_file.replace(
             ".safetensors",
-            f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.safetensors",
+            f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors",
         )
         shard_file = shard_file.replace(
-            ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdopt"
+            ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdopt"
         )
     return shard_file
@@ -1935,7 +2025,7 @@ def reduce_master_weights_status(has_master_weights=False):
     return data.item() > 0


-def gather_sharded_object(index_file, total_size, is_optimizer=False):
+def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False):
     index_file_list, total_size_list = [], []

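The shard suffix is `-{index:05d}-of-{count:05d}`; what the `get_sharded_file_name` hunk changes is the divisor used for the count. A sketch of the weights-file arithmetic with plain ints, assuming `dataset_world_size` is the data-parallel times sharding degree (as in TrainingArguments):

```python
# Sketch of the shard-count arithmetic in get_sharded_file_name (weights case).
# All arguments are plain ints standing in for the TrainingArguments fields.
def weight_shard_suffix(logical_process_index, world_size, dataset_world_size,
                        sharding_parallel_degree, use_expert_parallel):
    sd_degree = sharding_parallel_degree if sharding_parallel_degree > 1 else 1
    divisor = sd_degree if use_expert_parallel else dataset_world_size
    total = world_size // divisor
    return f"-{logical_process_index + 1:05d}-of-{total:05d}"


# 16 GPUs with tp=2, pp=2, sharding=2, dp=2 (dataset_world_size=4):
#   without expert parallel -> 16 // 4 = 4 weight shards
#   with expert parallel    -> 16 // 2 = 8 weight shards (each dp rank writes its own expert shard)
```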
@@ -1969,6 +2059,17 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
     if len(index_file_list) == 0 and len(total_size_list) == 0:
         index_file_list = [index_file]
         total_size_list = [total_size]
+
+    if use_expert_parallel:
+        data_group = hcg.get_data_parallel_group()
+        if data_group.nranks > 1:
+            data_index_file_list = []
+            data_total_size_list = []
+            dist.all_gather_object(data_index_file_list, index_file_list, data_group)
+            dist.all_gather_object(data_total_size_list, total_size_list, data_group)
+            index_file_list = flatten_list(data_index_file_list)
+            total_size_list = flatten_list(data_total_size_list)
+
     if is_optimizer:
         sharding_group = hcg.get_sharding_parallel_group()
         if sharding_group.nranks > 1:
@@ -1982,16 +2083,58 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):

     return index_file_list, total_size_list


+def rename_shard_file(args, shard_file, file_name):
+    """Rename the shard file when using expert parallel."""
+    assert args.use_expert_parallel, "only expert parallel needs to use this function"
+
+    shard_file_list = []
+
+    hcg = fleet.get_hybrid_communicate_group()
+    tp_group = hcg.get_model_parallel_group()
+    pp_group = hcg.get_pipe_parallel_group()
+    data_group = hcg.get_data_parallel_group()
+
+    if tp_group.nranks > 1:
+        dist.all_gather_object(shard_file_list, shard_file, tp_group)
+    if pp_group.nranks > 1:
+        pp_shard_file_list = []
+        dist.all_gather_object(
+            pp_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, pp_group
+        )
+        shard_file_list = flatten_list(pp_shard_file_list)
+    if data_group.nranks > 1:
+        data_shard_file_list = []
+        dist.all_gather_object(
+            data_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, data_group
+        )
+        shard_file_list = flatten_list(data_shard_file_list)
+
+    new_index = shard_file_list.index(shard_file)
+    sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1
+    shard_file = file_name.replace(
+        ".pdparams",
+        f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.pdparams",
+    )
+    shard_file = shard_file.replace(
+        ".safetensors",
+        f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.safetensors",
+    )
+    return shard_file
+
+
 def generate_base_static_name(vname):
     # return base static name and specific type name, like [embedding_0.w_0, moment1_0]
     if FP32_MASTER in vname:
         vname = vname.split("_" + FP32_MASTER + "_")
         return vname[0], vname[1]
     else:
-        vname = vname.split(".")
-        a = vname[0] + "." + vname[1][:3]
-        b = vname[1][4:]
-        return a, b
+        # Parse the type name directly from the variable name, e.g. moe_gate_1_moment1_0.
+        type_names = optimizer_scalar_name + optimizer_non_scaler_name
+        for name in type_names:
+            if name in vname:
+                a = vname.split(name)[0][:-1]
+                b = name
+                return a, b


 def filter_params(model_to_save, state_dict, is_optimizer=False):
@@ -2087,7 +2230,9 @@ def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst):
 def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
+    dp_group = hcg.get_data_parallel_group()
     tp_rank = tp_group.rank
+    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

     # filter actions for pipeline mode
     if hcg.get_pipe_parallel_group().nranks > 1:
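The rewritten `generate_base_static_name` no longer assumes a `name.w_0`-style pattern; it scans the known optimizer state-type names, so names like `moe_gate_1_moment1_0` parse too. A standalone sketch with an illustrative subset of type names (the real lists, `optimizer_scalar_name` and `optimizer_non_scaler_name`, are imported by the module):

```python
# Standalone sketch of the new parsing rule; TYPE_NAMES is an illustrative subset
# of the optimizer state-type names the real code iterates over.
TYPE_NAMES = ["moment1_0", "moment2_0", "beta1_pow_acc_0", "beta2_pow_acc_0"]


def split_static_name(vname: str):
    for type_name in TYPE_NAMES:
        if type_name in vname:
            base = vname.split(type_name)[0][:-1]  # drop the trailing "_"
            return base, type_name
    raise ValueError(f"unknown optimizer state name: {vname}")


assert split_static_name("moe_gate_1_moment1_0") == ("moe_gate_1", "moment1_0")
assert split_static_name("embedding_0.w_0_beta1_pow_acc_0") == ("embedding_0.w_0", "beta1_pow_acc_0")
```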
@@ -2105,6 +2250,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                 continue
             key = filter_keys[i]
             tensor = state_dict[key]
+            # When using expert parallel and dp_rank > 0, only tensors marked `no_sync` need to be saved.
+            if dp_rank > 0 and not getattr(tensor, "no_sync", False):
+                continue
             if key in tp_actions:
                 # Get tensor size
                 tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks
@@ -2128,16 +2276,24 @@
     if len(tp_actions) > 0:
         for x in tp_actions.keys():
-            logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.")
+            logger.debug(f"key <{x}> needs to merge tensor parallel but cannot be found in the model state.")

     return state_dict_to_save


-def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
+def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None):
     # Core function for UC
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
+    dp_group = hcg.get_data_parallel_group()
     tp_rank = tp_group.rank
+    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+
+    no_sync_kname = []
+    if model_state_dict is not None:
+        for k, v in model_state_dict.items():
+            if getattr(v, "no_sync", False):
+                no_sync_kname.append(k)

     state_dict_to_save = {}
     max_key_len = max([len(_) for _ in all_filter_keys])
@@ -2149,6 +2305,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
             # get base model key
             model_key = filter_keys[i].split("/")[0]
             tensor = state_dict[filter_keys[i]]
+            # When using expert parallel and dp_rank > 0, only states of tensors marked `no_sync` need to be saved.
+            if dp_rank > 0 and model_key not in no_sync_kname:
+                continue
             if model_key in tp_actions:
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
@@ -2217,7 +2376,7 @@ def get_optimizer_shard_files(optimizer_path, index_filename):
     return shard_filenames, sharded_metadata


-def get_expected_keys(sharded_metadata, model, optimizer):
+def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weights=False):
     hcg = fleet.get_hybrid_communicate_group()
     sharding_group = hcg.get_sharding_parallel_group()
     sharding_rank = sharding_group.rank
@@ -2225,11 +2384,23 @@ def get_expected_keys(sharded_metadata, model, optimizer):
     if in_sharding_parallel_model:
         params2rank = optimizer._param2rank

-    struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()}
+    model_state_dict = get_expected_state_dict(model)
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}

     expected_keys = []
     for key in list(sharded_metadata["all_optimizer_keys"]):
         key_name = key.split("/")[0]
+        if (
+            is_master_weights
+            and key_name in model_state_dict
+            and model_state_dict[key_name].dtype == core.VarDesc.VarType.FP32
+        ):
+            continue
+
+        if args.use_expert_parallel and args.data_parallel_rank > 0:
+            if key_name in model_state_dict and not getattr(model_state_dict[key_name], "no_sync", False):
+                continue
+
         static_name = struct2static_name_mappings.get(key_name, None)
         if in_sharding_parallel_model:
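Taken together, `get_expected_keys` now applies two skips before the sharding-rank ownership check: FP32 parameters are dropped from the master-weight pass (the parameter itself serves as its master copy), and non-zero dp ranks under expert parallel only claim keys for their `no_sync` parameters. A compact restatement as a predicate over a plain dict model state (sketch only, not the real helper):

```python
# Compact restatement of the new skip rules in get_expected_keys (hedged sketch).
def wants_optimizer_key(key: str, model_state: dict, is_master_weights: bool,
                        use_expert_parallel: bool, data_parallel_rank: int) -> bool:
    param = model_state.get(key.split("/")[0])
    if param is None:
        return True  # unknown keys fall through to the existing sharding-rank logic
    if is_master_weights and getattr(param, "dtype", None) == "float32":
        return False  # FP32 params carry no separate master weight
    if use_expert_parallel and data_parallel_rank > 0 and not getattr(param, "no_sync", False):
        return False  # non-expert states are owned by dp_rank 0
    return True
```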