
refactor(yaml): Config ctc/cmvn/tokenizer in train.yaml #2205

Merged Dec 13, 2023 (26 commits)

Commits
185b971 refactor(yaml): Config ctc/cmvn/tokenizer in train.yaml (xingchensong, Dec 8, 2023)
259f159 refactor(yaml): pass training (xingchensong, Dec 8, 2023)
807c874 refactor(yaml): try to pass unittest (xingchensong, Dec 8, 2023)
1a04689 refactor(yaml): remove lfmmi (xingchensong, Dec 8, 2023)
ffd6dc6 refactor(yaml): nst recipe (xingchensong, Dec 8, 2023)
8032a37 [refactor] refine run.sh (xingchensong, Dec 12, 2023)
6e7d876 [refactor] refine run.sh (xingchensong, Dec 12, 2023)
9b078ed [refactor] rebase main (xingchensong, Dec 12, 2023)
847a35d [refactor] try to pass ut (xingchensong, Dec 12, 2023)
9acc3e9 [refactor] refine librispeech in next PR (xingchensong, Dec 12, 2023)
d31fe15 [refactor] add todo (xingchensong, Dec 12, 2023)
9430a0b [refactor] refine paraformer in next PR (xingchensong, Dec 12, 2023)
ca302d3 [refactor] make sos = 2 (xingchensong, Dec 12, 2023)
fe0e9e0 [refactor] make sos = 2 (xingchensong, Dec 12, 2023)
1f9cbfa [refactor] try to pass ut (xingchensong, Dec 12, 2023)
334b4b3 [refactor] refine onnx_gpu (xingchensong, Dec 12, 2023)
69f0364 [refactor] try to pass ut (xingchensong, Dec 12, 2023)
3fb7e99 [refactor] try to pass ut (xingchensong, Dec 12, 2023)
6eeb4e4 [refactor] try to pass ut (xingchensong, Dec 12, 2023)
72a1e18 [refactor] try to pass ut (xingchensong, Dec 12, 2023)
924289a refactor: pass decoding (xingchensong, Dec 12, 2023)
a7b009a refactor: pass decoding (xingchensong, Dec 12, 2023)
73d8742 refactor: pass decoding (xingchensong, Dec 12, 2023)
5394ca1 refactor: refine tokenizer (xingchensong, Dec 12, 2023)
389d5f7 refactor: try to pass ut (xingchensong, Dec 12, 2023)
e1f6c38 Merge branch 'main' into xcsong-yaml (xingchensong, Dec 13, 2023)
25 changes: 25 additions & 0 deletions examples/aishell/NST/conf/train_conformer.yaml
@@ -28,12 +28,37 @@ decoder_conf:
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0

tokenizer: char
tokenizer_conf:
symbol_table_path: 'data/dict/lang_char.txt'
split_with_space: false
bpe_path: null
non_lang_syms_path: null
is_multilingual: false
num_languages: 1
special_tokens:
<blank>: 0
<unk>: 1
<sos>: 2
<eos>: 2

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: 'data/train/global_cmvn'
is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

dataset: asr
dataset_conf:
filter_conf:
max_length: 1200
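The new `tokenizer`/`ctc`/`cmvn` sections added to train.yaml above can be sanity-checked with a short sketch (assumes PyYAML; the YAML body is an abridged copy of the diff, and the paths in it are the recipe's own):

```python
# Sketch: parse the new train.yaml sections from this PR and check the
# special-token layout. Abridged from the diff above; not the full config.
import yaml

config_text = """
tokenizer: char
tokenizer_conf:
  symbol_table_path: 'data/dict/lang_char.txt'
  split_with_space: false
  special_tokens:
    <blank>: 0
    <unk>: 1
    <sos>: 2
    <eos>: 2
ctc: ctc
ctc_conf:
  ctc_blank_id: 0
cmvn: global_cmvn
cmvn_conf:
  cmvn_file: 'data/train/global_cmvn'
  is_json_cmvn: true
"""

config = yaml.safe_load(config_text)
special = config['tokenizer_conf']['special_tokens']
# <sos> and <eos> deliberately share id 2 (see the "make sos = 2" commits).
assert special['<sos>'] == special['<eos>'] == 2
assert config['ctc_conf']['ctc_blank_id'] == 0
print(config['tokenizer'], config['cmvn'])  # char global_cmvn
```

With everything under one YAML file, the `--symbol_table`, `--dict`, and `--cmvn` flags removed from the shell scripts below become redundant.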
9 changes: 0 additions & 9 deletions examples/aishell/NST/run_nst.sh
@@ -72,7 +72,6 @@ data_type=shard
num_utts_per_shard=1000
train_set=train
train_config=conf/train_conformer.yaml
cmvn=true
average_checkpoint=true
target_pt=80
decode_checkpoint=$dir/$target_pt.pt
@@ -113,9 +112,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
dist_backend="nccl"
# the global_cmvn file need to be calculated by combining both supervised/unsupervised datasets,
# and it should be positioned at data/${train_set}/global_cmvn .
cmvn_opts=
$cmvn && cp data/${train_set}/global_cmvn $dir/global_cmvn
$cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn"

# train.py rewrite $train_config to $dir/train.yaml with model input
# and output dimension, and $dir/train.yaml will be used for inference
@@ -133,14 +129,12 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--train_engine ${train_engine} \
--config $train_config \
--data_type $data_type \
--symbol_table $dict \
--train_data data/$train_set/$data_list \
--cv_data data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.dist_backend $dist_backend \
--num_workers 1 \
$cmvn_opts \
--pin_memory \
--deepspeed_config ${deepspeed_config} \
--deepspeed.save_states ${deepspeed_save_states}
@@ -190,7 +184,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--dict $dict \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_file $test_dir/text \
@@ -216,7 +209,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--dict $dict \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_file $dev_dir/text \
@@ -275,7 +267,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--dict $dict \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_file data/train/${dir_split}data_sublist${job_num}/${hypo_name} \
24 changes: 24 additions & 0 deletions examples/aishell/rnnt/conf/conformer_rnnt.yaml
@@ -49,6 +49,29 @@ decoder_conf:
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

tokenizer: char
tokenizer_conf:
symbol_table_path: 'data/dict/lang_char.txt'
split_with_space: false
bpe_path: null
non_lang_syms_path: null
is_multilingual: false
num_languages: 1
special_tokens:
<blank>: 0
<unk>: 1
<sos>: 2
<eos>: 2

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: 'data/train/global_cmvn'
is_json_cmvn: true

# hybrid transducer+ctc+attention
model: transducer
model_conf:
@@ -59,6 +82,7 @@ model_conf:
length_normalized_loss: false
reverse_weight: 0.3

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
24 changes: 24 additions & 0 deletions examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml
@@ -53,6 +53,29 @@ decoder_conf:
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

tokenizer: char
tokenizer_conf:
symbol_table_path: 'data/dict/lang_char.txt'
split_with_space: false
bpe_path: null
non_lang_syms_path: null
is_multilingual: false
num_languages: 1
special_tokens:
<blank>: 0
<unk>: 1
<sos>: 2
<eos>: 2

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: 'data/train/global_cmvn'
is_json_cmvn: true

# hybrid transducer+ctc+attention
model: transducer
model_conf:
@@ -63,6 +86,7 @@ model_conf:
length_normalized_loss: false
reverse_weight: 0.3

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
24 changes: 24 additions & 0 deletions examples/aishell/rnnt/conf/example_embedding_predictor.yaml
@@ -45,6 +45,29 @@ decoder_conf:
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

tokenizer: char
tokenizer_conf:
symbol_table_path: 'data/dict/lang_char.txt'
split_with_space: false
bpe_path: null
non_lang_syms_path: null
is_multilingual: false
num_languages: 1
special_tokens:
<blank>: 0
<unk>: 1
<sos>: 2
<eos>: 2

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: 'data/train/global_cmvn'
is_json_cmvn: true

# hybrid transducer+ctc+attention
model: transducer
model_conf:
@@ -55,6 +78,7 @@ model_conf:
length_normalized_loss: false
reverse_weight: 0.3

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
12 changes: 2 additions & 10 deletions examples/aishell/rnnt/run.sh
@@ -42,7 +42,6 @@ num_utts_per_shard=1000

train_set=train
train_config=conf/conformer_u2pp_rnnt.yaml
cmvn=true
dir=exp/conformer_rnnt
checkpoint=

@@ -92,11 +91,10 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
mkdir -p $(dirname $dict)
echo "<blank> 0" > ${dict} # 0 is for "blank" in CTC
echo "<unk> 1" >> ${dict} # <unk> must be 1
echo "<sos/eos> 2" >> $dict
tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \
| tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \
awk '{print $0 " " NR+1}' >> ${dict}
num_token=$(cat $dict | wc -l)
echo "<sos/eos> $num_token" >> $dict
awk '{print $0 " " NR+2}' >> ${dict}
fi
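The run.sh change above reserves fixed ids for the special tokens up front (`<blank> 0`, `<unk> 1`, `<sos/eos> 2`) and starts regular tokens at 3 via awk's `NR+2`, instead of the old scheme that appended `<sos/eos>` last with the running vocabulary size. A minimal sketch of the new numbering (the example tokens are illustrative, not from the recipe):

```python
# Sketch of the new dict numbering in run.sh: three reserved special
# tokens, then regular tokens at id 3 onward (mirrors awk '{print $0 " " NR+2}').
def build_dict(tokens):
    entries = [('<blank>', 0), ('<unk>', 1), ('<sos/eos>', 2)]
    entries += [(tok, i + 3) for i, tok in enumerate(tokens)]
    return entries

d = dict(build_dict(['你', '好']))
assert d['<sos/eos>'] == 2   # fixed id, matching special_tokens in train.yaml
assert d['你'] == 3          # first regular token
```

Pinning `<sos/eos>` at 2 keeps the shell-generated dict consistent with the `special_tokens` block now declared in train.yaml.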

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
@@ -118,9 +116,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
# Use "nccl" if it works, otherwise use "gloo"
dist_backend="nccl"
cmvn_opts=
$cmvn && cp data/${train_set}/global_cmvn $dir
$cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn"

# train.py rewrite $train_config to $dir/train.yaml with model input
# and output dimension, and $dir/train.yaml will be used for inference
@@ -137,14 +132,12 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--train_engine ${train_engine} \
--config $train_config \
--data_type $data_type \
--symbol_table $dict \
--train_data data/$train_set/data.list \
--cv_data data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.dist_backend $dist_backend \
--num_workers 1 \
$cmvn_opts \
--pin_memory \
--deepspeed_config ${deepspeed_config} \
--deepspeed.save_states ${deepspeed_save_states}
@@ -183,7 +176,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
--beam_size 10 \
--batch_size 32 \
--penalty 0.0 \
--dict $dict \
--ctc_weight $rescore_ctc_weight \
--transducer_weight $rescore_transducer_weight \
--attn_weight $rescore_attn_weight \
25 changes: 25 additions & 0 deletions examples/aishell/s0/conf/train_conformer.yaml
@@ -28,12 +28,37 @@ decoder_conf:
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0

tokenizer: char
tokenizer_conf:
symbol_table_path: 'data/dict/lang_char.txt'
split_with_space: false
bpe_path: null
non_lang_syms_path: null
is_multilingual: false
num_languages: 1
special_tokens:
<blank>: 0
<unk>: 1
<sos>: 2
<eos>: 2

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: 'data/train/global_cmvn'
is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
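The `cmvn_conf` block with `is_json_cmvn: true` replaces the old `--cmvn` command-line flag. A hedged sketch of turning such a JSON stats file into per-dimension mean and inverse-stddev vectors; the key names (`frame_num`, `mean_stat`, `var_stat`) are an assumption about the stats layout, and the tiny inline stats stand in for a real `json.load` of `data/train/global_cmvn`:

```python
# Sketch: derive normalization vectors from accumulated global-CMVN stats
# like the file cmvn_conf points at. Key names are assumptions; in practice
# the dict would come from json.load(open(cmvn_conf['cmvn_file'])).
import math

def load_cmvn(stats):
    n = stats['frame_num']
    mean = [s / n for s in stats['mean_stat']]
    var = [s / n - m * m for s, m in zip(stats['var_stat'], mean)]
    istd = [1.0 / math.sqrt(max(v, 1e-20)) for v in var]
    return mean, istd

stats = {'frame_num': 2, 'mean_stat': [2.0, 4.0], 'var_stat': [4.0, 10.0]}
mean, istd = load_cmvn(stats)
assert mean == [1.0, 2.0]
```

Because the file path now lives in train.yaml, inference picks up the same normalization from `$dir/train.yaml` without the scripts copying `global_cmvn` around, which is exactly what the `cmvn_opts` deletions in the run scripts rely on.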
25 changes: 25 additions & 0 deletions examples/aishell/s0/conf/train_conformer_no_pos.yaml
@@ -28,12 +28,37 @@ decoder_conf:
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0

tokenizer: char
tokenizer_conf:
symbol_table_path: 'data/dict/lang_char.txt'
split_with_space: false
bpe_path: null
non_lang_syms_path: null
is_multilingual: false
num_languages: 1
special_tokens:
<blank>: 0
<unk>: 1
<sos>: 2
<eos>: 2

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: 'data/train/global_cmvn'
is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
25 changes: 25 additions & 0 deletions examples/aishell/s0/conf/train_ebranchformer.yaml
@@ -31,12 +31,37 @@ decoder_conf:
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

tokenizer: char
tokenizer_conf:
symbol_table_path: 'data/dict/lang_char.txt'
split_with_space: false
bpe_path: null
non_lang_syms_path: null
is_multilingual: false
num_languages: 1
special_tokens:
<blank>: 0
<unk>: 1
<sos>: 2
<eos>: 2

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: 'data/train/global_cmvn'
is_json_cmvn: true

# hybrid CTC/attention
model: asr_model
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960