[s2t] move s2t data preprocess into paddlespeech.dataset (#3189)
* move s2t data preprocess into paddlespeech.dataset
* avg model, compute wer, format rsl into paddlespeech.dataset
* fix format rsl
* fix avg ckpts
1 parent 8c7859d, commit df3be4a. Showing 38 changed files with 1,337 additions and 1,149 deletions.
@@ -0,0 +1,20 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# s2t utils binaries.
from .avg_model import main as avg_ckpts_main
from .build_vocab import main as build_vocab_main
from .compute_mean_std import main as compute_mean_std_main
from .compute_wer import main as compute_wer_main
from .format_data import main as format_data_main
from .format_rsl import main as format_rsl_main
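This new `__init__` re-exports each tool's `main` so all s2t preprocessing entry points are reachable through one package. A minimal driver sketch, not part of the commit: the full module path `paddlespeech.dataset.s2t` is inferred from the commit title, and the argv values are placeholders, but the flags themselves are the ones defined in avg_model.py below.

import sys

from paddlespeech.dataset.s2t import avg_ckpts_main

# Emulate a CLI invocation; the re-exported mains parse sys.argv via argparse.
sys.argv = [
    "avg_model",
    "--ckpt_dir", "exp/conformer/checkpoints",  # placeholder path
    "--dst_model", "exp/conformer/checkpoints/avg_5.pdparams",
    "--num", "5",
    "--val_best",
]
avg_ckpts_main()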
@@ -0,0 +1,125 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os

import numpy as np
import paddle


def define_argparse():
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument(
        '--dst_model', required=True, help='path to write the averaged model')
    parser.add_argument(
        '--ckpt_dir', required=True, help='ckpt model dir for average')
    parser.add_argument(
        '--val_best',
        action="store_true",
        help='select checkpoints by best val loss instead of latest')
    parser.add_argument(
        '--num', default=5, type=int, help='number of checkpoints to average')
    parser.add_argument(
        '--min_epoch',
        default=0,
        type=int,
        help='min epoch used for averaging model')
    parser.add_argument(
        '--max_epoch',
        default=65536,  # Big enough
        type=int,
        help='max epoch used for averaging model')

    args = parser.parse_args()
    return args


def average_checkpoints(dst_model="",
                        ckpt_dir="",
                        val_best=True,
                        num=5,
                        min_epoch=0,
                        max_epoch=65536):
    paddle.set_device('cpu')

    val_scores = []
    # Collect per-epoch metadata JSONs. Note the [!train] character class
    # skips names starting with any of 't', 'r', 'a', 'i', 'n'; its intent
    # here is to skip train.json.
    jsons = glob.glob(f'{ckpt_dir}/[!train]*.json')
    jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
    for y in jsons:
        with open(y, 'r') as f:
            dic_json = json.load(f)
            loss = dic_json['val_loss']
            epoch = dic_json['epoch']
            if epoch >= min_epoch and epoch <= max_epoch:
                val_scores.append((epoch, loss))
    assert val_scores, f"No valid checkpoints found in {ckpt_dir}"
    val_scores = np.array(val_scores)

    if val_best:
        sort_idx = np.argsort(val_scores[:, 1])
        sorted_val_scores = val_scores[sort_idx]
    else:
        sorted_val_scores = val_scores

    best_val_scores = sorted_val_scores[:num, 1]
    selected_epochs = sorted_val_scores[:num, 0].astype(np.int64)
    avg_val_score = np.mean(best_val_scores)
    print("selected val scores = " + str(best_val_scores))
    print("selected epochs = " + str(selected_epochs))
    print("averaged val score = " + str(avg_val_score))

    path_list = [
        ckpt_dir + '/{}.pdparams'.format(int(epoch))
        for epoch in sorted_val_scores[:num, 0]
    ]
    print(path_list)

    avg = None
    assert num == len(path_list)
    for path in path_list:
        print(f'Processing {path}')
        states = paddle.load(path)
        if avg is None:
            avg = states
        else:
            for k in avg.keys():
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            avg[k] /= num

    paddle.save(avg, dst_model)
    print(f'Saving to {dst_model}')

    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
    with open(meta_path, 'w') as f:
        data = json.dumps({
            "mode": 'val_best' if val_best else 'latest',
            "avg_ckpt": dst_model,
            "val_loss_mean": avg_val_score,
            "ckpts": path_list,
            "epochs": selected_epochs.tolist(),
            "val_losses": best_val_scores.tolist(),
        })
        f.write(data + "\n")


def main():
    args = define_argparse()
    average_checkpoints(**vars(args))


if __name__ == '__main__':
    main()
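The accumulation loop above computes a plain elementwise mean over the selected checkpoints, avg[k] = (1/num) * sum_i states_i[k] for every parameter key k. A self-contained numpy sketch of the same operation, on toy data rather than real checkpoints:

import numpy as np

# Three toy "checkpoints", each a dict mapping parameter names to tensors.
ckpts = [{"w": np.full(2, v)} for v in (1.0, 2.0, 3.0)]

avg = None
for states in ckpts:
    if avg is None:
        avg = {k: v.copy() for k, v in states.items()}
    else:
        for k in avg:
            avg[k] += states[k]
for k in avg:
    avg[k] /= len(ckpts)

print(avg["w"])  # [2. 2.], the elementwise mean of 1, 2 and 3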
@@ -0,0 +1,166 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build vocabulary from manifest files.
Each item in the vocabulary file is a character.
"""
import argparse
import functools
import os
import tempfile
from collections import Counter

import jsonlines

from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS
from paddlespeech.s2t.frontend.utility import SPACE
from paddlespeech.s2t.frontend.utility import UNK
from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments


def count_manifest(counter, text_feature, manifest_path):
    manifest_jsons = []
    with jsonlines.open(manifest_path, 'r') as reader:
        for json_data in reader:
            manifest_jsons.append(json_data)

    for line_json in manifest_jsons:
        if isinstance(line_json['text'], str):
            tokens = text_feature.tokenize(
                line_json['text'], replace_space=False)
            counter.update(tokens)
        else:
            assert isinstance(line_json['text'], list)
            for text in line_json['text']:
                tokens = text_feature.tokenize(text, replace_space=False)
                counter.update(tokens)


def dump_text_manifest(fileobj, manifest_path, key='text'):
    manifest_jsons = []
    with jsonlines.open(manifest_path, 'r') as reader:
        for json_data in reader:
            manifest_jsons.append(json_data)

    for line_json in manifest_jsons:
        if isinstance(line_json[key], str):
            fileobj.write(line_json[key] + "\n")
        else:
            assert isinstance(line_json[key], list)
            for line in line_json[key]:
                fileobj.write(line + "\n")


def build_vocab(manifest_paths="",
                vocab_path="examples/librispeech/data/vocab.txt",
                unit_type="char",
                count_threshold=0,
                text_keys='text',
                spm_mode="unigram",
                spm_vocab_size=0,
                spm_model_prefix="",
                spm_character_coverage=0.9995):
    fout = open(vocab_path, 'w', encoding='utf-8')
    fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
    fout.write(UNK + '\n')  # <unk> must be 1

    if unit_type == 'spm':
        # tools/spm_train --input=$wave_data/lang_char/input.txt
        # --vocab_size=${nbpe} --model_type=${bpemode}
        # --model_prefix=${bpemodel} --input_sentence_size=100000000
        import sentencepiece as spm

        fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
        for manifest_path in manifest_paths:
            _text_keys = text_keys if isinstance(text_keys,
                                                 list) else [text_keys]
            for text_key in _text_keys:
                dump_text_manifest(fp, manifest_path, key=text_key)
        fp.close()
        # train
        spm.SentencePieceTrainer.Train(
            input=fp.name,
            vocab_size=spm_vocab_size,
            model_type=spm_mode,
            model_prefix=spm_model_prefix,
            input_sentence_size=100000000,
            character_coverage=spm_character_coverage)
        os.unlink(fp.name)

    # encode
    text_feature = TextFeaturizer(unit_type, "", spm_model_prefix)
    counter = Counter()

    for manifest_path in manifest_paths:
        count_manifest(counter, text_feature, manifest_path)

    count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    tokens = []
    for token, count in count_sorted:
        if count < count_threshold:
            break
        # replace space with `<space>`
        token = SPACE if token == ' ' else token
        tokens.append(token)

    tokens = sorted(tokens)
    for token in tokens:
        fout.write(token + '\n')

    fout.write(SOS + "\n")  # <sos/eos>
    fout.close()


def define_argparse():
    parser = argparse.ArgumentParser(description=__doc__)
    add_arg = functools.partial(add_arguments, argparser=parser)

    # yapf: disable
    add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm.")
    add_arg('count_threshold', int, 0,
            "Truncation threshold for char/word counts. Default 0, no truncation.")
    add_arg('vocab_path', str,
            'examples/librispeech/data/vocab.txt',
            "Filepath to write the vocabulary.")
    add_arg('manifest_paths', str,
            None,
            "Filepaths of manifests for building vocabulary. "
            "You can provide multiple manifest files.",
            nargs='+',
            required=True)
    add_arg('text_keys', str,
            'text',
            "Keys of the text in manifest for building vocabulary. "
            "You can provide multiple keys.",
            nargs='+')
    # bpe
    add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
    add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, bpe, char, word. Only needed when `unit_type` is spm.")
    add_arg('spm_model_prefix', str, "", "spm model prefix, e.g. spm_model_%(spm_mode)_%(count_threshold). Only needed when `unit_type` is spm.")
    add_arg('spm_character_coverage', float, 0.9995, "Character coverage to determine the minimum symbols.")
    # yapf: enable

    args = parser.parse_args()
    return args


def main():
    args = define_argparse()
    print_arguments(args, globals())
    build_vocab(**vars(args))


if __name__ == '__main__':
    main()
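Because `build_vocab` also takes plain keyword arguments, it can be called directly from Python as well as through the CLI. A minimal sketch for a character-level vocabulary; the manifest and output paths are placeholders, and the full module path is inferred from the commit title:

from paddlespeech.dataset.s2t.build_vocab import build_vocab

# Build a char-level vocab from one manifest; paths are placeholders.
build_vocab(
    manifest_paths=["data/manifest.train"],
    vocab_path="data/lang_char/vocab.txt",
    unit_type="char",
    count_threshold=0,
    text_keys="text")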