Squeezeformer #1447

Merged 29 commits on Oct 18, 2022
23c4f99
[init] enable SqueezeformerEncoder
Sep 15, 2022
7c24031
[update] enable Squeezeformer training
Sep 15, 2022
76ac435
[update] README.md
Sep 15, 2022
b291d1e
fix formatting issues
yygle Sep 15, 2022
6a36c23
fix formatting issues
yygle Sep 15, 2022
139fc36
fix formatting issues
yygle Sep 15, 2022
084c843
fix formatting issues
yygle Sep 15, 2022
1f4a8b3
[update] change residual connection & add copyrights
yygle Sep 15, 2022
0558fbb
fix formatting issues
yygle Sep 15, 2022
0264049
[update] enlarge adaptive scale dimensions
yygle Sep 15, 2022
be2f56e
fix formatting issues
yygle Sep 15, 2022
ba6825c
fix adaptive scale bugs
yygle Sep 16, 2022
89f133e
[update] encoder.py(fix init weights bugs) and README.md
yygle Sep 20, 2022
78b8077
[update] initialization for input projection
yygle Sep 21, 2022
76fbcf2
fix formatting issues
yygle Sep 21, 2022
cefa4cd
fix formatting issues
yygle Sep 22, 2022
c2f2a05
[update] time reduction layer with conv1d and conv2d
yygle Sep 23, 2022
ba7ed74
fix formatting issues
yygle Sep 23, 2022
027c85c
[update] operators
yygle Sep 23, 2022
ed342f2
[update] experiment results & code format
yygle Sep 25, 2022
ac4013c
[update] experiment results
yygle Sep 25, 2022
6592ae3
[update] streaming support & results, dw_stride trigger
yygle Oct 8, 2022
67e260a
fix formatting issue
yygle Oct 8, 2022
5973352
fix formatting issue
yygle Oct 9, 2022
08c49aa
fix formatting issue
yygle Oct 9, 2022
d777305
fix formatting issue
yygle Oct 9, 2022
cd82d89
[update] SqueezeFormer Large Results
yygle Oct 12, 2022
3c55dde
fix formatting issues
yygle Oct 12, 2022
0824e56
fix format issues
yygle Oct 14, 2022
45 changes: 45 additions & 0 deletions examples/librispeech/squeezeformer/README.md
@@ -0,0 +1,45 @@
# Develop Record

```
squeezeformer
├── attention.py # relative multi-head attention module
├── conv2d.py # self defined conv2d valid padding module
├── convolution.py # convolution module in squeezeformer block
├── encoder_layer.py # squeezeformer encoder layer
├── encoder.py # squeezeformer encoder class
├── positionwise_feed_forward.py # feed forward layer
├── subsampling.py # sub-sampling layer, time reduction layer
└── utils.py # residual connection module
```

* Implementation Details
  * Squeezeformer Encoder
    * [x] add pre layer norm before squeezeformer block
    * [x] derive time reduction layer from tensorflow version
    * [x] enable adaptive scale operation
    * [x] enable init weights for deep model training
    * [x] enable training config and results
    * [x] enable dynamic chunk and JIT export
  * Training
    * [x] enable NoamHoldAnnealing scheduler

# Performance Record

### Conformer

* encoder flops(30s): 2,797,274,624, params: 34,761,608

### Squeezeformer Result (SM12, FFN:1024)

* encoder flops(30s): 21,158,877,440, params: 22,219,912
* Feature info: using fbank feature, cmvn, dither, online speed perturb
* Training info: train_squeezeformer.yaml, kernel size 31, lr 0.001, batch size
12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1
* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30

| decoding mode | dev clean | dev other | test clean | test other |
|----------------------------------|-----------|-----------|------------|------------|
| ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 |
| ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 |
| attention decoder | 3.59 | 8.74 | 3.75 | 8.70 |
| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 |
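
The training info above implies an effective batch size per optimizer step. A minimal sketch of that arithmetic, assuming `batch_type: 'static'` means 12 utterances per GPU per forward pass and plain data-parallel training (the helper name `effective_batch_size` is hypothetical, not part of WeNet):

```python
def effective_batch_size(batch_size: int, num_gpus: int, accum_grad: int) -> int:
    """Utterances contributing to one optimizer step.

    Assumes every GPU processes `batch_size` utterances per forward pass
    and gradients are accumulated over `accum_grad` passes.
    """
    return batch_size * num_gpus * accum_grad


# Values taken from the training info above: batch size 12, 8 GPUs, acc_grad 4.
print(effective_batch_size(12, 8, 4))  # 384
```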
88 changes: 88 additions & 0 deletions examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml
@@ -0,0 +1,88 @@
# network architecture
# encoder related
encoder: squeezeformer
encoder_conf:
    encoder_dim: 256
    output_size: 256 # dimension of attention
    attention_heads: 4
    num_blocks: 12 # the number of encoder blocks
    reduce_idx: 5
    recover_idx: 11
    feed_forward_expansion_factor: 4
    input_dropout_rate: 0.1
    feed_forward_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    cnn_module_kernel: 31
    cnn_norm_type: layer_norm
    adaptive_scale: true
    normalize_before: false

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1 # label smoothing option
    length_normalized_loss: false

# dataset related
dataset_conf:
    even_sample: false
    syncbn: false
    filter_conf:
        max_length: 2000
        min_length: 50
        token_max_length: 400
        token_min_length: 1
        min_output_input_ratio: 0.0005
        max_output_input_ratio: 0.1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500 # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 12

grad_clip: 5
accum_grad: 4
max_epoch: 120
log_interval: 100

optim: adamw
optim_conf:
    lr: 1.e-3
    weight_decay: 4.e-5

scheduler: NoamHoldAnnealing
scheduler_conf:
    warmup_ratio: 0.2
    hold_ratio: 0.3
    max_steps: 87960
    decay_rate: 1.0
    min_lr: 1.e-5
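
The scheduler implementation itself (`wenet/utils/scheduler.py`) is not part of this excerpt. As a rough illustration of what a warmup / hold / decay schedule with these parameters could look like, here is a pure-Python sketch. The function name `noam_hold_annealing_lr`, the linear warmup, the choice that the hold phase follows the warmup phase, and the inverse-power decay form are all assumptions and may differ from the code in the PR:

```python
def noam_hold_annealing_lr(step: int,
                           max_lr: float = 1.0e-3,
                           max_steps: int = 87960,
                           warmup_ratio: float = 0.2,
                           hold_ratio: float = 0.3,
                           decay_rate: float = 1.0,
                           min_lr: float = 1.0e-5) -> float:
    """Illustrative warmup -> hold -> decay learning-rate curve (assumed form)."""
    warmup_steps = round(warmup_ratio * max_steps)
    hold_len = round(hold_ratio * max_steps)
    hold_end = warmup_steps + hold_len  # assumption: hold starts after warmup

    if step < warmup_steps:
        # Linear warmup from 0 to max_lr.
        return max_lr * step / warmup_steps
    if step <= hold_end:
        # Keep the peak learning rate constant.
        return max_lr
    # Inverse-power decay, continuous at step == hold_end, floored at min_lr.
    lr = max_lr * (warmup_steps / (step - hold_len)) ** decay_rate
    return max(lr, min_lr)


# Sample the curve at a few steps using the values from scheduler_conf above.
for s in (0, 8796, 17592, 43980, 87960):
    print(s, noam_hold_annealing_lr(s))
```

Under these assumptions, `warmup_ratio: 0.2` and `hold_ratio: 0.3` with `max_steps: 87960` give 17592 warmup steps followed by 26388 hold steps, so the decay phase covers the second half of training.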
1 change: 1 addition & 0 deletions examples/librispeech/squeezeformer/local
1 change: 1 addition & 0 deletions examples/librispeech/squeezeformer/tools
1 change: 1 addition & 0 deletions examples/librispeech/squeezeformer/wenet
17 changes: 14 additions & 3 deletions wenet/bin/train.py
@@ -31,7 +31,7 @@
                                     load_trained_modules)
 from wenet.utils.executor import Executor
 from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
-from wenet.utils.scheduler import WarmupLR
+from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
 from wenet.utils.config import override_config
 from wenet.utils.init_model import init_model

@@ -243,8 +243,19 @@ def main():
     device = torch.device('cuda' if use_cuda else 'cpu')
     model = model.to(device)
 
-    optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
-    scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
+    if configs['optim'] == 'adam':
+        optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
+    elif configs['optim'] == 'adamw':
+        optimizer = optim.AdamW(model.parameters(), **configs['optim_conf'])
+    else:
+        raise Exception('Please choose a correct optimizer.')
+    if configs['scheduler'] == 'warmuplr':
+        scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
+    elif configs['scheduler'] == 'NoamHoldAnnealing':
+        scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf'])
+    else:
+        raise Exception('Please choose a correct scheduler.')
 
     final_epoch = None
     configs['rank'] = args.rank
     configs['is_distributed'] = distributed
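
The `if`/`elif` selection in this hunk could also be expressed as a lookup table, which keeps the list of supported names in one place and lets the error message name the offending value. This is only a sketch of an alternative, not what the PR does; `OPTIMIZERS` and `build_optimizer` are hypothetical names, and the toy `torch.nn.Linear` model stands in for the real one:

```python
import torch
from torch import optim

# Hypothetical registry mapping config names to optimizer classes.
OPTIMIZERS = {
    'adam': optim.Adam,
    'adamw': optim.AdamW,
}


def build_optimizer(params, configs):
    """Pick the optimizer class named by configs['optim'] and construct it."""
    name = configs['optim']
    if name not in OPTIMIZERS:
        raise ValueError(
            f"Unknown optimizer '{name}', expected one of {sorted(OPTIMIZERS)}")
    return OPTIMIZERS[name](params, **configs['optim_conf'])


# Toy model and the optimizer settings from train_squeezeformer.yaml.
model = torch.nn.Linear(4, 2)
configs = {'optim': 'adamw',
           'optim_conf': {'lr': 1.e-3, 'weight_decay': 4.e-5}}
optimizer = build_optimizer(model.parameters(), configs)
print(type(optimizer).__name__)
```

The same pattern would apply to the scheduler branch.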
223 changes: 223 additions & 0 deletions wenet/squeezeformer/attention.py
@@ -0,0 +1,223 @@
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song ([email protected])
# 2022 Ximalaya Inc. (Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multi-Head Attention layer definition."""

import math
import torch
import torch.nn as nn
from wenet.transformer.attention import MultiHeadedAttention
from typing import Tuple


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        do_rel_shift (bool): Apply rel_shift to the positional scores.
        adaptive_scale (bool): Apply a learnable per-feature scale and bias
            to query, key and value.
        init_weights (bool): Re-initialize the linear layers with
            uniform(-n_feat ** -0.5, n_feat ** -0.5).
    """

    def __init__(self, n_head, n_feat, dropout_rate,
                 do_rel_shift=False, adaptive_scale=False, init_weights=False):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.do_rel_shift = do_rel_shift
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)
        self.adaptive_scale = adaptive_scale
        if self.adaptive_scale:
            self.ada_scale = nn.Parameter(
                torch.ones([1, 1, n_feat]), requires_grad=True)
            self.ada_bias = nn.Parameter(
                torch.zeros([1, 1, n_feat]), requires_grad=True)
        if init_weights:
            self.init_weights()

    def init_weights(self):
        input_max = (self.h * self.d_k) ** -0.5
        torch.nn.init.uniform_(self.linear_q.weight, -input_max, input_max)
        torch.nn.init.uniform_(self.linear_q.bias, -input_max, input_max)
        torch.nn.init.uniform_(self.linear_k.weight, -input_max, input_max)
        torch.nn.init.uniform_(self.linear_k.bias, -input_max, input_max)
        torch.nn.init.uniform_(self.linear_v.weight, -input_max, input_max)
        torch.nn.init.uniform_(self.linear_v.bias, -input_max, input_max)
        torch.nn.init.uniform_(self.linear_pos.weight, -input_max, input_max)
        torch.nn.init.uniform_(self.linear_out.weight, -input_max, input_max)
        torch.nn.init.uniform_(self.linear_out.bias, -input_max, input_max)

    def rel_shift(self, x, zero_triu: bool = False):
        """Compute relative positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, time2).
            zero_triu (bool): If true, return the lower triangular part of
                the matrix.
        Returns:
            torch.Tensor: Output tensor.
        """

        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if zero_triu:
            # create the mask on the same device and dtype as x
            ones = torch.ones((x.size(2), x.size(3)),
                              device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x
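
To make the pad / reshape / slice trick in `rel_shift` easier to follow, here is a standalone sketch that copies the same steps into a free function (the name `rel_shift_demo` is hypothetical) and applies it to a 3x3 score matrix. Row `i` ends up shifted left by `T - 1 - i` positions; the entries above the diagonal are leftovers, which `zero_triu=True` would zero out:

```python
import torch


def rel_shift_demo(x: torch.Tensor) -> torch.Tensor:
    """Same pad / view / slice steps as rel_shift, without zero_triu."""
    b, h, t1, t2 = x.size()
    zero_pad = torch.zeros((b, h, t1, 1), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=-1)   # (b, h, t1, t2 + 1)
    x_padded = x_padded.view(b, h, t2 + 1, t1)    # reinterpret the layout
    return x_padded[:, :, 1:].view_as(x)          # drop first row, restore shape


scores = torch.arange(1.0, 10.0).view(1, 1, 3, 3)
# scores[0, 0]:
# [[1, 2, 3],
#  [4, 5, 6],
#  [7, 8, 9]]
shifted = rel_shift_demo(scores)
print(shifted[0, 0])
# [[3, 0, 4],
#  [5, 6, 0],
#  [7, 8, 9]]
```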

    def forward_attention(
        self, value: torch.Tensor, scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            # (batch, head, time1, time2)
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)
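
The masking branch of `forward_attention` fills padded positions with `-inf` before the softmax and then with `0.0` afterwards. The second fill matters when a whole row is padding: the softmax over an all `-inf` row yields NaN, and the fill turns it into zeros. A small standalone sketch of just that step (the helper name `masked_softmax` is hypothetical):

```python
import torch


def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """scores: (batch, head, time1, time2); mask: (batch, 1, time2), True = keep."""
    pad = mask.unsqueeze(1).eq(0)                     # (batch, 1, 1, time2)
    scores = scores.masked_fill(pad, -float('inf'))   # padded keys get -inf
    # Second fill: zero the padded columns, which also removes NaN rows.
    return torch.softmax(scores, dim=-1).masked_fill(pad, 0.0)


scores = torch.zeros(1, 1, 2, 3)
mask = torch.tensor([[[True, True, False]]])           # last key is padding
attn = masked_softmax(scores, mask)                    # weight only on first two keys

all_pad = torch.zeros(1, 1, 3, dtype=torch.bool)       # every key is padding
empty = masked_softmax(scores, all_pad)                # zeros, not NaN
print(attn[0, 0], empty[0, 0])
```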

    def forward(self, query: torch.Tensor,
                key: torch.Tensor, value: torch.Tensor,
                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                pos_emb: torch.Tensor = torch.empty(0),
                cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        if self.adaptive_scale:
            query = self.ada_scale * query + self.ada_bias
            key = self.ada_scale * key + self.ada_bias
            value = self.ada_scale * value + self.ada_bias
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when exporting an onnx model, for the 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation (this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors (see code below).
        #   when exporting a jit model, for the 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(
                cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # rel_shift is skipped by default since it is useless in speech
        # recognition and requires special attention for streaming;
        # set do_rel_shift=True to enable it.
        if self.do_rel_shift:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
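
The adaptive scale step at the top of `forward` is a learnable per-feature affine transform, `scale * x + bias`, with the scale initialized to ones and the bias to zeros. A minimal standalone sketch, assuming the same `(1, 1, n_feat)` parameter shapes as in `__init__` (the module name `AdaptiveScale` is hypothetical and not a class in this PR):

```python
import torch
import torch.nn as nn


class AdaptiveScale(nn.Module):
    """Learnable per-feature scale and bias, broadcast over batch and time."""

    def __init__(self, n_feat: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1, 1, n_feat))
        self.bias = nn.Parameter(torch.zeros(1, 1, n_feat))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, n_feat)
        return self.scale * x + self.bias


layer = AdaptiveScale(256)
x = torch.randn(2, 10, 256)
y = layer(x)
# At initialization the layer is an identity map, so enabling it does not
# change the attention inputs until training updates the parameters.
print(torch.equal(x, y))
```

Note that the attention module above shares one scale and one bias across query, key and value rather than learning three separate pairs.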