diff --git a/src/gluonnlp/vocab/vocab.py b/src/gluonnlp/vocab/vocab.py
index d2eed49ba1..78e53bbe93 100644
--- a/src/gluonnlp/vocab/vocab.py
+++ b/src/gluonnlp/vocab/vocab.py
@@ -236,6 +236,8 @@ def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] =
 
         if token_to_idx:
             self._sort_index_according_to_user_specification(token_to_idx)
+        if unknown_token:
+            self._token_to_idx._default = self._token_to_idx[unknown_token]
 
     def _index_counter_keys(self, counter, unknown_token, special_tokens, max_size,
diff --git a/tests/unittest/test_models.py b/tests/unittest/test_models.py
index 2c01bf14b0..0f65befe22 100644
--- a/tests/unittest/test_models.py
+++ b/tests/unittest/test_models.py
@@ -176,7 +176,7 @@ def test_pretrained_bert_models(disable_missing_parameters):
         assert len(vocab) == vocab_size[dataset]
         for token in special_tokens:
             assert token in vocab, "Token %s not found in the vocab" % token
-        assert vocab['RandomWordByHaibin'] == 0
+        assert vocab['RandomWordByHaibin'] == vocab[vocab.unknown_token]
         assert vocab.padding_token == '[PAD]'
         assert vocab.unknown_token == '[UNK]'
         assert vocab.bos_token is None
diff --git a/tests/unittest/test_vocab_embed.py b/tests/unittest/test_vocab_embed.py
index e26cd9f04c..93bca311ba 100644
--- a/tests/unittest/test_vocab_embed.py
+++ b/tests/unittest/test_vocab_embed.py
@@ -1442,3 +1442,20 @@ def test_vocab_backwards_compatibility_prior_v0_7_corrupted_index_bug():
     assert v.idx_to_token[2] == ''
     assert v.idx_to_token[3] == ''
     assert v.idx_to_token[4] == 'token'
+
+
+@pytest.mark.parametrize('unknown_token', ['', ''])
+@pytest.mark.parametrize('padding_token', ['', '', None])
+@pytest.mark.parametrize('eos_token', ['', None])
+@pytest.mark.parametrize('reserved_tokens', [[''], []])
+def test_vocab_remapped_unknown_token_idx(unknown_token, padding_token, eos_token, reserved_tokens,
+                                          counter):
+    Vocab = functools.partial(nlp.Vocab, counter, max_size=None, min_freq=1,
+                              unknown_token=unknown_token, padding_token=padding_token,
+                              bos_token=None, eos_token=eos_token)
+
+    v = Vocab()
+    assert v['UNKNOWNWORD'] == 0
+
+    v = Vocab(token_to_idx={unknown_token: 1})
+    assert v['UNKNOWNWORD'] == 1