Skip to content

Commit

Permalink
[decode] blank penalty (#2278)
Browse files Browse the repository at this point in the history
* [decode] blank penalty

* [decode] add result to readme

* [decode] fix blank_id
  • Loading branch information
xingchensong authored Jan 6, 2024
1 parent cacc562 commit 62a486f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 13 deletions.
10 changes: 5 additions & 5 deletions examples/wenetspeech/s0/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@

* Feature info: using fbank feature, with dither 1.0, with cmvn
* Training info: lr 0.002, batch size dynamic24000, 24 gpus on 3090, acc_grad 16, 80 epochs, 4.5 days
* Decoding info: ctc_weight 0.5, reverse_weight 0.0, average_num 10
* Decoding info: ctc_weight 0.5, reverse_weight 0.0, average_num 10, blank_penalty 2.5

| Decoding mode - Chunk size | Dev | Test\_Net | Test\_Meeting |
|:-----------------------------:|:----:|:---------:|:-------------:|
| ctc greedy search - full | 8.50 | 9.47 | 15.77 |
| ctc greedy search - 16 | 9.01 | 11.14 | 16.89 |
| attention rescoring - full | 8.37 | 9.02 | 15.52 |
| attention rescoring - 16 | 8.62 | 10.30 | 16.32 |
| ctc prefix beam search - full | 7.21 % N=328207 C=309358 S=14175 D=4674 I=4801 | 9.46 % N=414285 C=381373 S=26013 D=6899 I=6295 | 14.02 % N=220358 C=195224 S=17266 D=7868 I=5754 |
| ctc prefix beam search - 16 | 7.93 % N=328207 C=307192 S=16529 D=4486 I=5000 | 11.14 % N=414285 C=374733 S=30241 D=9311 I=6596 | 16.37 % N=220358 C=191394 S=22435 D=6529 I=7116 |
| attention rescoring - full | 7.10 % N=328207 C=308457 S=13215 D=6535 I=3537 | 8.83 % N=414285 C=381936 S=24808 D=7541 I=4215 | 13.64 % N=220358 C=194438 S=16238 D=9682 I=4133 |
| attention rescoring - 16 | 7.57 % N=328207 C=307065 S=15169 D=5973 I=3687 | 10.13 % N=414285 C=376854 S=28486 D=8945 I=4541 | 15.55 % N=220358 C=191270 S=21136 D=7952 I=5184 |
12 changes: 7 additions & 5 deletions examples/wenetspeech/s0/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ dir=exp/conformer
decode_checkpoint=
average_checkpoint=true
average_num=10
decode_modes="attention_rescoring ctc_greedy_search"
decode_modes="attention_rescoring ctc_prefix_beam_search"

train_engine=torch_ddp

deepspeed_config=../../aishell/s0/conf/ds_stage2.json
deepspeed_save_states="model_only"

dict=data/dict/lang_char.txt

. tools/parse_options.sh || exit 1;

set -u
Expand All @@ -70,7 +72,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
data || exit 1;
fi

dict=data/dict/lang_char.txt
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Make a dictionary"
echo "dictionary: ${dict}"
Expand Down Expand Up @@ -166,19 +167,20 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
decoding_chunk_size=
ctc_weight=0.5
reverse_weight=0.0
blank_penalty=2.5
for testset in ${test_sets} ${dev_set}; do
{
base=$(basename $decode_checkpoint)
result_dir=$dir/${testset}_${base}
result_dir=$dir/${testset}_${base}_chunk${decoding_chunk_size}_ctc${ctc_weight}_reverse${reverse_weight}_blankpenalty${blank_penalty}
python wenet/bin/recognize.py --gpu 0 \
--modes $decode_modes \
--config $dir/train.yaml \
--data_type "shard" \
--test_data data/$testset/data.list \
--checkpoint $decode_checkpoint \
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--batch_size 32 \
--blank_penalty ${blank_penalty} \
--ctc_weight $ctc_weight \
--reverse_weight $reverse_weight \
--result_dir $result_dir \
Expand Down
7 changes: 6 additions & 1 deletion wenet/bin/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def get_args():
type=float,
default=0.0,
help='length penalty')
parser.add_argument('--blank_penalty',
type=float,
default=0.0,
help='blank penalty')
parser.add_argument('--result_dir', required=True, help='asr result file')
parser.add_argument('--batch_size',
type=int,
Expand Down Expand Up @@ -251,7 +255,8 @@ def main():
simulate_streaming=args.simulate_streaming,
reverse_weight=args.reverse_weight,
context_graph=context_graph,
blank_id=blank_id)
blank_id=blank_id,
blank_penalty=args.blank_penalty)
for i, key in enumerate(keys):
for mode, hyps in results.items():
tokens = hyps[i].tokens
Expand Down
11 changes: 9 additions & 2 deletions wenet/transformer/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def decode(
reverse_weight: float = 0.0,
context_graph: ContextGraph = None,
blank_id: int = 0,
blank_penalty: float = 0.0,
) -> Dict[str, List[DecodeResult]]:
""" Decode input speech
Expand Down Expand Up @@ -266,7 +267,12 @@ def decode(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
encoder_lens = encoder_mask.squeeze(1).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out)
if blank_penalty > 0.0:
logits = self.ctc.ctc_lo(encoder_out)
logits[:, :, blank_id] -= blank_penalty
ctc_probs = logits.log_softmax(dim=2)
else:
ctc_probs = self.ctc.log_softmax(encoder_out)
results = {}
if 'attention' in methods:
results['attention'] = attention_beam_search(
Expand All @@ -285,7 +291,8 @@ def decode(
ctc_prefix_result = results['ctc_prefix_beam_search']
else:
ctc_prefix_result = ctc_prefix_beam_search(
ctc_probs, encoder_lens, beam_size, context_graph)
ctc_probs, encoder_lens, beam_size, context_graph,
blank_id)
if self.apply_non_blank_embedding:
encoder_out, _ = self.filter_blank_embedding(
ctc_probs, encoder_out)
Expand Down

0 comments on commit 62a486f

Please sign in to comment.