From 1dca3e811b5915c4ab83db3cc608ef4fc56743af Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Tue, 14 Jan 2025 19:06:15 +0800
Subject: [PATCH] update

---
 paddlenlp/trainer/trainer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index d0f0f7b2d0f4..a240a2f7f71e 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -434,11 +434,6 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
                     "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config."
                 )
 
-        if args.sequence_parallel:
-            register_sequence_parallel_allreduce_hooks(
-                self.model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
-            )
-
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
         if args.fp16 or args.bf16:
@@ -2054,6 +2049,11 @@ def _wrap_model(self, model, training=True):
             else:
                 model, self.optimizer = decorated
 
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         if self.args.world_size == 1:
             if self.args.amp_master_grad:
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)