diff --git a/paddlenlp/peft/lora/loraga_utils.py b/paddlenlp/peft/lora/loraga_utils.py
index 72b4baac1de2..c7079b6bf456 100644
--- a/paddlenlp/peft/lora/loraga_utils.py
+++ b/paddlenlp/peft/lora/loraga_utils.py
@@ -16,6 +16,13 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        register_sequence_parallel_allreduce_hooks,
+    )
+except:
+    pass
+
 from paddlenlp.peft import LoRAModel
 from paddlenlp.peft.lora.lora_layers import (
     ColumnParallelLoRALinear,
@@ -83,6 +90,11 @@ def estimate_gradient(self, model: PretrainedModel):
 
     def _wrap_model(self, model):
         """Wrap Model without optimizer, support dp, tp and sharding"""
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1
         in_sharding_parallel_mode = self.sharding is not None
         in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index d0f0f7b2d0f4..a240a2f7f71e 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -434,11 +434,6 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
                     "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config."
                 )
 
-        if args.sequence_parallel:
-            register_sequence_parallel_allreduce_hooks(
-                self.model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
-            )
-
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
         if args.fp16 or args.bf16:
@@ -2054,6 +2049,11 @@ def _wrap_model(self, model, training=True):
             else:
                 model, self.optimizer = decorated
 
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         if self.args.world_size == 1:
             if self.args.amp_master_grad:
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
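
For reference, the pattern the diff applies is: guard the import of register_sequence_parallel_allreduce_hooks (the utility is missing from older Paddle builds) and register the hooks while wrapping the model, only when tensor parallelism and sequence parallelism are both enabled, instead of unconditionally in Trainer.__init__. Below is a minimal, self-contained sketch of that pattern; the helper name _register_sp_hooks_if_needed and the plain args namespace are illustrative stand-ins for the Trainer/LoRA-GA internals, not PaddleNLP API, and the sketch narrows the bare except to ImportError.

    # Sketch of the guarded-import + conditional-registration pattern from the diff above.
    try:
        from paddle.distributed.fleet.utils.sequence_parallel_utils import (
            register_sequence_parallel_allreduce_hooks,
        )
    except ImportError:  # older Paddle builds do not ship this utility
        register_sequence_parallel_allreduce_hooks = None


    def _register_sp_hooks_if_needed(model, args):
        """Hypothetical helper: attach sequence-parallel allreduce hooks to a wrapped model.

        Registration only makes sense when tensor parallelism is active and
        sequence parallelism is requested; otherwise the model is returned unchanged.
        """
        if register_sequence_parallel_allreduce_hooks is None:
            return model
        if args.tensor_parallel_degree > 1 and args.sequence_parallel:
            register_sequence_parallel_allreduce_hooks(
                model,
                args.gradient_accumulation_steps,
                args.fuse_sequence_parallel_allreduce,
            )
        return model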