[examples] better results on wenetspeech using revised transcripts #2371

Merged (2 commits, Mar 1, 2024)
examples/wenetspeech/s0/README.md (+13, -0)
@@ -38,6 +38,19 @@
| attention rescoring - 16 | 7.57 % N=328207 C=307065 S=15169 D=5973 I=3687 | 10.13 % N=414285 C=376854 S=28486 D=8945 I=4541 | 15.55 % N=220358 C=191270 S=21136 D=7952 I=5184 |
| attention - full | 7.73 % N=328207 C=306688 S=13166 D=8353 I=3845 | 9.44 % N=414285 C=378096 S=24532 D=11657 I=2908 | 14.98 % N=220358 C=191881 S=15303 D=13174 I=4540 |

## U2++ conformer (text\_fixed, see https://github.com/wenet-e2e/WenetSpeech/discussions/54)

* Feature info: using fbank feature, with dither 1.0, with cmvn
* Training info: lr 0.001, dynamic batch with max_frames_in_batch 36000, 8 * 3090 GPUs, acc_grad 4, 130k steps, about 4.6 days
* Decoding info: ctc_weight 0.5, reverse_weight 0.0, average_num 5, blank_penalty 0.0, length_penalty 0.0 (a reproduction sketch follows the results table below)
* PR link: https://github.com/wenet-e2e/wenet/pull/2371

| Decoding mode - Chunk size | Dev | Test\_Net | Test\_Meeting |
|:-----------------------------:|:----:|:---------:|:-------------:|
| ctc prefix beam search - full | 6.26 % N=328207 C=310671 S=15612 D=1924 I=3002 | 9.46 % N=414285 C=381373 S=26013 D=6899 I=6295 | 12.52 % N=220358 C=194801 S=19209 D=6348 I=2042 |
| attention rescoring - full | 5.90 % N=328207 C=311721 S=14597 D=1889 I=2888 | 8.96 % N=414092 C=380232 S=27606 D=6254 I=3222 | 11.99 % N=220358 C=195808 S=18243 D=6307 I=1878 |
| attention - full | 5.87 % N=328207 C=311922 S=14204 D=2081 I=2987 | 8.87 % N=414092 C=381014 S=27359 D=5719 I=3650 | 11.79 % N=220358 C=196484 S=17378 D=6496 I=2108 |
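
A minimal full-chunk decoding sketch assembled from the run.sh changes in this PR; the experiment directory, the averaged-checkpoint name (following run.sh's `avg${average_num}_mode${average_mode}_max${max_step}.pt` pattern), and the Test\_Net data path are assumptions, not commands taken verbatim from this PR:

```bash
# Sketch only: flags mirror wenet/bin/recognize.py as invoked in run.sh;
# the experiment dir, checkpoint name, and data path are assumptions.
python wenet/bin/recognize.py --gpu 0 \
  --modes attention_rescoring \
  --config exp/u2pp_conformer/train.yaml \
  --data_type "shard" \
  --test_data data/test_net/data.list \
  --checkpoint exp/u2pp_conformer/avg5_modestep_max88888888.pt \
  --beam_size 10 \
  --batch_size 16 \
  --blank_penalty 0.0 \
  --length_penalty 0.0 \
  --ctc_weight 0.5 \
  --reverse_weight 0.0 \
  --result_dir exp/u2pp_conformer/test_net_result
# Full-chunk decoding: simply leave --decoding_chunk_size unset.
```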

## U2++ conformer (wenetspeech plus aishell4)

* Feature info: using fbank feature, with dither 1.0, with cmvn
examples/wenetspeech/s0/conf/train_u2++_conformer.yaml (+116, -0)
@@ -0,0 +1,116 @@
encoder: conformer
encoder_conf:
    activation_type: swish
    attention_dropout_rate: 0.1
    attention_heads: 8
    causal: true
    cnn_module_kernel: 15
    cnn_module_norm: layer_norm
    dropout_rate: 0.1
    gradient_checkpointing: true
    input_layer: conv2d
    linear_units: 2048
    normalize_before: true
    num_blocks: 12
    output_size: 512
    pos_enc_layer_type: rel_pos
    positional_dropout_rate: 0.1
    selfattention_layer_type: rel_selfattn
    use_cnn_module: true
    use_dynamic_chunk: true
    use_dynamic_left_chunk: false

decoder: bitransformer
decoder_conf:
    attention_heads: 8
    dropout_rate: 0.1
    gradient_checkpointing: true
    linear_units: 2048
    num_blocks: 3
    positional_dropout_rate: 0.1
    r_num_blocks: 3
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

tokenizer: char
tokenizer_conf:
    bpe_path: null
    is_multilingual: false
    non_lang_syms_path: null
    num_languages: 1
    special_tokens:
        <blank>: 0
        <eos>: 2
        <sos>: 2
        <unk>: 1
    split_with_space: false
    symbol_table_path: data/dict/lang_char.txt

ctc: ctc
ctc_conf:
    ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
    cmvn_file: data/train_l/global_cmvn
    is_json_cmvn: true

model: asr_model
model_conf:
    ctc_weight: 0.3
    length_normalized_loss: false
    lsm_weight: 0.1
    reverse_weight: 0.3

dataset: asr
dataset_conf:
    batch_conf:
        batch_size: 32
        batch_type: dynamic
        max_frames_in_batch: 36000
    fbank_conf:
        dither: 1.0
        frame_length: 25
        frame_shift: 10
        num_mel_bins: 80
    filter_conf:
        max_length: 4096
        max_output_input_ratio: 0.25
        min_length: 10
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    shuffle: true
    shuffle_conf:
        shuffle_size: 5000
    sort: true
    sort_conf:
        sort_size: 1000
    spec_aug: true
    spec_aug_conf:
        max_f: 30
        max_t: 50
        num_f_mask: 2
        num_t_mask: 2
    spec_sub: true
    spec_sub_conf:
        max_t: 30
        num_t_sub: 3
    spec_trim: true
    spec_trim_conf:
        max_t: 30
    speed_perturb: true

grad_clip: 5
accum_grad: 4
max_epoch: 1 # NOTE(xcsong): Configure the epoch in run.sh
log_interval: 100
save_interval: 1000 # NOTE(xcsong): we use step_save instead of epoch_save for large datasets

optim: adam
optim_conf:
    lr: 0.001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 50000
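
run.sh consumes this file via `train_config=conf/train_u2++_conformer.yaml`, and because it sources tools/parse_options.sh, the variables defined before that call can be overridden on the command line. A hedged launch sketch (the visible-device list is an illustrative assumption):

```bash
# Sketch: run only the training stage with the config above.
# The visible-device list is an illustrative assumption.
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" bash run.sh \
  --stage 4 --stop_stage 4 \
  --train_config conf/train_u2++_conformer.yaml \
  --dir exp/u2pp_conformer
```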
examples/wenetspeech/s0/run.sh (+70, -18)
@@ -17,7 +17,18 @@ fi
# if you don't want to utilize all available GPU resources.
export CUDA_VISIBLE_DEVICES="${gpu_list}"
echo "CUDA_VISIBLE_DEVICES is ${CUDA_VISIBLE_DEVICES}"
stage=0

cuda_visible_devices=${CUDA_VISIBLE_DEVICES:-""}
if [ -z "$cuda_visible_devices" ]; then
  echo "CUDA_VISIBLE_DEVICES is not set. Using default device_ids."
  device_ids=(0 1 2 3 4 5 6 7)
else
  IFS=',' read -r -a device_ids <<< "$cuda_visible_devices"
  echo "Using CUDA_VISIBLE_DEVICES: $cuda_visible_devices"
fi
echo "Parsed device_ids: ${device_ids[@]}"

stage=4
stop_stage=5

# You should change the following two parameters for multiple machine training,
@@ -36,22 +47,34 @@ train_set=train_`echo $set | tr 'A-Z' 'a-z'`
dev_set=dev
test_sets="test_net test_meeting"

train_config=conf/train_conformer.yaml
# NOTE(xcsong): we use step_save instead of epoch_save for large datasets
epoch=100

train_config=conf/train_u2++_conformer.yaml
checkpoint=
dir=exp/u2pp_conformer

cmvn_sampling_divisor=20 # 20 means 5% of the training data to estimate cmvn
dir=exp/conformer

decode_checkpoint=
average_checkpoint=true
average_num=10
decode_modes="attention_rescoring ctc_prefix_beam_search"
average_num=5
average_mode=step
max_step=88888888
decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring"

train_engine=torch_ddp

deepspeed_config=../../aishell/s0/conf/ds_stage2.json
deepspeed_save_states="model_only"
deepspeed_config=../whisper/conf/ds_stage1.json
deepspeed_save_states="model+optimizer"

dict=data/dict/lang_char.txt
decoding_chunk_size=
ctc_weight=0.5
reverse_weight=0.0
blank_penalty=0.0
length_penalty=0.0
decode_batch=16

. tools/parse_options.sh || exit 1;

@@ -133,19 +156,33 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  else
    echo "$0: using torch ddp"
  fi

  # repeat data.list, we use step_save instead of epoch_save for large datasets
  train_data=data/$train_set/data.list.repeat${epoch}
  if [ ! -f "${train_data}" ]; then
    echo "repeat data/$train_set/data.list ${epoch} times"
    for (( i=1; i<=$epoch; i++ ))
    do
      cat "data/$train_set/data.list" >> "${train_data}"
    done
    echo "save new data.list in ${train_data}, it will be used for training"
  else
    echo "${train_data} already exists."
  fi
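  # Example (illustrative): with epoch=100, a single dataloader pass over the
  # repeated list covers 100 real epochs, so checkpoints come from the yaml's
  # save_interval (every 1000 steps) rather than from epoch boundaries.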

echo "$0: num_nodes is $num_nodes, proc_per_node is $num_gpus"
torchrun --nnodes=$num_nodes --nproc_per_node=$num_gpus --rdzv_endpoint=$HOST_NODE_ADDR \
--rdzv_id=2023 --rdzv_backend="c10d" \
wenet/bin/train.py \
--train_engine ${train_engine} \
--config $train_config \
--data_type "shard" \
--train_data data/$train_set/data.list \
--train_data ${train_data} \
--cv_data data/$dev_set/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.dist_backend $dist_backend \
--num_workers 8 \
--num_workers 2 \
--pin_memory \
--deepspeed_config ${deepspeed_config} \
--deepspeed.save_states ${deepspeed_save_states}
@@ -154,37 +191,52 @@ fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Test model"
if [ ${average_checkpoint} == true ]; then
decode_checkpoint=$dir/avg${average_num}.pt
decode_checkpoint=$dir/avg${average_num}_mode${average_mode}_max${max_step}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python wenet/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path $dir \
--num ${average_num} \
--mode ${average_mode} \
--max_step ${max_step} \
--val_best
fi
# Specify decoding_chunk_size if it's a unified dynamic chunk trained model
# -1 for full chunk
decoding_chunk_size=
ctc_weight=0.5
reverse_weight=0.0
blank_penalty=2.5
i=0
for testset in ${test_sets} ${dev_set}; do
{
base=$(basename $decode_checkpoint)
result_dir=$dir/${testset}_${base}_chunk${decoding_chunk_size}_ctc${ctc_weight}_reverse${reverse_weight}_blankpenalty${blank_penalty}
python wenet/bin/recognize.py --gpu 0 \
result_dir=$dir/${testset}_${base}_chunk${decoding_chunk_size}_ctc${ctc_weight}_reverse${reverse_weight}_blankpenalty${blank_penalty}_lengthpenalty${length_penalty}
mkdir -p ${result_dir}
device_id=${device_ids[i % ${#device_ids[@]}]}
echo "Testing ${testset} on GPU ${device_id}"
python wenet/bin/recognize.py --gpu ${device_id} \
--modes $decode_modes \
--config $dir/train.yaml \
--data_type "shard" \
--test_data data/$testset/data.list \
--checkpoint $decode_checkpoint \
--beam_size 10 \
--batch_size 32 \
--batch_size ${decode_batch} \
--blank_penalty ${blank_penalty} \
--length_penalty ${length_penalty} \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_dir $result_dir \
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} &
((i++))
if [[ $device_id -eq $((num_gpus - 1)) ]]; then
wait
fi
}
done
wait
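  # All backgrounded decode jobs have finished once the wait above returns;
  # WER is then computed serially per test set and decode mode below.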
  for testset in ${test_sets} ${dev_set}; do
  {
    base=$(basename $decode_checkpoint)
    result_dir=$dir/${testset}_${base}_chunk${decoding_chunk_size}_ctc${ctc_weight}_reverse${reverse_weight}_blankpenalty${blank_penalty}_lengthpenalty${length_penalty}
    mkdir -p ${result_dir}
    for mode in ${decode_modes}; do
      python tools/compute-wer.py --char=1 --v=1 \
        data/$testset/text $result_dir/$mode/text > $result_dir/$mode/wer
examples/wenetspeech/whisper/README.md (+13, -0)
@@ -55,6 +55,19 @@ python local/modify_ckpt.py \
| attention | 7.27 % N=328207 C=308016 S=11392 D=8799 I=3672 | 7.90 % N=414097 C=383382 S=18954 D=11761 I=2018 | 13.00 % N=220358 C=194417 S=11788 D=14153 I=2705 |
| attention_rescoring | 8.95 % N=328207 C=305892 S=16696 D=5619 I=7056 | 10.83 % N=414097 C=371515 S=30229 D=12353 I=2269 | 15.64 % N=220358 C=193717 S=18669 D=7972 I=7812 |

## Whisper-largev3 (conv1d2, full-parameter tuning) Result (text\_fixed, see https://github.com/wenet-e2e/WenetSpeech/discussions/54)

* Feature info: using log_mel_spectrogram feature, no cmvn
* Training info: bf16, deepspeed stage1, activation checkpointing, dynamic batch with max_frames_in_batch 12000, acc_grad 8, 8 * 3090 GPUs, 48k steps (about 6 days), conf/finetune_whisper_largev3.yaml
* Decoding info: ctc_weight 0.0, average_num 5 (an averaging sketch follows the results table below)
* PR link: https://github.com/wenet-e2e/wenet/pull/2371

| decoding_method | Dev | Test\_Net | Test\_Meeting |
|:-------------------:|:----:|:---------:|:-------------:|
| ctc_greedy_search | 7.09 % N=328207 C=308643 S=16976 D=2588 I=3709 | 10.98 % N=414092 C=373301 S=33375 D=7416 I=4697 | 12.84 % N=220358 C=194928 S=18398 D=7032 I=2862 |
| attention | 4.66 % N=328207 C=315591 S=10352 D=2264 I=2692 | 6.54 % N=414092 C=389523 S=19101 D=5468 I=2513 | 8.84 % N=220358 C=202722 S=11296 D=6340 I=1839 |
| attention_rescoring | 5.99 % N=328207 C=311106 S=14807 D=2294 I=2547 | 9.27 % N=414092 C=378406 S=28993 D=6693 I=2715 | 11.47 % N=220358 C=197013 S=16716 D=6629 I=1923 |
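
A hedged sketch of the step-mode checkpoint averaging implied by `average_num 5` above, mirroring the `wenet/bin/average_model.py` call in `s0/run.sh`; the experiment directory and `max_step` cap are assumptions:

```bash
# Sketch only: average the 5 best step-saved checkpoints by dev loss.
# The experiment dir and max_step cap are illustrative assumptions.
python wenet/bin/average_model.py \
  --dst_model exp/whisper_largev3/avg5_modestep_max88888888.pt \
  --src_path exp/whisper_largev3 \
  --num 5 \
  --mode step \
  --max_step 88888888 \
  --val_best
```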

# Frequently Asked Questions

- Q: Why are there so many insertion errors in the decoding results of CTC and attention_rescoring?