Skip to content

Commit

Permalink
Merge pull request #2371 from wenet-e2e/xcsong-wenetspech
Browse files Browse the repository at this point in the history
[examples] better results on wenetspeech using revised transcripts
  • Loading branch information
whiteshirt0429 authored Mar 1, 2024
2 parents 844d578 + 68e4e8c commit ad663fd
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 18 deletions.
13 changes: 13 additions & 0 deletions examples/wenetspeech/s0/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@
| attention rescoring - 16 | 7.57 % N=328207 C=307065 S=15169 D=5973 I=3687 | 10.13 % N=414285 C=376854 S=28486 D=8945 I=4541 | 15.55 % N=220358 C=191270 S=21136 D=7952 I=5184 |
| attention - full | 7.73 % N=328207 C=306688 S=13166 D=8353 I=3845 | 9.44 % N=414285 C=378096 S=24532 D=11657 I=2908 | 14.98 % N=220358 C=191881 S=15303 D=13174 I=4540 |

## U2++ conformer (text\_fixed, see https://github.com/wenet-e2e/WenetSpeech/discussions/54)

* Feature info: using fbank feature, with dither 1.0, with cmvn
* Training info: lr 0.001, batch size dynamic36000, 8 gpus on 3090, acc_grad 4, 130k steps, 4.6 days
* Decoding info: ctc_weight 0.5, reverse_weight 0.0, average_num 5, blank penalty 0.0, length penalty 0.0
* PR link: https://github.com/wenet-e2e/wenet/pull/2371

| Decoding mode - Chunk size | Dev | Test\_Net | Test\_Meeting |
|:-----------------------------:|:----:|:---------:|:-------------:|
| ctc prefix beam search - full | 6.26 % N=328207 C=310671 S=15612 D=1924 I=3002 | 9.46 % N=414285 C=381373 S=26013 D=6899 I=6295 | 12.52 % N=220358 C=194801 S=19209 D=6348 I=2042 |
| attention rescoring - full | 5.90 % N=328207 C=311721 S=14597 D=1889 I=2888 | 8.96 % N=414092 C=380232 S=27606 D=6254 I=3222 | 11.99 % N=220358 C=195808 S=18243 D=6307 I=1878 |
| attention - full | 5.87 % N=328207 C=311922 S=14204 D=2081 I=2987 | 8.87 % N=414092 C=381014 S=27359 D=5719 I=3650 | 11.79 % N=220358 C=196484 S=17378 D=6496 I=2108 |

## U2++ conformer (wenetspeech plus aishell4)

* Feature info: using fbank feature, with dither 1.0, with cmvn
Expand Down
116 changes: 116 additions & 0 deletions examples/wenetspeech/s0/conf/train_u2++_conformer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
encoder: conformer
encoder_conf:
activation_type: swish
attention_dropout_rate: 0.1
attention_heads: 8
causal: true
cnn_module_kernel: 15
cnn_module_norm: layer_norm
dropout_rate: 0.1
gradient_checkpointing: true
input_layer: conv2d
linear_units: 2048
normalize_before: true
num_blocks: 12
output_size: 512
pos_enc_layer_type: rel_pos
positional_dropout_rate: 0.1
selfattention_layer_type: rel_selfattn
use_cnn_module: true
use_dynamic_chunk: true
use_dynamic_left_chunk: false

decoder: bitransformer
decoder_conf:
attention_heads: 8
dropout_rate: 0.1
gradient_checkpointing: true
linear_units: 2048
num_blocks: 3
positional_dropout_rate: 0.1
r_num_blocks: 3
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

tokenizer: char
tokenizer_conf:
bpe_path: null
is_multilingual: false
non_lang_syms_path: null
num_languages: 1
special_tokens:
<blank>: 0
<eos>: 2
<sos>: 2
<unk>: 1
split_with_space: false
symbol_table_path: data/dict/lang_char.txt

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: data/train_l/global_cmvn
is_json_cmvn: true

model: asr_model
model_conf:
ctc_weight: 0.3
length_normalized_loss: false
lsm_weight: 0.1
reverse_weight: 0.3

