Skip to content

Commit

Permalink
Update japanese_bert.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Stardust-minus authored Sep 23, 2023
1 parent 4e549c5 commit 657c835
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions text/japanese_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from transformers import AutoTokenizer, AutoModelForMaskedLM
import sys

tokenizer = AutoTokenizer.from_pretrained("./bert/bert-base-japanese-v3")
BERT = './bert/bert-large-japanese-v2'

tokenizer = AutoTokenizer.from_pretrained(BERT)
# bert-large model has 25 hidden layers. You can decide which layer to use by setting this variable to a specific value
# default value is 3 (untested)
BERT_LAYER = 3


def get_bert_feature(text, word2ph, device=None):
Expand All @@ -14,20 +19,20 @@ def get_bert_feature(text, word2ph, device=None):
device = "mps"
if not device:
device = "cuda"
model = AutoModelForMaskedLM.from_pretrained("./bert/bert-base-japanese-v3").to(
model = AutoModelForMaskedLM.from_pretrained(BERT).to(
device
)
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
res = res['hidden_states'][BERT_LAYER]
assert inputs["input_ids"].shape[-1] == len(word2ph)
word2phone = word2ph
phone_level_feature = []
for i in range(len(word2phone)):
repeat_feature = res[i].repeat(word2phone[i], 1)
repeat_feature = res[0][i].repeat(word2phone[i], 1)
phone_level_feature.append(repeat_feature)

phone_level_feature = torch.cat(phone_level_feature, dim=0)
Expand Down

0 comments on commit 657c835

Please sign in to comment.