diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9ce558495..6f22b7cb9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,7 +7,7 @@ repos:
       - id: trailing-whitespace
 
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.291
+    rev: v0.0.290
    hooks:
      - id: ruff
        args: [ --fix ]
diff --git a/text/chinese_bert.py b/text/chinese_bert.py
index 8159425df..a7607198f 100644
--- a/text/chinese_bert.py
+++ b/text/chinese_bert.py
@@ -4,8 +4,6 @@
 
 tokenizer = AutoTokenizer.from_pretrained("./bert/chinese-roberta-wwm-ext-large")
 
-models = dict()
-
 
 def get_bert_feature(text, word2ph, device=None):
     if (
@@ -16,15 +14,14 @@ def get_bert_feature(text, word2ph, device=None):
         device = "mps"
     if not device:
         device = "cuda"
-    if device not in models.keys():
-        models[device] = AutoModelForMaskedLM.from_pretrained(
-            "./bert/chinese-roberta-wwm-ext-large"
-        ).to(device)
+    model = AutoModelForMaskedLM.from_pretrained(
+        "./bert/chinese-roberta-wwm-ext-large"
+    ).to(device)
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
         for i in inputs:
             inputs[i] = inputs[i].to(device)
-        res = models[device](**inputs, output_hidden_states=True)
+        res = model(**inputs, output_hidden_states=True)
         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
 
     assert len(word2ph) == len(text) + 2
diff --git a/text/japanese_bert.py b/text/japanese_bert.py
index 69004002d..0953b1bd3 100644
--- a/text/japanese_bert.py
+++ b/text/japanese_bert.py
@@ -9,8 +9,6 @@
 # default value is 3(untested)
 BERT_LAYER = 3
 
-models = dict()
-
 
 def get_bert_feature(text, word2ph, device=None):
     if (
@@ -21,10 +19,7 @@ def get_bert_feature(text, word2ph, device=None):
         device = "mps"
     if not device:
         device = "cuda"
-    if device not in models.keys():
-        models[device] = AutoModelForMaskedLM.from_pretrained(
-            "./bert/bert-base-japanese-v3"
-        ).to(device)
+    model = AutoModelForMaskedLM.from_pretrained(BERT).to(device)
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
         for i in inputs: