7 changes: 4 additions & 3 deletions src/gluonnlp/vocab/vocab.py
@@ -138,13 +138,13 @@ def __init__(self, counter=None, max_size=None, min_freq=1, unknown_token=C.UNK_
         self._unknown_token = unknown_token
         special_tokens = []
         self._padding_token = padding_token
-        if padding_token:
+        if padding_token and padding_token not in special_tokens:
             special_tokens.append(padding_token)
         self._bos_token = bos_token
-        if bos_token:
+        if bos_token and bos_token not in special_tokens:
             special_tokens.append(bos_token)
         self._eos_token = eos_token
-        if eos_token:
+        if eos_token and eos_token not in special_tokens:
             special_tokens.append(eos_token)
         if reserved_tokens:
             special_tokens.extend(reserved_tokens)
@@ -170,6 +170,7 @@ def _index_special_tokens(self, unknown_token, special_tokens):
             self._reserved_tokens = None
         else:
             self._reserved_tokens = special_tokens[:]
+        assert len(special_tokens) == len(set(special_tokens))  # sanity check
         self._idx_to_token.extend(special_tokens)

         if unknown_token:
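To illustrate the effect of the guarded `append` calls above, here is a minimal sketch (a hypothetical session, not part of this PR; it assumes `gluonnlp` with this fix applied and mirrors the expectations encoded in the test added below):

```python
# Illustrative sketch only; assumes gluonnlp with this fix applied.
import gluonnlp as nlp

counter = nlp.data.utils.Counter(['a', 'b', 'b', 'c', 'c', 'c'])

# padding_token and eos_token intentionally share the representation '<eos>'.
v = nlp.Vocab(counter,
              unknown_token='<unk>',
              padding_token='<eos>',
              bos_token=None,
              eos_token='<eos>')

# '<eos>' is appended to special_tokens only once, so the index stays
# consistent; before the fix, idx_to_token contained '<eos>' twice.
assert v.idx_to_token.count('<eos>') == 1
assert len(v.idx_to_token) == len(v.token_to_idx)
```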
45 changes: 45 additions & 0 deletions tests/unittest/test_vocab_embed.py
@@ -1223,3 +1223,48 @@ def test_bert_vocab_from_sentencepiece():
     for i in range(num_tokens):
         token = _convert_to_unicode(spm.IdToPiece(i))
         assert bert_vocab[token] == i
+
+
+@pytest.mark.parametrize('unknown_token', ['<unk>', None])
+@pytest.mark.parametrize('padding_token', ['<pad>', '<eos>', None])  # '<eos>' makes padding_token == eos_token
+@pytest.mark.parametrize('eos_token', ['<eos>', None])
+@pytest.mark.parametrize('reserved_tokens', [['<tok>'], []])
+def test_vocab_duplicate_special_tokens(unknown_token, padding_token,
+                                        eos_token, reserved_tokens):
+    """Different special tokens are allowed to map to the same representation.
+
+    Special tokens are a subset of the reserved tokens. In general, reserved
+    tokens must not contain duplicates; however, multiple special tokens may
+    use the same reserved token.
+
+    """
+    counter = nlp.data.utils.Counter(
+        ['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    Vocab = functools.partial(nlp.Vocab,
+                              counter,
+                              max_size=None,
+                              min_freq=1,
+                              unknown_token=unknown_token,
+                              padding_token=padding_token,
+                              bos_token=None,
+                              eos_token=eos_token)
+
+    v = Vocab(reserved_tokens=reserved_tokens)
+
+    # Duplicate special tokens must not corrupt the index
+    # (broken before GluonNLP 0.7).
+    if eos_token is not None and padding_token == eos_token:
+        # padding_token == eos_token; there should be only a single index for
+        # <eos>.
+        # Before GluonNLP 0.7, idx_to_token looked like
+        # ['<unk>', '<eos>', '<eos>', 'c', 'b', 'a'],
+        # but it should look like
+        # ['<unk>', '<eos>', 'c', 'b', 'a'].
+        assert len(v.idx_to_token) == len(v.token_to_idx)
+        assert len(v.idx_to_token) == len(set(v.idx_to_token))
+
+    # Specifying a special token as a reserved token counts as a duplicate.
+    if eos_token is not None:
+        with pytest.raises(AssertionError):
+            Vocab(reserved_tokens=reserved_tokens + [eos_token])
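By contrast, repeating a special token in `reserved_tokens` remains a hard error, caught by the new `assert` in `_index_special_tokens`, as the last branch of the test exercises. A minimal sketch under the same assumptions (hypothetical, not part of this PR):

```python
# Illustrative sketch only; assumes gluonnlp with this fix applied.
import gluonnlp as nlp

counter = nlp.data.utils.Counter(['a', 'b', 'b', 'c', 'c', 'c'])
try:
    nlp.Vocab(counter,
              unknown_token='<unk>',
              padding_token='<pad>',
              bos_token=None,
              eos_token='<eos>',
              reserved_tokens=['<eos>'])  # '<eos>' duplicates eos_token
except AssertionError:
    # Caught by the new sanity check
    # `assert len(special_tokens) == len(set(special_tokens))`.
    print('duplicate reserved token rejected')
```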