From 8e7b5949c9a1ba37a9203e39774120d8979849d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stardust=C2=B7=E5=87=8F?=
Date: Sat, 30 Sep 2023 09:38:24 +0800
Subject: [PATCH] Update chinese_bert.py

---
 text/chinese_bert.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/text/chinese_bert.py b/text/chinese_bert.py
index a7607198f..c6627d68e 100644
--- a/text/chinese_bert.py
+++ b/text/chinese_bert.py
@@ -3,7 +3,7 @@
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 
 tokenizer = AutoTokenizer.from_pretrained("./bert/chinese-roberta-wwm-ext-large")
 
-
+models = dict()
 
 def get_bert_feature(text, word2ph, device=None):
     if (
@@ -14,14 +14,15 @@ def get_bert_feature(text, word2ph, device=None):
         device = "mps"
     if not device:
         device = "cuda"
-    model = AutoModelForMaskedLM.from_pretrained(
-        "./bert/chinese-roberta-wwm-ext-large"
-    ).to(device)
+    if device not in models.keys():
+        models[device] = AutoModelForMaskedLM.from_pretrained(
+            "./bert/chinese-roberta-wwm-ext-large"
+        ).to(device)
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
         for i in inputs:
             inputs[i] = inputs[i].to(device)
-        res = model(**inputs, output_hidden_states=True)
+        res = models[device](**inputs, output_hidden_states=True)
     res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
 
     assert len(word2ph) == len(text) + 2