diff --git a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
index d0e55b392c95..9b162d4a88c1 100644
--- a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
+++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
@@ -308,7 +308,7 @@ def load_resolved_archive_file(
                 if model_state_dict[key_name[0]].dtype != paddle.float32:
                     key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
                 else:
-                    # for moe gate with float32 dtype.
+                    # for parameters with float32 dtype, no need to have fp32 master weights.
                     key_name = "_".join([static_name, key_name[1]])
             else:
                 key_name = "_".join([static_name, key_name[1]])
diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
index 9d4a084cf715..41ba54972efb 100644
--- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
+++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -67,6 +67,7 @@
     FP32_MASTER,
     UnifiedCheckpointOption,
     filter_params,
+    filter_sync_parameters,
     gather_sharded_object,
     generate_base_static_name,
     get_expected_state_dict,
@@ -218,25 +219,9 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
         for key in list(master_weights.keys()):
             master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
 
-        no_sync_kname = []
-        model_state_dict = get_expected_state_dict(model)
-        for k, v in model_state_dict.items():
-            if getattr(v, "no_sync", False):
-                no_sync_kname.append(k)
-
-        hcg = fleet.get_hybrid_communicate_group()
-        dp_group = hcg.get_data_parallel_group()
-        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
         if self.args.use_expert_parallel:
-            for k in list(optim_state_dict.keys()):
-                model_k = k.split("/")[0]
-                if dp_rank > 0 and model_k not in no_sync_kname:
-                    optim_state_dict.pop(k)
-            if master_weights is not None:
-                for k in list(master_weights.keys()):
-                    model_k = k.split("/")[0]
-                    if dp_rank > 0 and model_k not in no_sync_kname:
-                        master_weights.pop(k)
+            model_state_dict = get_expected_state_dict(model)
+            filter_sync_parameters(model_state_dict, optim_state_dict, master_weights, is_model_weight=False)
 
         optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
         master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)
@@ -518,12 +503,7 @@ def unified_checkpoint_into_shards(
 
     if args.use_expert_parallel:
         # ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0.
-        hcg = fleet.get_hybrid_communicate_group()
-        dp_group = hcg.get_data_parallel_group()
-        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-        for key in list(state_dict.keys()):
-            if dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
-                state_dict.pop(key)
+        filter_sync_parameters(state_dict, is_model_weight=True)
 
     if config_to_save.tensor_parallel_degree > 1:
         if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
@@ -631,25 +611,11 @@ def unified_optimizer_into_shards(
     filter_master_keys = filter_params(model, master_weights, args, is_optimizer=True)
     filter_optim_keys = filter_params(model, optim_state_dict, args, is_optimizer=True)
 
-    hcg = fleet.get_hybrid_communicate_group()
-    tp_group = hcg.get_model_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
+    tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
     tp_size = tp_group.nranks
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
 
     if args.use_expert_parallel:
-        no_sync_kname = []
-        for k, v in state_dict.items():
-            if getattr(state_dict[k], "no_sync", False):
-                no_sync_kname.append(k)
-        for key in list(optim_state_dict.keys()):
-            model_key = key.split("/")[0]
-            if dp_rank > 0 and model_key not in no_sync_kname:
-                optim_state_dict.pop(key)
-        if master_weights is not None:
-            for key in list(master_weights.keys()):
-                if dp_rank > 0 and key not in no_sync_kname:
-                    master_weights.pop(key)
+        filter_sync_parameters(state_dict, optim_state_dict, master_weights, is_model_weight=False)
 
     if tp_size > 1:
         # get tp_actions
diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py
index ccafb244702a..413ca7c47210 100644
--- a/paddlenlp/trainer/unified_checkpoint/utils.py
+++ b/paddlenlp/trainer/unified_checkpoint/utils.py
@@ -758,3 +758,31 @@ def save_config(model_to_save):
     # save generation config
     if model_to_save.can_generate():
         model_to_save.generation_config.save_pretrained(save_directory)
+
+
+def filter_sync_parameters(model_state_dict, optim_state_dict=None, master_weights=None, is_model_weight=True):
+    """Filter sync parameters under expert parallel mode."""
+
+    hcg = fleet.get_hybrid_communicate_group()
+    dp_group = hcg.get_data_parallel_group()
+    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+
+    if is_model_weight:
+        for key in list(model_state_dict.keys()):
+            if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
+                model_state_dict.pop(key)
+    else:
+        no_sync_kname = []
+        for k, v in model_state_dict.items():
+            if getattr(v, "no_sync", False):
+                no_sync_kname.append(k)
+
+        for key in list(optim_state_dict.keys()):
+            model_key = key.split("/")[0]
+            if dp_rank > 0 and model_key not in no_sync_kname:
+                optim_state_dict.pop(key)
+
+        if master_weights is not None:
+            for key in list(master_weights.keys()):
+                if dp_rank > 0 and key not in no_sync_kname:
+                    master_weights.pop(key)
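For reviewers who want to sanity-check the consolidation, here is a minimal, self-contained sketch of what the new `filter_sync_parameters` helper does. It is not the PaddleNLP API: `dp_rank` is passed in explicitly instead of being read from `fleet.get_hybrid_communicate_group()`, and the parameter names `expert_w` / `shared_w` and the optimizer-key suffix `moment1_0` are made up for illustration. It demonstrates that a `dp_rank > 0` worker keeps only `no_sync` (expert) entries in the model weights, optimizer state, and master weights.

```python
from types import SimpleNamespace


def filter_sync_parameters_demo(model_state_dict, optim_state_dict=None, master_weights=None,
                                is_model_weight=True, dp_rank=1):
    # Mirrors the helper's filtering logic; dp_rank is a plain argument here,
    # whereas the real helper reads it from the data-parallel group.
    if is_model_weight:
        # dp_rank > 0 keeps only expert (`no_sync`) weights; rank 0 keeps everything.
        for key in list(model_state_dict.keys()):
            if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
                model_state_dict.pop(key)
    else:
        # Optimizer keys look like "<param_name>/<state_name>"; keep an entry only
        # if its parameter is marked `no_sync`.
        no_sync_kname = [k for k, v in model_state_dict.items() if getattr(v, "no_sync", False)]
        for key in list(optim_state_dict.keys()):
            if dp_rank > 0 and key.split("/")[0] not in no_sync_kname:
                optim_state_dict.pop(key)
        if master_weights is not None:
            for key in list(master_weights.keys()):
                if dp_rank > 0 and key not in no_sync_kname:
                    master_weights.pop(key)


# Hypothetical parameters: "expert_w" is expert-parallel (no_sync=True), "shared_w" is replicated.
weights = {"expert_w": SimpleNamespace(no_sync=True), "shared_w": SimpleNamespace(no_sync=False)}
optim = {"expert_w/moment1_0": 1.0, "shared_w/moment1_0": 2.0}
master = {"expert_w": 3.0, "shared_w": 4.0}

filter_sync_parameters_demo(weights, optim, master, is_model_weight=False, dp_rank=1)
print(sorted(optim), sorted(master))  # ['expert_w/moment1_0'] ['expert_w']
```

This is the same pruning that the diff removes from the three call sites (`save_non_merge_optimizer`, `unified_checkpoint_into_shards`, `unified_optimizer_into_shards`), now routed through the single shared helper.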