-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Squeezeformer #1447
Merged
+2,197
−8
Merged
Squeezeformer #1447
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
23c4f99
[init] enable SqueezeformerEncoder
7c24031
[update] enable Squeezeformer training
76ac435
[update] README.md
b291d1e
fix formatting issues
yygle 6a36c23
fix formatting issues
yygle 139fc36
fix formatting issues
yygle 084c843
fix formatting issues
yygle 1f4a8b3
[update] change residual connection & add copyrights
yygle 0558fbb
fix formatting issues
yygle 0264049
[update] enlarge adaptive scale dimensions
yygle be2f56e
fix formatting issues
yygle ba6825c
fix adaptive scale bugs
yygle 89f133e
[update] encoder.py(fix init weights bugs) and README.md
yygle 78b8077
[update] initialization for input projection
yygle 76fbcf2
fix formatting issues
yygle cefa4cd
fix formatting issues
yygle c2f2a05
[update] time reduction layer with conv1d and conv2d
yygle ba7ed74
fix formatting issues
yygle 027c85c
[update] operators
yygle ed342f2
[update] experiment results & code format
yygle ac4013c
[update] experiment results
yygle 6592ae3
[update] streaming support & results, dw_stride trigger
yygle 67e260a
fix formatting issue
yygle 5973352
fix formatting issue
yygle 08c49aa
fix formatting issue
yygle d777305
fix formatting issue
yygle cd82d89
[update] SqueezeFormer Large Results
yygle 3c55dde
fix formatting issues
yygle 0824e56
fix format issues
yygle File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Develop Record | ||
|
||
``` | ||
squeezeformer | ||
├── attention.py # reltive multi-head attention module | ||
├── conv2d.py # self defined conv2d valid padding module | ||
├── convolution.py # convolution module in squeezeformer block | ||
├── encoder_layer.py # squeezeformer encoder layer | ||
├── encoder.py # squeezeformer encoder class | ||
├── positionwise_feed_forward.py # feed forward layer | ||
├── subsampling.py # sub-sampling layer, time reduction layer | ||
└── utils.py # residual connection module | ||
``` | ||
|
||
* Implementation Details | ||
* Squeezeformer Encoder | ||
* [x] add pre layer norm before squeezeformer block | ||
* [x] derive time reduction layer from tensorflow version | ||
* [x] enable adaptive scale operation | ||
* [x] enable init weights for deep model training | ||
* [x] enable training config and results | ||
* [x] enable dynamic chunk and JIT export | ||
* Training | ||
* [x] enable NoamHoldAnnealing schedular | ||
|
||
# Performance Record | ||
|
||
### Conformer | ||
|
||
* encoder flops(30s): 2,797,274,624, params: 34,761,608 | ||
|
||
### Squeezeformer Result (SM12, FFN:1024) | ||
|
||
* encoder flops(30s): 21,158,877,440, params: 22,219,912 | ||
* Feature info: using fbank feature, cmvn, dither, online speed perturb | ||
* Training info: train_squeezeformer.yaml, kernel size 31, lr 0.001, batch size | ||
12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 | ||
* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 | ||
|
||
| decoding mode | dev clean | dev other | test clean | test other | | ||
|----------------------------------|-----------|-----------|------------|------------| | ||
| ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 | | ||
| ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 | | ||
| attention decoder | 8.74 | 3.59 | 3.75 | 8.70 | | ||
| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | |
88 changes: 88 additions & 0 deletions
88
examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# network architecture | ||
# encoder related | ||
encoder: squeezeformer | ||
encoder_conf: | ||
encoder_dim: 256 | ||
output_size: 256 # dimension of attention | ||
attention_heads: 4 | ||
num_blocks: 12 # the number of encoder blocks | ||
reduce_idx: 5 | ||
recover_idx: 11 | ||
feed_forward_expansion_factor: 4 | ||
input_dropout_rate: 0.1 | ||
feed_forward_dropout_rate: 0.1 | ||
attention_dropout_rate: 0.1 | ||
cnn_module_kernel: 31 | ||
cnn_norm_type: layer_norm | ||
adaptive_scale: true | ||
normalize_before: false | ||
|
||
# decoder related | ||
decoder: transformer | ||
decoder_conf: | ||
attention_heads: 4 | ||
linear_units: 2048 | ||
num_blocks: 6 | ||
dropout_rate: 0.1 | ||
positional_dropout_rate: 0.1 | ||
self_attention_dropout_rate: 0.0 | ||
src_attention_dropout_rate: 0.0 | ||
|
||
# hybrid CTC/attention | ||
model_conf: | ||
ctc_weight: 0.3 | ||
lsm_weight: 0.1 # label smoothing option | ||
length_normalized_loss: false | ||
|
||
# dataset related | ||
dataset_conf: | ||
even_sample: false | ||
syncbn: false | ||
filter_conf: | ||
max_length: 2000 | ||
min_length: 50 | ||
token_max_length: 400 | ||
token_min_length: 1 | ||
min_output_input_ratio: 0.0005 | ||
max_output_input_ratio: 0.1 | ||
resample_conf: | ||
resample_rate: 16000 | ||
speed_perturb: true | ||
fbank_conf: | ||
num_mel_bins: 80 | ||
frame_shift: 10 | ||
frame_length: 25 | ||
dither: 0.1 | ||
spec_aug: true | ||
spec_aug_conf: | ||
num_t_mask: 2 | ||
num_f_mask: 2 | ||
max_t: 50 | ||
max_f: 10 | ||
shuffle: true | ||
shuffle_conf: | ||
shuffle_size: 1500 | ||
sort: true | ||
sort_conf: | ||
sort_size: 500 # sort_size should be less than shuffle_size | ||
batch_conf: | ||
batch_type: 'static' # static or dynamic | ||
batch_size: 12 | ||
|
||
grad_clip: 5 | ||
accum_grad: 4 | ||
max_epoch: 120 | ||
log_interval: 100 | ||
|
||
optim: adamw | ||
optim_conf: | ||
lr: 1.e-3 | ||
weight_decay: 4.e-5 | ||
|
||
scheduler: NoamHoldAnnealing | ||
scheduler_conf: | ||
warmup_ratio: 0.2 | ||
hold_ratio: 0.3 | ||
max_steps: 87960 | ||
decay_rate: 1.0 | ||
min_lr: 1.e-5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../s0/local | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../tools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../wenet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
# Copyright (c) 2019 Shigeki Karita | ||
# 2020 Mobvoi Inc (Binbin Zhang) | ||
# 2022 Xingchen Song ([email protected]) | ||
# 2022 Ximalaya Inc. (Yuguang Yang) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Multi-Head Attention layer definition.""" | ||
|
||
import math | ||
import torch | ||
import torch.nn as nn | ||
from wenet.transformer.attention import MultiHeadedAttention | ||
from typing import Tuple | ||
|
||
|
||
class RelPositionMultiHeadedAttention(MultiHeadedAttention): | ||
"""Multi-Head Attention layer with relative position encoding. | ||
Paper: https://arxiv.org/abs/1901.02860 | ||
Args: | ||
n_head (int): The number of heads. | ||
n_feat (int): The number of features. | ||
dropout_rate (float): Dropout rate. | ||
""" | ||
|
||
def __init__(self, n_head, n_feat, dropout_rate, | ||
do_rel_shift=False, adaptive_scale=False, init_weights=False): | ||
"""Construct an RelPositionMultiHeadedAttention object.""" | ||
super().__init__(n_head, n_feat, dropout_rate) | ||
# linear transformation for positional encoding | ||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) | ||
# these two learnable bias are used in matrix c and matrix d | ||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | ||
self.do_rel_shift = do_rel_shift | ||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) | ||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) | ||
torch.nn.init.xavier_uniform_(self.pos_bias_u) | ||
torch.nn.init.xavier_uniform_(self.pos_bias_v) | ||
self.adaptive_scale = adaptive_scale | ||
if self.adaptive_scale: | ||
self.ada_scale = nn.Parameter( | ||
torch.ones([1, 1, n_feat]), requires_grad=True) | ||
self.ada_bias = nn.Parameter( | ||
torch.zeros([1, 1, n_feat]), requires_grad=True) | ||
if init_weights: | ||
self.init_weights() | ||
|
||
def init_weights(self): | ||
input_max = (self.h * self.d_k) ** -0.5 | ||
torch.nn.init.uniform_(self.linear_q.weight, -input_max, input_max) | ||
torch.nn.init.uniform_(self.linear_q.bias, -input_max, input_max) | ||
torch.nn.init.uniform_(self.linear_k.weight, -input_max, input_max) | ||
torch.nn.init.uniform_(self.linear_k.bias, -input_max, input_max) | ||
torch.nn.init.uniform_(self.linear_v.weight, -input_max, input_max) | ||
torch.nn.init.uniform_(self.linear_v.bias, -input_max, input_max) | ||
torch.nn.init.uniform_(self.linear_pos.weight, -input_max, input_max) | ||
torch.nn.init.uniform_(self.linear_out.weight, -input_max, input_max) | ||
torch.nn.init.uniform_(self.linear_out.bias, -input_max, input_max) | ||
|
||
def rel_shift(self, x, zero_triu: bool = False): | ||
"""Compute relative positinal encoding. | ||
Args: | ||
x (torch.Tensor): Input tensor (batch, time, size). | ||
zero_triu (bool): If true, return the lower triangular part of | ||
the matrix. | ||
Returns: | ||
torch.Tensor: Output tensor. | ||
""" | ||
|
||
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), | ||
device=x.device, | ||
dtype=x.dtype) | ||
x_padded = torch.cat([zero_pad, x], dim=-1) | ||
|
||
x_padded = x_padded.view(x.size()[0], | ||
x.size()[1], | ||
x.size(3) + 1, x.size(2)) | ||
x = x_padded[:, :, 1:].view_as(x) | ||
|
||
if zero_triu: | ||
ones = torch.ones((x.size(2), x.size(3))) | ||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] | ||
|
||
return x | ||
|
||
def forward_attention( | ||
self, value: torch.Tensor, scores: torch.Tensor, | ||
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) | ||
) -> torch.Tensor: | ||
"""Compute attention context vector. | ||
|
||
Args: | ||
value (torch.Tensor): Transformed value, size | ||
(#batch, n_head, time2, d_k). | ||
scores (torch.Tensor): Attention score, size | ||
(#batch, n_head, time1, time2). | ||
mask (torch.Tensor): Mask, size (#batch, 1, time2) or | ||
(#batch, time1, time2), (0, 0, 0) means fake mask. | ||
|
||
Returns: | ||
torch.Tensor: Transformed value (#batch, time1, d_model) | ||
weighted by the attention score (#batch, time1, time2). | ||
|
||
""" | ||
n_batch = value.size(0) | ||
# NOTE(xcsong): When will `if mask.size(2) > 0` be True? | ||
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the | ||
# 1st chunk to ease the onnx export.] | ||
# 2. pytorch training | ||
if mask.size(2) > 0: # time2 > 0 | ||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) | ||
# For last chunk, time2 might be larger than scores.size(-1) | ||
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) | ||
scores = scores.masked_fill(mask, -float('inf')) | ||
# (batch, head, time1, time2) | ||
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) | ||
# NOTE(xcsong): When will `if mask.size(2) > 0` be False? | ||
# 1. onnx(16/-1, -1/-1, 16/0) | ||
# 2. jit (16/-1, -1/-1, 16/0, 16/4) | ||
else: | ||
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) | ||
|
||
p_attn = self.dropout(attn) | ||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) | ||
x = (x.transpose(1, 2).contiguous().view(n_batch, -1, | ||
self.h * self.d_k) | ||
) # (batch, time1, d_model) | ||
|
||
return self.linear_out(x) # (batch, time1, d_model) | ||
|
||
def forward(self, query: torch.Tensor, | ||
key: torch.Tensor, value: torch.Tensor, | ||
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), | ||
pos_emb: torch.Tensor = torch.empty(0), | ||
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. | ||
Args: | ||
query (torch.Tensor): Query tensor (#batch, time1, size). | ||
key (torch.Tensor): Key tensor (#batch, time2, size). | ||
value (torch.Tensor): Value tensor (#batch, time2, size). | ||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or | ||
(#batch, time1, time2), (0, 0, 0) means fake mask. | ||
pos_emb (torch.Tensor): Positional embedding tensor | ||
(#batch, time2, size). | ||
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), | ||
where `cache_t == chunk_size * num_decoding_left_chunks` | ||
and `head * d_k == size` | ||
Returns: | ||
torch.Tensor: Output tensor (#batch, time1, d_model). | ||
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) | ||
where `cache_t == chunk_size * num_decoding_left_chunks` | ||
and `head * d_k == size` | ||
""" | ||
if self.adaptive_scale: | ||
query = self.ada_scale * query + self.ada_bias | ||
key = self.ada_scale * key + self.ada_bias | ||
value = self.ada_scale * value + self.ada_bias | ||
q, k, v = self.forward_qkv(query, key, value) | ||
q = q.transpose(1, 2) # (batch, time1, head, d_k) | ||
|
||
# NOTE(xcsong): | ||
# when export onnx model, for 1st chunk, we feed | ||
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) | ||
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). | ||
# In all modes, `if cache.size(0) > 0` will alwayse be `True` | ||
# and we will always do splitting and | ||
# concatnation(this will simplify onnx export). Note that | ||
# it's OK to concat & split zero-shaped tensors(see code below). | ||
# when export jit model, for 1st chunk, we always feed | ||
# cache(0, 0, 0, 0) since jit supports dynamic if-branch. | ||
# >>> a = torch.ones((1, 2, 0, 4)) | ||
# >>> b = torch.ones((1, 2, 3, 4)) | ||
# >>> c = torch.cat((a, b), dim=2) | ||
# >>> torch.equal(b, c) # True | ||
# >>> d = torch.split(a, 2, dim=-1) | ||
# >>> torch.equal(d[0], d[1]) # True | ||
if cache.size(0) > 0: | ||
key_cache, value_cache = torch.split( | ||
cache, cache.size(-1) // 2, dim=-1) | ||
k = torch.cat([key_cache, k], dim=2) | ||
v = torch.cat([value_cache, v], dim=2) | ||
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's | ||
# non-trivial to calculate `next_cache_start` here. | ||
new_cache = torch.cat((k, v), dim=-1) | ||
|
||
n_batch_pos = pos_emb.size(0) | ||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) | ||
p = p.transpose(1, 2) # (batch, head, time1, d_k) | ||
|
||
# (batch, head, time1, d_k) | ||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) | ||
# (batch, head, time1, d_k) | ||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) | ||
|
||
# compute attention score | ||
# first compute matrix a and matrix c | ||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3 | ||
# (batch, head, time1, time2) | ||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) | ||
|
||
# compute matrix b and matrix d | ||
# (batch, head, time1, time2) | ||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) | ||
# Remove rel_shift since it is useless in speech recognition, | ||
# and it requires special attention for streaming. | ||
if self.do_rel_shift: | ||
matrix_bd = self.rel_shift(matrix_bd) | ||
|
||
scores = (matrix_ac + matrix_bd) / math.sqrt( | ||
self.d_k) # (batch, head, time1, time2) | ||
|
||
return self.forward_attention(v, scores, mask), new_cache |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
squeezeformer
in examples/librispeech is not required here, we can just do it in s0 by configue since it shares same training and decoding recipe.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, what about remove it after the whole
README.md
of squeezeformer part is updated?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok!