diff --git a/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
index 77247495520c..2cc4271841c7 100755
--- a/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
@@ -133,6 +133,14 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
     with open(src_vocab_file, "w", encoding="utf-8") as f:
         f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
 
+    # detect whether this is a do_lower_case situation, which can be derived by checking whether we
+    # have at least one uppercase letter in the source vocab
+    do_lower_case = True
+    for k in src_vocab.keys():
+        if not k.islower():
+            do_lower_case = False
+            break
+
     tgt_dict = Dictionary.load(tgt_dict_file)
     tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
     tgt_vocab_size = len(tgt_vocab)
@@ -207,6 +215,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
     tokenizer_conf = {
         "langs": [src_lang, tgt_lang],
         "model_max_length": 1024,
+        "do_lower_case": do_lower_case,
     }
 
     print(f"Generating {fsmt_tokenizer_config_file}")
diff --git a/src/transformers/tokenization_fsmt.py b/src/transformers/tokenization_fsmt.py
index 0f5420407c0e..fae7a7a562b4 100644
--- a/src/transformers/tokenization_fsmt.py
+++ b/src/transformers/tokenization_fsmt.py
@@ -154,7 +154,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
             File containing the vocabulary for the target language.
         merges_file (:obj:`str`):
             File containing the merges.
-        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to lowercase the input when tokenizing.
         unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
@@ -186,6 +186,7 @@ def __init__(
         src_vocab_file=None,
         tgt_vocab_file=None,
         merges_file=None,
+        do_lower_case=False,
         unk_token="<unk>",
         bos_token="<s>",
         sep_token="</s>",
@@ -197,6 +198,7 @@
             src_vocab_file=src_vocab_file,
             tgt_vocab_file=tgt_vocab_file,
             merges_file=merges_file,
+            do_lower_case=do_lower_case,
             unk_token=unk_token,
             bos_token=bos_token,
             sep_token=sep_token,
@@ -207,6 +209,7 @@
         self.src_vocab_file = src_vocab_file
         self.tgt_vocab_file = tgt_vocab_file
         self.merges_file = merges_file
+        self.do_lower_case = do_lower_case
 
         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
@@ -351,6 +354,9 @@ def _tokenize(self, text, lang="en", bypass_tokenizer=False):
         # raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
         lang = self.src_lang
 
+        if self.do_lower_case:
+            text = text.lower()
+
         if bypass_tokenizer:
             text = text.split()
         else:
diff --git a/tests/test_tokenization_fsmt.py b/tests/test_tokenization_fsmt.py
index 21eb02a339b2..790df2247cdf 100644
--- a/tests/test_tokenization_fsmt.py
+++ b/tests/test_tokenization_fsmt.py
@@ -151,6 +151,13 @@ def test_match_encode_decode(self):
         decoded_text = tokenizer_dec.decode(encoded_ids, skip_special_tokens=True)
         self.assertEqual(decoded_text, src_text)
 
+    @slow
+    def test_tokenizer_lower(self):
+        tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en", do_lower_case=True)
+        tokens = tokenizer.tokenize("USA is United States of America")
+        expected = ["us", "a</w>", "is</w>", "un", "i", "ted</w>", "st", "ates</w>", "of</w>", "am", "er", "ica</w>"]
+        self.assertListEqual(tokens, expected)
+
     @unittest.skip("FSMTConfig.__init__ requires non-optional args")
     def test_torch_encode_plus_sent_to_model(self):
         pass
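
Illustrative usage of the new flag (a minimal sketch based on the test added above; it assumes the facebook/wmt19-ru-en checkpoint is reachable, e.g. from the model hub):

    from transformers import FSMTTokenizer

    # do_lower_case=True lowercases the input inside _tokenize() before the Moses/BPE
    # pipeline runs, matching checkpoints whose source vocab is entirely lowercase
    tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en", do_lower_case=True)
    print(tokenizer.tokenize("USA is United States of America"))

For checkpoints converted with the updated conversion script, the detected value is also written into tokenizer_config.json, so from_pretrained() picks it up without the flag being passed explicitly.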