Additional Japanese processor for NMT that uses MeCab segmentation. Fix for BLEU in one-many NMT (#3889)

* Ja Mecab tokenizer

Signed-off-by: MaximumEntropy <[email protected]>

* Set correct lang for ja sacrebleu

Signed-off-by: MaximumEntropy <[email protected]>

* EnJa tok fixes

Signed-off-by: MaximumEntropy <[email protected]>

* BLEU score for one-many fix

Signed-off-by: MaximumEntropy <[email protected]>

* Revert megatron nmt changes

Signed-off-by: MaximumEntropy <[email protected]>

* Fix extra space

Signed-off-by: MaximumEntropy <[email protected]>

* Style

Signed-off-by: MaximumEntropy <[email protected]>

* Empty commit to restart CI

Signed-off-by: MaximumEntropy <[email protected]>
MaximumEntropy authored Mar 30, 2022
1 parent 9580cf8 commit 84236ba
Showing 2 changed files with 43 additions and 5 deletions.
30 changes: 30 additions & 0 deletions nemo/collections/common/tokenizers/en_ja_tokenizers.py
@@ -11,8 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from typing import List
 
+import ipadic
+import MeCab
+from pangu import spacing
 from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer


@@ -55,3 +59,29 @@ def normalize(self, text) -> str:
             return self.normalizer.normalize(text)
         else:
             return text
+
+
+class JaMecabProcessor:
+    """
+    Tokenizer, Detokenizer and Normalizer utilities for Japanese MeCab & English
+    """
+
+    def __init__(self):
+        self.mecab_tokenizer = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")
+
+    def detokenize(self, text: List[str]) -> str:
+        RE_WS_IN_FW = re.compile(
+            r'([\u2018\u2019\u201c\u201d\u2e80-\u312f\u3200-\u32ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff\uff00-\uffef])\s+(?=[\u2018\u2019\u201c\u201d\u2e80-\u312f\u3200-\u32ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff\uff00-\uffef])'
+        )
+
+        detokenize = lambda s: spacing(RE_WS_IN_FW.sub(r'\1', s)).strip()
+        return detokenize(' '.join(text))
+
+    def tokenize(self, text) -> str:
+        """
+        Tokenizes text using MeCab. Returns a string of tokens.
+        """
+        return self.mecab_tokenizer.parse(text).strip()
+
+    def normalize(self, text) -> str:
+        return text
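
For context, here is a minimal usage sketch of the new `JaMecabProcessor` (not part of the commit; it assumes the `mecab-python3`, `ipadic`, and `pangu` dependencies are installed, and the segmentation shown in the comments is illustrative):

```python
from nemo.collections.common.tokenizers.en_ja_tokenizers import JaMecabProcessor

processor = JaMecabProcessor()

# MeCab's "-Owakati" mode returns the input with spaces inserted between
# morphemes, e.g. "私 は 猫 が 好き です 。"
tokens = processor.tokenize("私は猫が好きです。")

# detokenize() strips the whitespace between fullwidth characters again and
# applies pangu's spacing() to re-space any CJK/Latin boundaries.
print(processor.detokenize(tokens.split()))  # 私は猫が好きです。
```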
18 changes: 13 additions & 5 deletions nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py
@@ -34,7 +34,7 @@
 from nemo.collections.common.parts import transformer_weights_init
 from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelProcessor
 from nemo.collections.common.tokenizers.chinese_tokenizers import ChineseProcessor
-from nemo.collections.common.tokenizers.en_ja_tokenizers import EnJaProcessor
+from nemo.collections.common.tokenizers.en_ja_tokenizers import EnJaProcessor, JaMecabProcessor
 from nemo.collections.common.tokenizers.indic_tokenizers import IndicProcessor
 from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
 from nemo.collections.nlp.data import TarredTranslationDataset, TranslationDataset
@@ -427,9 +427,13 @@ def eval_epoch_end(self, outputs, mode, global_rank):
                     _translations += [t for (t, g) in tr_and_gt[rank]]
                     _ground_truths += [g for (t, g) in tr_and_gt[rank]]
 
-                if self.tgt_language in ['ja']:
+                if self.multilingual and isinstance(self.tgt_language, ListConfig):
+                    tgt_language = self.tgt_language[dataloader_idx]
+                else:
+                    tgt_language = self.tgt_language
+                if tgt_language in ['ja', 'ja-mecab']:
                     sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="ja-mecab")
-                elif self.tgt_language in ['zh']:
+                elif tgt_language in ['zh']:
                     sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="zh")
                 else:
                     sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="13a")
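
The hunk above is the one-many BLEU fix: instead of always reading the scalar `self.tgt_language`, the model now indexes into the per-dataloader language list when running multilingual one-many training. A standalone sketch of the same selection logic (illustrative only; NeMo checks an omegaconf `ListConfig` rather than a plain Python list):

```python
from sacrebleu import corpus_bleu

# One-many setup: one validation dataloader per target language.
tgt_languages = ['ja-mecab', 'de']

def eval_bleu(dataloader_idx, translations, ground_truths):
    # Pick the language for this dataloader, mirroring the fix above.
    if isinstance(tgt_languages, list):
        tgt_language = tgt_languages[dataloader_idx]
    else:
        tgt_language = tgt_languages
    if tgt_language in ['ja', 'ja-mecab']:
        tokenize = 'ja-mecab'  # needs the sacrebleu[ja] extra
    elif tgt_language in ['zh']:
        tokenize = 'zh'
    else:
        tokenize = '13a'
    return corpus_bleu(translations, [ground_truths], tokenize=tokenize)

# The German dataloader is scored with the default 13a tokenizer.
print(eval_bleu(1, ['dies ist ein kleiner test'], ['dies ist ein kleiner test']).score)  # 100.0
```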
@@ -862,7 +866,9 @@ def setup_pre_and_post_processing_utils(
         if encoder_tokenizer_library == 'byte-level':
             source_processor = ByteLevelProcessor()
         elif (source_lang == 'en' and target_lang == 'ja') or (source_lang == 'ja' and target_lang == 'en'):
-            source_processor = EnJaProcessor(source_lang)
+            self.source_processor = EnJaProcessor(source_lang)
+        elif source_lang == 'ja-mecab':
+            self.source_processor = JaMecabProcessor()
         elif source_lang == 'zh':
             source_processor = ChineseProcessor()
         elif source_lang == 'hi':
@@ -875,7 +881,9 @@ def setup_pre_and_post_processing_utils(
         if decoder_tokenizer_library == 'byte-level':
             target_processor = ByteLevelProcessor()
         elif (source_lang == 'en' and target_lang == 'ja') or (source_lang == 'ja' and target_lang == 'en'):
-            target_processor = EnJaProcessor(target_lang)
+            self.target_processor = EnJaProcessor(target_lang)
+        elif target_lang == 'ja-mecab':
+            self.target_processor = JaMecabProcessor()
         elif target_lang == 'zh':
             target_processor = ChineseProcessor()
         elif target_lang == 'hi':
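
Taken together, the two hunks above mean a `ja-mecab` language code routes pre- and post-processing through the new MeCab path. A condensed, hypothetical mirror of the target-side branching, assuming a NeMo environment with the MeCab dependencies installed (only the languages touched by this commit; the helper name is made up):

```python
from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelProcessor
from nemo.collections.common.tokenizers.chinese_tokenizers import ChineseProcessor
from nemo.collections.common.tokenizers.en_ja_tokenizers import EnJaProcessor, JaMecabProcessor

def pick_target_processor(decoder_tokenizer_library, source_lang, target_lang):
    # Same precedence as setup_pre_and_post_processing_utils: byte-level first,
    # then the En<->Ja pair, then the new MeCab code, then Chinese.
    if decoder_tokenizer_library == 'byte-level':
        return ByteLevelProcessor()
    if (source_lang, target_lang) in [('en', 'ja'), ('ja', 'en')]:
        return EnJaProcessor(target_lang)
    if target_lang == 'ja-mecab':
        return JaMecabProcessor()  # added by this commit
    if target_lang == 'zh':
        return ChineseProcessor()
    return None

assert isinstance(pick_target_processor('sentencepiece', 'en', 'ja-mecab'), JaMecabProcessor)
```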
