From a13dedf5ad16fc98152d0a825c3c3158ac25890c Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 28 Nov 2023 00:24:33 +0800 Subject: [PATCH] [text] fix bpe model in multiprocess env --- wenet/text/bpe_tokenizer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/wenet/text/bpe_tokenizer.py b/wenet/text/bpe_tokenizer.py index 6dffa0989..de1b504a9 100644 --- a/wenet/text/bpe_tokenizer.py +++ b/wenet/text/bpe_tokenizer.py @@ -17,11 +17,19 @@ def __init__( ) -> None: super().__init__(symbol_table, non_lang_syms, split_with_space, connect_symbol, unk) - import sentencepiece as spm - self.bpe_model = spm.SentencePieceProcessor() - self.bpe_model.load(bpe_model) + self._model = bpe_model + # NOTE(Mddct): multiprocessing.Process() issues + # don't build sp here + self.bpe_model = None + + def _build_sp(self): + if self.bpe_model is None: + import sentencepiece as spm + self.bpe_model = spm.SentencePieceProcessor() + self.bpe_model.load(self._model) def text2tokens(self, line: str) -> List[str]: + self._build_sp() line = line.strip() if self.non_lang_syms_pattern is not None: parts = self.non_lang_syms_pattern.split(line.upper()) @@ -38,5 +46,6 @@ def text2tokens(self, line: str) -> List[str]: return tokens def tokens2text(self, tokens: List[str]) -> str: + self._build_sp() text = super().tokens2text(tokens) return text.replace("▁", ' ').strip()