Skip to content

Commit

Permalink
Revert "Revert "change g2p and other fix" (#40)"
Browse files Browse the repository at this point in the history
This reverts commit 2f687e4.
  • Loading branch information
Stardust-minus authored Sep 30, 2023
1 parent 2f687e4 commit e0c9e04
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 253 deletions.
53 changes: 0 additions & 53 deletions bert/bert-base-japanese-v3/README.md

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 3072,
"intermediate_size": 4096,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 32768
Expand Down
10 changes: 10 additions & 0 deletions bert/bert-large-japanese-v2/tokenizer_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"tokenizer_class": "BertJapaneseTokenizer",
"model_max_length": 512,
"do_lower_case": false,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "wordpiece",
"mecab_kwargs": {
"mecab_dic": "unidic_lite"
}
}
File renamed without changes.
14 changes: 10 additions & 4 deletions data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ def get_text(self, text, word2ph, phone, tone, language_str, wav_path):

if language_str == "ZH":
bert = bert
ja_bert = torch.zeros(768, len(phone))
ja_bert = torch.zeros(1024, len(phone))
elif language_str == "JA":
ja_bert = bert
bert = torch.zeros(1024, len(phone))
else:
bert = torch.zeros(1024, len(phone))
ja_bert = torch.zeros(768, len(phone))
ja_bert = torch.zeros(1024, len(phone))
assert bert.shape[-1] == len(phone), (
bert.shape,
len(phone),
Expand Down Expand Up @@ -208,7 +208,13 @@ def __call__(self, batch):
torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
)

max_text_len = max([len(x[0]) for x in batch])
max_text_len = max(
[
batch[ids_sorted_decreasing[i]][7].size(1)
for i in range(len(ids_sorted_decreasing))
]
+ [len(x[0]) for x in batch]
)
max_spec_len = max([x[1].size(1) for x in batch])
max_wav_len = max([x[2].size(1) for x in batch])

Expand All @@ -221,7 +227,7 @@ def __call__(self, batch):
tone_padded = torch.LongTensor(len(batch), max_text_len)
language_padded = torch.LongTensor(len(batch), max_text_len)
bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
ja_bert_padded = torch.FloatTensor(len(batch), 768, max_text_len)
ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)

spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
Expand Down
2 changes: 1 addition & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def __init__(
self.language_emb = nn.Embedding(num_languages, hidden_channels)
nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)

self.encoder = attentions.Encoder(
hidden_channels,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pypinyin
cn2an
gradio
av
mecab-python3
pyopenjtalk
loguru
unidic-lite
cmudict
Expand Down
7 changes: 6 additions & 1 deletion text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):
Returns:
List of integers corresponding to the symbols in the text
"""
phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
phones = [] # _symbol_to_id[symbol] for symbol in cleaned_text
for symbol in cleaned_text:
try:
phones.append(_symbol_to_id[symbol])
except KeyError:
phones.append(0) # symbol not found in ID map, use 0('_') by default
tone_start = language_tone_start_map[language]
tones = [i + tone_start for i in tones]
lang_id = language_id_map[language]
Expand Down
Loading

0 comments on commit e0c9e04

Please sign in to comment.