diff --git a/src/gluonnlp/vocab/vocab.py b/src/gluonnlp/vocab/vocab.py
index a650841e8b..0483f234da 100644
--- a/src/gluonnlp/vocab/vocab.py
+++ b/src/gluonnlp/vocab/vocab.py
@@ -138,13 +138,13 @@ def __init__(self, counter=None, max_size=None, min_freq=1, unknown_token=C.UNK_
         self._unknown_token = unknown_token
         special_tokens = []
         self._padding_token = padding_token
-        if padding_token:
+        if padding_token and padding_token not in special_tokens:
             special_tokens.append(padding_token)
         self._bos_token = bos_token
-        if bos_token:
+        if bos_token and bos_token not in special_tokens:
             special_tokens.append(bos_token)
         self._eos_token = eos_token
-        if eos_token:
+        if eos_token and eos_token not in special_tokens:
             special_tokens.append(eos_token)
         if reserved_tokens:
             special_tokens.extend(reserved_tokens)
@@ -170,6 +170,7 @@ def _index_special_tokens(self, unknown_token, special_tokens):
             self._reserved_tokens = None
         else:
             self._reserved_tokens = special_tokens[:]
+            assert len(special_tokens) == len(set(special_tokens))  # sanity check
         self._idx_to_token.extend(special_tokens)
 
         if unknown_token:
diff --git a/tests/unittest/test_vocab_embed.py b/tests/unittest/test_vocab_embed.py
index 9aff7e5f65..8d9e573a15 100644
--- a/tests/unittest/test_vocab_embed.py
+++ b/tests/unittest/test_vocab_embed.py
@@ -1223,3 +1223,48 @@ def test_bert_vocab_from_sentencepiece():
     for i in range(num_tokens):
         token = _convert_to_unicode(spm.IdToPiece(i))
         assert bert_vocab[token] == i
+
+
+@pytest.mark.parametrize('unknown_token', ['<unk>', None])
+@pytest.mark.parametrize('padding_token', ['<pad>', '<eos>', None])  # padding_token == eos_token
+@pytest.mark.parametrize('eos_token', ['<eos>', None])
+@pytest.mark.parametrize('reserved_tokens', [['<reserved>'], []])
+def test_vocab_duplicate_special_tokens(unknown_token, padding_token,
+                                        eos_token, reserved_tokens):
+    """Different special tokens are allowed to map to the same representations.
+
+    Special tokens are a subset of the reserved tokens. In general reserved
+    tokens must not contain duplicates; however, it is allowed that multiple
+    special tokens use the same reserved token.
+
+    """
+    counter = nlp.data.utils.Counter(
+        ['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    Vocab = functools.partial(nlp.Vocab,
+                              counter,
+                              max_size=None,
+                              min_freq=1,
+                              unknown_token=unknown_token,
+                              padding_token=padding_token,
+                              bos_token=None,
+                              eos_token=eos_token)
+
+    v = Vocab(reserved_tokens=reserved_tokens)
+
+    # Duplicate special tokens must not corrupt the index
+    # (Broken before GluonNLP 0.7)
+    if eos_token is not None and padding_token == eos_token:
+        # padding_token == eos_token; there should only be a single index for
+        # <eos>
+        # Before GluonNLP 0.7, idx_to_token looked like
+        # ['<unk>', '<eos>', '<eos>', 'c', 'b', 'a']
+        # But it should look like
+        # ['<unk>', '<eos>', 'c', 'b', 'a']
+        assert len(v.idx_to_token) == len(v.token_to_idx)
+        assert len(v.idx_to_token) == len(set(v.idx_to_token))
+
+    # Specifying a special token as a reserved token is counted as a duplicate
+    if eos_token is not None:
+        with pytest.raises(AssertionError):
+            Vocab(reserved_tokens=reserved_tokens + [eos_token])