diff --git a/bert/bert-base-japanese-v3/README.md b/bert/bert-base-japanese-v3/README.md
new file mode 100644
index 000000000..c5b345671
--- /dev/null
+++ b/bert/bert-base-japanese-v3/README.md
@@ -0,0 +1,53 @@
+---
+license: apache-2.0
+datasets:
+- cc100
+- wikipedia
+language:
+- ja
+widget:
+- text: 東北大学で[MASK]の研究をしています。
+---
+
+# BERT base Japanese (unidic-lite with whole word masking, CC-100 and jawiki-20230102)
+
+This is a [BERT](https://github.com/google-research/bert) model pretrained on texts in the Japanese language.
+
+This version of the model processes input texts with word-level tokenization based on the Unidic 2.1.2 dictionary (available in the [unidic-lite](https://pypi.org/project/unidic-lite/) package), followed by WordPiece subword tokenization.
+Additionally, the model is trained with whole word masking enabled for the masked language modeling (MLM) objective.
+
+The code for pretraining is available at [cl-tohoku/bert-japanese](https://github.com/cl-tohoku/bert-japanese/).
+
+## Model architecture
+
+The model architecture is the same as the original BERT base model: 12 layers, 768-dimensional hidden states, and 12 attention heads.
+
+## Training Data
+
+The model is trained on the Japanese portion of the [CC-100 dataset](https://data.statmt.org/cc-100/) and the Japanese version of Wikipedia.
+For Wikipedia, we generated a text corpus from the [Wikipedia Cirrussearch dump file](https://dumps.wikimedia.org/other/cirrussearch/) as of January 2, 2023.
+The corpus files generated from CC-100 and Wikipedia are 74.3GB and 4.9GB in size and consist of approximately 392M and 34M sentences, respectively.
+
+To split the texts into sentences, we used [fugashi](https://github.com/polm/fugashi) with the [mecab-ipadic-NEologd](https://github.com/neologd/mecab-ipadic-neologd) dictionary (v0.0.7).
+
+## Tokenization
+
+The texts are first tokenized by MeCab with the Unidic 2.1.2 dictionary and then split into subwords by the WordPiece algorithm.
+The vocabulary size is 32768.
+
+We used the [fugashi](https://github.com/polm/fugashi) and [unidic-lite](https://github.com/polm/unidic-lite) packages for tokenization.
+
+## Training
+
+We trained the model first on the CC-100 corpus for 1M steps and then on the Wikipedia corpus for another 1M steps.
+For the MLM (masked language modeling) objective, we introduced whole word masking, in which all of the subword tokens corresponding to a single word (as tokenized by MeCab) are masked at once.
+
+For training each model, we used a v3-8 instance of Cloud TPUs provided by the [TPU Research Cloud](https://sites.research.google/trc/about/).
+
+## Licenses
+
+The pretrained models are distributed under the Apache License 2.0.
+
+## Acknowledgments
+
+This model was trained with Cloud TPUs provided by the [TPU Research Cloud](https://sites.research.google/trc/about/) program.
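The repo code further down consumes this checkpoint through Hugging Face `transformers` (see the `text/japanese_bert.py` hunk). As a rough sketch, not part of the patch, of what loading it from the local path looks like — assuming the model weights and tokenizer files are also present in that directory (the diff itself only tracks `config.json` and `vocab.txt`) and that `fugashi` plus `unidic-lite` are installed for the MeCab word tokenizer:

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Local path used throughout this patch.
bert_dir = "./bert/bert-base-japanese-v3"
tokenizer = AutoTokenizer.from_pretrained(bert_dir)
model = AutoModelForMaskedLM.from_pretrained(bert_dir)

# Any Japanese sentence works; this one is just an example.
inputs = tokenizer("東北大学で自然言語処理の研究をしています。", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# BERT base: embedding layer + 12 hidden layers, each 768-dimensional.
print(len(out.hidden_states))       # 13
print(out.hidden_states[-1].shape)  # torch.Size([1, seq_len, 768])
```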
diff --git a/bert/bert-large-japanese-v2/config.json b/bert/bert-base-japanese-v3/config.json
similarity index 75%
rename from bert/bert-large-japanese-v2/config.json
rename to bert/bert-base-japanese-v3/config.json
index 427134b5f..65a2f3322 100644
--- a/bert/bert-large-japanese-v2/config.json
+++ b/bert/bert-base-japanese-v3/config.json
@@ -5,14 +5,14 @@
   "attention_probs_dropout_prob": 0.1,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
-  "hidden_size": 1024,
+  "hidden_size": 768,
   "initializer_range": 0.02,
-  "intermediate_size": 4096,
+  "intermediate_size": 3072,
   "layer_norm_eps": 1e-12,
   "max_position_embeddings": 512,
   "model_type": "bert",
-  "num_attention_heads": 16,
-  "num_hidden_layers": 24,
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
   "pad_token_id": 0,
   "type_vocab_size": 2,
   "vocab_size": 32768
diff --git a/bert/bert-large-japanese-v2/vocab.txt b/bert/bert-base-japanese-v3/vocab.txt
similarity index 100%
rename from bert/bert-large-japanese-v2/vocab.txt
rename to bert/bert-base-japanese-v3/vocab.txt
diff --git a/bert/bert-large-japanese-v2/tokenizer_config.json b/bert/bert-large-japanese-v2/tokenizer_config.json
deleted file mode 100644
index dfbcc4690..000000000
--- a/bert/bert-large-japanese-v2/tokenizer_config.json
+++ /dev/null
@@ -1,10 +0,0 @@
-{
-  "tokenizer_class": "BertJapaneseTokenizer",
-  "model_max_length": 512,
-  "do_lower_case": false,
-  "word_tokenizer_type": "mecab",
-  "subword_tokenizer_type": "wordpiece",
-  "mecab_kwargs": {
-    "mecab_dic": "unidic_lite"
-  }
-}
diff --git a/data_utils.py b/data_utils.py
index 7f490e39f..5bf1132b3 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -154,13 +154,13 @@ def get_text(self, text, word2ph, phone, tone, language_str, wav_path):
 
         if language_str == "ZH":
             bert = bert
-            ja_bert = torch.zeros(1024, len(phone))
+            ja_bert = torch.zeros(768, len(phone))
         elif language_str == "JA":
             ja_bert = bert
             bert = torch.zeros(1024, len(phone))
         else:
             bert = torch.zeros(1024, len(phone))
-            ja_bert = torch.zeros(1024, len(phone))
+            ja_bert = torch.zeros(768, len(phone))
         assert bert.shape[-1] == len(phone), (
             bert.shape,
             len(phone),
@@ -208,13 +208,7 @@ def __call__(self, batch):
             torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
         )
 
-        max_text_len = max(
-            [
-                batch[ids_sorted_decreasing[i]][7].size(1)
-                for i in range(len(ids_sorted_decreasing))
-            ]
-            + [len(x[0]) for x in batch]
-        )
+        max_text_len = max([len(x[0]) for x in batch])
         max_spec_len = max([x[1].size(1) for x in batch])
         max_wav_len = max([x[2].size(1) for x in batch])
 
@@ -227,7 +221,7 @@
         tone_padded = torch.LongTensor(len(batch), max_text_len)
         language_padded = torch.LongTensor(len(batch), max_text_len)
         bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
-        ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
+        ja_bert_padded = torch.FloatTensor(len(batch), 768, max_text_len)
         spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
 
diff --git a/models.py b/models.py
index 44e56a22d..f392136e1 100644
--- a/models.py
+++ b/models.py
@@ -340,7 +340,7 @@ def __init__(
         self.language_emb = nn.Embedding(num_languages, hidden_channels)
         nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
         self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
-        self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
+        self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)
 
         self.encoder = attentions.Encoder(
             hidden_channels,
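All of the 1024 → 768 edits above follow from swapping the large Japanese BERT for the base one: Chinese features stay 1024-dimensional, Japanese features become 768-dimensional, and each language keeps its own 1x1 `Conv1d` projection into the encoder width, so nothing downstream of the projections changes. A shape-only sketch of that contract, not part of the patch; the `hidden_channels` and sequence-length values here are illustrative, not taken from the repo config:

```python
import torch
import torch.nn as nn

hidden_channels = 192  # illustrative; the real value comes from the model config
n_phones = 37          # arbitrary phone-sequence length

bert_proj = nn.Conv1d(1024, hidden_channels, 1)    # Chinese BERT features (unchanged)
ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)  # Japanese BERT features (new size)

bert = torch.zeros(1, 1024, n_phones)    # shaped like bert_padded in data_utils.py
ja_bert = torch.zeros(1, 768, n_phones)  # shaped like ja_bert_padded in data_utils.py

# Both branches land in the same (batch, hidden_channels, n_phones) space.
print(bert_proj(bert).shape)        # torch.Size([1, 192, 37])
print(ja_bert_proj(ja_bert).shape)  # torch.Size([1, 192, 37])
```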
diff --git a/requirements.txt b/requirements.txt
index 923cbb1c9..b809dd5cc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,7 +15,7 @@ pypinyin
 cn2an
 gradio
 av
-pyopenjtalk
+mecab-python3
 loguru
 unidic-lite
 cmudict
diff --git a/text/__init__.py b/text/__init__.py
index 35b8e2105..8dd10db04 100644
--- a/text/__init__.py
+++ b/text/__init__.py
@@ -11,12 +11,7 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):
     Returns:
       List of integers corresponding to the symbols in the text
     """
-    phones = []  # _symbol_to_id[symbol] for symbol in cleaned_text
-    for symbol in cleaned_text:
-        try:
-            phones.append(_symbol_to_id[symbol])
-        except KeyError:
-            phones.append(0)  # symbol not found in ID map, use 0('_') by default
+    phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
     tone_start = language_tone_start_map[language]
     tones = [i + tone_start for i in tones]
     lang_id = language_id_map[language]
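The `text/__init__.py` change makes `cleaned_text_to_sequence` strict again: a phoneme that is not listed in `text/symbols.py` now raises `KeyError` instead of being silently mapped to id 0 (`"_"`). The Japanese g2p in the next file guards this with an `assert`, so bad readings fail at preprocessing time rather than turning into silent `"_"` tokens. A minimal sketch of the behaviour, with a toy symbol table standing in for the real one:

```python
# Toy symbol table; the real _symbol_to_id is built from text/symbols.py.
_symbol_to_id = {"_": 0, "a": 1, "i": 2, "k": 3, "N": 4}


def to_ids(cleaned_text):
    # Strict lookup: unknown symbols raise KeyError instead of becoming "_" (id 0).
    return [_symbol_to_id[symbol] for symbol in cleaned_text]


print(to_ids(["k", "a", "N"]))  # [3, 1, 4]
# to_ids(["q"]) now raises KeyError: 'q' instead of returning [0]
```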
diff --git a/text/japanese.py b/text/japanese.py
index 0d43c7ffd..53db38b73 100644
--- a/text/japanese.py
+++ b/text/japanese.py
@@ -2,13 +2,17 @@
 # compatible with Julius https://github.com/julius-speech/segmentation-kit
 import re
 import unicodedata
-import sys
 
 from transformers import AutoTokenizer
-import pyopenjtalk
-
-BERT = "./bert/bert-large-japanese-v2"
+from text import punctuation, symbols
+
+try:
+    import MeCab
+except ImportError as e:
+    raise ImportError("Japanese requires mecab-python3 and unidic-lite.") from e
+from num2words import num2words
+
 
 _CONVRULES = [
     # Conversion of 2 letters
     "アァ/ a a",
@@ -349,6 +353,99 @@ def hira2kata(text: str) -> str:
     return text.replace("う゛", "ヴ")
 
 
+_SYMBOL_TOKENS = set(list("・、。?!"))
+_NO_YOMI_TOKENS = set(list("「」『』―()[][]"))
+_TAGGER = MeCab.Tagger()
+
+
+def text2kata(text: str) -> str:
+    parsed = _TAGGER.parse(text)
+    res = []
+    for line in parsed.split("\n"):
+        if line == "EOS":
+            break
+        parts = line.split("\t")
+
+        word, yomi = parts[0], parts[1]
+        if yomi:
+            res.append(yomi)
+        else:
+            if word in _SYMBOL_TOKENS:
+                res.append(word)
+            elif word in ("っ", "ッ"):
+                res.append("ッ")
+            elif word in _NO_YOMI_TOKENS:
+                pass
+            else:
+                res.append(word)
+    return hira2kata("".join(res))
+
+
+_ALPHASYMBOL_YOMI = {
+    "#": "シャープ",
+    "%": "パーセント",
+    "&": "アンド",
+    "+": "プラス",
+    "-": "マイナス",
+    ":": "コロン",
+    ";": "セミコロン",
+    "<": "小なり",
+    "=": "イコール",
+    ">": "大なり",
+    "@": "アット",
+    "a": "エー",
+    "b": "ビー",
+    "c": "シー",
+    "d": "ディー",
+    "e": "イー",
+    "f": "エフ",
+    "g": "ジー",
+    "h": "エイチ",
+    "i": "アイ",
+    "j": "ジェー",
+    "k": "ケー",
+    "l": "エル",
+    "m": "エム",
+    "n": "エヌ",
+    "o": "オー",
+    "p": "ピー",
+    "q": "キュー",
+    "r": "アール",
+    "s": "エス",
+    "t": "ティー",
+    "u": "ユー",
+    "v": "ブイ",
+    "w": "ダブリュー",
+    "x": "エックス",
+    "y": "ワイ",
+    "z": "ゼット",
+    "α": "アルファ",
+    "β": "ベータ",
+    "γ": "ガンマ",
+    "δ": "デルタ",
+    "ε": "イプシロン",
+    "ζ": "ゼータ",
+    "η": "イータ",
+    "θ": "シータ",
+    "ι": "イオタ",
+    "κ": "カッパ",
+    "λ": "ラムダ",
+    "μ": "ミュー",
+    "ν": "ニュー",
+    "ξ": "クサイ",
+    "ο": "オミクロン",
+    "π": "パイ",
+    "ρ": "ロー",
+    "σ": "シグマ",
+    "τ": "タウ",
+    "υ": "ウプシロン",
+    "φ": "ファイ",
+    "χ": "カイ",
+    "ψ": "プサイ",
+    "ω": "オメガ",
+}
+
+
 _NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+")
 _CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"}
 _CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])")
@@ -356,12 +453,48 @@ def hira2kata(text: str) -> str:
 
 
 def japanese_convert_numbers_to_words(text: str) -> str:
-    res = text
-    for x in _CURRENCY_MAP.keys():
-        res = res.replace(x, _CURRENCY_MAP[x])
+    res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text)
+    res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res)
+    res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res)
+    return res
+
+
+def japanese_convert_alpha_symbols_to_words(text: str) -> str:
+    return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()])
+
+
+def japanese_text_to_phonemes(text: str) -> str:
+    """Convert Japanese text to phonemes."""
+    res = unicodedata.normalize("NFKC", text)
+    res = japanese_convert_numbers_to_words(res)
+    # res = japanese_convert_alpha_symbols_to_words(res)
+    res = text2kata(res)
+    res = kata2phoneme(res)
     return res
 
 
+def is_japanese_character(char):
+    # Unicode ranges that make up the Japanese writing system
+    japanese_ranges = [
+        (0x3040, 0x309F),  # Hiragana
+        (0x30A0, 0x30FF),  # Katakana
+        (0x4E00, 0x9FFF),  # Kanji (CJK Unified Ideographs)
+        (0x3400, 0x4DBF),  # CJK Extension A
+        (0x20000, 0x2A6DF),  # CJK Extension B
+        # Further CJK extension ranges can be added here if needed
+    ]
+
+    # Convert the character to its Unicode code point
+    char_code = ord(char)
+
+    # Check whether the code point falls inside any of the Japanese ranges
+    for start, end in japanese_ranges:
+        if start <= char_code <= end:
+            return True
+
+    return False
+
+
 rep_map = {
     ":": ",",
     ";": ",",
@@ -377,9 +510,18 @@ def japanese_convert_numbers_to_words(text: str) -> str:
 
 
 def replace_punctuation(text):
-    replaced_text = text
-    for x in rep_map.keys():
-        replaced_text = replaced_text.replace(x, rep_map[x])
+    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
+
+    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
+
+    replaced_text = re.sub(
+        r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF"
+        + "".join(punctuation)
+        + r"]+",
+        "",
+        replaced_text,
+    )
+
     return replaced_text
 
 
@@ -391,42 +533,54 @@ def text_normalize(text):
     return res
 
 
-tokenizer = AutoTokenizer.from_pretrained(BERT)
+def distribute_phone(n_phone, n_word):
+    phones_per_word = [0] * n_word
+    for task in range(n_phone):
+        min_tasks = min(phones_per_word)
+        min_index = phones_per_word.index(min_tasks)
+        phones_per_word[min_index] += 1
+    return phones_per_word
+
+
+tokenizer = AutoTokenizer.from_pretrained("./bert/bert-base-japanese-v3")
 
 
 def g2p(norm_text):
     tokenized = tokenizer.tokenize(norm_text)
-    st = [x.replace("#", "") for x in tokenized]
+    phs = []
+    ph_groups = []
+    for t in tokenized:
+        if not t.startswith("#"):
+            ph_groups.append([t])
+        else:
+            ph_groups[-1].append(t.replace("#", ""))
     word2ph = []
-    phs = pyopenjtalk.g2p(norm_text).split(
-        " "
-    )  # Directly use the entire norm_text sequence.
-    for sub in st:  # the following code is only for calculating word2ph
-        wph = 0
-        for x in sub:
-            sys.stdout.flush()
-            if x not in ["?", ".", "!", "…", ","]:  # This will throw warnings.
-                phonemes = pyopenjtalk.g2p(x)
-            else:
-                phonemes = "pau"
-            # for x in range(repeat):
-            wph += len(phonemes.split(" "))
-            # print(f'{x}-->:{phones}')
-        word2ph.append(wph)
-    phonemes = ["_"] + phs + ["_"]
-    tones = [0 for i in phonemes]
+    for group in ph_groups:
+        phonemes = kata2phoneme(text2kata("".join(group)))
+        # phonemes = [i for i in phonemes if i in symbols]
+        for i in phonemes:
+            assert i in symbols, (group, norm_text, tokenized)
+        phone_len = len(phonemes)
+        word_len = len(group)
+
+        aaa = distribute_phone(phone_len, word_len)
+        word2ph += aaa
+
+        phs += phonemes
+    phones = ["_"] + phs + ["_"]
+    tones = [0 for i in phones]
     word2ph = [1] + word2ph + [1]
-    return phonemes, tones, word2ph
+    return phones, tones, word2ph
 
 
 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained(BERT)
+    tokenizer = AutoTokenizer.from_pretrained("./bert/bert-base-japanese-v3")
    text = "hello,こんにちは、世界!……"
 
     from text.japanese_bert import get_bert_feature
 
     text = text_normalize(text)
     print(text)
-    phones, tones, word2ph = g2p_ojt(text)
+    phones, tones, word2ph = g2p(text)
     bert = get_bert_feature(text, word2ph)
     print(phones, tones, word2ph, bert.shape)
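The key mechanics of the new `g2p` are the WordPiece groups and `distribute_phone`: each group's phonemes are spread across its tokens as evenly as possible, and the resulting `word2ph` later tells `get_bert_feature` how many times to repeat each token-level BERT vector. A small worked example using the helper as defined in the hunk above; the token/phoneme counts are made up for illustration:

```python
def distribute_phone(n_phone, n_word):
    # Greedily give each phoneme to the token that currently has the fewest.
    phones_per_word = [0] * n_word
    for task in range(n_phone):
        min_tasks = min(phones_per_word)
        min_index = phones_per_word.index(min_tasks)
        phones_per_word[min_index] += 1
    return phones_per_word


# A group of 2 WordPiece tokens whose joint reading yields 5 phonemes:
print(distribute_phone(5, 2))  # [3, 2]

# For a whole utterance, g2p builds word2ph as [1] + one count per WordPiece token + [1];
# the outer 1s line up with the "_" padding phones and the [CLS]/[SEP] tokens.
```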
diff --git a/text/japanese_bert.py b/text/japanese_bert.py
index 69004002d..5dd196483 100644
--- a/text/japanese_bert.py
+++ b/text/japanese_bert.py
@@ -2,12 +2,7 @@
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import sys
 
-BERT = "./bert/bert-large-japanese-v2"
-
-tokenizer = AutoTokenizer.from_pretrained(BERT)
-# bert-large model has 25 hidden layers.You can decide which layer to use by setting this variable to a specific value
-# default value is 3(untested)
-BERT_LAYER = 3
+tokenizer = AutoTokenizer.from_pretrained("./bert/bert-base-japanese-v3")
 
 models = dict()
 
@@ -29,13 +24,13 @@ def get_bert_feature(text, word2ph, device=None):
     inputs = tokenizer(text, return_tensors="pt")
     for i in inputs:
         inputs[i] = inputs[i].to(device)
-    res = model(**inputs, output_hidden_states=True)
-    res = res["hidden_states"][BERT_LAYER]
+    res = models[device](**inputs, output_hidden_states=True)
+    res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
     assert inputs["input_ids"].shape[-1] == len(word2ph)
     word2phone = word2ph
     phone_level_feature = []
     for i in range(len(word2phone)):
-        repeat_feature = res[0][i].repeat(word2phone[i], 1)
+        repeat_feature = res[i].repeat(word2phone[i], 1)
         phone_level_feature.append(repeat_feature)
 
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
diff --git a/text/symbols.py b/text/symbols.py
index 94f0f63cd..161ae9f71 100644
--- a/text/symbols.py
+++ b/text/symbols.py
@@ -74,7 +74,6 @@
 
 # japanese
 ja_symbols = [
-    "pau",
     "N",
     "a",
     "a:",
@@ -118,8 +117,6 @@
     "z",
     "zy",
 ]
-for x in range(ord("a"), ord("z") + 1):
-    ja_symbols.append(chr(x).upper())
 num_ja_tones = 1
 
 # English
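In `get_bert_feature` above, `hidden_states[-3:-2]` is a one-element slice, so the `torch.cat` effectively just picks the third-from-last layer; each token vector is then repeated `word2ph[i]` times to produce phone-level features whose transpose matches the `(768, n_phones)` tensors stored by `data_utils.py`. A shape-only sketch of that expansion, not part of the patch, with random tensors standing in for real BERT outputs:

```python
import torch

hidden_size = 768       # bert-base-japanese-v3 hidden size
word2ph = [1, 3, 2, 1]  # one entry per BERT token ([CLS], ..., [SEP])

# Stand-in for hidden_states[-3][0]: shape (num_tokens, hidden_size).
token_feats = torch.randn(len(word2ph), hidden_size)

phone_level_feature = torch.cat(
    [token_feats[i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0
)

print(phone_level_feature.shape)    # torch.Size([7, 768]) == (sum(word2ph), 768)
print(phone_level_feature.T.shape)  # (768, n_phones), i.e. the ja_bert layout
```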