[s2t] move s2t data preprocess into paddlespeech.dataset (#3189)
* move s2t data preprocess into paddlespeech.dataset

* avg model, compute wer, format rsl into paddlespeech.dataset

* fix format rsl

* fix avg ckpts
zh794390558 authored Apr 23, 2023
1 parent 8c7859d commit df3be4a
Showing 38 changed files with 1,337 additions and 1,149 deletions.
15 changes: 11 additions & 4 deletions examples/aishell/asr1/local/test.sh
@@ -1,15 +1,21 @@
#!/bin/bash

if [ $# != 3 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
exit -1
fi
set -e

stage=0
stop_stage=100

source utils/parse_options.sh || exit 1;

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."


if [ $# != 3 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
exit -1
fi

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
@@ -92,6 +98,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi

if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
echo "using sclite to compute cer..."
# format the reference test file for sclite
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
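The stage-101 block above formats the raw test manifest with utils/format_rsl.py before handing it to sclite. As a rough, hypothetical sketch of what such a conversion generally involves (this is not the actual format_rsl.py, and the 'utt'/'text' field names are assumptions), a jsonlines manifest can be rewritten into sclite's trn format, one "<text> (<utt_id>)" line per utterance:

import jsonlines

def manifest_to_trn(manifest_path, trn_path):
    # Hypothetical helper, not the actual utils/format_rsl.py: write each
    # manifest entry as an sclite trn line "<text> (<utt_id>)".
    with jsonlines.open(manifest_path, 'r') as reader, \
            open(trn_path, 'w', encoding='utf-8') as fout:
        for item in reader:
            fout.write(f"{item['text']} ({item['utt']})\n")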
3 changes: 2 additions & 1 deletion paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py
@@ -28,6 +28,7 @@

from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
from paddlespeech.utils.argparse import print_arguments

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

@@ -139,7 +140,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):


def main():
print(f"args: {args}")
print_arguments(args, globals())
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)

3 changes: 2 additions & 1 deletion paddlespeech/dataset/aishell/aishell.py
@@ -28,6 +28,7 @@

from paddlespeech.dataset.download import download
from paddlespeech.dataset.download import unpack
from paddlespeech.utils.argparse import print_arguments

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

@@ -205,7 +206,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path=None, check=False):


def main():
print(f"args: {args}")
print_arguments(args, globals())
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)

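Both dataset scripts above only swap the bare f-string for print_arguments from paddlespeech.utils.argparse, whose implementation is not part of this diff. A minimal sketch of what such a helper typically does, purely for illustration:

def print_arguments_sketch(args, info=None):
    # Illustrative only; the real helper lives in paddlespeech/utils/argparse.py
    # and is not shown in this commit.
    print("----------- Configuration Arguments -----------")
    for key, value in sorted(vars(args).items()):
        print(f"{key}: {value}")
    print("------------------------------------------------")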
20 changes: 20 additions & 0 deletions paddlespeech/dataset/s2t/__init__.py
@@ -0,0 +1,20 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# s2t utils binaries.
from .avg_model import main as avg_ckpts_main
from .build_vocab import main as build_vocab_main
from .compute_mean_std import main as compute_mean_std_main
from .compute_wer import main as compute_wer_main
from .format_data import main as format_data_main
from .format_rsl import main as format_rsl_main
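These re-exports make each s2t preprocessing tool callable as a plain function. A hedged usage sketch (the checkpoint paths are hypothetical; each main() reads its options from sys.argv via argparse, so they are injected before the call):

import sys
from paddlespeech.dataset.s2t import avg_ckpts_main

# Hypothetical paths; average the 5 checkpoints with the lowest val_loss.
sys.argv = [
    "avg_model",
    "--ckpt_dir", "exp/conformer/checkpoints",
    "--num", "5",
    "--val_best",
    "--dst_model", "exp/conformer/checkpoints/avg_5.pdparams",
]
avg_ckpts_main()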
125 changes: 125 additions & 0 deletions paddlespeech/dataset/s2t/avg_model.py
@@ -0,0 +1,125 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os

import numpy as np
import paddle


def define_argparse():
parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='path to save the averaged model')
parser.add_argument(
'--ckpt_dir', required=True, help='ckpt model dir for average')
    parser.add_argument(
        '--val_best', action="store_true",
        help='average the checkpoints with the best (lowest) validation loss')
    parser.add_argument(
        '--num', default=5, type=int, help='number of checkpoints to average')
parser.add_argument(
'--min_epoch',
default=0,
type=int,
help='min epoch used for averaging model')
parser.add_argument(
'--max_epoch',
default=65536, # Big enough
type=int,
help='max epoch used for averaging model')

args = parser.parse_args()
return args


def average_checkpoints(dst_model="",
ckpt_dir="",
val_best=True,
num=5,
min_epoch=0,
max_epoch=65536):
paddle.set_device('cpu')

val_scores = []
jsons = glob.glob(f'{ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
for y in jsons:
with open(y, 'r') as f:
dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch']
if epoch >= min_epoch and epoch <= max_epoch:
val_scores.append((epoch, loss))
    assert val_scores, f"No valid checkpoints found in: {ckpt_dir}"
val_scores = np.array(val_scores)

if val_best:
sort_idx = np.argsort(val_scores[:, 1])
sorted_val_scores = val_scores[sort_idx]
else:
sorted_val_scores = val_scores

    best_val_scores = sorted_val_scores[:num, 1]
    selected_epochs = sorted_val_scores[:num, 0].astype(np.int64)
    avg_val_score = np.mean(best_val_scores)
    print("selected val scores = " + str(best_val_scores))
    print("selected epochs = " + str(selected_epochs))
    print("averaged val score = " + str(avg_val_score))

path_list = [
ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:num, 0]
]
print(path_list)

avg = None
    assert num == len(path_list)
for path in path_list:
print(f'Processing {path}')
states = paddle.load(path)
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
avg[k] /= num

    paddle.save(avg, dst_model)
    print(f'Saving to {dst_model}')

    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
    with open(meta_path, 'w') as f:
        data = json.dumps({
            "mode": 'val_best' if val_best else 'latest',
            "avg_ckpt": dst_model,
            "val_loss_mean": avg_val_score,
            "ckpts": path_list,
            "epochs": selected_epochs.tolist(),
            "val_losses": best_val_scores.tolist(),
        })
        f.write(data + "\n")


def main():
args = define_argparse()
    average_checkpoints(**vars(args))


if __name__ == '__main__':
main()
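A hedged sketch of calling average_checkpoints directly, assuming a checkpoint directory that contains <epoch>.pdparams files plus matching <epoch>.json files carrying 'val_loss' and 'epoch' fields, as the code above expects (paths are hypothetical):

from paddlespeech.dataset.s2t.avg_model import average_checkpoints

average_checkpoints(
    dst_model="exp/conformer/checkpoints/avg_5.pdparams",  # hypothetical output path
    ckpt_dir="exp/conformer/checkpoints",                   # hypothetical checkpoint dir
    val_best=True,   # keep the 5 checkpoints with the lowest val_loss
    num=5)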
166 changes: 166 additions & 0 deletions paddlespeech/dataset/s2t/build_vocab.py
@@ -0,0 +1,166 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build vocabulary from manifest files.
Each item in vocabulary file is a character.
"""
import argparse
import functools
import os
import tempfile
from collections import Counter

import jsonlines

from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS
from paddlespeech.s2t.frontend.utility import SPACE
from paddlespeech.s2t.frontend.utility import UNK
from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments


def count_manifest(counter, text_feature, manifest_path):
manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)

for line_json in manifest_jsons:
if isinstance(line_json['text'], str):
tokens = text_feature.tokenize(
line_json['text'], replace_space=False)

counter.update(tokens)
else:
assert isinstance(line_json['text'], list)
for text in line_json['text']:
tokens = text_feature.tokenize(text, replace_space=False)
counter.update(tokens)


def dump_text_manifest(fileobj, manifest_path, key='text'):
manifest_jsons = []
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)

for line_json in manifest_jsons:
if isinstance(line_json[key], str):
fileobj.write(line_json[key] + "\n")
else:
assert isinstance(line_json[key], list)
for line in line_json[key]:
fileobj.write(line + "\n")


def build_vocab(manifest_paths="",
vocab_path="examples/librispeech/data/vocab.txt",
unit_type="char",
count_threshold=0,
text_keys='text',
spm_mode="unigram",
spm_vocab_size=0,
spm_model_prefix="",
spm_character_coverage=0.9995):
fout = open(vocab_path, 'w', encoding='utf-8')
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
fout.write(UNK + '\n') # <unk> must be 1

if unit_type == 'spm':
# tools/spm_train --input=$wave_data/lang_char/input.txt
# --vocab_size=${nbpe} --model_type=${bpemode}
# --model_prefix=${bpemodel} --input_sentence_size=100000000
import sentencepiece as spm

fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
for manifest_path in manifest_paths:
_text_keys = [text_keys] if type(
text_keys) is not list else text_keys
for text_key in _text_keys:
dump_text_manifest(fp, manifest_path, key=text_key)
fp.close()
# train
spm.SentencePieceTrainer.Train(
input=fp.name,
vocab_size=spm_vocab_size,
model_type=spm_mode,
model_prefix=spm_model_prefix,
input_sentence_size=100000000,
character_coverage=spm_character_coverage)
os.unlink(fp.name)

# encode
text_feature = TextFeaturizer(unit_type, "", spm_model_prefix)
counter = Counter()

for manifest_path in manifest_paths:
count_manifest(counter, text_feature, manifest_path)

count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
tokens = []
for token, count in count_sorted:
if count < count_threshold:
break
# replace space by `<space>`
token = SPACE if token == ' ' else token
tokens.append(token)

tokens = sorted(tokens)
for token in tokens:
fout.write(token + '\n')

fout.write(SOS + "\n") # <sos/eos>
fout.close()


def define_argparse():
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

# yapf: disable
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('count_threshold', int, 0,
"Truncation threshold for char/word counts.Default 0, no truncate.")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
"Filepath to write the vocabulary.")
add_arg('manifest_paths', str,
None,
"Filepaths of manifests for building vocabulary. "
"You can provide multiple manifest files.",
nargs='+',
required=True)
add_arg('text_keys', str,
'text',
"keys of the text in manifest for building vocabulary. "
"You can provide multiple k.",
nargs='+')
# bpe
add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
    add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, bpe, char, word. Only needed when `unit_type` is spm.")
    add_arg('spm_model_prefix', str, "", "spm model prefix, e.g. spm_model_%(spm_mode)_%(count_threshold). Only needed when `unit_type` is spm.")
    add_arg('spm_character_coverage', float, 0.9995, "Character coverage to determine the minimum set of symbols.")
    # yapf: enable

args = parser.parse_args()
return args

def main():
args = define_argparse()
print_arguments(args, globals())
build_vocab(**vars(args))

if __name__ == '__main__':
main()
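A hedged usage sketch for build_vocab, building a character-level vocabulary from a single manifest; the manifest and output paths are hypothetical:

from paddlespeech.dataset.s2t.build_vocab import build_vocab

build_vocab(
    manifest_paths=["data/manifest.train.raw"],   # hypothetical manifest
    vocab_path="data/lang_char/vocab.txt",        # hypothetical output path
    unit_type="char",
    count_threshold=0,     # keep every token seen in the manifests
    text_keys=["text"])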