From 9df547f63bd1b6a5a6a6c0996ebe6f14d09917f2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 23 Sep 2023 12:24:38 +0000 Subject: [PATCH] Format code --- data_utils.py | 8 +++++++- text/__init__.py | 2 +- text/japanese.py | 18 +++++++++++------- text/japanese_bert.py | 8 +++----- text/symbols.py | 2 +- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/data_utils.py b/data_utils.py index af7d601bd..7f490e39f 100644 --- a/data_utils.py +++ b/data_utils.py @@ -208,7 +208,13 @@ def __call__(self, batch): torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True ) - max_text_len = max([batch[ids_sorted_decreasing[i]][7].size(1) for i in range(len(ids_sorted_decreasing))] + [len(x[0]) for x in batch]) + max_text_len = max( + [ + batch[ids_sorted_decreasing[i]][7].size(1) + for i in range(len(ids_sorted_decreasing)) + ] + + [len(x[0]) for x in batch] + ) max_spec_len = max([x[1].size(1) for x in batch]) max_wav_len = max([x[2].size(1) for x in batch]) diff --git a/text/__init__.py b/text/__init__.py index 897ff10b7..35b8e2105 100644 --- a/text/__init__.py +++ b/text/__init__.py @@ -11,7 +11,7 @@ def cleaned_text_to_sequence(cleaned_text, tones, language): Returns: List of integers corresponding to the symbols in the text """ - phones = [] # _symbol_to_id[symbol] for symbol in cleaned_text + phones = [] # _symbol_to_id[symbol] for symbol in cleaned_text for symbol in cleaned_text: try: phones.append(_symbol_to_id[symbol]) diff --git a/text/japanese.py b/text/japanese.py index bea2fd511..d1279b6b1 100644 --- a/text/japanese.py +++ b/text/japanese.py @@ -11,7 +11,8 @@ import pyopenjtalk from num2words import num2words -BERT = './bert/bert-large-japanese-v2' + +BERT = "./bert/bert-large-japanese-v2" _CONVRULES = [ # Conversion of 2 letters "アァ/ a a", @@ -385,6 +386,7 @@ def replace_punctuation(text): replaced_text = replaced_text.replace(x, rep_map[x]) return replaced_text + def text_normalize(text): res = unicodedata.normalize("NFKC", text) res = japanese_convert_numbers_to_words(res) @@ -393,27 +395,29 @@ def text_normalize(text): return res - tokenizer = AutoTokenizer.from_pretrained(BERT) + def g2p(norm_text): tokenized = tokenizer.tokenize(norm_text) st = [x.replace("#", "") for x in tokenized] word2ph = [] - phs = pyopenjtalk.g2p(norm_text).split(" ") # Directly use the entire norm_text sequence. + phs = pyopenjtalk.g2p(norm_text).split( + " " + ) # Directly use the entire norm_text sequence. for sub in st: # the following code is only for calculating word2ph wph = 0 for x in sub: sys.stdout.flush() - if x not in ['?', '.', '!', '…', ',']: # This will throw warnings. + if x not in ["?", ".", "!", "…", ","]: # This will throw warnings. phonemes = pyopenjtalk.g2p(x) else: - phonemes = 'pau' + phonemes = "pau" # for x in range(repeat): - wph += len(phonemes.split(' ')) + wph += len(phonemes.split(" ")) # print(f'{x}-->:{phones}') word2ph.append(wph) - phonemes = ['_'] + phs + ['_'] + phonemes = ["_"] + phs + ["_"] tones = [0 for i in phonemes] word2ph = [1] + word2ph + [1] return phonemes, tones, word2ph diff --git a/text/japanese_bert.py b/text/japanese_bert.py index f83e36cdf..0953b1bd3 100644 --- a/text/japanese_bert.py +++ b/text/japanese_bert.py @@ -2,7 +2,7 @@ from transformers import AutoTokenizer, AutoModelForMaskedLM import sys -BERT = './bert/bert-large-japanese-v2' +BERT = "./bert/bert-large-japanese-v2" tokenizer = AutoTokenizer.from_pretrained(BERT) # bert-large model has 25 hidden layers.You can decide which layer to use by setting this variable to a specific value @@ -19,15 +19,13 @@ def get_bert_feature(text, word2ph, device=None): device = "mps" if not device: device = "cuda" - model = AutoModelForMaskedLM.from_pretrained(BERT).to( - device - ) + model = AutoModelForMaskedLM.from_pretrained(BERT).to(device) with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) res = model(**inputs, output_hidden_states=True) - res = res['hidden_states'][BERT_LAYER] + res = res["hidden_states"][BERT_LAYER] assert inputs["input_ids"].shape[-1] == len(word2ph) word2phone = word2ph phone_level_feature = [] diff --git a/text/symbols.py b/text/symbols.py index 3690218c2..94f0f63cd 100644 --- a/text/symbols.py +++ b/text/symbols.py @@ -118,7 +118,7 @@ "z", "zy", ] -for x in range(ord('a'), ord('z')+1): +for x in range(ord("a"), ord("z") + 1): ja_symbols.append(chr(x).upper()) num_ja_tones = 1