Squeezeformer #1447

Merged: 29 commits, merged Oct 18, 2022

Commits (29)
23c4f99  [init] enable SqueezeformerEncoder (Sep 15, 2022)
7c24031  [update] enable Squeezeformer training (Sep 15, 2022)
76ac435  [update] README.md (Sep 15, 2022)
b291d1e  fix formatting issues (yygle, Sep 15, 2022)
6a36c23  fix formatting issues (yygle, Sep 15, 2022)
139fc36  fix formatting issues (yygle, Sep 15, 2022)
084c843  fix formatting issues (yygle, Sep 15, 2022)
1f4a8b3  [update] change residual connection & add copyrights (yygle, Sep 15, 2022)
0558fbb  fix formatting issues (yygle, Sep 15, 2022)
0264049  [update] enlarge adaptive scale dimensions (yygle, Sep 15, 2022)
be2f56e  fix formatting issues (yygle, Sep 15, 2022)
ba6825c  fix adaptive scale bugs (yygle, Sep 16, 2022)
89f133e  [update] encoder.py (fix init weights bugs) and README.md (yygle, Sep 20, 2022)
78b8077  [update] initialization for input projection (yygle, Sep 21, 2022)
76fbcf2  fix formatting issues (yygle, Sep 21, 2022)
cefa4cd  fix formatting issues (yygle, Sep 22, 2022)
c2f2a05  [update] time reduction layer with conv1d and conv2d (yygle, Sep 23, 2022)
ba7ed74  fix formatting issues (yygle, Sep 23, 2022)
027c85c  [update] operators (yygle, Sep 23, 2022)
ed342f2  [update] experiment results & code format (yygle, Sep 25, 2022)
ac4013c  [update] experiment results (yygle, Sep 25, 2022)
6592ae3  [update] streaming support & results, dw_stride trigger (yygle, Oct 8, 2022)
67e260a  fix formatting issue (yygle, Oct 8, 2022)
5973352  fix formatting issue (yygle, Oct 9, 2022)
08c49aa  fix formatting issue (yygle, Oct 9, 2022)
d777305  fix formatting issue (yygle, Oct 9, 2022)
cd82d89  [update] SqueezeFormer Large Results (yygle, Oct 12, 2022)
3c55dde  fix formatting issues (yygle, Oct 12, 2022)
0824e56  fix format issues (yygle, Oct 14, 2022)
133 changes: 133 additions & 0 deletions examples/librispeech/s0/README.md
@@ -2,6 +2,7 @@

## Conformer Result Bidecoder (large)

* Encoder FLOPs(30s): 96,238,430,720, params: 85,709,704
* Feature info: using fbank feature, cmvn, dither, online speed perturb
* Training info: train_conformer_bidecoder_large.yaml, kernel size 31, lr 0.002, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0
* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30
@@ -18,8 +19,31 @@
| LM-tglarge + attention rescoring | 2.68 | 6.10 |
| LM-fglarge + attention rescoring | 2.65 | 5.98 |

## SqueezeFormer Result (U2++ Large, FFN:2048)

* Encoder info:
  * SM12, reduce_idx 5, recover_idx 11, conv1d, batch_norm, syncbn (a sketch of the conv1d time-reduction/recovery idea follows the results table below)
  * encoder_dim 512, output_size 512, head 8, ffn_dim 512*4=2048
  * Encoder FLOPs(30s): 82,283,704,832, params: 85,984,648
* Feature info:
  * using fbank feature, cmvn, dither, online speed perturb, spec_aug
* Training info:
  * train_squeezeformer_bidecoder_large.yaml, kernel size 31
  * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0
  * AdamW, lr 8e-4, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0
* Decoding info:
  * ctc_weight 0.3, reverse weight 0.5, average_num 30

| decoding mode | dev clean | dev other | test clean | test other |
|----------------------------------|-----------|-----------|------------|------------|
| ctc greedy search | 2.62 | 6.80 | 2.92 | 6.77 |
| ctc prefix beam search | 2.60 | 6.79 | 2.90 | 6.79 |
| attention decoder | 3.06 | 6.90 | 3.38 | 6.82 |
| attention rescoring | 2.33 | 6.29 | 2.57 | 6.22 |
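
The SM12 encoder above halves the frame rate at block reduce_idx 5 with a conv1d time-reduction layer and restores it at recover_idx 11. Below is a minimal, hypothetical sketch of such a reduce/recover pair; the module name, kernel size, and repeat-based upsampling are illustrative assumptions, not the wenet implementation.

```python
import torch
import torch.nn as nn

class TimeReductionConv1d(nn.Module):
    """Stride-2 depthwise-separable conv over time (illustrative only)."""
    def __init__(self, dim: int, kernel_size: int = 5):
        super().__init__()
        self.dw = nn.Conv1d(dim, dim, kernel_size, stride=2,
                            padding=kernel_size // 2, groups=dim)
        self.pw = nn.Conv1d(dim, dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, dim) -> (batch, ~time/2, dim)
        return self.pw(self.dw(x.transpose(1, 2))).transpose(1, 2)

def recover(reduced: torch.Tensor, pre_reduce: torch.Tensor) -> torch.Tensor:
    # Upsample back to the original frame rate by repeating frames,
    # then add the pre-reduction features as a residual.
    up = torch.repeat_interleave(reduced, 2, dim=1)[:, :pre_reduce.size(1)]
    return up + pre_reduce
```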

## Conformer Result

* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608
* Feature info: using fbank feature, cmvn, dither, online speed perturb
* Training info: train_conformer.yaml, kernel size 31, lr 0.004, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1
* Decoding info: ctc_weight 0.5, average_num 30
@@ -35,6 +59,82 @@
| attention rescoring (beam 50) | 3.12 | 8.55 |
| LM-fglarge + attention rescoring | 3.09 | 7.40 |

## Conformer Result (12 layers, FFN:2048)

* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608
* Feature info: using fbank feature, cmvn, dither, online speed perturb
* Training info:
  * train_squeezeformer.yaml, kernel size 31
  * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1
  * AdamW, lr 1e-3, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0
* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30

| decoding mode | dev clean | dev other | test clean | test other |
|----------------------------------|-----------|-----------|------------|------------|
| ctc greedy search | 3.49 | 9.59 | 3.66 | 9.59 |
| ctc prefix beam search | 3.49 | 9.61 | 3.66 | 9.55 |
| attention decoder | 3.52 | 9.04 | 3.85 | 8.97 |
| attention rescoring | 3.10 | 8.91 | 3.29 | 8.81 |

## SqueezeFormer Result (SM12, FFN:1024)

* Encoder info:
  * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn
  * encoder_dim 256, output_size 256, head 4, ffn_dim 256*4=1024
  * Encoder FLOPs(30s): 21,158,877,440, params: 22,219,912
* Feature info:
  * using fbank feature, cmvn, dither, online speed perturb
* Training info:
  * train_squeezeformer.yaml, kernel size 31
  * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1
  * AdamW, lr 1e-3, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0
* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30

| decoding mode | dev clean | dev other | test clean | test other |
|----------------------------------|-----------|-----------|------------|------------|
| ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 |
| ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 |
| attention decoder | 3.59 | 8.74 | 3.75 | 8.70 |
| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 |

## SqueezeFormer Result (SM12, FFN:2048)

* Encoder info:
  * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn
  * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048
  * Encoder FLOPs(30s): 28,230,473,984, params: 34,827,400
* Feature info: using fbank feature, cmvn, dither, online speed perturb
* Training info:
  * train_squeezeformer.yaml, kernel size 31
  * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1
  * AdamW, lr 1e-3, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0
* Decoding info:
  * ctc_weight 0.3, reverse weight 0.5, average_num 30

| decoding mode | dev clean | dev other | test clean | test other |
|----------------------------------|-----------|-----------|------------|------------|
| ctc greedy search | 3.34 | 9.01 | 3.47 | 8.85 |
| ctc prefix beam search | 3.33 | 9.02 | 3.46 | 8.81 |
| attention decoder | 3.64 | 8.62 | 3.91 | 8.33 |
| attention rescoring | 2.89 | 8.34 | 3.10 | 8.03 |

## SqueezeFormer Result (SM12, FFN:1312)

* Encoder info:
  * SM12, reduce_idx 5, recover_idx 11, conv1d, w/o syncbn
  * encoder_dim 328, output_size 256, head 4, ffn_dim 328*4=1312
  * Encoder FLOPs(30s): 34,103,960,008, params: 35,678,352
* Feature info:
  * using fbank feature, cmvn, dither, online speed perturb
* Training info:
  * train_squeezeformer.yaml, kernel size 31
  * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0
  * AdamW, lr 1e-3, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0
* Decoding info:
  * ctc_weight 0.3, reverse weight 0.5, average_num 30

| decoding mode | dev clean | dev other | test clean | test other |
|----------------------------------|-----------|-----------|------------|------------|
| ctc greedy search | 3.20 | 8.46 | 3.30 | 8.58 |
| ctc prefix beam search | 3.18 | 8.44 | 3.30 | 8.55 |
| attention decoder | 3.38 | 8.31 | 3.89 | 8.32 |
| attention rescoring | 2.81 | 7.86 | 2.96 | 7.91 |

## Conformer U2++ Result

* Feature info: using fbank feature, cmvn, no speed perturb, dither
@@ -43,17 +143,48 @@
* Git hash: 65270043fc8c2476d1ab95e7c39f730017a670e0

test clean

| decoding mode | full | 16 |
|--------------------------------|------|------|
| ctc prefix beam search | 3.76 | 4.54 |
| attention rescoring | 3.32 | 3.80 |

test other

| decoding mode | full | 16 |
|--------------------------------|-------|-------|
| ctc prefix beam search | 9.50 | 11.52 |
| attention rescoring | 8.67 | 10.38 |

## SqueezeFormer Result (U2++, FFN:2048)

* Encoder info:
  * SM12, reduce_idx 5, recover_idx 11, conv1d, layer_norm, do_rel_shift false
  * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048
  * Encoder FLOPs(30s): 28,230,473,984, params: 34,827,400
* Feature info:
  * using fbank feature, cmvn, dither, online speed perturb
* Training info:
  * train_squeezeformer.yaml, kernel size 31
  * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1
  * AdamW, lr 1e-3, NoamHold, warmup 0.1, hold 0.4, lr_decay 1.0
* Decoding info:
  * ctc_weight 0.3, reverse weight 0.5, average_num 30

test clean

| decoding mode | full | 16 |
|--------------------------------|------|------|
| ctc prefix beam search | 3.81 | 4.59 |
| attention rescoring | 3.36 | 3.93 |

test other

| decoding mode | full | 16 |
|--------------------------------|-------|-------|
| ctc prefix beam search | 9.12 | 11.17 |
| attention rescoring | 8.43 | 10.21 |
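
In the streaming tables, "full" denotes full-context decoding and "16" denotes chunk-based streaming decoding with decoding chunk size 16. A minimal sketch of the kind of chunk-wise attention mask such decoding relies on (an illustrative reading of the setup, not the wenet code):

```python
import torch

def chunk_attention_mask(num_frames: int, chunk_size: int = 16) -> torch.Tensor:
    """Boolean (num_frames, num_frames) mask: entry [i, j] is True if frame i
    may attend to frame j. Each frame sees its own chunk and every earlier
    chunk, so attention never looks into future chunks."""
    idx = torch.arange(num_frames)
    chunk_end = (idx // chunk_size + 1) * chunk_size  # first frame i may NOT see
    return idx.unsqueeze(0) < chunk_end.unsqueeze(1)
```

Full-context decoding corresponds to an all-True mask, which is why the "full" column is consistently better than the "16" column.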

## Conformer U2 Result

* Feature info: using fbank feature, cmvn, speed perturb, dither
@@ -65,6 +196,7 @@ test other
* LM-fglarge: [4-gram.arpa.gz](http://www.openslr.org/resources/11/4-gram.arpa.gz)

test clean

| decoding mode | full | 16 |
|----------------------------------|------|------|
| ctc prefix beam search | 4.26 | 5.00 |
@@ -76,6 +208,7 @@ test clean
| LM-fglarge + attention rescoring | 3.38 | 3.74 |

test other

| decoding mode | full | 16 |
|----------------------------------|-------|-------|
| ctc prefix beam search | 10.87 | 12.87 |
88 changes: 88 additions & 0 deletions examples/librispeech/s0/conf/train_squeezeformer.yaml
@@ -0,0 +1,88 @@
# network architecture
# encoder related
encoder: squeezeformer
encoder_conf:
    encoder_dim: 256
    output_size: 256    # dimension of attention
    attention_heads: 4
    num_blocks: 12      # the number of encoder blocks
    reduce_idx: 5
    recover_idx: 11
    pos_enc_layer_type: 'rel_pos'
    time_reduction_layer_type: 'conv1d'
    feed_forward_expansion_factor: 4
    input_dropout_rate: 0.1
    feed_forward_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    cnn_module_kernel: 31
    cnn_norm_type: layer_norm
    adaptive_scale: true
    normalize_before: false

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false

# dataset related
dataset_conf:
    filter_conf:
        max_length: 2000
        min_length: 50
        token_max_length: 400
        token_min_length: 1
        min_output_input_ratio: 0.0005
        max_output_input_ratio: 0.1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 12

grad_clip: 5
accum_grad: 4
max_epoch: 120
log_interval: 100

optim: adamw
optim_conf:
    lr: 1.e-3
    weight_decay: 4.e-5

scheduler: NoamHoldAnnealing
scheduler_conf:
    warmup_ratio: 0.2
    hold_ratio: 0.3
    max_steps: 87960
    decay_rate: 1.0
    min_lr: 1.e-5
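
The NoamHoldAnnealing scheduler above raises the learning rate over the first warmup_ratio of max_steps, holds it for the next hold_ratio, then anneals it toward min_lr. A rough sketch of that shape, assuming linear warmup and polynomial annealing with exponent decay_rate; the exact curve in the wenet scheduler may differ:

```python
def noam_hold_annealing_lr(step: int, base_lr: float = 1e-3, max_steps: int = 87960,
                           warmup_ratio: float = 0.2, hold_ratio: float = 0.3,
                           decay_rate: float = 1.0, min_lr: float = 1e-5) -> float:
    warmup_steps = int(max_steps * warmup_ratio)
    hold_steps = int(max_steps * hold_ratio)
    if step <= warmup_steps:                      # warm up to the peak lr
        return base_lr * step / max(1, warmup_steps)
    if step <= warmup_steps + hold_steps:         # hold at the peak
        return base_lr
    # anneal down to min_lr over the remaining steps
    remaining = max(1, max_steps - warmup_steps - hold_steps)
    progress = min(1.0, (step - warmup_steps - hold_steps) / remaining)
    return min_lr + (base_lr - min_lr) * (1.0 - progress) ** decay_rate
```
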
96 changes: 96 additions & 0 deletions examples/librispeech/s0/conf/train_squeezeformer_bidecoder_large.yaml
@@ -0,0 +1,96 @@
# network architecture
# encoder related
encoder: squeezeformer
encoder_conf:
    encoder_dim: 512
    output_size: 512    # dimension of attention
    attention_heads: 8
    num_blocks: 12      # the number of encoder blocks
    reduce_idx: 5
    recover_idx: 11
    feed_forward_expansion_factor: 4
    input_dropout_rate: 0.1
    feed_forward_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    cnn_module_kernel: 31
    cnn_norm_type: batch_norm
    adaptive_scale: true
    normalize_before: false

# decoder related
decoder: bitransformer
decoder_conf:
    attention_heads: 8
    linear_units: 2048
    num_blocks: 3
    r_num_blocks: 3
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    reverse_weight: 0.3

# dataset related
dataset_conf:
    syncbn: true

Member (review comment on `syncbn: true`):

The code related to this doesn't seem to have been included yet? The syncbn conversion can apparently be done with a single torch API call in Train.py:

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

What was the reasoning behind placing this option under the dataset_conf section?

Contributor Author (yygle):

syncbn cannot be enabled directly in wenet mainly because data imbalance across ranks makes processes wait on each other. In the full implementation I considered two cases: 1. DDP without splitting the data (each process iterates over the full dataset), and 2. splitting the dataset and dropping the leftover portion. That is why, in this version, the flag is bound to the dataset configuration. Since that code is unrelated to the Squeezeformer algorithm update and falls under engineering optimization, it will be submitted in a separate PR.

Collaborator:

OK, I'll merge it first; you can keep optimizing and iterating.

    filter_conf:
        max_length: 2000
        min_length: 50
        token_max_length: 400
        token_min_length: 1
        min_output_input_ratio: 0.0005
        max_output_input_ratio: 0.1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 3
        num_f_mask: 2
        max_t: 100
        max_f: 27
        max_w: 80
        # warp_for_time: true
    spec_sub: true
    spec_sub_conf:
        num_t_sub: 3
        max_t: 30
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 12

grad_clip: 5
accum_grad: 4
max_epoch: 500
log_interval: 100

optim: adamw
optim_conf:
    lr: 1.e-3
    weight_decay: 4.e-5

scheduler: NoamHoldAnnealing
scheduler_conf:
    warmup_ratio: 0.2
    hold_ratio: 0.3
    max_steps: 87960
    decay_rate: 1.0
    min_lr: 1.e-5
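
On the syncbn discussion above: the torch API the reviewer mentions replaces every BatchNorm layer with SyncBatchNorm before DDP wrapping. A minimal sketch of that conversion in a generic DDP setup, not the wenet Train.py (whose syncbn and data-splitting handling the author plans to submit in a separate PR); the helper name is hypothetical:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp_with_syncbn(model: torch.nn.Module, local_rank: int) -> DDP:
    # Assumes torch.distributed is already initialized (e.g. by torchrun).
    # Replace every BatchNorm layer with SyncBatchNorm so batch statistics
    # are synchronized across GPUs, then wrap the model for DDP training.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(f"cuda:{local_rank}")
    return DDP(model, device_ids=[local_rank])
```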
