Merge pull request #2202 from wenet-e2e/diwu-u2++-lite
[train] u2++-lite training support
whiteshirt0429 authored Dec 8, 2023
2 parents 1f8e1c5 + bfaa8a3 commit 2894f7c
Showing 9 changed files with 201 additions and 29 deletions.
13 changes: 13 additions & 0 deletions examples/aishell/s0/README.md
@@ -32,6 +32,19 @@
| HLG(k2 LM) + attention rescoring | 4.32 | 4.70 |
| HLG(k2 LM) + attention rescoring + LFMMI | 4.11 | 4.47 |

## U2++ lite Conformer Result (uio shard)

* Feature info: using fbank feature, dither=1.0, cmvn, online speed perturb
* Training info: lr 0.001, batch size 16, 8 GPUs, acc_grad 1, load a well-trained model and continue training for 80 epochs with the u2++ lite config
* Decoding info: ctc_weight 0.3, reverse_weight 0.5, average_num 30
* Git hash: 73185808fa1463b0163a922dc722513b7baabe9e

| decoding mode/chunk size | full | 16 |
|---------------------------|-------|-------|
| ctc greedy search | 5.21 | 5.91 |
| ctc prefix beam search | 5.20 | 5.91 |
| attention rescoring | 4.67 | 5.10 |

## Unified Conformer Result

* Feature info: using fbank feature, dither=0, cmvn, online speed perturb
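For readers unfamiliar with the decoding knobs listed above, the following toy sketch shows roughly how ctc_weight and reverse_weight combine hypothesis scores during attention rescoring. It mirrors the usual wenet rescoring formula but is only an illustration, not code from this commit; the function name and the example scores are made up.

def rescore(ctc_score: float, att_score: float, r_att_score: float,
            ctc_weight: float = 0.3, reverse_weight: float = 0.5) -> float:
    # blend left-to-right and right-to-left decoder scores, then add the CTC score
    att = (1.0 - reverse_weight) * att_score + reverse_weight * r_att_score
    return att + ctc_weight * ctc_score

print(rescore(ctc_score=-12.3, att_score=-10.1, r_att_score=-10.7))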
91 changes: 91 additions & 0 deletions examples/aishell/s0/conf/train_u2++_lite_conformer.yaml
@@ -0,0 +1,91 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
normalize_before: true
cnn_module_kernel: 8
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false

# decoder related
decoder: bitransformer
decoder_conf:
attention_heads: 4
linear_units: 1024
num_blocks: 3
r_num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
reverse_weight: 0.3
    apply_non_blank_embedding: true # warning: you had better use a well-trained model as the init model

dataset_conf:
filter_conf:
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample_conf:
resample_rate: 16000
speed_perturb: true
fbank_conf:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 2
num_f_mask: 2
max_t: 50
max_f: 10
spec_sub: true
spec_sub_conf:
num_t_sub: 3
max_t: 30
spec_trim: false
spec_trim_conf:
max_t: 50
shuffle: true
shuffle_conf:
shuffle_size: 1500
sort: true
sort_conf:
sort_size: 500 # sort_size should be less than shuffle_size
batch_conf:
batch_type: 'static' # static or dynamic
batch_size: 16

grad_clip: 5
accum_grad: 1
max_epoch: 360
log_interval: 100

optim: adam
optim_conf:
lr: 0.001
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
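Because this config only makes sense as a continuation run, a small sanity check can help before launching training: apply_non_blank_embedding relies on CTC posteriors, so ctc_weight must stay non-zero (the assert in wenet/transformer/asr_model.py below enforces the same constraint at runtime). The snippet is a hypothetical check, not part of the recipe; only the config path comes from this commit.

import yaml

with open("examples/aishell/s0/conf/train_u2++_lite_conformer.yaml") as f:
    configs = yaml.safe_load(f)

model_conf = configs["model_conf"]
if model_conf.get("apply_non_blank_embedding", False):
    # non-blank filtering needs the CTC branch and a well-trained init model
    assert model_conf.get("ctc_weight", 0.0) > 0.0, \
        "apply_non_blank_embedding requires a non-zero ctc_weight"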
3 changes: 3 additions & 0 deletions examples/aishell/s0/run.sh
@@ -38,6 +38,9 @@ train_set=train
# 4. conf/train_unified_transformer.yaml: Unified dynamic chunk transformer
# 5. conf/train_u2++_conformer.yaml: U2++ conformer
# 6. conf/train_u2++_transformer.yaml: U2++ transformer
# 7. conf/train_u2++_lite_conformer.yaml: U2++ lite conformer, must load a well
#    trained model and freeze the encoder modules, otherwise there will be an
#    autograd error
train_config=conf/train_conformer.yaml
cmvn=true
dir=exp/conformer
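The freezing that item 7 above refers to is driven by the new --freeze_modules flag added to wenet/utils/train_utils.py in this commit. Below is a stand-alone sketch of how that flag parses a comma-separated list of module-name prefixes; the minimal parser is a stand-in for wenet's full argument setup.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--freeze_modules',
                    default="",
                    type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
                    help='freeze module names')

args = parser.parse_args(['--freeze_modules', 'encoder.,encoder.embed.'])
print(args.freeze_modules)  # ['encoder.', 'encoder.embed.']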
10 changes: 5 additions & 5 deletions wenet/k2/model.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
@@ -49,9 +49,9 @@ def __init__(
@torch.jit.ignore(drop=True)
def _forward_ctc(self, encoder_out: torch.Tensor,
encoder_mask: torch.Tensor, text: torch.Tensor,
text_lengths: torch.Tensor) -> torch.Tensor:
loss_ctc = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
return loss_ctc
text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
return loss_ctc, ctc_probs

@torch.jit.ignore(drop=True)
def load_lfmmi_resource(self):
@@ -106,7 +106,7 @@ def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
for i in text
]
loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text)
return loss
return loss, ctc_probs

def load_hlg_resource_if_necessary(self, hlg, word):
try:
13 changes: 8 additions & 5 deletions wenet/paraformer/paraformer.py
@@ -136,8 +136,9 @@ def forward(
        # 3.1 ctc branch
loss_ctc: Optional[torch.Tensor] = None
if self.ctc_weight != 0.0:
loss_ctc = self._forward_ctc(encoder_out, encoder_out_mask, text,
text_lengths)
loss_ctc, ctc_probs = self._forward_ctc(encoder_out,
encoder_out_mask,
text, text_lengths)
# TODO(Mddc): thu acc
loss_decoder = self.criterion_att(decoder_out, ys_pad)
loss = loss_decoder
@@ -152,10 +153,12 @@ def forward(
@torch.jit.ignore(drop=True)
def _forward_ctc(self, encoder_out: torch.Tensor,
encoder_mask: torch.Tensor, text: torch.Tensor,
text_lengths: torch.Tensor) -> torch.Tensor:
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
return loss_ctc
loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens,
text, text_lengths)
return loss_ctc, ctc_probs

@torch.jit.ignore(drop=True)
def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens,
64 changes: 53 additions & 11 deletions wenet/transformer/asr_model.py
@@ -16,6 +16,7 @@
from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
@@ -25,6 +26,7 @@
ctc_prefix_beam_search,
attention_beam_search,
attention_rescoring, DecodeResult)
from wenet.utils.mask import make_pad_mask
from wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy,
reverse_pad_list)
from wenet.utils.context_graph import ContextGraph
@@ -45,6 +47,7 @@ def __init__(
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
special_tokens: dict = None,
apply_non_blank_embedding: bool = False,
):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight

@@ -65,6 +68,7 @@
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.reverse_weight = reverse_weight
self.apply_non_blank_embedding = apply_non_blank_embedding

self.encoder = encoder
self.decoder = decoder
@@ -102,19 +106,25 @@ def forward(
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)

# 2a. Attention-decoder branch
# 2a. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text,
text_lengths)
else:
loss_ctc = None

# 2b. Attention-decoder branch
# use non blank (token level) embedding for decoder
if self.apply_non_blank_embedding:
assert self.ctc_weight != 0
encoder_out, encoder_mask = self.filter_blank_embedding(
ctc_probs, encoder_out)
if self.ctc_weight != 1.0:
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
text, text_lengths)
else:
loss_att = None

# 2b. CTC branch
if self.ctc_weight != 0.0:
loss_ctc = self._forward_ctc(encoder_out, encoder_mask, text,
text_lengths)
else:
loss_ctc = None
acc_att = None

if loss_ctc is None:
loss = loss_att
@@ -128,10 +138,39 @@
@torch.jit.ignore(drop=True)
def _forward_ctc(self, encoder_out: torch.Tensor,
encoder_mask: torch.Tensor, text: torch.Tensor,
text_lengths: torch.Tensor) -> torch.Tensor:
text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
return loss_ctc
loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens,
text, text_lengths)
return loss_ctc, ctc_probs

def filter_blank_embedding(
self, ctc_probs: torch.Tensor,
encoder_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = encoder_out.size(0)
maxlen = encoder_out.size(1)
top1_index = torch.argmax(ctc_probs, dim=2)
indices = []
for j in range(batch_size):
indices.append(
torch.tensor(
[i for i in range(maxlen) if top1_index[j][i] != 0]))

select_encoder_out = [
torch.index_select(encoder_out[i, :, :], 0,
indices[i].to(encoder_out.device))
for i in range(batch_size)
]
select_encoder_out = pad_sequence(select_encoder_out,
batch_first=True,
padding_value=0).to(
encoder_out.device)
xs_lens = torch.tensor([len(indices[i]) for i in range(batch_size)
]).to(encoder_out.device)
T = select_encoder_out.size(1)
encoder_mask = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
encoder_out = select_encoder_out
return encoder_out, encoder_mask

def _calc_att_loss(
self,
@@ -257,6 +296,9 @@ def decode(
else:
ctc_prefix_result = ctc_prefix_beam_search(
ctc_probs, encoder_lens, beam_size, context_graph)
if self.apply_non_blank_embedding:
encoder_out, _ = self.filter_blank_embedding(
ctc_probs, encoder_out)
results['attention_rescoring'] = attention_rescoring(
self, ctc_prefix_result, encoder_out, encoder_lens, ctc_weight,
reverse_weight)
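For intuition about the new filter_blank_embedding step, which drops every frame whose CTC argmax is the blank id (0) before the attention decoder consumes the encoder output, here is a standalone sketch with made-up tensor sizes. It reproduces the per-utterance indexing idea above but is not the code from this commit.

import torch

B, T, D, V = 2, 6, 4, 10                      # toy sizes: batch, frames, dim, vocab
encoder_out = torch.randn(B, T, D)
ctc_probs = torch.randn(B, T, V).log_softmax(dim=-1)

top1 = ctc_probs.argmax(dim=2)                # (B, T) frame-level best token ids
keep = [torch.nonzero(top1[b] != 0).squeeze(-1) for b in range(B)]
filtered = [encoder_out[b].index_select(0, keep[b]) for b in range(B)]

for b, frames in enumerate(filtered):
    print("utt {}: kept {} of {} frames".format(b, frames.size(0), T))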
9 changes: 7 additions & 2 deletions wenet/transformer/ctc.py
@@ -13,6 +13,8 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)

from typing import Tuple

import torch
import torch.nn.functional as F

@@ -46,7 +48,9 @@ def __init__(
reduction=reduction_type)

def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor:
ys_pad: torch.Tensor,
ys_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

"""Calculate CTC loss.
Args:
@@ -63,7 +67,8 @@ def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
# Batch-size average
loss = loss / ys_hat.size(1)
return loss
ys_hat = ys_hat.transpose(0, 1)
return loss, ys_hat

def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
"""log_softmax of frame activations
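To make the new return contract concrete (a scalar loss plus batch-first log-probabilities that downstream code such as filter_blank_embedding and CTC decoding can reuse), here is a plain-PyTorch sketch of the same pattern. It calls torch.nn.functional.ctc_loss directly instead of wenet's CTC module, and all sizes are made up.

import torch
import torch.nn.functional as F

B, T, V, L = 2, 50, 30, 10                              # toy sizes: batch, frames, vocab, label length
logits = torch.randn(B, T, V)
ys_hat = F.log_softmax(logits, dim=2).transpose(0, 1)   # (T, B, V), as CTCLoss expects
ys_pad = torch.randint(1, V, (B, L))                    # labels; blank id 0 excluded
hlens = torch.full((B,), T, dtype=torch.long)
ys_lens = torch.full((B,), L, dtype=torch.long)

loss = F.ctc_loss(ys_hat, ys_pad, hlens, ys_lens,
                  blank=0, reduction="sum", zero_infinity=True)
loss = loss / ys_hat.size(1)                            # batch-size average, as above
ctc_probs = ys_hat.transpose(0, 1)                      # back to (B, T, V) for the caller
print(loss.item(), ctc_probs.shape)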
1 change: 0 additions & 1 deletion wenet/utils/executor.py
@@ -38,7 +38,6 @@ def train(self, model, optimizer, scheduler, data_loader, writer, configs,
info_dict["tag"] = "TRAIN"
logging.info('using accumulate grad, new batch size is {} times'
' larger than before'.format(info_dict['accum_grad']))

# A context manager to be used in conjunction with an instance of
# torch.nn.parallel.DistributedDataParallel to be able to train
# with uneven inputs across participating processes.
26 changes: 21 additions & 5 deletions wenet/utils/train_utils.py
@@ -66,12 +66,15 @@ def add_model_args(parser):
default=None,
type=str,
help="Pre-trained model to initialize encoder")
parser.add_argument(
"--enc_init_mods",
default="encoder.",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
help="List of encoder modules \
parser.add_argument('--enc_init_mods',
default="encoder.",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
help="List of encoder modules \
                        to initialize, separated by a comma")
parser.add_argument('--freeze_modules',
default="",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
                        help='freeze module names')
parser.add_argument('--lfmmi_dir',
default='',
required=False,
@@ -239,6 +242,12 @@ def check_modify_and_save_config(args, configs, symbol_table):
data = yaml.dump(configs)
fout.write(data)

if configs["model_conf"]["apply_non_blank_embedding"]:
logging.warn(
            'You had better load a well-trained model '
            'if apply_non_blank_embedding is true!'
)

return configs


@@ -601,3 +610,10 @@ def log_per_epoch(writer, info_dict):
if int(os.environ.get('RANK', 0)) == 0:
writer.add_scalar('epoch/cv_loss', info_dict["cv_loss"], epoch)
writer.add_scalar('epoch/lr', info_dict["lr"], epoch)

def freeze_modules(model, args):
for name, param in model.named_parameters():
for module_name in args.freeze_modules:
if module_name in name:
param.requires_grad = False
                logging.debug("{} module is frozen".format(name))
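A quick usage sketch of the new helper, and of what freezing the encoder modules in run.sh amounts to. The toy model and the hand-built args namespace are stand-ins; only freeze_modules itself comes from this commit, assumed importable from wenet.utils.train_utils.

import torch
from types import SimpleNamespace

from wenet.utils.train_utils import freeze_modules

model = torch.nn.ModuleDict({
    "encoder": torch.nn.Linear(8, 8),
    "decoder": torch.nn.Linear(8, 8),
})
args = SimpleNamespace(freeze_modules=["encoder."])  # as produced by --freeze_modules "encoder."

freeze_modules(model, args)
print([n for n, p in model.named_parameters() if p.requires_grad])
# only decoder.* parameters remain trainable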
