diff --git a/wenet/text/bpe_tokenizer.py b/wenet/text/bpe_tokenizer.py index 6dffa09892..c350dbfc2f 100644 --- a/wenet/text/bpe_tokenizer.py +++ b/wenet/text/bpe_tokenizer.py @@ -17,11 +17,19 @@ def __init__( ) -> None: super().__init__(symbol_table, non_lang_syms, split_with_space, connect_symbol, unk) - import sentencepiece as spm - self.bpe_model = spm.SentencePieceProcessor() - self.bpe_model.load(bpe_model) + # NOTE(Mddct): multiprocessing.Process() issues + # see: https://github.com/espnet/espnet/blob/master/espnet2/text/sentencepiece_tokenizer.py#L19 + self._model = bpe_model + self.bpe_model = None + + def _build_sp(self): + if self.bpe_model is None: + import sentencepiece as spm + self.bpe_model = spm.SentencePieceProcessor() + self.bpe_model.load(self._model) def text2tokens(self, line: str) -> List[str]: + self._build_sp() line = line.strip() if self.non_lang_syms_pattern is not None: parts = self.non_lang_syms_pattern.split(line.upper()) @@ -38,5 +46,6 @@ def text2tokens(self, line: str) -> List[str]: return tokens def tokens2text(self, tokens: List[str]) -> str: + self._build_sp() text = super().tokens2text(tokens) return text.replace("▁", ' ').strip()