Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/gluonnlp/embedding/token_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,8 +830,6 @@ def serialize(self, file_path, compress=True):

if not unknown_token: # Store empty string instead of None
unknown_token = ''
else:
assert unknown_token == idx_to_token[UNK_IDX]

if not compress:
np.savez(file=file_path, unknown_token=unknown_token,
Expand Down
19 changes: 18 additions & 1 deletion tests/unittest/test_token_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# under the License.

import functools
import os

import mxnet as mx
import pytest
Expand Down Expand Up @@ -49,10 +50,11 @@ def __getitem__(self, tokens):
[
(None, None),
(['<unk>', 'hello', 'world'], mx.nd.zeros(shape=[3, 300])), # 300 == embsize
(['hello', 'world', '<unk>'], mx.nd.zeros(shape=[3, 300])), # 300 == embsize
(['hello', 'world'], mx.nd.zeros(shape=[2, 300])), # 300 == embsize
])
def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_extend, unknown_lookup,
idx_token_vec_mapping, embsize=300):
idx_token_vec_mapping, tmp_path, embsize=300):
idx_to_token, idx_to_vec = idx_token_vec_mapping

TokenEmbedding = functools.partial(
Expand All @@ -61,28 +63,40 @@ def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_exte
unknown_lookup=unknown_lookup(embsize) if unknown_lookup is not None else None,
idx_to_token=idx_to_token, idx_to_vec=idx_to_vec)

def test_serialization(emb, tmp_path=tmp_path):
    """Round-trip ``emb`` through an ``emb.npz`` file and check equality.

    The embedding is written with ``emb.serialize`` to ``emb.npz`` inside
    ``tmp_path``, read back with
    ``nlp.embedding.TokenEmbedding.deserialize``, and the reloaded object
    must compare equal to the original.
    """
    npz_file = os.path.join(str(tmp_path), "emb.npz")
    emb.serialize(npz_file)
    restored = nlp.embedding.TokenEmbedding.deserialize(npz_file)
    assert restored == emb

## Test "legacy" constructor
if idx_to_token is None:
emb = TokenEmbedding()
assert len(emb.idx_to_token) == 1 if unknown_token else len(emb.idx_to_token) == 0
# emb does not know the embsize, thus idx_to_vec could not be initialized
assert emb.idx_to_vec is None
with pytest.raises(AttributeError):
# Cannot serialize as idx_to_vec is not initialized
test_serialization(emb)

# Set unknown_token
if unknown_token:
emb[unknown_token] = mx.nd.zeros(embsize) - 1
assert (emb[unknown_token].asnumpy() == mx.nd.zeros(embsize).asnumpy() - 1).all()
assert emb.idx_to_vec.shape[1] == embsize
test_serialization(emb)

if allow_extend:
emb = TokenEmbedding()
emb[unknown_token] = mx.nd.zeros(embsize) - 1
assert emb.idx_to_vec.shape[1] == embsize
test_serialization(emb)

emb = TokenEmbedding()
emb['<some_token>'] = mx.nd.zeros(embsize) - 1
assert emb.idx_to_vec.shape[0] == 2 if unknown_token else emb.idx_to_vec.shape[0] == 1
assert (emb['<some_token>'].asnumpy() == (mx.nd.zeros(embsize) - 1).asnumpy()).all()
test_serialization(emb)

## Test with idx_to_vec and idx_to_token arguments
else:
Expand All @@ -95,11 +109,13 @@ def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_exte
else:
assert emb.idx_to_token == idx_to_token
assert (emb.idx_to_vec.asnumpy() == idx_to_vec.asnumpy()).all()
test_serialization(emb)

if allow_extend:
emb = TokenEmbedding()
emb[unknown_token] = mx.nd.zeros(embsize) - 1
assert emb.idx_to_vec.shape[1] == embsize
test_serialization(emb)

emb = TokenEmbedding()
emb['<some_token>'] = mx.nd.zeros(embsize) - 1
Expand All @@ -109,3 +125,4 @@ def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_exte
assert emb.idx_to_vec.shape[0] == len(idx_to_token) + 2
else:
assert emb.idx_to_vec.shape[0] == len(idx_to_token) + 1
test_serialization(emb)