[recipe] refine yaml for wenetspeech #2229

Merged · 1 commit · Dec 13, 2023
examples/wenetspeech/s0/conf/train_conformer.yaml: 25 additions & 0 deletions

@@ -29,12 +29,37 @@ decoder_conf:
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+tokenizer: char
+tokenizer_conf:
+  symbol_table_path: 'data/dict/lang_char.txt'
+  split_with_space: false
+  bpe_path: null
+  non_lang_syms_path: null
+  is_multilingual: false
+  num_languages: 1
+  special_tokens:
+    <blank>: 0
+    <unk>: 1
+    <sos>: 2
+    <eos>: 2
+
+ctc: ctc
+ctc_conf:
+  ctc_blank_id: 0
+
+cmvn: global_cmvn
+cmvn_conf:
+  cmvn_file: 'data/train_l/global_cmvn'
+  is_json_cmvn: true
+
 # hybrid CTC/attention
+model: asr_model
 model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1 # label smoothing option
     length_normalized_loss: false
 
+dataset: asr
 dataset_conf:
     filter_conf:
         max_length: 1200
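The new top-level tokenizer, ctc, and cmvn sections make the config self-contained, so run.sh no longer has to pass a dict or a CMVN file on the command line. A minimal sketch (assuming only PyYAML and the paths shown in this diff; this is not WeNet's actual loader) of how these sections can be read and cross-checked:

```python
import yaml

with open("examples/wenetspeech/s0/conf/train_conformer.yaml") as f:
    conf = yaml.safe_load(f)

special = conf["tokenizer_conf"]["special_tokens"]
# <sos> and <eos> share id 2 on purpose: the updated run.sh writes a single
# <sos/eos> entry into lang_char.txt instead of appending it at the end.
assert special["<sos>"] == special["<eos>"] == 2
# The CTC blank id must agree with <blank> in the symbol table.
assert conf["ctc_conf"]["ctc_blank_id"] == special["<blank>"] == 0
print(conf["cmvn"], conf["cmvn_conf"]["cmvn_file"])
```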
examples/wenetspeech/s0/conf/train_conformer_bidecoder.yaml: 25 additions & 0 deletions

@@ -30,13 +30,38 @@ decoder_conf:
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+tokenizer: char
+tokenizer_conf:
+  symbol_table_path: 'data/dict/lang_char.txt'
+  split_with_space: false
+  bpe_path: null
+  non_lang_syms_path: null
+  is_multilingual: false
+  num_languages: 1
+  special_tokens:
+    <blank>: 0
+    <unk>: 1
+    <sos>: 2
+    <eos>: 2
+
+ctc: ctc
+ctc_conf:
+  ctc_blank_id: 0
+
+cmvn: global_cmvn
+cmvn_conf:
+  cmvn_file: 'data/train_l/global_cmvn'
+  is_json_cmvn: true
+
 # hybrid CTC/attention
+model: asr_model
 model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1 # label smoothing option
     length_normalized_loss: false
     reverse_weight: 0.3
 
+dataset: asr
 dataset_conf:
     filter_conf:
         max_length: 1200
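The bidecoder config differs only in reverse_weight: 0.3, which adds a right-to-left decoder whose score is blended with the left-to-right one during attention rescoring. A hedged sketch of the usual blending formula (the scores below are invented for illustration):

```python
reverse_weight = 0.3   # mirrors model_conf.reverse_weight above
forward_score = -12.4  # hypothetical left-to-right attention log-prob
reverse_score = -13.1  # hypothetical right-to-left attention log-prob

# Convex combination of the two decoder scores for one hypothesis.
blended = (1 - reverse_weight) * forward_score + reverse_weight * reverse_score
print(f"{blended:.2f}")  # -12.61
```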
examples/wenetspeech/s0/run.sh: 13 additions & 23 deletions

@@ -38,7 +38,6 @@ test_sets="test_net test_meeting"

 train_config=conf/train_conformer.yaml
 checkpoint=
-cmvn=true
 cmvn_sampling_divisor=20 # 20 means 5% of the training data to estimate cmvn
 dir=exp/conformer

@@ -78,33 +77,29 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
   mkdir -p $(dirname $dict)
   echo "<blank> 0" > ${dict} # 0 will be used for "blank" in CTC
   echo "<unk> 1" >> ${dict} # <unk> must be 1
-  echo "▁ 2" >> ${dict} # ▁ is for space
+  echo "<sos/eos> 2" >> $dict
+  echo "▁ 3" >> ${dict} # ▁ is for space
   tools/text2token.py -s 1 -n 1 --space "▁" data/${train_set}/text \
     | cut -f 2- -d" " | tr " " "\n" \
     | sort | uniq | grep -a -v -e '^\s*$' \
     | grep -v "▁" \
-    | awk '{print $0 " " NR+2}' >> ${dict} \
+    | awk '{print $0 " " NR+3}' >> ${dict} \
     || exit 1;
-  num_token=$(cat $dict | wc -l)
-  echo "<sos/eos> $num_token" >> $dict
 fi

 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
   echo "Compute cmvn"
   # Here we use all the training data, you can sample some data to save time
   # BUG!!! We should use the segmented data for CMVN
-  if $cmvn; then
-    full_size=`cat data/${train_set}/wav.scp | wc -l`
-    sampling_size=$((full_size / cmvn_sampling_divisor))
-    shuf -n $sampling_size data/$train_set/wav.scp \
-      > data/$train_set/wav.scp.sampled
-    python3 tools/compute_cmvn_stats.py \
-      --num_workers 16 \
-      --train_config $train_config \
-      --in_scp data/$train_set/wav.scp.sampled \
-      --out_cmvn data/$train_set/global_cmvn \
-      || exit 1;
-  fi
+  full_size=`cat data/${train_set}/wav.scp | wc -l`
+  sampling_size=$((full_size / cmvn_sampling_divisor))
+  shuf -n $sampling_size data/$train_set/wav.scp \
+    > data/$train_set/wav.scp.sampled
+  python3 tools/compute_cmvn_stats.py \
+    --num_workers 16 \
+    --train_config $train_config \
+    --in_scp data/$train_set/wav.scp.sampled \
+    --out_cmvn data/$train_set/global_cmvn
 fi
 
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
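Two things change in stages 1 and 2. First, <sos/eos> now gets the fixed id 2 right after <unk>, so ▁ moves to 3 and regular tokens start at 4 (hence NR+3 instead of NR+2), matching special_tokens in the YAML. Second, with the $cmvn toggle gone, CMVN stats are always estimated on a 1/cmvn_sampling_divisor sample of wav.scp. A small sanity-check sketch, assuming the paths used above:

```python
# Verify the dict layout written by stage 1.
with open("data/dict/lang_char.txt", encoding="utf-8") as f:
    entries = [line.split() for line in f]
assert entries[0] == ["<blank>", "0"]
assert entries[1] == ["<unk>", "1"]
assert entries[2] == ["<sos/eos>", "2"]
assert entries[3] == ["▁", "3"]  # regular tokens follow from id 4

# Reproduce the stage-2 sampling arithmetic: 1/20 of wav.scp, i.e. 5%.
cmvn_sampling_divisor = 20
with open("data/train_l/wav.scp") as f:
    full_size = sum(1 for _ in f)
sampling_size = full_size // cmvn_sampling_divisor
print(sampling_size)
```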
@@ -129,9 +124,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
   num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
   # Use "nccl" if it works, otherwise use "gloo"
   dist_backend="nccl"
-  cmvn_opts=
-  $cmvn && cp data/${train_set}/global_cmvn $dir
-  $cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn"
   # train.py will write $train_config to $dir/train.yaml with model input
   # and output dimension, train.yaml will be used for inference or model
   # export later
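With the $cmvn flag removed, global_cmvn is no longer copied into $dir or passed via --cmvn; the model locates it through cmvn_conf in the YAML (is_json_cmvn: true). As a sketch of what a JSON CMVN loader typically does (the field names below are an assumption about compute_cmvn_stats.py's output, not verified here), accumulated statistics become a per-dimension mean and inverse standard deviation:

```python
import json
import math

# Assumed layout: {"mean_stat": [...], "var_stat": [...], "frame_num": n}
with open("data/train_l/global_cmvn") as f:
    stats = json.load(f)

n = stats["frame_num"]
means = [s / n for s in stats["mean_stat"]]
# var = E[x^2] - E[x]^2, floored so the sqrt never sees ~0
istds = [1.0 / math.sqrt(max(s / n - m * m, 1e-20))
         for s, m in zip(stats["var_stat"], means)]
# Normalization then applies (x - mean) * istd to each feature dimension.
```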
Expand All @@ -147,13 +139,11 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     --train_engine ${train_engine} \
     --config $train_config \
     --data_type "shard" \
-    --symbol_table $dict \
     --train_data data/$train_set/data.list \
     --cv_data data/$dev_set/data.list \
     ${checkpoint:+--checkpoint $checkpoint} \
     --model_dir $dir \
     --ddp.dist_backend $dist_backend \
-    $cmvn_opts \
     --num_workers 8 \
     --pin_memory \
     --deepspeed_config ${deepspeed_config} \
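--symbol_table and $cmvn_opts drop out of the training command because both now travel inside $train_config. A sketch of what the training entry point can derive from the YAML alone (paths from the config above; the loader logic is assumed, not WeNet's exact code):

```python
import yaml

with open("conf/train_conformer.yaml") as f:
    conf = yaml.safe_load(f)

symbol_table = {}
with open(conf["tokenizer_conf"]["symbol_table_path"], encoding="utf-8") as f:
    for line in f:
        token, idx = line.split()
        symbol_table[token] = int(idx)

vocab_size = len(symbol_table)              # formerly derived from --symbol_table
cmvn_file = conf["cmvn_conf"]["cmvn_file"]  # formerly passed via --cmvn
print(vocab_size, cmvn_file)
```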
@@ -189,7 +179,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
       --beam_size 10 \
       --batch_size 1 \
       --penalty 0.0 \
-      --dict $dict \
       --ctc_weight $ctc_weight \
       --reverse_weight $reverse_weight \
       --result_dir $result_dir \
@@ -199,6 +188,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
         data/$testset/text $result_dir/$mode/text > $result_dir/$mode/wer
       done
     }
+  done
 fi
 
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then