From 657c835ffdc72c8afdfce1d33ec9e73f5584dc01 Mon Sep 17 00:00:00 2001
From: Stardust·减
Date: Sat, 23 Sep 2023 20:12:08 +0800
Subject: [PATCH] Update japanese_bert.py

---
 text/japanese_bert.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/text/japanese_bert.py b/text/japanese_bert.py
index 5cc104da4..f83e36cdf 100644
--- a/text/japanese_bert.py
+++ b/text/japanese_bert.py
@@ -2,7 +2,12 @@
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import sys
 
-tokenizer = AutoTokenizer.from_pretrained("./bert/bert-base-japanese-v3")
+BERT = "./bert/bert-large-japanese-v2"
+
+tokenizer = AutoTokenizer.from_pretrained(BERT)
+# The bert-large model outputs 25 hidden states (embeddings + 24 encoder layers);
+# set BERT_LAYER (0-24) to choose which one to use. The default of 3 is untested.
+BERT_LAYER = 3
 
 
 def get_bert_feature(text, word2ph, device=None):
@@ -14,7 +19,7 @@ def get_bert_feature(text, word2ph, device=None):
         device = "mps"
     if not device:
         device = "cuda"
-    model = AutoModelForMaskedLM.from_pretrained("./bert/bert-base-japanese-v3").to(
+    model = AutoModelForMaskedLM.from_pretrained(BERT).to(
         device
     )
     with torch.no_grad():
@@ -22,12 +27,12 @@ def get_bert_feature(text, word2ph, device=None):
         for i in inputs:
             inputs[i] = inputs[i].to(device)
         res = model(**inputs, output_hidden_states=True)
-        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
+        res = res["hidden_states"][BERT_LAYER].cpu()
     assert inputs["input_ids"].shape[-1] == len(word2ph)
     word2phone = word2ph
     phone_level_feature = []
     for i in range(len(word2phone)):
-        repeat_feature = res[i].repeat(word2phone[i], 1)
+        repeat_feature = res[0][i].repeat(word2phone[i], 1)
         phone_level_feature.append(repeat_feature)
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
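
Note (not part of the patch): with output_hidden_states=True, transformers returns
the embedding output plus one tensor per encoder layer, so the 24-layer bert-large
yields len(hidden_states) == 25 and BERT_LAYER can range from 0 to 24. A minimal
sketch to verify this locally, assuming transformers is installed and the model
files are present at ./bert/bert-large-japanese-v2 (the sample sentence is
arbitrary):

    import torch
    from transformers import AutoTokenizer, AutoModelForMaskedLM

    BERT = "./bert/bert-large-japanese-v2"
    tokenizer = AutoTokenizer.from_pretrained(BERT)
    model = AutoModelForMaskedLM.from_pretrained(BERT)
    with torch.no_grad():
        inputs = tokenizer("こんにちは", return_tensors="pt")  # arbitrary sample text
        res = model(**inputs, output_hidden_states=True)
    print(len(res.hidden_states))  # expect 25: embeddings + 24 encoder layers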
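
Note (not part of the patch): the phone-level expansion rewritten in the last hunk
can be checked in isolation. res["hidden_states"][BERT_LAYER] has shape
[batch, seq_len, hidden_size], so res[0][i] is the feature vector of token i,
repeated word2ph[i] times. A minimal sketch with dummy tensors (a hidden_size of
1024 matches bert-large; the word2ph values are made up for illustration):

    import torch

    seq_len, hidden_size = 5, 1024
    res = torch.randn(1, seq_len, hidden_size)  # stand-in for res["hidden_states"][BERT_LAYER]
    word2ph = [1, 2, 3, 1, 2]  # phones per token; len(word2ph) must equal seq_len

    phone_level_feature = torch.cat(
        [res[0][i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0
    )
    assert phone_level_feature.shape == (sum(word2ph), hidden_size)  # (9, 1024)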