From a218bf8de62bc27e98b3b8483768823a4b3e9d76 Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Tue, 30 Jul 2024 15:58:33 +0800
Subject: [PATCH] update release_grads

---
 paddlenlp/trainer/trainer.py       | 9 +++++----
 paddlenlp/trainer/training_args.py | 5 +++++
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index b42e596e97e4..79c1065ad996 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1062,11 +1062,12 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                     if optimizer_was_run:
                         self.lr_scheduler.step()

-                    if enable_release_grads and args.pipeline_parallel_degree > 1:
+                    if args.release_grads or enable_release_grads:
                         self.optimizer.clear_grad(set_to_zero=False)
-                        for _, buffers in model._chunk_2_comm_buffers.items():
-                            for buffer in buffers:
-                                buffer._clear_grad_storage()
+                        if args.pipeline_parallel_degree > 1:
+                            for _, buffers in model._chunk_2_comm_buffers.items():
+                                for buffer in buffers:
+                                    buffer._clear_grad_storage()
                     else:
                         self.optimizer.clear_grad()

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index b31e55d7b4f0..99a91296c057 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -344,6 +344,8 @@ class TrainingArguments:
             Whether skip profile timer, timer will record time usage of forward/ backward/ step, etc.
         distributed_dataloader (`bool`, *optional*):
             Whether to use distributed dataloader. Default is `False`.
+        release_grads (`bool`, *optional*):
+            Whether to release gradients during training. Default is `False`.
     """

     output_dir: str = field(
@@ -791,6 +793,9 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},
     )
+    release_grads: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."}
+    )

     def __post_init__(self):
         env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
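
Usage sketch (not part of the patch itself): a minimal, hypothetical example of
enabling the new flag through the usual `TrainingArguments` entry point; the
`output_dir` value is a placeholder, and the model/dataset wiring is omitted.

    from paddlenlp.trainer import TrainingArguments

    # Hypothetical example: request gradient release after each optimizer step.
    # With release_grads=True the trainer calls clear_grad(set_to_zero=False),
    # and under pipeline parallelism it also clears the communication buffers'
    # gradient storage (see the trainer.py hunk above).
    args = TrainingArguments(
        output_dir="./checkpoints",  # placeholder path
        release_grads=True,          # new flag added by this patch; default is False
    )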