[ASR] add amp for U2 conformer #3167

Merged · 7 commits · Apr 25, 2023
Changes from 2 commits
42 changes: 35 additions & 7 deletions paddlespeech/s2t/exps/u2/model.py
@@ -23,6 +23,7 @@
 import numpy as np
 import paddle
 from paddle import distributed as dist
+from paddle.nn.utils import clip_grad_norm_
 
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@@ -47,14 +48,16 @@ class U2Trainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def train_batch(self, batch_index, batch_data, msg):
+    def train_batch(self, batch_index, batch_data, scaler, msg):
         train_conf = self.config
         start = time.time()
 
         # forward
         utt, audio, audio_len, text, text_len = batch_data
-        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
-                                                    text_len)
+        with paddle.amp.auto_cast(
+                level=self.amp_level, enable=True if scaler else False):
+            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                        text_len)
 
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
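
The forward pass now runs inside paddle.amp.auto_cast, so eligible ops execute in reduced precision only when a scaler is present. A minimal sketch of that pattern follows; the model, tensors, and flag are illustrative placeholders, not the trainer's real members:

import paddle

# Toy stand-ins for the acoustic model and one feature batch.
model = paddle.nn.Linear(80, 256)
feats = paddle.randn([8, 80])

use_amp = True  # mirrors `enable=True if scaler else False`
with paddle.amp.auto_cast(level='O1', enable=use_amp):
    logits = model(feats)   # executed in mixed precision when enabled
    loss = logits.mean()
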
@@ -77,12 +80,24 @@ def train_batch(self, batch_index, batch_data, msg):
             # processes.
             context = nullcontext
         with context():
-            loss.backward()
+            if scaler:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
             layer_tools.print_grads(self.model, print_func=None)
 
         # optimizer step
         if (batch_index + 1) % train_conf.accum_grad == 0:
-            self.optimizer.step()
+            # do global grad clip
+            if train_conf.global_grad_clip != 0:
+                # need paddlepaddle==develop or paddlepaddle>=2.5
+                clip_grad_norm_(self.model.parameters(),
+                                train_conf.global_grad_clip)
+            if scaler:
+                scaler.step(self.optimizer)
+                scaler.update()
+            else:
+                self.optimizer.step()
             self.optimizer.clear_grad()
             self.lr_scheduler.step()
             self.iteration += 1
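
The backward and optimizer steps follow Paddle's standard GradScaler recipe: scale the loss before backward, then let the scaler unscale gradients, skip the step on inf/nan, and adapt the loss scale. A rough sketch with toy objects (not the trainer's real model or optimizer):

import paddle

# Toy model/optimizer standing in for the trainer's members.
model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)

for _ in range(2):
    with paddle.amp.auto_cast(level='O1'):
        loss = model(paddle.randn([4, 16])).mean()
    scaler.scale(loss).backward()  # backward on the scaled loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscales grads, runs optimizer.step() unless inf/nan
    scaler.update()                # adjusts the loss scale for the next iteration
    optimizer.clear_grad()
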
@@ -173,7 +188,8 @@ def do_train(self):
                             report("epoch", self.epoch)
                             report('step', self.iteration)
                             report("lr", self.lr_scheduler())
-                            self.train_batch(batch_index, batch, msg)
+                            self.train_batch(batch_index, batch, self.scaler,
+                                             msg)
                             self.after_train_batch()
                             report('iter', batch_index + 1)
                             if not self.use_streamdata:
@@ -253,6 +269,19 @@ def setup_model(self):
             model_conf.output_dim = self.test_loader.vocab_size
 
         model = U2Model.from_config(model_conf)
+
+        # For Mixed Precision Training
+        self.use_amp = self.config.get("use_amp", True)
+        self.amp_level = self.config.get("amp_level", "O1")
+        if self.train and self.use_amp:
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.config.get(
+                    "scale_loss", 32768.0))  #amp default num 32768.0
+            #Set amp_level
+            if self.amp_level == 'O2':
+                model = paddle.amp.decorate(models=model, level=self.amp_level)
+        else:
+            self.scaler = None
         if self.parallel:
             model = paddle.DataParallel(model)
 
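
Under O1, mixed precision is applied only while auto_cast is active; under O2, paddle.amp.decorate additionally casts the model's parameters, which is why decoration only happens for that level. A rough sketch of the branching above, with `cfg` as a stand-in plain dict rather than the trainer's config object:

import paddle

cfg = {"use_amp": True, "amp_level": "O2", "scale_loss": 32768.0}  # illustrative defaults
model = paddle.nn.Linear(10, 10)

scaler = None
if cfg["use_amp"]:
    scaler = paddle.amp.GradScaler(init_loss_scaling=cfg["scale_loss"])
    if cfg["amp_level"] == "O2":
        # O2 keeps most parameters/ops in fp16; decorate() casts the model accordingly.
        model = paddle.amp.decorate(models=model, level="O2")
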
@@ -290,7 +319,6 @@ def optimizer_args(
             scheduler_type = train_config.scheduler
             scheduler_conf = train_config.scheduler_conf
             return {
-                "grad_clip": train_config.global_grad_clip,
                 "weight_decay": optim_conf.weight_decay,
                 "learning_rate": lr_scheduler
                 if lr_scheduler else optim_conf.lr,
6 changes: 5 additions & 1 deletion paddlespeech/s2t/training/trainer.py
@@ -110,6 +110,7 @@ def __init__(self, config, args):
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         self._train = True
+        self.scaler = None
 
         # print deps version
         all_version()
@@ -187,7 +188,8 @@ def save(self, tag=None, infos: dict=None):
         infos.update({
             "step": self.iteration,
             "epoch": self.epoch,
-            "lr": self.optimizer.get_lr()
+            "lr": self.optimizer.get_lr(),
+            "scaler": self.scaler.state_dict()
         })
         self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
                                         if tag is None else tag, self.model,
@@ -211,6 +213,8 @@ def resume_or_scratch(self):
             # lr will resotre from optimizer ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
+            self.scaler = paddle.amp.GradScaler()
+            self.scaler.load_state_dict(infos["scaler"])
             scratch = False
             logger.info(
                 f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
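
Saving and restoring the scaler keeps the dynamic loss-scaling state consistent across a checkpoint resume. A minimal sketch of the round-trip, where `infos` is an illustrative dict rather than the real checkpoint record:

import paddle

scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)
infos = {"scaler": scaler.state_dict()}      # stored next to step/epoch/lr on save

restored = paddle.amp.GradScaler()           # default-constructed on resume
restored.load_state_dict(infos["scaler"])    # recovers the current loss scale and counters
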