-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Noisy student training for wenet #1600
Changes from 43 commits
220d7b2
fd2cbeb
92974fd
4d68c70
2f84cd7
bd36149
1c4043b
1683016
0496b1d
7595ac9
b3c880c
0b47162
aa5bd02
5663ec0
b4dfc71
9a059a2
145fe9c
dc53142
463d09a
974ada1
3b6e69b
0918842
2289c79
555edae
a31a7d1
06e58a6
d6751d3
7c8a859
68d974a
91baa96
98eced6
898ccd0
bd341f6
eb58e5b
52dcdfe
6460e82
52c4783
b553f01
40b7b61
62e8cf1
de8bbe2
0141767
908752c
7f5ee03
4b19bb8
74caa24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# Recipe to run Noisy Student Training with LM filter in WeNet | ||
|
||
Noisy Student Training (NST) has recently demonstrated extremely strong performance in Automatic Speech Recognition (ASR). | ||
|
||
Here, we provide a recipe to run NST with `LM filter` strategy using AISHELL-1 as supervised data and WenetSpeech as unsupervised data from [this paper](https://arxiv.org/abs/2211.04717), where hypotheses with and without Language Model are generated and CER differences between them are utilized as a filter threshold to improve the ASR performances of non-target domain datas. | ||
|
||
## Table of Contents | ||
|
||
- [Guideline](#guideline) | ||
- [Data preparation](#data-preparation) | ||
- [Initial supervised teacher](#initial-supervised-teacher) | ||
- [Noisy student interations](#noisy-student-interations) | ||
- [Performance Record](#performance-record) | ||
- [Supervised baseline and standard NST](##supervised-baseline-and-standard-nst) | ||
- [Supervised AISHELL-1 and unsupervised 1khr WenetSpeech](#supervised-aishell-1-and-unsupervised-1khr-wenetspeech) | ||
- [Supervised AISHELL-2 and unsupervised 4khr WenetSpeech](#supervised-aishell-2-and-unsupervised-4khr-wenetspeech) | ||
- [Citations](#citations) | ||
|
||
## Guideline | ||
|
||
|
||
First, you have to prepare supervised and unsupervised data for NST. Then in stage 1 of `run.sh`, you will train an initial supervised teacher and generate pseudo labels for unsupervised data. | ||
After that, you can run the noisy student training iteratively in stage 2. The whole pipeline is illustrated in the following picture. | ||
|
||
![plot](local/NST_plot.png) | ||
|
||
### Data preparation | ||
|
||
To run this recipe, you should follow the steps from [WeNet examples](https://github.com/wenet-e2e/wenet/tree/main/examples) to prepare [AISHELL1](https://github.com/wenet-e2e/wenet/tree/main/examples/aishell/s0) and [WenetSpeech](https://github.com/wenet-e2e/wenet/tree/main/examples/wenetspeech/s0) data. | ||
We extract 1khr data from WenetSpeech and data should be prepared and stored in the following format: | ||
|
||
``` | ||
data/ | ||
├── train/ | ||
├──── data_aishell.list | ||
├──── wenet_1khr.list | ||
├──── wav_dir/ | ||
├──── utter_time.json (optional) | ||
├── dev/ | ||
└── test/ | ||
|
||
``` | ||
- Files `*.list` contain paths for all the data shards for training. | ||
- A Json file containing the audio length should be prepared as `utter_time.json` if you want to apply the `speaking rate` filter. | ||
- A wav_dir contains all the audio data (id.wav) and labels (id.txt which is optional) for unsupervised data. | ||
> **HINTS** We include a tiny example under `data_example` to make it clearer for reproduction. | ||
|
||
### Initial supervised teacher | ||
|
||
To train an initial supervised teacher model, run the following command: | ||
|
||
```bash | ||
bash run.sh --stage 1 --stop-stage 1 | ||
``` | ||
|
||
Full arguments are listed below, you can check `run.sh` and `run_nst.sh` for more information about steps in each stage and their arguments. We used `num_split = 60` and generate shards with different cpu for the experiments in our paper which saved us lots of inference time and data shards generation time. | ||
|
||
```bash | ||
bash run.sh --stage 1 --stop-stage 1 --dir exp/conformer_test_fully_supervised --supervised_data_list data_aishell.list --enable_nst 0 --num_split 1 --unsupervised_data_list wenet_1khr.list --dir_split wenet_split_60_test/ --job_num 0 --hypo_name hypothesis_nst0.txt --label 1 --wav_dir data/train/wenet_1k_untar/ --cer_hypo_dir wenet_cer_hypo --cer_label_dir wenet_cer_label --label_file label.txt --cer_hypo_threshold 10 --speak_rate_threshold 0 --utter_time_file utter_time.json --untar_dir data/train/wenet_1khr_untar/ --tar_dir data/train/wenet_1khr_tar/ --out_data_list data/train/wenet_1khr.list | ||
``` | ||
- `dir` contains the training parameters. | ||
- `data_list` contains paths for the training data list. | ||
- `supervised_data_list` contains paths for supervised data shards. | ||
- `unsupervised_data_list`contains paths for unsupervised data shards which is used for inference. | ||
- `dir_split` is the directory stores split unsupervised data for parallel computing. | ||
- `out_data_list` is the pseudo label data list file path. | ||
- `enable_nst` indicates whether we train with pseudo label and split data, for initial teacher we set it to 0. | ||
- This recipe uses the default `num_split=1` while we strongly recommend use larger number to decrease the inference and shards generation time. | ||
> **HINTS** If num_split is set to N larger than 1, you need to modify the script in step 4-8 in run_nst.sh to submit N tasks into your own clusters (such as slurm,ngc etc..). | ||
> We strongly recommend to do so since inference and pseudo-data generation is time-consuming. | ||
|
||
### Noisy student interations | ||
|
||
After finishing the initial fully supervised baseline, we now have the mixed list contains both supervised and pseudo data which is `wenet_1khr_nst0.list`. | ||
We will use it as the `data_list` in the training step and the `data_list` for next NST iteration will be generated. | ||
|
||
Here is an example command: | ||
|
||
```bash | ||
bash run.sh --stage 2 --stop-stage 2 --iter_num 2 | ||
``` | ||
|
||
Here we add extra argument `iter_num` for number of NST iterations. Intermediate files are named with `iter_num` as a suffix. | ||
Please check the `run.sh` and `run_nst.sh` scripts for more information about each stage and their arguments. | ||
|
||
## Performance Record | ||
|
||
### Supervised baseline and standard NST | ||
* Non-streaming conformer model with attention rescoring decoder. | ||
* Without filter strategy, first iteration | ||
* Feature info: using FBANK feature, dither, cmvn, online speed perturb | ||
* Training info: lr 0.002, batch size 32, 8 gpu, acc_grad 4, 240 epochs, dither 0.1 | ||
* Decoding info: ctc_weight 0.3, average_num 30 | ||
|
||
|
||
| Supervised | Unsupervised | Test CER | | ||
|--------------------------|--------------|----------| | ||
| AISHELL-1 Only | ---- | 4.85 | | ||
| AISHELL-1+WenetSpeech | ---- | 3.54 | | ||
| AISHELL-1+AISHELL-2 | ---- | 1.01 | | ||
| AISHELL-1 (standard NST) | WenetSpeech | 5.52 | | ||
|
||
|
||
|
||
### Supervised AISHELL-1 and unsupervised 1khr WenetSpeech | ||
* Non-streaming conformer model with attention rescoring decoder. | ||
* Feature info: using FBANK feature | ||
* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1 | ||
* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75 | ||
|
||
| # nst iteration | AISHELL-1 test CER | Pseudo CER| Filtered CER | Filtered hours | | ||
|----------------|--------------------|-----------|--------------|----------------| | ||
| 0 | 4.85 | 47.10 | 25.18 | 323 | | ||
| 1 | 4.86 | 37.02 | 20.93 | 436 | | ||
| 2 | 4.75 | 31.81 | 19.74 | 540 | | ||
| 3 | 4.69 | 28.27 | 17.85 | 592 | | ||
| 4 | 4.48 | 26.64 | 14.76 | 588 | | ||
| 5 | 4.41 | 24.70 | 15.86 | 670 | | ||
| 6 | 4.34 | 23.64 | 15.40 | 669 | | ||
| 7 | 4.31 | 23.79 | 15.75 | 694 | | ||
|
||
### Supervised AISHELL-2 and unsupervised 4khr WenetSpeech | ||
* Non-streaming conformer model with attention rescoring decoder. | ||
* Feature info: using FBANK feature | ||
* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1 | ||
* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75 | ||
|
||
| # nst iteration | AISHELL-2 test CER | Pseudo CER | Filtered CER | Filtered hours | | ||
|----------------|--------------------|------------|--------------|----------------| | ||
| 0 | 5.48 | 30.10 | 11.73 | 1637 | | ||
| 1 | 5.09 | 28.31 | 9.39 | 2016 | | ||
| 2 | 4.88 | 25.38 | 9.99 | 2186 | | ||
| 3 | 4.74 | 22.47 | 10.66 | 2528 | | ||
| 4 | 4.73 | 22.23 | 10.43 | 2734 | | ||
|
||
|
||
|
||
## Citations | ||
|
||
``` bibtex | ||
|
||
@article{chen2022NST, | ||
title={Improving Noisy Student Training on Non-target Domain Data for Automatic Speech Recognition}, | ||
author={Chen, Yu and Wen, Ding and Lai, Junjie}, | ||
journal={arXiv preprint arXiv:2203.15455}, | ||
year={2022} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# network architecture | ||
# encoder related | ||
encoder: conformer | ||
encoder_conf: | ||
output_size: 256 # dimension of attention | ||
attention_heads: 4 | ||
linear_units: 2048 # the number of units of position-wise feed forward | ||
num_blocks: 12 # the number of encoder blocks | ||
dropout_rate: 0.1 | ||
positional_dropout_rate: 0.1 | ||
attention_dropout_rate: 0.0 | ||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 | ||
normalize_before: true | ||
cnn_module_kernel: 15 | ||
use_cnn_module: True | ||
activation_type: 'swish' | ||
pos_enc_layer_type: 'rel_pos' | ||
selfattention_layer_type: 'rel_selfattn' | ||
|
||
# decoder related | ||
decoder: transformer | ||
decoder_conf: | ||
attention_heads: 4 | ||
linear_units: 2048 | ||
num_blocks: 6 | ||
dropout_rate: 0.1 | ||
positional_dropout_rate: 0.1 | ||
self_attention_dropout_rate: 0.0 | ||
src_attention_dropout_rate: 0.0 | ||
|
||
# hybrid CTC/attention | ||
model_conf: | ||
ctc_weight: 0.3 | ||
lsm_weight: 0.1 # label smoothing option | ||
length_normalized_loss: false | ||
|
||
dataset_conf: | ||
filter_conf: | ||
max_length: 1200 | ||
min_length: 0 | ||
token_max_length: 200 | ||
token_min_length: 1 | ||
resample_conf: | ||
resample_rate: 16000 | ||
speed_perturb: true | ||
fbank_conf: | ||
num_mel_bins: 80 | ||
frame_shift: 10 | ||
frame_length: 25 | ||
dither: 0.1 | ||
spec_aug: true | ||
spec_aug_conf: | ||
num_t_mask: 2 | ||
num_f_mask: 2 | ||
max_t: 50 | ||
max_f: 10 | ||
shuffle: true | ||
shuffle_conf: | ||
shuffle_size: 1500 | ||
sort: true | ||
sort_conf: | ||
sort_size: 500 # sort_size should be less than shuffle_size | ||
batch_conf: | ||
batch_type: 'static' # static or dynamic | ||
batch_size: 16 | ||
|
||
grad_clip: 5 | ||
accum_grad: 4 | ||
max_epoch: 4 | ||
log_interval: 100 | ||
|
||
optim: adam | ||
optim_conf: | ||
lr: 0.002 | ||
scheduler: warmuplr # pytorch v1.1.0+ required | ||
scheduler_conf: | ||
warmup_steps: 25000 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# network architecture | ||
# encoder related | ||
encoder: conformer # ~ 30M parameters | ||
encoder_conf: | ||
output_size: 256 # dimension of attention | ||
attention_heads: 4 | ||
linear_units: 2048 # the number of units of position-wise feed forward | ||
num_blocks: 12 # the number of encoder blocks | ||
dropout_rate: 0.1 | ||
positional_dropout_rate: 0.1 | ||
attention_dropout_rate: 0.0 | ||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 | ||
normalize_before: true | ||
cnn_module_kernel: 15 | ||
use_cnn_module: True | ||
activation_type: 'swish' | ||
pos_enc_layer_type: 'rel_pos' | ||
selfattention_layer_type: 'rel_selfattn' | ||
|
||
# decoder related | ||
decoder: transformer | ||
decoder_conf: | ||
attention_heads: 4 | ||
linear_units: 2048 | ||
num_blocks: 6 | ||
dropout_rate: 0.1 | ||
positional_dropout_rate: 0.1 | ||
self_attention_dropout_rate: 0.0 | ||
src_attention_dropout_rate: 0.0 | ||
|
||
# hybrid CTC/attention | ||
model_conf: | ||
ctc_weight: 0.3 | ||
lsm_weight: 0.1 # label smoothing option | ||
length_normalized_loss: false | ||
|
||
dataset_conf: | ||
filter_conf: | ||
max_length: 1200 #40960 | ||
min_length: 10 #0 | ||
token_max_length: 100 # 200 | ||
token_min_length: 1 | ||
resample_conf: | ||
resample_rate: 16000 | ||
speed_perturb: false #true | ||
fbank_conf: | ||
num_mel_bins: 80 | ||
frame_shift: 10 | ||
frame_length: 25 | ||
dither: 0.1 | ||
spec_aug: true | ||
spec_aug_conf: | ||
num_t_mask: 2 | ||
num_f_mask: 2 | ||
max_t: 50 | ||
max_f: 10 | ||
shuffle: true | ||
shuffle_conf: | ||
shuffle_size: 1500 | ||
sort: true | ||
sort_conf: | ||
sort_size: 500 # sort_size should be less than shuffle_size | ||
batch_conf: | ||
batch_type: 'static' # static or dynamic | ||
batch_size: 32 | ||
|
||
supervised_dataset_conf: | ||
filter_conf: | ||
max_length: 1200 #40960 | ||
min_length: 10 #0 | ||
token_max_length: 100 # 200 | ||
token_min_length: 1 | ||
resample_conf: | ||
resample_rate: 16000 | ||
speed_perturb: false #true | ||
fbank_conf: | ||
num_mel_bins: 80 | ||
frame_shift: 10 | ||
frame_length: 25 | ||
dither: 0.1 | ||
spec_aug: true | ||
spec_aug_conf: | ||
num_t_mask: 2 | ||
num_f_mask: 2 | ||
max_t: 50 | ||
max_f: 10 | ||
shuffle: true | ||
shuffle_conf: | ||
shuffle_size: 1500 | ||
sort: true | ||
sort_conf: | ||
sort_size: 500 # sort_size should be less than shuffle_size | ||
batch_conf: | ||
batch_type: 'static' # static or dynamic | ||
batch_size: 32 | ||
|
||
unsupervised_dataset_conf: | ||
filter_conf: | ||
max_length: 1200 #40960 | ||
min_length: 10 #0 | ||
token_max_length: 100 # 200 | ||
token_min_length: 1 | ||
resample_conf: | ||
resample_rate: 16000 | ||
speed_perturb: false #true | ||
fbank_conf: | ||
num_mel_bins: 80 | ||
frame_shift: 10 | ||
frame_length: 25 | ||
dither: 0.1 | ||
spec_aug: true | ||
spec_aug_conf: | ||
num_t_mask: 2 | ||
num_f_mask: 2 | ||
max_t: 50 | ||
max_f: 10 | ||
shuffle: true | ||
shuffle_conf: | ||
shuffle_size: 1500 | ||
sort: true | ||
sort_conf: | ||
sort_size: 500 # sort_size should be less than shuffle_size | ||
batch_conf: | ||
batch_type: 'static' # static or dynamic | ||
batch_size: 32 | ||
|
||
grad_clip: 5 | ||
accum_grad: 4 | ||
max_epoch: 30 | ||
log_interval: 100 | ||
# for full supervised training, just set this pseudo_ratio to 0 | ||
pseudo_ratio: 0.75 | ||
|
||
|
||
optim: adam | ||
optim_conf: | ||
lr: 0.002 | ||
scheduler: warmuplr # pytorch v1.1.0+ required | ||
scheduler_conf: | ||
warmup_steps: 20000 #25000 | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
data/train/shards/shards_000000000.tar | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please remove data_example/train dir |
||
data/train/shards/shards_000000001.tar | ||
data/train/shards/shards_000000002.tar | ||
data/train/shards/shards_000000003.tar | ||
data/train/shards/shards_000000004.tar | ||
data/train/shards/shards_000000005.tar | ||
data/train/shards/shards_000000006.tar | ||
data/train/shards/shards_000000007.tar | ||
data/train/shards/shards_000000008.tar | ||
data/train/shards/shards_000000009.tar | ||
data/train/shards/shards_000000010.tar |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
"ID001": 2.05, | ||
"ID002": 2.75, | ||
"ID003": 3.36 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
你好你好 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove wav_dir |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
data/train/wenet_1khr_tar//dir0_000000.tar | ||
data/train/wenet_1khr_tar//dir0_000001.tar | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove the file |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
data/train/wenet_1khr_tar//dir0_000000.tar | ||
data/train/wenet_1khr_tar//dir0_000001.tar |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file is not required, right?