Skip to content

Commit

Permalink
[refactor] rebase main
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchensong committed Dec 12, 2023
1 parent f402469 commit 97d44a1
Show file tree
Hide file tree
Showing 22 changed files with 58 additions and 30 deletions.
2 changes: 1 addition & 1 deletion examples/aishell/NST/conf/train_conformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/rnnt/conf/conformer_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ encoder_conf:
selfattention_layer_type: 'rel_selfattn'


joint: transducerjoint
joint: transducer_joint
joint_conf:
enc_output_size: 256
pred_output_size: 256
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ encoder_conf:
use_dynamic_left_chunk: false


joint: transducerjoint
joint: transducer_joint
joint_conf:
enc_output_size: 256
pred_output_size: 256
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ encoder_conf:
selfattention_layer_type: 'rel_selfattn'


joint: transducerjoint
joint: transducer_joint
joint_conf:
enc_output_size: 256
pred_output_size: 320
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_conformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_conformer_no_pos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_ebranchformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_transformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_u2++_branchformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_u2++_conformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_u2++_efficonformer_v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_u2++_efficonformer_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
25 changes: 25 additions & 0 deletions examples/aishell/s0/conf/train_u2++_lite_conformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,39 @@ decoder_conf:
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

tokenizer: char
tokenizer_conf:
symbol_table_path: 'data/dict/lang_char.txt'
split_with_space: false
bpe_path: null
non_lang_syms_path: null
is_multilingual: false
num_languages: 1
special_tokens:
<blank>: 0
<unk>: 1
<sos>: 4232
<eos>: 4232

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: 'data/train/global_cmvn'
is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
reverse_weight: 0.3
apply_non_blank_embedding: true # warning: had better use a well trained model as init model

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_u2++_transformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_unified_conformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_unified_conformer_ctl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: ctlmodel
model: ctl_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_unified_transformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ cmvn_conf:
is_json_cmvn: true

# hybrid CTC/attention
model: asrmodel
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
Expand Down
2 changes: 1 addition & 1 deletion examples/aishell/s0/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
tools/k2/prepare_mmi.sh data/train/ data/dev data/local/lfmmi

# 9.2 Run LF-MMI training from stage 4, modify below args in train.yaml
# model: k2model
# model: k2_model
# model_conf:
# lfmmi_dir data/local/lfmmi

Expand Down
2 changes: 1 addition & 1 deletion wenet/transducer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def __init__(self,
layer_norm_epsilon: float = 1e-5) -> None:
super().__init__()

assert output_size == embed_size
assert embed_size == output_size
assert history_size >= 0
self.embed_size = embed_size
self.context_size = history_size + 1
Expand Down
23 changes: 13 additions & 10 deletions wenet/utils/init_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@
}

WENET_JOINT_CLASSES = {
"transducerjoint": TransducerJoint,
"transducer_joint": TransducerJoint,
}

WENET_MODEL_CLASSES = {
"asrmodel": ASRModel,
"ctlmodel": CTLModel,
"asr_model": ASRModel,
"ctl_model": CTLModel,
"whisper": Whisper,
"k2model": K2Model,
"k2_model": K2Model,
"transducer": Transducer,
}

Expand Down Expand Up @@ -107,9 +107,10 @@ def init_model(args, configs):
blank_id=configs['ctc_conf']['ctc_blank_id']
if 'ctc_conf' in configs else 0)

if configs['model'] == "transducer":
model_type = configs.get('model', 'asr_model')
if model_type == "transducer":
predictor_type = configs.get('predictor', 'rnn')
joint_type = configs.get('joint', 'transducerjoint')
joint_type = configs.get('joint', 'transducer_joint')
predictor = WENET_PREDICTOR_CLASSES[predictor_type](
vocab_size, **configs['predictor_conf'])
joint = WENET_JOINT_CLASSES[joint_type](vocab_size,
Expand All @@ -122,9 +123,10 @@ def init_model(args, configs):
attention_decoder=decoder,
joint=joint,
ctc=ctc,
special_tokens=configs['tokenizer_conf']['special_tokens'],
special_tokens=configs.get('tokenizer_conf',
{}).get('special_tokens', None),
**configs['model_conf'])
elif configs['model'] == 'paraformer':
elif model_type == 'paraformer':
""" NOTE(Mddct): support fintune paraformer, if there is a need for
sanmencoder/decoder in the future, simplify here.
"""
Expand All @@ -135,12 +137,13 @@ def init_model(args, configs):
print(configs)
return model, configs
else:
model = WENET_MODEL_CLASSES[configs['model']](
model = WENET_MODEL_CLASSES[model_type](
vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
special_tokens=configs['tokenizer_conf']['special_tokens'],
special_tokens=configs.get('tokenizer_conf',
{}).get('special_tokens', None),
**configs['model_conf'])

# If specify checkpoint, load some info from checkpoint
Expand Down

0 comments on commit 97d44a1

Please sign in to comment.