Skip to content

Commit

Permalink
add typehint for g2pw (#2390)
Browse files Browse the repository at this point in the history
  • Loading branch information
yt605155624 authored Sep 16, 2022
1 parent 68c2ec7 commit eac3620
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 58 deletions.
2 changes: 1 addition & 1 deletion paddlespeech/t2s/frontend/g2pw/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter
from .onnx_api import G2PWOnnxConverter
66 changes: 34 additions & 32 deletions paddlespeech/t2s/frontend/g2pw/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
from typing import Dict
from typing import List
from typing import Tuple

import numpy as np

from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
Expand All @@ -23,22 +27,17 @@


def prepare_onnx_input(tokenizer,
labels,
char2phonemes,
chars,
texts,
query_ids,
phonemes=None,
pos_tags=None,
use_mask=False,
use_char_phoneme=False,
use_pos=False,
window_size=None,
max_len=512):
labels: List[str],
char2phonemes: Dict[str, List[int]],
chars: List[str],
texts: List[str],
query_ids: List[int],
use_mask: bool=False,
window_size: int=None,
max_len: int=512) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(window_size,
texts, query_ids)

truncated_texts, truncated_query_ids = _truncate_texts(
window_size=window_size, texts=texts, query_ids=query_ids)
input_ids = []
token_type_ids = []
attention_masks = []
Expand All @@ -51,13 +50,19 @@ def prepare_onnx_input(tokenizer,
query_id = (truncated_query_ids if window_size else query_ids)[idx]

try:
tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
tokens, text2token, token2text = tokenize_and_map(
tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}

text, query_id, tokens, text2token, token2text = _truncate(
max_len, text, query_id, tokens, text2token, token2text)
max_len=max_len,
text=text,
query_id=query_id,
tokens=tokens,
text2token=text2token,
token2text=token2text)

processed_tokens = ['[CLS]'] + tokens + ['[SEP]']

Expand Down Expand Up @@ -91,7 +96,8 @@ def prepare_onnx_input(tokenizer,
return outputs


def _truncate_texts(window_size, texts, query_ids):
def _truncate_texts(window_size: int, texts: List[str],
query_ids: List[int]) -> Tuple[List[str], List[int]]:
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
Expand All @@ -105,7 +111,12 @@ def _truncate_texts(window_size, texts, query_ids):
return truncated_texts, truncated_query_ids


def _truncate(max_len, text, query_id, tokens, text2token, token2text):
def _truncate(max_len: int,
text: str,
query_id: int,
tokens: List[str],
text2token: List[int],
token2text: List[Tuple[int]]):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
Expand All @@ -132,18 +143,8 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])


def prepare_data(sent_path, lb_path=None):
raw_texts = open(sent_path).read().rstrip().split('\n')
query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts]
texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts]
if lb_path is None:
return texts, query_ids
else:
phonemes = open(lb_path).read().rstrip().split('\n')
return texts, query_ids, phonemes


def get_phoneme_labels(polyphonic_chars):
def get_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
Expand All @@ -153,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):
return labels, char2phonemes


