Skip to content

Commit

Permalink
Revert "sync master commit to dev (#39)"
Browse files Browse the repository at this point in the history
This reverts commit 4c488a9.
  • Loading branch information
Stardust-minus authored Sep 30, 2023
1 parent 4c488a9 commit 5c04ee8
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.291
rev: v0.0.290
hooks:
- id: ruff
args: [ --fix ]
Expand Down
11 changes: 4 additions & 7 deletions text/chinese_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

tokenizer = AutoTokenizer.from_pretrained("./bert/chinese-roberta-wwm-ext-large")

models = dict()


def get_bert_feature(text, word2ph, device=None):
if (
Expand All @@ -16,15 +14,14 @@ def get_bert_feature(text, word2ph, device=None):
device = "mps"
if not device:
device = "cuda"
if device not in models.keys():
models[device] = AutoModelForMaskedLM.from_pretrained(
"./bert/chinese-roberta-wwm-ext-large"
).to(device)
model = AutoModelForMaskedLM.from_pretrained(
"./bert/chinese-roberta-wwm-ext-large"
).to(device)
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = models[device](**inputs, output_hidden_states=True)
res = model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()

assert len(word2ph) == len(text) + 2
Expand Down
7 changes: 1 addition & 6 deletions text/japanese_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
# default value is 3(untested)
BERT_LAYER = 3

models = dict()


def get_bert_feature(text, word2ph, device=None):
if (
Expand All @@ -21,10 +19,7 @@ def get_bert_feature(text, word2ph, device=None):
device = "mps"
if not device:
device = "cuda"
if device not in models.keys():
models[device] = AutoModelForMaskedLM.from_pretrained(
"./bert/bert-base-japanese-v3"
).to(device)
model = AutoModelForMaskedLM.from_pretrained(BERT).to(device)
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
Expand Down

0 comments on commit 5c04ee8

Please sign in to comment.