Skip to content
39 changes: 37 additions & 2 deletions docs/source/en/model_doc/nllb.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,43 @@ specific language governing permissions and limitations under the License.

# NLLB

**DISCLAIMER:** If you see something strange, file a [Github Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=bug&template=bug-report.yml) and assign
@LysandreJik
**DISCLAIMER:** The default behaviour for the tokenizer has recently been fixed (and thus changed!
Comment thread
ArthurZucker marked this conversation as resolved.
Outdated

The previous version adds [self.eos_token_id, self.cur_lang_code] at the end of the token sequence for both target and source tokenization. This is wrong as the NLLB paper mentions (page 48, 6.1.1. Model Architecture) :
Comment thread
ArthurZucker marked this conversation as resolved.
Outdated

Note that we prefix the source sequence with the source language, as opposed to the target
language as previously done in several works (Arivazhagan et al., 2019; Johnson et al.,
2017). This is primarily because we prioritize optimizing zero-shot performance of our
model on any pair of 200 languages at a minor cost to supervised performance.

Previous behaviour:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs proofing :-)


```python
>>> from transformers import NllbTokenizer

>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
>>> tokenizer("How was your day?").input_ids
[13374, 1398, 4260, 4039, 248130, 2, 256047]

>>> # 2: '</s>'
>>> # 256047 : 'eng_Latn'
```
New behaviour

```python
>>> from transformers import NllbTokenizer

>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
>>> tokenizer("How was your day?").input_ids
[256047, 13374, 1398, 4260, 4039, 248130, 2]
Enabling the old behaviour:
Comment thread
ArthurZucker marked this conversation as resolved.
Outdated

>>> from transformers import NllbTokenizer

>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", legacy_behaviour=True)
```

For more details, feel free to check the linked [PR](https://github.com/huggingface/transformers/pull/22313) and [Issue](https://github.com/huggingface/transformers/issues/19943)
Comment thread
ArthurZucker marked this conversation as resolved.
Outdated

## Overview of NLLB

Expand Down
30 changes: 23 additions & 7 deletions src/transformers/models/nllb/tokenization_nllb.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,14 @@ def __init__(
tgt_lang=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
additional_special_tokens=None,
legacy_behaviour=False,
**kwargs,
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.legacy_behaviour = legacy_behaviour

super().__init__(
bos_token=bos_token,
Expand All @@ -160,13 +162,13 @@ def __init__(
tgt_lang=tgt_lang,
additional_special_tokens=additional_special_tokens,
sp_model_kwargs=self.sp_model_kwargs,
legacy_behaviour=legacy_behaviour,
**kwargs,
)

self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file

# Original fairseq vocab and spm vocab must be "aligned":
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
# -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
Expand Down Expand Up @@ -388,13 +390,27 @@ def _switch_to_target_mode(self):
return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
"""Reset the special tokens to the source lang setting.
- In legacy mode: No prefix and suffix=[eos, src_lang_code].
- In default mode: Prefix=[src_lang_code], suffix = [eos]
"""
self.cur_lang_code = self.lang_code_to_id[src_lang]
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
if self.legacy_behaviour:
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
else:
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]

def set_tgt_lang_special_tokens(self, lang: str) -> None:
"""Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
"""Reset the special tokens to the target lang setting.
- In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
- In default mode: Prefix=[tgt_lang_code], suffix = [eos]
"""
self.cur_lang_code = self.lang_code_to_id[lang]
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
if self.legacy_behaviour:
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
else:
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]
31 changes: 24 additions & 7 deletions src/transformers/models/nllb/tokenization_nllb_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,12 @@ def __init__(
src_lang=None,
tgt_lang=None,
additional_special_tokens=None,
legacy_behaviour=False,
**kwargs,
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

self.legacy_behaviour = legacy_behaviour
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
Expand All @@ -169,6 +170,7 @@ def __init__(
src_lang=src_lang,
tgt_lang=tgt_lang,
additional_special_tokens=additional_special_tokens,
legacy_behaviour=legacy_behaviour,
**kwargs,
)

Expand Down Expand Up @@ -287,10 +289,18 @@ def _switch_to_target_mode(self):
return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
"""Reset the special tokens to the source lang setting.
- In legacy mode: No prefix and suffix=[eos, src_lang_code].
- In default mode: Prefix=[src_lang_code], suffix = [eos]
"""
self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]

if self.legacy_behaviour:
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
else:
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]

prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
Expand All @@ -302,10 +312,17 @@ def set_src_lang_special_tokens(self, src_lang) -> None:
)

def set_tgt_lang_special_tokens(self, lang: str) -> None:
"""Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
"""Reset the special tokens to the target lang setting.
- In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
- In default mode: Prefix=[tgt_lang_code], suffix = [eos]
"""
self.cur_lang_code = self.convert_tokens_to_ids(lang)
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
if self.legacy_behaviour:
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
else:
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]

prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
Expand Down
32 changes: 25 additions & 7 deletions tests/models/nllb/test_tokenization_nllb.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
" face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
]
expected_src_tokens = [
256047,
16297,
134408,
8165,
Expand All @@ -319,7 +320,6 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
108,
49486,
2,
256047,
]

@classmethod
Expand Down Expand Up @@ -355,8 +355,8 @@ def test_enro_tokenizer_truncation(self):
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
self.assertEqual(ids[-1], 2)
self.assertEqual(ids[0], EN_CODE)
self.assertEqual(len(ids), desired_max_length)

def test_mask_token(self):
Expand Down Expand Up @@ -389,10 +389,10 @@ def test_enro_tokenizer_prepare_batch(self):
self.assertEqual((2, 15), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
self.assertEqual(RO_CODE, batch.decoder_input_ids[0, 0]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])

def test_seq2seq_max_length(self):
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
Expand All @@ -419,9 +419,27 @@ def test_tokenizer_translation(self):
nested_simplify(inputs),
{
# A, test, EOS, en_XX
"input_ids": [[70, 7356, 2, 256047]],
"input_ids": [[256047, 70, 7356, 2]],
"attention_mask": [[1, 1, 1, 1]],
# ar_AR
"forced_bos_token_id": 256057,
},
)

@require_torch
def test_legacy_behaviour(self):
self.tokenizer.legacy_behaviour = True
inputs = self.tokenizer(
"UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
)
self.assertEqual(
inputs.input_ids, [16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2, 256047]
)

self.tokenizer.legacy_behaviour = False
inputs = self.tokenizer(
"UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
)
self.assertEqual(
inputs.input_ids, [256047, 16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2]
)