dataset: asr
dataset_conf:
batch_conf:
batch_size: 32
batch_type: dynamic
max_frames_in_batch: 36000
fbank_conf:
dither: 1.0
frame_length: 25
frame_shift: 10
num_mel_bins: 80
filter_conf:
max_length: 4096
max_output_input_ratio: 0.25
min_length: 10
token_max_length: 200
token_min_length: 1
resample_conf:
resample_rate: 16000
shuffle: true
shuffle_conf:
shuffle_size: 5000
sort: true
sort_conf:
sort_size: 1000
spec_aug: true
spec_aug_conf:
max_f: 30
max_t: 50
num_f_mask: 2
num_t_mask: 2
spec_sub: true
spec_sub_conf:
max_t: 30
num_t_sub: 3
spec_trim: true
spec_trim_conf:
max_t: 30
speed_perturb: true

grad_clip: 5
accum_grad: 4
max_epoch: 1 # NOTE(xcsong): Configure the epoch in run.sh
log_interval: 100
save_interval: 1000 # NOTE(xcsong): we use step_save instead of epoch_save for large datasets

optim: adam
optim_conf:
lr: 0.001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 50000
88 changes: 70 additions & 18 deletions examples/wenetspeech/s0/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,18 @@ fi
# if you don't want to utilize all available GPU resources.
export CUDA_VISIBLE_DEVICES="${gpu_list}"
echo "CUDA_VISIBLE_DEVICES is ${CUDA_VISIBLE_DEVICES}"
stage=0

cuda_visible_devices=${CUDA_VISIBLE_DEVICES:-""}
if [ -z "$cuda_visible_devices" ]; then
echo "CUDA_VISIBLE_DEVICES is not set. Using default device_ids."
device_ids=(0 1 2 3 4 5 6 7)
else
IFS=',' read -r -a device_ids <<< "$cuda_visible_devices"
echo "Using CUDA_VISIBLE_DEVICES: $cuda_visible_devices"
fi
echo "Parsed device_ids: ${device_ids[@]}"

stage=4
stop_stage=5

# You should change the following two parameters for multiple machine training,
Expand All @@ -36,22 +47,34 @@ train_set=train_`echo $set | tr 'A-Z' 'a-z'`
dev_set=dev
test_sets="test_net test_meeting"

train_config=conf/train_conformer.yaml
# NOTE(xcsong): we use step_save instead of epoch_save for large datasets
epoch=100

train_config=conf/train_u2++_conformer.yaml
checkpoint=
dir=exp/u2pp_conformer

cmvn_sampling_divisor=20 # 20 means 5% of the training data to estimate cmvn
dir=exp/conformer

decode_checkpoint=
average_checkpoint=true
average_num=10
decode_modes="attention_rescoring ctc_prefix_beam_search"
average_num=5
average_mode=step
max_step=88888888
decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring"

train_engine=torch_ddp

deepspeed_config=../../aishell/s0/conf/ds_stage2.json
deepspeed_save_states="model_only"
deepspeed_config=../whisper/conf/ds_stage1.json
deepspeed_save_states="model+optimizer"

dict=data/dict/lang_char.txt
decoding_chunk_size=
ctc_weight=0.5
reverse_weight=0.0
blank_penalty=0.0
length_penalty=0.0
decode_batch=16

. tools/parse_options.sh || exit 1;

Expand Down Expand Up @@ -133,19 +156,33 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
else
echo "$0: using torch ddp"
fi

# repeat data.list, we use step_save instead of epoch_save for large datasets
train_data=data/$train_set/data.list.repeat${epoch}
if [ ! -f "${train_data}" ]; then
echo "repeat data/$train_set/data.list ${epoch} times"
for (( i=1; i<=$epoch; i++ ))
do
cat "data/$train_set/data.list" >> "${train_data}"
done
echo "save new data.list in ${train_data}, it will be used for training"
else
echo "${train_data} already exists."
fi

