From fb4eec74d053356f0bda4d822473d5ecd36d711b Mon Sep 17 00:00:00 2001
From: Zhong Hui
Date: Fri, 16 Sep 2022 12:07:22 +0800
Subject: [PATCH] [Trainer] Support recompute for trainer. (#3261)

* support recompute for trainer.
---
 docs/trainer.md                    |  7 +++++++
 paddlenlp/trainer/trainer_base.py  | 25 ++++++++++++++++++++++---
 paddlenlp/trainer/training_args.py | 12 ++++++++++++
 3 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/docs/trainer.md b/docs/trainer.md
index 81f35b0f2823..f141d15a665d 100644
--- a/docs/trainer.md
+++ b/docs/trainer.md
@@ -395,6 +395,13 @@ Trainer 是一个简单，但功能完整的 Paddle训练和评估模块，并
                          The value of initial scale_loss for fp16. (default: 32768)
 
+     --recompute
+                         是否使用重计算训练。可以节省显存。
+                         重新计算前向过程以获取梯度，减少中间变量显存
+                         (`bool`, 可选, 默认为 `False`)
+
+                         Recompute the forward pass to calculate gradients. Used for saving memory (default: False)
+
      --minimum_eval_times
                          最少评估次数，如果当前设置的eval_steps，评估次数少于minimum_eval_times,
                          此选项会覆盖eval_steps参数。
 
diff --git a/paddlenlp/trainer/trainer_base.py b/paddlenlp/trainer/trainer_base.py
index cd7be4a6a7a6..e9b801b36b12 100644
--- a/paddlenlp/trainer/trainer_base.py
+++ b/paddlenlp/trainer/trainer_base.py
@@ -38,6 +38,7 @@
 import paddle.nn as nn
 import paddle.amp.auto_cast as autocast
 import paddle.distributed as dist
+from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
 from paddle.io import (
     Dataset,
     DataLoader,
@@ -247,6 +248,15 @@ def __init__(
                 init_loss_scaling=self.args.scale_loss)
             logger.info("Using half precision")
 
+        if args.recompute:
+
+            def fn(layer):
+                if type(layer) == paddle.nn.TransformerEncoder or type(
+                        layer) == paddle.nn.TransformerDecoder:
+                    layer.enable_recompute = True
+
+            model.apply(fn)
+
         default_label_names = ([
             "start_positions", "end_positions"
         ] if "QusetionAnswering" in type(self.model).__name__ else ["labels"])
@@ -549,9 +559,13 @@ def train(
                 self.control = self.callback_handler.on_step_begin(
                     args, self.state, self.control)
 
-                if (((step + 1) % args.gradient_accumulation_steps != 0)
-                        and args.local_rank != -1
-                        and args._no_sync_in_gradient_accumulation):
+                is_no_sync = (((
+                    (step + 1) % args.gradient_accumulation_steps != 0)
+                               and args.local_rank != -1
+                               and args._no_sync_in_gradient_accumulation)
+                              or (args.recompute and args.local_rank != -1))
+
+                if is_no_sync:
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
                         tr_loss_step = self.training_step(model, inputs)
@@ -564,6 +578,11 @@ def train(
                         # last step in epoch but step is always smaller than gradient_accumulation_steps
                         steps_in_epoch <= args.gradient_accumulation_steps
                         and (step + 1) == steps_in_epoch):
+
+                    if (args.recompute and args.local_rank != -1):
+                        fused_allreduce_gradients(list(model.parameters()),
+                                                  None)
+
                     if self.do_grad_scaling:
                         self.scaler.minimize(self.optimizer, tr_loss)
                     else:
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 5f5d33f3a278..9a85680f8d60 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -175,6 +175,9 @@ class TrainingArguments:
         fp16_opt_level (`str`, *optional*, defaults to 'O1'):
             For `fp16` training, AMP optimization level selected in ['O0', 'O1', 'O2']. See details at
             https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/amp/auto_cast_cn.html
+        recompute (`bool`, *optional*, defaults to `False`):
+            Recompute the forward pass to calculate gradients. Used for saving memory.
+            Only support for networks with transformer blocks.
         scale_loss (`float`, *optional*, defaults to 32768):
             The value of initial scale_loss for fp16. (default: 32768)
         local_rank (`int`, *optional*, defaults to -1):
@@ -401,6 +404,15 @@ class TrainingArguments:
         },
     )
 
+    recompute: bool = field(
+        default=False,
+        metadata={
+            "help":
+            "Recompute the forward pass to calculate gradients. Used for saving memory. "
+            "Only support for networks with transformer blocks."
+        },
+    )
+
     scale_loss: float = field(
         default=2**15,
         metadata={"help": "The value of initial scale_loss for fp16."})
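Usage note (not part of the patch): with this change, activation recompute is switched on through TrainingArguments, either as recompute=True in code or as --recompute on the command line of any Trainer-based script. The sketch below is a minimal, illustrative example only; the model name, the chnsenticorp dataset, the sequence length, and the output directory are placeholder assumptions, and it presumes a PaddleNLP build that already contains this commit. The hook added in __init__ only flips enable_recompute on layers that are instances of paddle.nn.TransformerEncoder or paddle.nn.TransformerDecoder, so the model is expected to be built from those blocks (as the ERNIE-style pretrained encoders in PaddleNLP are).

# Minimal usage sketch (illustrative only, not part of this patch). Assumes a
# PaddleNLP build containing this commit; model name, dataset, sequence length,
# and output directory are placeholder choices.
from paddlenlp.datasets import load_dataset
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")
model = AutoModelForSequenceClassification.from_pretrained("ernie-3.0-medium-zh")

def preprocess(example):
    # Tokenize the raw text and expose the label under the name the Trainer expects.
    features = tokenizer(example["text"], max_seq_len=128)
    features["labels"] = example["label"]
    return features

train_ds = load_dataset("chnsenticorp", splits="train").map(preprocess)

args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=32,
    num_train_epochs=1,
    recompute=True,  # re-run the forward pass during backward to trade compute for activation memory
)

trainer = Trainer(model=model, args=args, tokenizer=tokenizer, train_dataset=train_ds)
trainer.train()

Under data parallelism, the patch also routes every step through model.no_sync() when recompute is enabled and performs a single fused_allreduce_gradients call before the optimizer step, which is why no extra user code is needed for the distributed case.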