diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 01b902478622..62742ea1ac65 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -143,6 +143,7 @@ from .utils import reshard as reshard_util from .utils.helper import ( # nested_truncate, broadcast_dp_optimizer, + broadcast_moe_optimizer, distributed_concat, distributed_file, distributed_isfile, @@ -945,7 +946,8 @@ def _inner_training_loop( ((step_control + 1) % args.gradient_accumulation_steps != 0) and availiable_no_sync and args._no_sync_in_gradient_accumulation - ) or (args.recompute and availiable_no_sync) + ) or (args.recompute and availiable_no_sync + ) or (args.use_moe and availiable_no_sync) # sharding # stage1. the same as ddp # stage2. manualy collect gradient on dp group @@ -965,6 +967,14 @@ def _inner_training_loop( tr_loss += tr_loss_step + def fused_allreduce_gradients_no_sync(paramlist, hcg): + paramlist = list(paramlist) + nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)] + moelist = [p for p in paramlist if getattr(p, "no_sync", False)] + if moelist and not self.args.use_moe: + logger.warning("found `no sync` param when `use_moe=False`") + fused_allreduce_gradients(nonmoe_list, hcg) + if (step_control + 1) % args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps steps_in_epoch <= args.gradient_accumulation_steps @@ -983,12 +993,12 @@ def _inner_training_loop( # Case 1: Use recompute and dp / sharding stage1, # manualy collect gradient for dp. - if args.recompute and availiable_no_sync: - fused_allreduce_gradients(list(model.parameters()), None) + if (args.recompute or args.use_moe) and availiable_no_sync: + fused_allreduce_gradients_no_sync(list(model.parameters()), None) # Case 2: hack dp with master_grad - if dp_master_grad and not (args.recompute and availiable_no_sync): - fused_allreduce_gradients(list(model.parameters()), None) + elif dp_master_grad: + fused_allreduce_gradients_no_sync(list(model.parameters()), None) # Pipeline parallel mode, handle gradient reduce here to overlap pipeline_parallel_config = ( @@ -1007,8 +1017,7 @@ def _inner_training_loop( self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): - fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) - + fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg) self.timers and self.timers("all-reduce").stop() self.timers and self.timers("optimizer-step").start() @@ -1028,7 +1037,9 @@ def _inner_training_loop( ) optimizer_was_run = True if self.do_grad_scaling: - scale_before = paddle.assign(self.scaler._scale) + if args.pipeline_parallel_degree > 1: + assert not self.args.use_moe, "pipline moe not work under fp16" + scale_before = self.scaler._scale.numpy() self.scaler.step(self.optimizer) self.scaler.update() scale_after = self.scaler._scale @@ -2042,7 +2053,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, model.train() inputs = self._prepare_inputs(inputs) - + self.timers and self.timers(f"forward-acc-{self._cur_acc_step}").start() with self.autocast_smart_context_manager(): loss = self.compute_loss(model, inputs) @@ -2053,7 +2064,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, self.scaler.scale(loss).backward() else: loss.backward() - + self.timers and self.timers(f"backward-acc-{self._cur_acc_step}").stop() return loss.detach() def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: @@ -2143,6 +2154,19 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op # For ckpt integrity paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done")) + def _save_moe_weights( + self, + output_dir: Optional[str] = None, + merge_tensor_parallel: Optional[bool] = False,): + # save moe optimizer and model state # TODO 默认为冗余存储 + + self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel) + optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix) + saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}") + paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name)) + with open(saved_signal_path, mode="w+") as f: + f.write("1") + def _save_checkpoint(self, model, metrics=None): # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" self.runtime_timer.start("checkpoint saving time") @@ -2245,6 +2269,8 @@ def _save_checkpoint(self, model, metrics=None): os.makedirs(output_dir, exist_ok=True) paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + if self.args.use_moe and self.args.data_parallel_rank > 0: + self._save_moe_weights(output_dir) # Maybe delete some older checkpoints. # For hybrid parallel training, the checkpoint files maybe on different node. need_to_rotate_checkpoints = False @@ -2476,7 +2502,10 @@ def _load_optimizer_and_scheduler(self, checkpoint): # broadcast optimizer state in dp group if self.args.local_rank != -1: dist.barrier() - opt_state_dict = broadcast_dp_optimizer(opt_state_dict) + if not self.args.use_moe: + opt_state_dict = broadcast_dp_optimizer(opt_state_dict) + # else: + # opt_state_dict = broadcast_moe_optimizer(opt_state_dict) if opt_state_dict is not None: # Load in optimizer and scheduler states diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index d30876a0f2d2..d5d43093565e 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -803,6 +803,10 @@ class TrainingArguments: default=False, metadata={"help": "whether to run distributed training in auto parallel mode"}, ) + use_moe: Optional[bool] = field( + default=False, + metadata={"help": "开启moe训练"}, + ) def __post_init__(self): env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1)) @@ -1149,6 +1153,8 @@ def is_segment_parallel_supported(): order = ["dp", "sharding", "pp", "sep", "mp"] else: order = ["dp", "sharding", "pp", "mp"] + if self.use_moe: + order = order[1: -1] + ["dp", "mp"] if is_segment_parallel_supported(): hybrid_configs = { @@ -1640,8 +1646,12 @@ def optimizer_name_suffix(self): name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree)) if self.sharding_parallel_degree > 1: name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree)) + if self.use_moe: + name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)) return "_".join(name) else: + if self.use_moe: + return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree) return None @property @@ -1652,12 +1662,16 @@ def weight_name_suffix(self): name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree)) if self.pipeline_parallel_degree > 1: name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree)) + if self.use_moe: + name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)) return "_".join(name) else: + if self.use_moe: + return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree) return None - def sharded_name_suffix(self, shard_id=None, pp_id=None): + def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None): if self.use_hybrid_parallel: name = [] if self.tensor_parallel_degree > 1: @@ -1672,8 +1686,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None): shard_id = self.sharding_parallel_rank assert isinstance(shard_id, int) name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree)) + if self.use_moe: + if moe_id is None: + moe_id = self.data_parallel_rank + assert isinstance(moe_id, int) + name.append(self._format_name("moe", moe_id, self.data_parallel_degree)) return "_".join(name) else: + if self.use_moe: + if moe_id is None: + moe_id = self.data_parallel_rank + return self._format_name("moe", moe_id, self.data_parallel_degree) return None @property diff --git a/paddlenlp/trainer/utils/helper.py b/paddlenlp/trainer/utils/helper.py index 25f593f71e35..d69fd184ec03 100644 --- a/paddlenlp/trainer/utils/helper.py +++ b/paddlenlp/trainer/utils/helper.py @@ -18,7 +18,7 @@ import os from typing import Any, Optional - +import copy import numpy as np import paddle import paddle.distributed as dist @@ -226,3 +226,108 @@ def broadcast_dp_optimizer(state_dict): state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group) return state_dict + +# def broadcast_moe_optimizer(state_dict): +# if paddle.distributed.get_world_size() <= 1: +# return state_dict + +# logger.info("Start broadcast optimizer in MoE(data) parallel group.") +# try: +# hcg = fleet.get_hybrid_communicate_group() +# dp_group = hcg.get_data_parallel_group() +# src_rank = hcg.get_data_parallel_group_src_rank() +# process_rank = paddle.distributed.get_rank() +# # Don't broadcast optimizer for dp rank is 1. +# if dp_group.nranks <= 1: +# return state_dict +# except: +# dp_group = None +# src_rank = 0 +# process_rank = paddle.distributed.get_rank() + +# if process_rank == src_rank: +# if state_dict is None: +# logger.warning( +# f"Your local rank {paddle.distributed.get_rank()} must have a state_dict. dp_rank:{process_rank}, src_rank:{src_rank}" +# ) +# fake_state_dict = [nested_reduce_tensor(state_dict)] +# else: +# fake_state_dict = [None] + +# paddle.distributed.broadcast_object_list( +# fake_state_dict, +# src=src_rank, +# group=dp_group, +# ) +# fake_state_dict = fake_state_dict[0] +# if process_rank != src_rank: +# sync_state_dict = nested_empty_tensor(fake_state_dict) +# else: +# sync_state_dict = state_dict +# logger.info(f"SYNC-state-dict--{sync_state_dict.keys()}") +# sync_state_dict = nested_broadcast_tensor(sync_state_dict, src=src_rank, group=dp_group) +# if process_rank != src_rank: +# master_weights = state_dict.pop('master_weights', {}) +# sync_state_dict['master_weights'].update(master_weights) +# sync_state_dict.update(state_dict) +# state_dict = sync_state_dict +# logger.info("broadcast_moe_optimizer done") +# return state_dict + + +def broadcast_moe_optimizer(state_dict): + + try: + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + src_rank = hcg.get_data_parallel_group_src_rank() + process_rank = paddle.distributed.get_rank() + data_parallel_rank = hcg.get_data_parallel_rank() + # Don't broadcast optimizer for dp rank is 1. + if dp_group.nranks <= 1: + return state_dict + except: + dp_group = None + src_rank = 0 + data_parallel_rank = 0 + process_rank = paddle.distributed.get_rank() + + def _broadcast_moe_optimizer_state(state_dict): + # boardcast_keys + base_state_dict = {"master_weights": {}} + buf = [ + {i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]}, + {i: j.shape for i, j in state_dict["master_weights"].items()}, + {"LR_Scheduler": state_dict.get("LR_Scheduler", {})}, + ] + + dist.broadcast_object_list(buf, src=src_rank, group=dp_group) + # logger.info(f"moe-optimizer-gather-keys{buf}") + for k, s in buf[0].items(): + v = state_dict.get(k, paddle.zeros(s, "float32")).cuda() + v.name = k + # k = k.replace("_fp32_master_0", "") + dist.broadcast(v, src=src_rank, group=dp_group) + logger.info(f"broadcast moe optimizer {k} from {src_rank}") + base_state_dict[k] = v.cpu() + for k, s in buf[1].items(): + v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda() + v.name = k + dist.broadcast(v, src=src_rank, group=dp_group) + logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}") + base_state_dict["master_weights"][k] = v.cpu() + base_state_dict.update(buf[2]) + return base_state_dict + + base_state_dict = _broadcast_moe_optimizer_state(state_dict) + if data_parallel_rank > 0: + master_weight = state_dict.pop("master_weights", {}) + base_state_dict.update(state_dict) + if master_weight: + if "master_weights" in base_state_dict: + base_state_dict["master_weights"].update(master_weight) + else: + base_state_dict["master_weights"] = master_weight + state_dict = base_state_dict + del base_state_dict + return state_dict diff --git a/paddlenlp/trainer/utils/reshard/common.py b/paddlenlp/trainer/utils/reshard/common.py index cc834862e299..61674eeff4bf 100644 --- a/paddlenlp/trainer/utils/reshard/common.py +++ b/paddlenlp/trainer/utils/reshard/common.py @@ -291,7 +291,7 @@ def _opt_name_to_tname(tensor_names, opt_names): (self._model_weights, model_weights_tmp) = (model_weights_tmp, self._model_weights) for k in list(model_weights_tmp.keys()): t_name = structure_name_mapping[k] - self._model_weights[(k, t_name)] = model_weights_tmp[k].cpu() + self._model_weights[(k, t_name)] = paddle.to_tensor(model_weights_tmp[k]).cpu() del model_weights_tmp[k] # opt diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py index 56f4c426ce0a..d927de769e2f 100644 --- a/paddlenlp/trainer/utils/sharding_io.py +++ b/paddlenlp/trainer/utils/sharding_io.py @@ -444,12 +444,17 @@ def filter_func(name): master_weights = reshard_util.all_gather_state_dict(master_weights, filter_func, self.sharding_group) model_state_dict = self.model.state_dict() + logger.info(f"state-dict-keys: {state_dict.keys()}, nums: {len(state_dict.keys())}") logger.info("before recover, model_state_dict number: {}".format(len(model_state_dict))) for key, param in model_state_dict.items(): if param.name in master_weights: assert param.shape == master_weights[param.name].shape - paddle.assign(master_weights[param.name].cuda(), model_state_dict[key]) - + paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key]) + elif key in state_dict: + logger.info(f"key: {key} is in state_dict, but not in master_weights") + paddle.assign(state_dict[key], model_state_dict[key]) + else: + logger.info(f"key: {key} is not in state_dict and master_weights") logger.info("after recover, casted model_state_dict number: {}".format(len(model_state_dict))) state_dict.update(model_state_dict) return state_dict diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py index f785a5358af4..4a5c067fed6c 100644 --- a/paddlenlp/transformers/utils.py +++ b/paddlenlp/transformers/utils.py @@ -818,8 +818,14 @@ def weight_name_suffix(): name.append(f"tp{hcg.get_model_parallel_rank():0>2d}") if hcg.get_pipe_parallel_world_size() > 1: name.append(f"pp{hcg.get_stage_id():0>2d}") + if config and getattr(config, "moe_num_experts", 0): + dp_group = hcg.get_data_parallel_group() + name.append(f"moe{dp_group.rank:0>2d}") return "_".join(name) else: + if config and getattr(config, "moe_num_experts", 0): + rank = paddle.distributed.get_rank() + return f"moe{rank:0>2d}" return None