[Trainer] Support recompute for trainer. (#3261)
* support recompute for trainer.
ZHUI authored Sep 16, 2022
1 parent 8ca5cd8 commit fb4eec7
Showing 3 changed files with 41 additions and 3 deletions.
7 changes: 7 additions & 0 deletions docs/trainer.md
@@ -395,6 +395,13 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and

The value of initial scale_loss for fp16. (default: 32768)

--recompute
Whether to train with recompute, which can save GPU memory.
The forward pass is recomputed to obtain gradients, reducing the memory held by intermediate variables.
(`bool`, optional, defaults to `False`)

Recompute the forward pass to calculate gradients. Used for saving memory (default: False)

--minimum_eval_times
Minimum number of evaluations. If the current eval_steps setting results in fewer
evaluations than minimum_eval_times, this option overrides the eval_steps argument.
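
As a quick illustration (not part of this commit), the `--recompute` flag documented above can also be set programmatically on `TrainingArguments`; the `output_dir` value below is a placeholder and model/data setup is omitted:

```python
from paddlenlp.trainer import TrainingArguments

# Equivalent to passing --recompute on the command line.
args = TrainingArguments(
    output_dir="./checkpoints",  # placeholder path
    recompute=True,
)
print(args.recompute)  # True
```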
25 changes: 22 additions & 3 deletions paddlenlp/trainer/trainer_base.py
@@ -38,6 +38,7 @@
import paddle.nn as nn
import paddle.amp.auto_cast as autocast
import paddle.distributed as dist
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
from paddle.io import (
    Dataset,
    DataLoader,
@@ -247,6 +248,15 @@ def __init__(
                init_loss_scaling=self.args.scale_loss)
            logger.info("Using half precision")

        if args.recompute:

            def fn(layer):
                if type(layer) == paddle.nn.TransformerEncoder or type(
                        layer) == paddle.nn.TransformerDecoder:
                    layer.enable_recompute = True

            model.apply(fn)

        default_label_names = ([
            "start_positions", "end_positions"
        ] if "QusetionAnswering" in type(self.model).__name__ else ["labels"])
@@ -549,9 +559,13 @@ def train(
                self.control = self.callback_handler.on_step_begin(
                    args, self.state, self.control)

-               if (((step + 1) % args.gradient_accumulation_steps != 0)
-                       and args.local_rank != -1
-                       and args._no_sync_in_gradient_accumulation):
+               is_no_sync = (((
+                   (step + 1) % args.gradient_accumulation_steps != 0)
+                   and args.local_rank != -1
+                   and args._no_sync_in_gradient_accumulation)
+                   or (args.recompute and args.local_rank != -1))
+
+               if is_no_sync:
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                    with model.no_sync():
                        tr_loss_step = self.training_step(model, inputs)
@@ -564,6 +578,11 @@
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        steps_in_epoch <= args.gradient_accumulation_steps and
                        (step + 1) == steps_in_epoch):

                    if (args.recompute and args.local_rank != -1):
                        fused_allreduce_gradients(list(model.parameters()),
                                                  None)

                    if self.do_grad_scaling:
                        self.scaler.minimize(self.optimizer, tr_loss)
                    else:
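
To make the interplay between `no_sync` and the manual `fused_allreduce_gradients` call above concrete, here is a toy sketch of the same pattern outside the Trainer. It assumes the script is launched with `paddle.distributed.launch` and uses a throwaway linear model; it illustrates the idea rather than the Trainer's exact code path:

```python
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

dist.init_parallel_env()
model = paddle.DataParallel(paddle.nn.Linear(8, 2))
optimizer = paddle.optimizer.SGD(parameters=model.parameters())

# Accumulate gradients over two micro-batches without the per-backward allreduce.
with model.no_sync():
    for _ in range(2):
        loss = model(paddle.randn([4, 8])).mean()
        loss.backward()

# Synchronize the accumulated gradients once, then step the optimizer.
fused_allreduce_gradients(list(model.parameters()), None)
optimizer.step()
optimizer.clear_grad()
```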
12 changes: 12 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -175,6 +175,9 @@ class TrainingArguments:
        fp16_opt_level (`str`, *optional*, defaults to 'O1'):
            For `fp16` training, AMP optimization level selected in ['O0', 'O1', 'O2']. See details at
            https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/amp/auto_cast_cn.html
        recompute (`bool`, *optional*, defaults to `False`):
            Recompute the forward pass to calculate gradients. Used for saving memory.
            Only support for networks with transformer blocks.
        scale_loss (`float`, *optional*, defaults to 32768):
            The value of initial scale_loss for fp16. (default: 32768)
        local_rank (`int`, *optional*, defaults to -1):
@@ -401,6 +404,15 @@ class TrainingArguments:
        },
    )

    recompute: bool = field(
        default=False,
        metadata={
            "help":
            "Recompute the forward pass to calculate gradients. Used for saving memory. "
            "Only support for networks with transformer blocks."
        },
    )

    scale_loss: float = field(
        default=2**15,
        metadata={"help": "The value of initial scale_loss for fp16."})
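
For completeness, a small illustrative snippet (not from the commit) that inspects the new dataclass field with the standard-library `dataclasses` helpers, e.g. to confirm its default and help text:

```python
from dataclasses import fields
from paddlenlp.trainer import TrainingArguments

recompute_field = next(f for f in fields(TrainingArguments) if f.name == "recompute")
print(recompute_field.default)           # False
print(recompute_field.metadata["help"])  # "Recompute the forward pass ..."
```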