echo "$0: num_nodes is $num_nodes, proc_per_node is $num_gpus"
torchrun --nnodes=$num_nodes --nproc_per_node=$num_gpus --rdzv_endpoint=$HOST_NODE_ADDR \
--rdzv_id=2023 --rdzv_backend="c10d" \
wenet/bin/train.py \
--train_engine ${train_engine} \
--config $train_config \
--data_type "shard" \
--train_data data/$train_set/data.list \
--train_data ${train_data} \
--cv_data data/$dev_set/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.dist_backend $dist_backend \
--num_workers 8 \
--num_workers 2 \
--pin_memory \
--deepspeed_config ${deepspeed_config} \
--deepspeed.save_states ${deepspeed_save_states}
Expand All @@ -154,37 +191,52 @@ fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Test model"
if [ ${average_checkpoint} == true ]; then
decode_checkpoint=$dir/avg${average_num}.pt
decode_checkpoint=$dir/avg${average_num}_mode${average_mode}_max${max_step}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python wenet/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path $dir \
--num ${average_num} \
--mode ${average_mode} \
--max_step ${max_step} \
--val_best
fi
# Specify decoding_chunk_size if it's a unified dynamic chunk trained model
# -1 for full chunk
decoding_chunk_size=
ctc_weight=0.5
reverse_weight=0.0
blank_penalty=2.5
i=0
for testset in ${test_sets} ${dev_set}; do
{
base=$(basename $decode_checkpoint)
result_dir=$dir/${testset}_${base}_chunk${decoding_chunk_size}_ctc${ctc_weight}_reverse${reverse_weight}_blankpenalty${blank_penalty}
python wenet/bin/recognize.py --gpu 0 \
result_dir=$dir/${testset}_${base}_chunk${decoding_chunk_size}_ctc${ctc_weight}_reverse${reverse_weight}_blankpenalty${blank_penalty}_lengthpenalty${length_penalty}
mkdir -p ${result_dir}
device_id=${device_ids[i % ${#device_ids[@]}]}
echo "Testing ${testset} on GPU ${device_id}"
python wenet/bin/recognize.py --gpu ${device_id} \
--modes $decode_modes \
--config $dir/train.yaml \
--data_type "shard" \
--test_data data/$testset/data.list \
--checkpoint $decode_checkpoint \
--beam_size 10 \
--batch_size 32 \
--batch_size ${decode_batch} \
--blank_penalty ${blank_penalty} \
--length_penalty ${length_penalty} \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_dir $result_dir \
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} &
((i++))
if [[ $device_id -eq $((num_gpus - 1)) ]]; then
wait
fi
}
done
wait
for testset in ${test_sets} ${dev_set}; do
{
base=$(basename $decode_checkpoint)
result_dir=$dir/${testset}_${base}_chunk${decoding_chunk_size}_ctc${ctc_weight}_reverse${reverse_weight}_blankpenalty${blank_penalty}_lengthpenalty${length_penalty}
mkdir -p ${result_dir}
for mode in ${decode_modes}; do
python tools/compute-wer.py --char=1 --v=1 \
data/$testset/text $result_dir/$mode/text > $result_dir/$mode/wer
Expand Down
13 changes: 13 additions & 0 deletions examples/wenetspeech/whisper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ python local/modify_ckpt.py \
| attention | 7.27 % N=328207 C=308016 S=11392 D=8799 I=3672 | 7.90 % N=414097 C=383382 S=18954 D=11761 I=2018 | 13.00 % N=220358 C=194417 S=11788 D=14153 I=2705 |
| attention_rescoring | 8.95 % N=328207 C=305892 S=16696 D=5619 I=7056 | 10.83 % N=414097 C=371515 S=30229 D=12353 I=2269 | 15.64 % N=220358 C=193717 S=18669 D=7972 I=7812 |

## Whisper-largev3 (conv1d2, full-parameter tuning) Result (text\_fixed, see https://github.com/wenet-e2e/WenetSpeech/discussions/54)

* Feature info: using log_mel_spectrogram feature, no cmvn
* Training info: bf16, deepspeed stage1, activation checkpointing, batch dynamic12000, acc_grad 8, 8 * 3090 gpu, 48k steps (about 6 days), conf/finetune_whisper_largev3.yaml
* Decoding info: ctc_weight 0.0, average_num 5
* PR link: https://github.com/wenet-e2e/wenet/pull/2371

| decoding_method | Dev | Test\_Net | Test\_Meeting |
|:-------------------:|:----:|:---------:|:-------------:|
| ctc_greedy_search | 7.09 % N=328207 C=308643 S=16976 D=2588 I=3709 | 10.98 % N=414092 C=373301 S=33375 D=7416 I=4697 | 12.84 % N=220358 C=194928 S=18398 D=7032 I=2862 |
| attention | 4.66 % N=328207 C=315591 S=10352 D=2264 I=2692 | 6.54 % N=414092 C=389523 S=19101 D=5468 I=2513 | 8.84 % N=220358 C=202722 S=11296 D=6340 I=1839 |
| attention_rescoring | 5.99 % N=328207 C=311106 S=14807 D=2294 I=2547 | 9.27 % N=414092 C=378406 S=28993 D=6693 I=2715 | 11.47 % N=220358 C=197013 S=16716 D=6629 I=1923 |

# Frequently Asked Questions

- Q: Why are there so many insertion errors in the decoding results of CTC and attention_rescoring?
Expand Down

0 comments on commit ad663fd

Please sign in to comment.