[train] u2++-lite training support #2202

Merged 1 commit on Dec 8, 2023

13 changes: 13 additions & 0 deletions examples/aishell/s0/README.md
@@ -32,6 +32,19 @@
| HLG(k2 LM) + attention rescoring | 4.32 | 4.70 |
| HLG(k2 LM) + attention rescoring + LFMMI | 4.11 | 4.47 |

## U2++ lite Conformer Result (uio shard)

* Feature info: using fbank feature, dither=1.0, cmvn, online speed perturb
* Training info: lr 0.001, batch size 16, 8 GPUs, acc_grad 1, initialize from a well-trained model and continue training for 80 epochs with the u2++ lite config
* Decoding info: ctc_weight 0.3, reverse_weight 0.5, average_num 30 (checkpoint averaging; see the sketch below the results table)
* Git hash: 73185808fa1463b0163a922dc722513b7baabe9e

| decoding mode/chunk size | full | 16 |
|---------------------------|-------|-------|
| ctc greedy search | 5.21 | 5.91 |
| ctc prefix beam search | 5.20 | 5.91 |
| attention rescoring | 4.67 | 5.10 |
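
The `average_num 30` above means decoding uses an average of 30 saved checkpoints rather than a single one. A minimal sketch of that averaging step (a simplified stand-in for wenet's own averaging script, whose interface is not shown in this PR):

```python
import torch

def average_checkpoints(ckpt_paths):
    """Average model parameters over several checkpoints (the `average_num 30` step)."""
    avg_state = None
    for path in ckpt_paths:
        state = torch.load(path, map_location='cpu')
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg_state[k] += v.float()
    # divide the accumulated sums by the number of checkpoints
    return {k: v / len(ckpt_paths) for k, v in avg_state.items()}
```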

## Unified Conformer Result

* Feature info: using fbank feature, dither=0, cmvn, online speed perturb
91 changes: 91 additions & 0 deletions examples/aishell/s0/conf/train_u2++_lite_conformer.yaml
@@ -0,0 +1,91 @@
# network architecture
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 8
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false

# decoder related
decoder: bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 1024
    num_blocks: 3
    r_num_blocks: 3
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    reverse_weight: 0.3
    apply_non_blank_embedding: true # warning: a well-trained model should be used as the init model

dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 0
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    spec_sub: true
    spec_sub_conf:
        num_t_sub: 3
        max_t: 30
    spec_trim: false
    spec_trim_conf:
        max_t: 50
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 16

grad_clip: 5
accum_grad: 1
max_epoch: 360
log_interval: 100

optim: adam
optim_conf:
    lr: 0.001
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
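
The key switch in this config is `apply_non_blank_embedding: true` under `model_conf`. A small sketch of loading the file and checking the two u2++-lite constraints this PR introduces (a non-zero `ctc_weight`, and a well-trained checkpoint to start from); the helper function and its arguments are illustrative, not part of the PR:

```python
import yaml

def check_u2pp_lite_config(config_path, init_checkpoint=None):
    """Sanity-check the u2++-lite specific settings before training."""
    with open(config_path, 'r') as fin:
        configs = yaml.safe_load(fin)
    model_conf = configs['model_conf']
    if model_conf.get('apply_non_blank_embedding', False):
        # filter_blank_embedding needs CTC posteriors, so ctc_weight must be non-zero
        assert model_conf['ctc_weight'] > 0.0
        # the blank filter only works well on top of an already trained model
        if init_checkpoint is None:
            print('warning: load a well-trained checkpoint '
                  'when apply_non_blank_embedding is true')
    return configs

configs = check_u2pp_lite_config('conf/train_u2++_lite_conformer.yaml')
```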
3 changes: 3 additions & 0 deletions examples/aishell/s0/run.sh
@@ -38,6 +38,9 @@ train_set=train
# 4. conf/train_unified_transformer.yaml: Unified dynamic chunk transformer
# 5. conf/train_u2++_conformer.yaml: U2++ conformer
# 6. conf/train_u2++_transformer.yaml: U2++ transformer
# 7. conf/train_u2++_lite_conformer.yaml: U2++ lite conformer, must load a well
#    trained model and freeze the encoder module, otherwise there will be an
#    autograd error
train_config=conf/train_conformer.yaml
cmvn=true
dir=exp/conformer
10 changes: 5 additions & 5 deletions wenet/k2/model.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
@@ -49,9 +49,9 @@ def __init__(
@torch.jit.ignore(drop=True)
def _forward_ctc(self, encoder_out: torch.Tensor,
encoder_mask: torch.Tensor, text: torch.Tensor,
text_lengths: torch.Tensor) -> torch.Tensor:
loss_ctc = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
return loss_ctc
text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
return loss_ctc, ctc_probs

@torch.jit.ignore(drop=True)
def load_lfmmi_resource(self):
@@ -106,7 +106,7 @@ def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
for i in text
]
loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text)
return loss
return loss, ctc_probs

def load_hlg_resource_if_necessary(self, hlg, word):
try:
13 changes: 8 additions & 5 deletions wenet/paraformer/paraformer.py
@@ -136,8 +136,9 @@ def forward(
        # 3.1 ctc branch
loss_ctc: Optional[torch.Tensor] = None
if self.ctc_weight != 0.0:
loss_ctc = self._forward_ctc(encoder_out, encoder_out_mask, text,
text_lengths)
loss_ctc, ctc_probs = self._forward_ctc(encoder_out,
encoder_out_mask,
text, text_lengths)
# TODO(Mddc): thu acc
loss_decoder = self.criterion_att(decoder_out, ys_pad)
loss = loss_decoder
@@ -152,10 +153,12 @@
@torch.jit.ignore(drop=True)
def _forward_ctc(self, encoder_out: torch.Tensor,
encoder_mask: torch.Tensor, text: torch.Tensor,
text_lengths: torch.Tensor) -> torch.Tensor:
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
return loss_ctc
loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens,
text, text_lengths)
return loss_ctc, ctc_probs

@torch.jit.ignore(drop=True)
def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens,
64 changes: 53 additions & 11 deletions wenet/transformer/asr_model.py
@@ -16,6 +16,7 @@
from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
@@ -25,6 +26,7 @@
ctc_prefix_beam_search,
attention_beam_search,
attention_rescoring, DecodeResult)
from wenet.utils.mask import make_pad_mask
from wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy,
reverse_pad_list)
from wenet.utils.context_graph import ContextGraph
@@ -45,6 +47,7 @@ def __init__(
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
special_tokens: dict = None,
apply_non_blank_embedding: bool = False,
):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight

@@ -65,6 +68,7 @@
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.reverse_weight = reverse_weight
self.apply_non_blank_embedding = apply_non_blank_embedding

self.encoder = encoder
self.decoder = decoder
@@ -102,19 +106,25 @@ def forward(
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)

# 2a. Attention-decoder branch
# 2a. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text,
text_lengths)
else:
loss_ctc = None

# 2b. Attention-decoder branch
# use non blank (token level) embedding for decoder
if self.apply_non_blank_embedding:
assert self.ctc_weight != 0
encoder_out, encoder_mask = self.filter_blank_embedding(
ctc_probs, encoder_out)
if self.ctc_weight != 1.0:
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
text, text_lengths)
else:
loss_att = None

# 2b. CTC branch
if self.ctc_weight != 0.0:
loss_ctc = self._forward_ctc(encoder_out, encoder_mask, text,
text_lengths)
else:
loss_ctc = None
acc_att = None

if loss_ctc is None:
loss = loss_att
@@ -128,10 +138,39 @@
@torch.jit.ignore(drop=True)
def _forward_ctc(self, encoder_out: torch.Tensor,
encoder_mask: torch.Tensor, text: torch.Tensor,
text_lengths: torch.Tensor) -> torch.Tensor:
text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
return loss_ctc
loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens,
text, text_lengths)
return loss_ctc, ctc_probs

def filter_blank_embedding(
self, ctc_probs: torch.Tensor,
encoder_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = encoder_out.size(0)
maxlen = encoder_out.size(1)
top1_index = torch.argmax(ctc_probs, dim=2)
indices = []
for j in range(batch_size):
indices.append(
torch.tensor(
[i for i in range(maxlen) if top1_index[j][i] != 0]))

select_encoder_out = [
torch.index_select(encoder_out[i, :, :], 0,
indices[i].to(encoder_out.device))
for i in range(batch_size)
]
select_encoder_out = pad_sequence(select_encoder_out,
batch_first=True,
padding_value=0).to(
encoder_out.device)
xs_lens = torch.tensor([len(indices[i]) for i in range(batch_size)
]).to(encoder_out.device)
T = select_encoder_out.size(1)
encoder_mask = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
encoder_out = select_encoder_out
return encoder_out, encoder_mask

def _calc_att_loss(
self,
@@ -257,6 +296,9 @@ def decode(
else:
ctc_prefix_result = ctc_prefix_beam_search(
ctc_probs, encoder_lens, beam_size, context_graph)
if self.apply_non_blank_embedding:
encoder_out, _ = self.filter_blank_embedding(
ctc_probs, encoder_out)
results['attention_rescoring'] = attention_rescoring(
self, ctc_prefix_result, encoder_out, encoder_lens, ctc_weight,
reverse_weight)
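
To make the new `filter_blank_embedding` step concrete: it keeps only the encoder frames whose frame-level CTC argmax is not the blank token (index 0), then re-pads the shortened sequences and rebuilds the padding mask. A self-contained toy version of the same logic, using random tensors and assuming blank id 0:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
B, T, D, V = 2, 6, 4, 5                 # batch, frames, encoder dim, vocab size (0 = blank)
encoder_out = torch.randn(B, T, D)
ctc_probs = torch.randn(B, T, V).log_softmax(dim=2)

top1 = ctc_probs.argmax(dim=2)          # (B, T) best token per frame
keep = [torch.nonzero(top1[b] != 0).squeeze(1) for b in range(B)]   # non-blank frame indices
filtered = [encoder_out[b].index_select(0, keep[b]) for b in range(B)]
filtered = pad_sequence(filtered, batch_first=True, padding_value=0.0)
lengths = torch.tensor([len(k) for k in keep])
# new mask is True on kept frames: (B, 1, T'), the shape the attention decoder expects
mask = (torch.arange(filtered.size(1)).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1)

print(filtered.shape, lengths.tolist(), mask.shape)
```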
9 changes: 7 additions & 2 deletions wenet/transformer/ctc.py
@@ -13,6 +13,8 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)

from typing import Tuple

import torch
import torch.nn.functional as F

@@ -46,7 +48,9 @@ def __init__(
reduction=reduction_type)

def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor:
ys_pad: torch.Tensor,
ys_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

"""Calculate CTC loss.

Args:
@@ -63,7 +67,8 @@ def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
# Batch-size average
loss = loss / ys_hat.size(1)
return loss
ys_hat = ys_hat.transpose(0, 1)
return loss, ys_hat
Member:
Since the CTC forward now returns two values here, the call sites in the k2 and paraformer models also need to be updated to unpack both return values, otherwise they will raise an error.
[error screenshot omitted]


def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
"""log_softmax of frame activations
1 change: 0 additions & 1 deletion wenet/utils/executor.py
@@ -38,7 +38,6 @@ def train(self, model, optimizer, scheduler, data_loader, writer, configs,
info_dict["tag"] = "TRAIN"
logging.info('using accumulate grad, new batch size is {} times'
' larger than before'.format(info_dict['accum_grad']))

# A context manager to be used in conjunction with an instance of
# torch.nn.parallel.DistributedDataParallel to be able to train
# with uneven inputs across participating processes.
Expand Down
26 changes: 21 additions & 5 deletions wenet/utils/train_utils.py
@@ -66,12 +66,15 @@ def add_model_args(parser):
default=None,
type=str,
help="Pre-trained model to initialize encoder")
parser.add_argument(
"--enc_init_mods",
default="encoder.",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
help="List of encoder modules \
parser.add_argument('--enc_init_mods',
default="encoder.",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
help="List of encoder modules \
to initialize ,separated by a comma")
parser.add_argument('--freeze_modules',
default="",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
                        help='freeze module names',)
parser.add_argument('--lfmmi_dir',
default='',
required=False,
@@ -239,6 +242,12 @@ def check_modify_and_save_config(args, configs, symbol_table):
data = yaml.dump(configs)
fout.write(data)

if configs["model_conf"]["apply_non_blank_embedding"]:
logging.warn(
            'Had better load a well trained model '
'if apply_non_blank_embedding is true !!!'
)

return configs


@@ -601,3 +610,10 @@ def log_per_epoch(writer, info_dict):
if int(os.environ.get('RANK', 0)) == 0:
writer.add_scalar('epoch/cv_loss', info_dict["cv_loss"], epoch)
writer.add_scalar('epoch/lr', info_dict["lr"], epoch)

def freeze_modules(model, args):
for name, param in model.named_parameters():
for module_name in args.freeze_modules:
if module_name in name:
param.requires_grad = False
                logging.debug("{} module is frozen".format(name))
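
A self-contained toy illustration of how the new `freeze_modules` helper behaves when the encoder is frozen for u2++-lite fine-tuning; the dummy model and the `SimpleNamespace` arguments stand in for a real ASRModel and the parsed `--freeze_modules encoder.` option:

```python
import types
import torch

def freeze_modules(model, args):
    # same logic as the helper added above in wenet/utils/train_utils.py
    for name, param in model.named_parameters():
        for module_name in args.freeze_modules:
            if module_name in name:
                param.requires_grad = False

# toy model standing in for an ASRModel loaded from a well-trained checkpoint
model = torch.nn.ModuleDict({
    'encoder': torch.nn.Linear(8, 8),
    'decoder': torch.nn.Linear(8, 8),
})

args = types.SimpleNamespace(freeze_modules=['encoder.'])
freeze_modules(model, args)

print([n for n, p in model.named_parameters() if p.requires_grad])
# -> only decoder.* parameters remain trainable
```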
Member:
Just curious: does freezing give better results than not freezing?

Collaborator (author):
Without freezing, multi-GPU training runs into problems, and the alignment also changes.

Member:
> Without freezing, multi-GPU training runs into problems, and the alignment also changes.

Got it. What error does multi-GPU training report?