
Commit

Fix japanese cleaner (#61)
* Initial work; going to sleep, will continue tomorrow (

* Oops, pushed to the wrong branch; staying up late is a bad idea

* [pre-commit.ci] pre-commit autoupdate (#55)

* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/pre-commit/pre-commit-hooks: v4.4.0 → v4.5.0](pre-commit/pre-commit-hooks@v4.4.0...v4.5.0)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Create tokenizer_config.json

* update preprocess_text.py: filter out cases where one audio file is matched to multiple transcripts (#57)

* update preprocess_text.py: filter out cases where the audio file does not exist (#58)

* Fix the Japanese cleaner and BERT

* better

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Stardust·减 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sora <[email protected]>
4 people authored Oct 12, 2023
1 parent d2210b2 commit 09e8146
Showing 6 changed files with 87 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
10 changes: 10 additions & 0 deletions bert/bert-base-japanese-v3/tokenizer_config.json
@@ -0,0 +1,10 @@
+{
+  "tokenizer_class": "BertJapaneseTokenizer",
+  "model_max_length": 512,
+  "do_lower_case": false,
+  "word_tokenizer_type": "mecab",
+  "subword_tokenizer_type": "wordpiece",
+  "mecab_kwargs": {
+    "mecab_dic": "unidic_lite"
+  }
+}
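This new config tells AutoTokenizer to build a BertJapaneseTokenizer that segments with MeCab (using the unidic_lite dictionary) and then applies WordPiece subwords, which is what the reworked cleaner below relies on. A minimal loading sketch, assuming fugashi and unidic-lite are installed and the ./bert/bert-base-japanese-v3 directory also holds the vocabulary files:

```python
# Minimal sketch; assumes fugashi + unidic-lite are installed and
# ./bert/bert-base-japanese-v3 contains the vocab plus this config.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./bert/bert-base-japanese-v3")

# MeCab splits the sentence into words, WordPiece then splits rare words
# into subword units; do_lower_case=false keeps the original casing.
print(tokenizer.tokenize("日本語のクリーナーを修正します。"))
# model_max_length=512 caps the usable sequence length for BERT features.
print(tokenizer.model_max_length)
```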
15 changes: 15 additions & 0 deletions preprocess_text.py
@@ -1,4 +1,5 @@
 import json
+import os.path
 from collections import defaultdict
 from random import shuffle
 from typing import Optional
@@ -67,13 +68,27 @@ def main(
     current_sid = 0

     with open(transcription_path, encoding="utf-8") as f:
+        audioPaths = set()
+        countSame = 0
+        countNotFound = 0
         for line in f.readlines():
             utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
+            if utt in audioPaths:
+                # Filter dataset errors: the same audio matched to several transcripts breaks the BERT step later on
+                print(f"Duplicate audio/transcript: {line}")
+                countSame += 1
+                continue
+            if not os.path.isfile(utt):
+                print(f"Audio file not found: {utt}")
+                countNotFound += 1
+                continue
+            audioPaths.add(utt)
             spk_utt_map[spk].append(line)

             if spk not in spk_id_map.keys():
                 spk_id_map[spk] = current_sid
                 current_sid += 1
+    print(f"Total duplicate audio files: {countSame}, total missing audio files: {countNotFound}")

     train_list = []
     val_list = []
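For reference, each transcription line is pipe-separated as utt|spk|language|text|phones|tones|word2ph, with utt pointing at the audio file. A tiny self-contained sketch of the same filtering rule, with made-up paths, just to show what gets dropped:

```python
# Illustrative only: the paths below are made up, so on a real run they would
# also fail the isfile() check; with real files, only the first line survives.
import os.path

lines = [
    "data/a.wav|spk1|JP|こんにちは|...|...|...",
    "data/a.wav|spk1|JP|こんばんは|...|...|...",      # same audio, second transcript
    "data/missing.wav|spk1|JP|おはよう|...|...|...",  # audio not on disk
]

seen, kept = set(), []
for line in lines:
    utt = line.split("|")[0]       # first field is the audio path
    if utt in seen:                # duplicate audio -> skip
        continue
    if not os.path.isfile(utt):    # missing audio -> skip
        continue
    seen.add(utt)
    kept.append(line)
print(len(kept))
```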
58 changes: 39 additions & 19 deletions text/japanese.py
@@ -381,6 +381,31 @@ def text2kata(text: str) -> str:
     return hira2kata("".join(res))


+def text2sep_kata(text: str) -> (list, list):
+    parsed = _TAGGER.parse(text)
+    res = []
+    sep = []
+    for line in parsed.split("\n"):
+        if line == "EOS":
+            break
+        parts = line.split("\t")
+
+        word, yomi = parts[0], parts[1]
+        if yomi:
+            res.append(yomi)
+        else:
+            if word in _SYMBOL_TOKENS:
+                res.append(word)
+            elif word in ("っ", "ッ"):
+                res.append("ッ")
+            elif word in _NO_YOMI_TOKENS:
+                pass
+            else:
+                res.append(word)
+        sep.append(word)
+    return sep, [hira2kata(i) for i in res]
+
+
 _ALPHASYMBOL_YOMI = {
     "#": "シャープ",
     "%": "パーセント",
@@ -505,7 +530,7 @@ def is_japanese_character(char):
"\n": ".",
"·": ",",
"、": ",",
"...": "",
"": "...",
}


@@ -546,28 +571,22 @@ def distribute_phone(n_phone, n_word):


 def g2p(norm_text):
-    tokenized = tokenizer.tokenize(norm_text)
-    phs = []
-    ph_groups = []
-    for t in tokenized:
-        if not t.startswith("#"):
-            ph_groups.append([t])
-        else:
-            ph_groups[-1].append(t.replace("#", ""))
+    sep_text, sep_kata = text2sep_kata(norm_text)
+    sep_tokenized = [tokenizer.tokenize(i) for i in sep_text]
+    sep_phonemes = [kata2phoneme(i) for i in sep_kata]
+    # Error handling: words MeCab does not recognize get passed all the way down here and blow up; so far only extremely rare, obscure words trigger this
+    for i in sep_phonemes:
+        for j in i:
+            assert j in symbols, (sep_text, sep_kata, sep_phonemes)
+
     word2ph = []
-    for group in ph_groups:
-        phonemes = kata2phoneme(text2kata("".join(group)))
-        # phonemes = [i for i in phonemes if i in symbols]
-        for i in phonemes:
-            assert i in symbols, (group, norm_text, tokenized)
-        phone_len = len(phonemes)
-        word_len = len(group)
+    for token, phoneme in zip(sep_tokenized, sep_phonemes):
+        phone_len = len(phoneme)
+        word_len = len(token)

         aaa = distribute_phone(phone_len, word_len)
         word2ph += aaa

-        phs += phonemes
-    phones = ["_"] + phs + ["_"]
+    phones = ["_"] + [j for i in sep_phonemes for j in i] + ["_"]
     tones = [0 for i in phones]
     word2ph = [1] + word2ph + [1]
     return phones, tones, word2ph
@@ -580,6 +599,7 @@ def g2p(norm_text):

     text = text_normalize(text)
     print(text)
+
     phones, tones, word2ph = g2p(text)
     bert = get_bert_feature(text, word2ph)

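Taken together, text2sep_kata and the rewritten g2p keep one katakana reading, one phoneme list, and one BERT subword group per MeCab word, so word2ph lines up with the token ids built in text/japanese_bert.py. A rough usage sketch, assuming it runs from the repo root so that the text package and the local BERT tokenizer files resolve, with an illustrative sentence:

```python
# Rough sketch under assumptions: text.japanese is importable and
# ./bert/bert-base-japanese-v3 exists locally; the sentence is illustrative.
from text.japanese import text_normalize, text2sep_kata, g2p

text = text_normalize("こんにちは、世界。")
sep_text, sep_kata = text2sep_kata(text)  # per-word surface forms and katakana readings
phones, tones, word2ph = g2p(text)

# word2ph holds one entry per BERT subword token (plus the two "_" padding
# slots), and its entries sum to len(phones), so BERT features can later be
# expanded phoneme by phoneme.
print(sep_text)
print(sep_kata)
print(len(phones), sum(word2ph), len(word2ph))
```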
23 changes: 20 additions & 3 deletions text/japanese_bert.py
@@ -1,13 +1,22 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import sys
+from text.japanese import text2sep_kata

 tokenizer = AutoTokenizer.from_pretrained("./bert/bert-base-japanese-v3")

 models = dict()


 def get_bert_feature(text, word2ph, device=None):
+    sep_text, _ = text2sep_kata(text)
+    sep_tokens = [tokenizer.tokenize(t) for t in sep_text]
+    sep_ids = [tokenizer.convert_tokens_to_ids(t) for t in sep_tokens]
+    sep_ids = [2] + [item for sublist in sep_ids for item in sublist] + [3]
+    return get_bert_feature_with_token(sep_ids, word2ph, device)
+
+
+def get_bert_feature_with_token(tokens, word2ph, device=None):
     if (
         sys.platform == "darwin"
         and torch.backends.mps.is_available()
@@ -21,9 +30,17 @@ def get_bert_feature(text, word2ph, device=None):
"./bert/bert-base-japanese-v3"
).to(device)
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
inputs = torch.tensor(tokens).to(device).unsqueeze(0)
token_type_ids = torch.zeros_like(inputs).to(device)
attention_mask = torch.ones_like(inputs).to(device)
inputs = {
"input_ids": inputs,
"token_type_ids": token_type_ids,
"attention_mask": attention_mask,
}

# for i in inputs:
# inputs[i] = inputs[i].to(device)
res = models[device](**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
assert inputs["input_ids"].shape[-1] == len(word2ph)
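Instead of letting tokenizer(text, return_tensors="pt") re-segment the whole sentence, get_bert_feature now assembles input_ids from the per-word tokens of text2sep_kata, with 2 and 3 hard-coded as the [CLS]/[SEP] ids of this vocabulary, so the sequence length matches word2ph. A small sketch of that construction, assuming the local model directory exists and those special-token ids are correct for it:

```python
# Sketch under assumptions: the ./bert/bert-base-japanese-v3 files exist and
# 2/3 really are the [CLS]/[SEP] ids of this vocab (as the diff hard-codes).
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./bert/bert-base-japanese-v3")

words = ["今日", "は", "いい", "天気"]  # per-word pieces, as text2sep_kata would return
ids = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(w)) for w in words]
ids = [2] + [i for chunk in ids for i in chunk] + [3]   # [CLS] ... [SEP]

input_ids = torch.tensor(ids).unsqueeze(0)              # shape (1, seq_len)
inputs = {
    "input_ids": input_ids,
    "token_type_ids": torch.zeros_like(input_ids),      # single segment
    "attention_mask": torch.ones_like(input_ids),       # no padding
}
# This dict can be fed to the masked-LM model just like the output of
# tokenizer(text, return_tensors="pt"); the difference is that the token
# boundaries now follow the MeCab word segmentation used for word2ph.
print(tokenizer.cls_token_id, tokenizer.sep_token_id)   # expected: 2 3
```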
2 changes: 2 additions & 0 deletions webui.py
@@ -48,6 +48,8 @@ def tts_fn(
             length_scale=length_scale,
             sid=speaker,
             language=language,
+            sid=speaker,
+            language=language,
             hps=hps,
             net_g=net_g,
             device=device,
