Skip to content

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] committed Sep 23, 2023
1 parent 732e9ce commit 9df547f
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
8 changes: 7 additions & 1 deletion data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,13 @@ def __call__(self, batch):
torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
)

max_text_len = max([batch[ids_sorted_decreasing[i]][7].size(1) for i in range(len(ids_sorted_decreasing))] + [len(x[0]) for x in batch])
max_text_len = max(
[
batch[ids_sorted_decreasing[i]][7].size(1)
for i in range(len(ids_sorted_decreasing))
]
+ [len(x[0]) for x in batch]
)
max_spec_len = max([x[1].size(1) for x in batch])
max_wav_len = max([x[2].size(1) for x in batch])

Expand Down
2 changes: 1 addition & 1 deletion text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):
Returns:
List of integers corresponding to the symbols in the text
"""
phones = [] # _symbol_to_id[symbol] for symbol in cleaned_text
phones = [] # _symbol_to_id[symbol] for symbol in cleaned_text
for symbol in cleaned_text:
try:
phones.append(_symbol_to_id[symbol])
Expand Down
18 changes: 11 additions & 7 deletions text/japanese.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import pyopenjtalk

from num2words import num2words
BERT = './bert/bert-large-japanese-v2'

BERT = "./bert/bert-large-japanese-v2"
_CONVRULES = [
# Conversion of 2 letters
"アァ/ a a",
Expand Down Expand Up @@ -385,6 +386,7 @@ def replace_punctuation(text):
replaced_text = replaced_text.replace(x, rep_map[x])
return replaced_text


def text_normalize(text):
res = unicodedata.normalize("NFKC", text)
res = japanese_convert_numbers_to_words(res)
Expand All @@ -393,27 +395,29 @@ def text_normalize(text):
return res



tokenizer = AutoTokenizer.from_pretrained(BERT)


def g2p(norm_text):
tokenized = tokenizer.tokenize(norm_text)
st = [x.replace("#", "") for x in tokenized]
word2ph = []
phs = pyopenjtalk.g2p(norm_text).split(" ") # Directly use the entire norm_text sequence.
phs = pyopenjtalk.g2p(norm_text).split(
" "
) # Directly use the entire norm_text sequence.
for sub in st: # the following code is only for calculating word2ph
wph = 0
for x in sub:
sys.stdout.flush()
if x not in ['?', '.', '!', '…', ',']: # This will throw warnings.
if x not in ["?", ".", "!", "…", ","]: # This will throw warnings.
phonemes = pyopenjtalk.g2p(x)
else:
phonemes = 'pau'
phonemes = "pau"
# for x in range(repeat):
wph += len(phonemes.split(' '))
wph += len(phonemes.split(" "))
# print(f'{x}-->:{phones}')
word2ph.append(wph)
phonemes = ['_'] + phs + ['_']
phonemes = ["_"] + phs + ["_"]
tones = [0 for i in phonemes]
word2ph = [1] + word2ph + [1]
return phonemes, tones, word2ph
Expand Down
8 changes: 3 additions & 5 deletions text/japanese_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from transformers import AutoTokenizer, AutoModelForMaskedLM
import sys

BERT = './bert/bert-large-japanese-v2'
BERT = "./bert/bert-large-japanese-v2"

tokenizer = AutoTokenizer.from_pretrained(BERT)
# bert-large model has 25 hidden layers. You can decide which layer to use by setting this variable to a specific value.
Expand All @@ -19,15 +19,13 @@ def get_bert_feature(text, word2ph, device=None):
device = "mps"
if not device:
device = "cuda"
model = AutoModelForMaskedLM.from_pretrained(BERT).to(
device
)
model = AutoModelForMaskedLM.from_pretrained(BERT).to(device)
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = model(**inputs, output_hidden_states=True)
res = res['hidden_states'][BERT_LAYER]
res = res["hidden_states"][BERT_LAYER]
assert inputs["input_ids"].shape[-1] == len(word2ph)
word2phone = word2ph
phone_level_feature = []
Expand Down
2 changes: 1 addition & 1 deletion text/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@
"z",
"zy",
]
for x in range(ord('a'), ord('z')+1):
for x in range(ord("a"), ord("z") + 1):
ja_symbols.append(chr(x).upper())
num_ja_tones = 1

Expand Down

0 comments on commit 9df547f

Please sign in to comment.