Skip to content

Commit

Permalink
Update chinese_bert.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Stardust-minus authored Sep 30, 2023
1 parent bab550c commit 8e7b594
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions text/chinese_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("./bert/chinese-roberta-wwm-ext-large")

models = dict()

def get_bert_feature(text, word2ph, device=None):
if (
Expand All @@ -14,14 +14,15 @@ def get_bert_feature(text, word2ph, device=None):
device = "mps"
if not device:
device = "cuda"
model = AutoModelForMaskedLM.from_pretrained(
"./bert/chinese-roberta-wwm-ext-large"
).to(device)
if device not in models.keys():
models[device] = AutoModelForMaskedLM.from_pretrained(
"./bert/chinese-roberta-wwm-ext-large"
).to(device)
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = model(**inputs, output_hidden_states=True)
res = models[device](**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()

assert len(word2ph) == len(text) + 2
Expand Down

0 comments on commit 8e7b594

Please sign in to comment.