diff --git a/data_utils.py b/data_utils.py index 7ef204b7c..03c5e2928 100644 --- a/data_utils.py +++ b/data_utils.py @@ -147,7 +147,7 @@ def get_text(self, text, word2ph, phone, tone, language_str, wav_path): try: bert = torch.load(bert_path) except: - bert = get_bert(text, word2ph, language_str,device=None) + bert = get_bert(text, word2ph, language_str, "cuda", tokenizer) torch.save(bert, bert_path) if language_str == "ZH":