
Commit 1dca3e8
update
DesmonDay committed Jan 14, 2025
1 parent 3e8b4b1 commit 1dca3e8
Showing 1 changed file with 5 additions and 5 deletions.
paddlenlp/trainer/trainer.py: 5 additions & 5 deletions
@@ -434,11 +434,6 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
                     "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config."
                 )
 
-        if args.sequence_parallel:
-            register_sequence_parallel_allreduce_hooks(
-                self.model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
-            )
-
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
         if args.fp16 or args.bf16:
@@ -2054,6 +2049,11 @@ def _wrap_model(self, model, training=True):
             else:
                 model, self.optimizer = decorated
 
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         if self.args.world_size == 1:
             if self.args.amp_master_grad:
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
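
The commit moves the sequence-parallel allreduce hook registration out of Trainer initialization and into _wrap_model, and additionally gates it on tensor_parallel_degree > 1, since sequence parallelism only applies on top of tensor parallelism. Below is a minimal, self-contained Python sketch of the ordering this enforces; the names TrainerArgs, register_hooks, and wrap_model are illustrative stand-ins, not PaddleNLP's API.

# Hypothetical sketch of the ordering enforced by this commit: hooks are
# registered only after the model has been wrapped for distributed
# training, and only when tensor parallelism (degree > 1) and sequence
# parallelism are both enabled. All names here are stand-ins.
from dataclasses import dataclass


@dataclass
class TrainerArgs:
    tensor_parallel_degree: int = 1
    sequence_parallel: bool = False
    gradient_accumulation_steps: int = 1
    fuse_sequence_parallel_allreduce: bool = False


def register_hooks(model, accumulation_steps, fuse):
    # Stand-in for register_sequence_parallel_allreduce_hooks, which in
    # PaddleNLP attaches gradient-allreduce hooks for sequence parallelism.
    print(f"hooks on {model!r}: accum={accumulation_steps}, fuse={fuse}")


def wrap_model(args, model):
    wrapped = f"wrapped({model})"  # placeholder for the actual decoration
    # Registration now happens here, on the wrapped model, not in __init__.
    if args.tensor_parallel_degree > 1 and args.sequence_parallel:
        register_hooks(
            wrapped, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
        )
    return wrapped


wrap_model(TrainerArgs(tensor_parallel_degree=2, sequence_parallel=True), "model")

A likely motivation for the move: wrapping can replace the module or its parameters, so hooks attached to the pre-wrap self.model during __init__ might not apply to what training actually uses, and the added guard skips registration entirely when tensor parallelism is not in effect.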