def get_char_phoneme_labels(polyphonic_chars):
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
char2phonemes = {}
Expand Down
50 changes: 30 additions & 20 deletions paddlespeech/t2s/frontend/g2pw/onnx_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
"""
import json
import os
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple

import numpy as np
import onnxruntime
Expand All @@ -37,7 +41,8 @@
model_version = '1.1'


def predict(session, onnx_input, labels):
def predict(session, onnx_input: Dict[str, Any],
labels: List[str]) -> Tuple[List[str], List[float]]:
all_preds = []
all_confidences = []
probs = session.run([], {
Expand All @@ -61,10 +66,10 @@ def predict(session, onnx_input, labels):

class G2PWOnnxConverter:
def __init__(self,
model_dir=MODEL_HOME,
style='bopomofo',
model_source=None,
enable_non_tradional_chinese=False):
model_dir: os.PathLike=MODEL_HOME,
style: str='bopomofo',
model_source: str=None,
enable_non_tradional_chinese: bool=False):
uncompress_path = download_and_decompress(
g2pw_onnx_models['G2PWModel'][model_version], model_dir)

Expand All @@ -76,7 +81,8 @@ def __init__(self,
os.path.join(uncompress_path, 'g2pW.onnx'),
sess_options=sess_options)
self.config = load_config(
os.path.join(uncompress_path, 'config.py'), use_default=True)
config_path=os.path.join(uncompress_path, 'config.py'),
use_default=True)

self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese
Expand All @@ -103,9 +109,9 @@ def __init__(self,
.strip().split('\n')
]
self.labels, self.char2phonemes = get_char_phoneme_labels(
self.polyphonic_chars
polyphonic_chars=self.polyphonic_chars
) if self.config.use_char_phoneme else get_phoneme_labels(
self.polyphonic_chars)
polyphonic_chars=self.polyphonic_chars)

self.chars = sorted(list(self.char2phonemes.keys()))

Expand Down Expand Up @@ -146,7 +152,7 @@ def __init__(self,
if self.enable_opencc:
self.cc = OpenCC('s2tw')

def _convert_bopomofo_to_pinyin(self, bopomofo):
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
assert tone in '12345'
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
Expand All @@ -156,7 +162,7 @@ def _convert_bopomofo_to_pinyin(self, bopomofo):
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None

def __call__(self, sentences):
def __call__(self, sentences: List[str]) -> List[List[str]]:
if isinstance(sentences, str):
sentences = [sentences]

Expand All @@ -169,23 +175,25 @@ def __call__(self, sentences):
sentences = translated_sentences

texts, query_ids, sent_ids, partial_results = self._prepare_data(
sentences)
sentences=sentences)
if len(texts) == 0:
# none of the sentences contain polyphonic characters
return partial_results

onnx_input = prepare_onnx_input(
self.tokenizer,
self.labels,
self.char2phonemes,
self.chars,
texts,
query_ids,
tokenizer=self.tokenizer,
labels=self.labels,
char2phonemes=self.char2phonemes,
chars=self.chars,
texts=texts,
query_ids=query_ids,
use_mask=self.config.use_mask,
use_char_phoneme=self.config.use_char_phoneme,
window_size=None)

preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
preds, confidences = predict(
session=self.session_g2pW,
onnx_input=onnx_input,
labels=self.labels)
if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds]

Expand All @@ -195,7 +203,9 @@ def __call__(self, sentences):

return results

def _prepare_data(self, sentences):
def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works better for Simplified Chinese than for Traditional Chinese
Expand Down
11 changes: 6 additions & 5 deletions paddlespeech/t2s/frontend/g2pw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import os
import re


def wordize_and_map(text):
def wordize_and_map(text: str):
words = []
index_map_from_text_to_word = []
index_map_from_word_to_text = []
Expand Down Expand Up @@ -54,8 +55,8 @@ def wordize_and_map(text):
return words, index_map_from_text_to_word, index_map_from_word_to_text


def tokenize_and_map(tokenizer, text):
words, text2word, word2text = wordize_and_map(text)
def tokenize_and_map(tokenizer, text: str):
words, text2word, word2text = wordize_and_map(text=text)

tokens = []
index_map_from_token_to_text = []
Expand All @@ -82,7 +83,7 @@ def tokenize_and_map(tokenizer, text):
return tokens, index_map_from_text_to_token, index_map_from_token_to_text


def _load_config(config_path):
def _load_config(config_path: os.PathLike):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)
config = importlib.util.module_from_spec(spec)
Expand Down Expand Up @@ -130,7 +131,7 @@ def _load_config(config_path):
}


def load_config(config_path, use_default=False):
def load_config(config_path: os.PathLike, use_default: bool=False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():
Expand Down

0 comments on commit eac3620

Please sign in to comment.