[ASR] add amp for U2 conformer #3167

Merged (7 commits) on Apr 25, 2023
1 change: 1 addition & 0 deletions README.md
@@ -178,6 +178,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).

### Recent Update
- 👑 2023.04.25: Add [AMP for U2 conformer](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167).
- 🔥 2023.04.06: Add [subtitle file (.srt format) generation example](./demos/streaming_asr_server).
- 🔥 2023.03.14: Add SVS(Singing Voice Synthesis) examples with Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1)、[PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5), the effect is continuously optimized.
- 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).
1 change: 1 addition & 0 deletions README_cn.md
@@ -183,6 +183,7 @@
- 🧩 Cascaded models application: as an extension of traditional speech tasks, we combine them with natural language processing, computer vision and other fields to build industrial-grade applications closer to real-world needs.

### Recent Update
- 👑 2023.04.25: Add [AMP training for U2 conformer](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167).
- 👑 2023.04.06: Add [subtitle file (.srt format) generation](./demos/streaming_asr_server).
- 🔥 2023.03.14: Add SVS (Singing Voice Synthesis) examples with the Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1), [PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5); quality is continuously being improved.
- 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).
2 changes: 1 addition & 1 deletion docs/source/released_model.md
@@ -10,7 +10,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |-|
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |-|
[Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz) | WenetSpeech Dataset | Char-based | 540 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |[FP32](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz) </br>[INT8](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model.tar.gz) |
- [Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |-|
+ [Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.5.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.051968 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |-|
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |-|
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |-|
[Ds2 Offline Librispeech ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_offline_librispeech_ckpt_1.0.1.model.tar.gz)| Librispeech Dataset | Char-based | 1.3 GB | 2 Conv + 5 bidirectional LSTM layers| - |0.0467| 960 h | [Ds2 Offline Librispeech ASR0](../../examples/librispeech/asr0) | inference/python |-|
10 changes: 5 additions & 5 deletions examples/aishell/asr1/RESULTS.md
@@ -13,15 +13,15 @@ paddlespeech version: 1.0.1

## Conformer Streaming
paddle version: 2.2.2
- paddlespeech version: 0.2.0
+ paddlespeech version: 1.4.1
You need to set `decoding.decoding_chunk_size=16` when decoding.

| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
- | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention | 16, -1 | - | 0.0551 |
- | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 0.0629 |
- | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 0.0629 |
- | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.0544 |
+ | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention | 16, -1 | - | 0.056102 |
+ | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 0.058160 |
+ | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 0.058160 |
+ | conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.051968 |
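The CER column above is the character error rate: the Levenshtein (edit) distance between the reference and hypothesis character sequences, divided by the reference length. A minimal sketch of the metric (an illustration, not the repo's implementation):

```python
def cer(ref, hyp):
    """Character error rate: edit distance / reference length."""
    prev = list(range(len(hyp) + 1))  # distances against the empty reference
    for i in range(1, len(ref) + 1):
        cur = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            sub = 0 if ref[i - 1] == hyp[j - 1] else 1
            cur[j] = min(prev[j] + 1,        # deletion
                         cur[j - 1] + 1,     # insertion
                         prev[j - 1] + sub)  # substitution / match
        prev = cur
    return prev[len(hyp)] / len(ref)
```

For example, one substituted character in a four-character reference gives a CER of 0.25, the same scale as the figures in the table.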


## Transformer
10 changes: 10 additions & 0 deletions paddlespeech/resource/pretrained_models.py
@@ -228,6 +228,16 @@
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30',
},
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.5.0.model.tar.gz',
'md5':
'a0adb2b204902982718bc1d8917f7038',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30',
},
},
"transformer_librispeech-en-16k": {
'1.0': {
44 changes: 37 additions & 7 deletions paddlespeech/s2t/exps/u2/model.py
@@ -23,6 +23,7 @@
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.nn.utils import clip_grad_norm_

from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@@ -47,14 +48,16 @@ class U2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)

def train_batch(self, batch_index, batch_data, msg):
def train_batch(self, batch_index, batch_data, scaler, msg):
train_conf = self.config
start = time.time()

# forward
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
with paddle.amp.auto_cast(
level=self.amp_level, enable=True if scaler else False):
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)

# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
@@ -77,12 +80,26 @@ def train_batch(self, batch_index, batch_data, msg):
# processes.
context = nullcontext
with context():
loss.backward()
if scaler:
scaler.scale(loss).backward()
else:
loss.backward()
layer_tools.print_grads(self.model, print_func=None)

# optimizer step
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
# do global grad clip
if train_conf.global_grad_clip != 0:
if scaler:
scaler.unscale_(self.optimizer)
# need paddlepaddle==develop or paddlepaddle>=2.5
clip_grad_norm_(self.model.parameters(),
train_conf.global_grad_clip)
if scaler:
scaler.step(self.optimizer)
scaler.update()
else:
self.optimizer.step()
self.optimizer.clear_grad()
self.lr_scheduler.step()
self.iteration += 1
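The optimizer step above branches on whether a `GradScaler` is active: with AMP, the loss is scaled before `backward()`, gradients are unscaled before global clipping, and the scaler drives the optimizer step; without AMP it is the plain backward/clip/step flow. A pure-Python sketch of that control flow follows; `StubScaler` is a hypothetical stand-in for `paddle.amp.GradScaler` that only records the bookkeeping calls, so the ordering can be checked without Paddle or a GPU:

```python
class StubScaler:
    """Hypothetical stand-in for paddle.amp.GradScaler: records call order."""

    def __init__(self, trace):
        self.trace = trace

    def scale(self, loss):
        self.trace.append("scale")  # real scaler multiplies by the loss scale
        return loss

    def unscale_(self, optimizer):
        self.trace.append("unscale_")  # divide grads back before clipping

    def step(self, optimizer):
        self.trace.append("step")  # real scaler skips the step on overflow

    def update(self):
        self.trace.append("update")  # adjust the loss scale for the next step


def optimizer_step(trace, scaler, global_grad_clip=5.0):
    """Mirror of the branch in U2Trainer.train_batch once accum_grad is hit."""
    if global_grad_clip != 0:
        if scaler:
            scaler.unscale_(None)  # clip at the true gradient magnitude
        trace.append("clip_grad_norm_")
    if scaler:
        scaler.step(None)
        scaler.update()
    else:
        trace.append("optimizer.step")


amp_trace = []
scaler = StubScaler(amp_trace)
scaler.scale(1.0)  # stands in for scaler.scale(loss).backward()
optimizer_step(amp_trace, scaler)

fp32_trace = []
optimizer_step(fp32_trace, None)
```

With AMP the trace reads scale, unscale_, clip_grad_norm_, step, update; without it, clip_grad_norm_ then optimizer.step.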
@@ -173,7 +190,8 @@ def do_train(self):
report("epoch", self.epoch)
report('step', self.iteration)
report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg)
self.train_batch(batch_index, batch, self.scaler,
msg)
self.after_train_batch()
report('iter', batch_index + 1)
if not self.use_streamdata:
@@ -253,6 +271,19 @@ def setup_model(self):
model_conf.output_dim = self.test_loader.vocab_size

model = U2Model.from_config(model_conf)

# For Mixed Precision Training
self.use_amp = self.config.get("use_amp", True)
self.amp_level = self.config.get("amp_level", "O1")
if self.train and self.use_amp:
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=self.config.get(
"scale_loss", 32768.0))  # AMP default loss scaling: 32768.0
# set amp_level
if self.amp_level == 'O2':
model = paddle.amp.decorate(models=model, level=self.amp_level)
else:
self.scaler = None
if self.parallel:
model = paddle.DataParallel(model)
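The setup above is driven by three config keys, each with a default: `use_amp` (defaults to `True`), `amp_level` (`"O1"`), and `scale_loss` (`32768.0`), and a scaler is only created when training with AMP. A sketch of just that decision logic, where `amp_settings` is a hypothetical helper and not part of the PR:

```python
def amp_settings(config, is_train):
    """Return (use_amp, amp_level, init_loss_scaling) the way setup_model
    reads them; init_loss_scaling is None when no GradScaler is created."""
    use_amp = config.get("use_amp", True)      # note: AMP is on by default
    amp_level = config.get("amp_level", "O1")  # "O1" autocast, "O2" decorated model
    scale = config.get("scale_loss", 32768.0) if (is_train and use_amp) else None
    return use_amp, amp_level, scale


defaults = amp_settings({}, is_train=True)
no_amp = amp_settings({"use_amp": False}, is_train=True)
custom = amp_settings({"amp_level": "O2", "scale_loss": 1024.0}, is_train=True)
```

An empty config therefore trains with AMP at level O1 and loss scaling 32768.0; `use_amp: False` must be set explicitly to opt out.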

@@ -290,7 +321,6 @@ def optimizer_args(
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
return {
"grad_clip": train_config.global_grad_clip,
"weight_decay": optim_conf.weight_decay,
"learning_rate": lr_scheduler
if lr_scheduler else optim_conf.lr,
16 changes: 15 additions & 1 deletion paddlespeech/s2t/training/trainer.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
from collections import OrderedDict
@@ -110,6 +111,7 @@ def __init__(self, config, args):
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self._train = True
self.scaler = None

# print deps version
all_version()
@@ -187,8 +189,13 @@ def save(self, tag=None, infos: dict=None):
infos.update({
"step": self.iteration,
"epoch": self.epoch,
"lr": self.optimizer.get_lr()
"lr": self.optimizer.get_lr(),
})
if self.scaler:
scaler_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.scaler'
paddle.save(self.scaler.state_dict(), scaler_path)

self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
@@ -211,6 +218,13 @@ def resume_or_scratch(self):
# lr will restore from optimizer ckpt
self.iteration = infos["step"]
self.epoch = infos["epoch"]

scaler_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.scaler'
if os.path.exists(scaler_path):
scaler_state_dict = paddle.load(scaler_path)
self.scaler.load_state_dict(scaler_state_dict)

scratch = False
logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")