diff --git a/src/gluonnlp/embedding/token_embedding.py b/src/gluonnlp/embedding/token_embedding.py index 171be71e9f..d22100e487 100644 --- a/src/gluonnlp/embedding/token_embedding.py +++ b/src/gluonnlp/embedding/token_embedding.py @@ -830,8 +830,6 @@ def serialize(self, file_path, compress=True): if not unknown_token: # Store empty string instead of None unknown_token = '' - else: - assert unknown_token == idx_to_token[UNK_IDX] if not compress: np.savez(file=file_path, unknown_token=unknown_token, diff --git a/tests/unittest/test_token_embedding.py b/tests/unittest/test_token_embedding.py index 7f5a635332..eb82c65dcc 100644 --- a/tests/unittest/test_token_embedding.py +++ b/tests/unittest/test_token_embedding.py @@ -18,6 +18,7 @@ # under the License. import functools +import os import mxnet as mx import pytest @@ -49,10 +50,11 @@ def __getitem__(self, tokens): [ (None, None), (['', 'hello', 'world'], mx.nd.zeros(shape=[3, 300])), # 300 == embsize + (['hello', 'world', ''], mx.nd.zeros(shape=[3, 300])), # 300 == embsize (['hello', 'world'], mx.nd.zeros(shape=[2, 300])), # 300 == embsize ]) def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_extend, unknown_lookup, - idx_token_vec_mapping, embsize=300): + idx_token_vec_mapping, tmp_path, embsize=300): idx_to_token, idx_to_vec = idx_token_vec_mapping TokenEmbedding = functools.partial( @@ -61,28 +63,40 @@ def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_exte unknown_lookup=unknown_lookup(embsize) if unknown_lookup is not None else None, idx_to_token=idx_to_token, idx_to_vec=idx_to_vec) + def test_serialization(emb, tmp_path=tmp_path): + emb_path = os.path.join(str(tmp_path), "emb.npz") + emb.serialize(emb_path) + loaded_emb = nlp.embedding.TokenEmbedding.deserialize(emb_path) + assert loaded_emb == emb + ## Test "legacy" constructor if idx_to_token is None: emb = TokenEmbedding() assert len(emb.idx_to_token) == 1 if unknown_token else len(emb.idx_to_token) == 0 # emb does not know the embsize, thus idx_to_vec could not be initialized assert emb.idx_to_vec is None + with pytest.raises(AttributeError): + # Cannot serialize as idx_to_vec is not initialized + test_serialization(emb) # Set unknown_token if unknown_token: emb[unknown_token] = mx.nd.zeros(embsize) - 1 assert (emb[unknown_token].asnumpy() == mx.nd.zeros(embsize).asnumpy() - 1).all() assert emb.idx_to_vec.shape[1] == embsize + test_serialization(emb) if allow_extend: emb = TokenEmbedding() emb[unknown_token] = mx.nd.zeros(embsize) - 1 assert emb.idx_to_vec.shape[1] == embsize + test_serialization(emb) emb = TokenEmbedding() emb[''] = mx.nd.zeros(embsize) - 1 assert emb.idx_to_vec.shape[0] == 2 if unknown_token else emb.idx_to_vec.shape[0] == 1 assert (emb[''].asnumpy() == (mx.nd.zeros(embsize) - 1).asnumpy()).all() + test_serialization(emb) ## Test with idx_to_vec and idx_to_token arguments else: @@ -95,11 +109,13 @@ def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_exte else: assert emb.idx_to_token == idx_to_token assert (emb.idx_to_vec.asnumpy() == idx_to_vec.asnumpy()).all() + test_serialization(emb) if allow_extend: emb = TokenEmbedding() emb[unknown_token] = mx.nd.zeros(embsize) - 1 assert emb.idx_to_vec.shape[1] == embsize + test_serialization(emb) emb = TokenEmbedding() emb[''] = mx.nd.zeros(embsize) - 1 @@ -109,3 +125,4 @@ def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_exte assert emb.idx_to_vec.shape[0] == len(idx_to_token) + 2 else: assert emb.idx_to_vec.shape[0] == len(idx_to_token) + 1 + test_serialization(emb)