diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
old mode 100755
new mode 100644
diff --git a/src/transformers/models/plbart/tokenization_plbart.py b/src/transformers/models/plbart/tokenization_plbart.py
index f6f393f9b8bd..94ec77c468c9 100644
--- a/src/transformers/models/plbart/tokenization_plbart.py
+++ b/src/transformers/models/plbart/tokenization_plbart.py
@@ -88,8 +88,18 @@
 }
 
 FAIRSEQ_LANGUAGE_CODES = {
-    "base": ["java", "python", "en_XX"],
-    "multi": ["java", "python", "en_XX", "javascript", "php", "ruby", "go"],
+    "base": ["__java__", "__python__", "__en_XX__"],
+    "multi": ["__java__", "__python__", "__en_XX__", "__javascript__", "__php__", "__ruby__", "__go__"],
+}
+
+FAIRSEQ_LANGUAGE_CODES_MAP = {
+    "java": "__java__",
+    "python": "__python__",
+    "en_XX": "__en_XX__",
+    "javascript": "__javascript__",
+    "php": "__php__",
+    "ruby": "__ruby__",
+    "go": "__go__",
 }
 
 
@@ -202,6 +212,8 @@ def __init__(
             sp_model_kwargs=self.sp_model_kwargs,
             **kwargs,
         )
+        src_lang = self._convert_lang_code_special_format(src_lang)
+        tgt_lang = self._convert_lang_code_special_format(tgt_lang)
 
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(str(vocab_file))
@@ -247,7 +259,7 @@ def __init__(
                 self.lang_code_to_id[self._src_lang] if self._src_lang is not None else self._src_lang
             )
         else:
-            self._src_lang = src_lang if src_lang is not None else "en_XX"
+            self._src_lang = src_lang if src_lang is not None else "__en_XX__"
             self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
 
         self.tgt_lang = tgt_lang
@@ -284,6 +296,7 @@ def src_lang(self) -> str:
 
     @src_lang.setter
    def src_lang(self, new_src_lang: str) -> None:
+        new_src_lang = self._convert_lang_code_special_format(new_src_lang)
         self._src_lang = new_src_lang
         self.set_src_lang_special_tokens(self._src_lang)
 
@@ -374,9 +387,10 @@ def _build_translation_inputs(
         """Used by translation pipeline, to prepare inputs for the generate function"""
         if src_lang is None or tgt_lang is None:
             raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
-        self.src_lang = src_lang
+        self.src_lang = self._convert_lang_code_special_format(src_lang)
+        self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
         inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
-        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+        tgt_lang_id = self.convert_tokens_to_ids(self.tgt_lang)
         inputs["forced_bos_token_id"] = tgt_lang_id
         return inputs
 
@@ -433,8 +447,8 @@ def prepare_seq2seq_batch(
         tgt_lang: str = "python",
         **kwargs,
     ) -> BatchEncoding:
-        self.src_lang = src_lang
-        self.tgt_lang = tgt_lang
+        self.src_lang = self._convert_lang_code_special_format(src_lang)
+        self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
 
     def _switch_to_input_mode(self):
@@ -445,6 +459,7 @@ def _switch_to_target_mode(self):
 
     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
+        src_lang = self._convert_lang_code_special_format(src_lang)
         self.cur_lang_code = self.lang_code_to_id[src_lang] if src_lang is not None else None
         self.prefix_tokens = []
         if self.cur_lang_code is not None:
@@ -454,9 +469,16 @@ def set_src_lang_special_tokens(self, src_lang) -> None:
 
     def set_tgt_lang_special_tokens(self, lang: str) -> None:
         """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
+        lang = self._convert_lang_code_special_format(lang)
+        self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None
         self.prefix_tokens = []
         if self.cur_lang_code is not None:
             self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
         else:
             self.suffix_tokens = [self.eos_token_id]
+
+    def _convert_lang_code_special_format(self, lang: str) -> str:
+        """Convert Language Codes to format tokenizer uses if required"""
+        lang = FAIRSEQ_LANGUAGE_CODES_MAP[lang] if lang in FAIRSEQ_LANGUAGE_CODES_MAP.keys() else lang
+        return lang
diff --git a/tests/models/plbart/test_tokenization_plbart.py b/tests/models/plbart/test_tokenization_plbart.py
index 2ce7cafbda6e..f9cc38e0de69 100644
--- a/tests/models/plbart/test_tokenization_plbart.py
+++ b/tests/models/plbart/test_tokenization_plbart.py
@@ -129,7 +129,14 @@ def test_full_base_tokenizer(self):
 
         end = tokenizer.vocab_size
         language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 4, end)]
-        self.assertListEqual(language_tokens, ["java", "python", "en_XX", "<mask>"])
+        self.assertListEqual(language_tokens, ["__java__", "__python__", "__en_XX__", "<mask>"])
+
+        code = "java.lang.Exception, python.lang.Exception, javascript, php, ruby, go"
+        input_ids = tokenizer(code).input_ids
+        self.assertEqual(
+            tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False),
+            code,
+        )
 
     def test_full_multi_tokenizer(self):
         tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="multi", keep_accents=True)
@@ -208,7 +215,15 @@ def test_full_multi_tokenizer(self):
 
         end = tokenizer.vocab_size
         language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 7, end)]
-        self.assertListEqual(language_tokens, ["java", "python", "en_XX", "javascript", "php", "ruby", "go"])
+        self.assertListEqual(
+            language_tokens, ["__java__", "__python__", "__en_XX__", "__javascript__", "__php__", "__ruby__", "__go__"]
+        )
+        code = "java.lang.Exception, python.lang.Exception, javascript, php, ruby, go"
+        input_ids = tokenizer(code).input_ids
+        self.assertEqual(
+            tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False),
+            code,
+        )
 
 
 @require_torch
@@ -262,9 +277,9 @@ def setUpClass(cls):
         return cls
 
     def check_language_codes(self):
-        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["java"], 50001)
-        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["python"], 50002)
-        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_XX"], 50003)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["__java__"], 50001)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["__python__"], 50002)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["__en_XX__"], 50003)
 
     def test_python_en_tokenizer_batch_encode_plus(self):
         ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
@@ -288,7 +303,7 @@ def test_python_en_tokenizer_truncation(self):
         self.assertEqual(len(ids), desired_max_length)
 
     def test_mask_token(self):
-        self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "java"]), [50004, 50001])
+        self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "__java__"]), [50004, 50001])
 
     def test_special_tokens_unaffacted_by_save_load(self):
         tmpdirname = tempfile.mkdtemp()
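
Usage sketch (not part of the patch). The checkpoint name uclanlp/plbart-python-en_XX and the round-trip string below are illustrative assumptions; the behaviour relied on is only what the diff and its new tests encode: plain codes such as "python"/"en_XX" are normalized to "__python__"/"__en_XX__", and identifiers like "java" in the input text no longer collide with the language-code special tokens.

from transformers import PLBartTokenizer

# Checkpoint name assumed for illustration; any PLBart checkpoint with source/target languages would do.
tokenizer = PLBartTokenizer.from_pretrained(
    "uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX"
)

# __init__ and the src_lang setter route through _convert_lang_code_special_format,
# so the plain code "python" is stored in the underscore-wrapped token format.
assert tokenizer.src_lang == "__python__"

# Words like "java" or "python" inside the input are tokenized as ordinary text,
# so an encode/decode round-trip preserves them (mirrors the new tests above).
code = "java.lang.Exception, python.lang.Exception"
ids = tokenizer(code).input_ids
assert tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) == code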