From e3379816f838f3c27e2b3521ed8ef5637ff701d9 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Wed, 5 Jun 2019 13:05:21 +0000 Subject: [PATCH 1/6] Refactor TokenEmbedding to reduce number of places that initialize internals Prior to this commit, the TokenEmbedding constructor could only construct an empty TokenEmbedding. However, an empty TokenEmbedding is of little use, thus there exist a variety of places that modify and overwrite TokenEmbedding internals after construction to "fill" the idx_to_token, idx_to_vec. Examples are _load_embedding_text or _load_embedding_serialized within the TokenEmbedding class, but also set_embedding in the Vocab class. This commits 1) makes these methods static and changes them to return the idx_to_token and idx_to_vec, 2) extends the TokenEmbedding constructor to allow constructing a "non-empty" TokenEmbedding given newly added idx_to_token and idx_to_vec arguments This change is backwards compatible in that it does not change any public API besides introducing idx_to_token and idx_to_vec arguments to TokenEmbedding. For the future, the "empty" TokenEmbedding initialization may be removed as it provides little benefit. For now it is kept for backwards compatibility. --- src/gluonnlp/embedding/token_embedding.py | 386 ++++++++++++++-------- src/gluonnlp/vocab/vocab.py | 11 +- tests/unittest/test_vocab_embed.py | 33 +- 3 files changed, 285 insertions(+), 145 deletions(-) diff --git a/src/gluonnlp/embedding/token_embedding.py b/src/gluonnlp/embedding/token_embedding.py index c52b99c568..8db31faafd 100644 --- a/src/gluonnlp/embedding/token_embedding.py +++ b/src/gluonnlp/embedding/token_embedding.py @@ -42,9 +42,9 @@ from ..data.utils import DefaultLookupDict from ..model.train import FasttextEmbeddingModel -UNK_IDX = 0 # This should not be changed as long as serialized token - # embeddings redistributed on S3 contain an unknown token. Commit - # 46eb28ffcf542ab8ed801a727a45404a73adbce3 has more context. +UNK_IDX = 0 +ENCODING = 'utf8' +INIT_UNKNOWN_VEC = nd.zeros def register(embedding_cls): """Registers a new token embedding. @@ -168,7 +168,7 @@ class TokenEmbedding(object): Any unknown token will be replaced by unknown_token and consequently will be indexed as the same representation. Only used if oov_imputer is not specified. - init_unknown_vec : callback + init_unknown_vec : callback, default nd.zeros The callback used to initialize the embedding vector for the unknown token. Only used if `unknown_token` is not None. allow_extend : bool, default False @@ -183,19 +183,62 @@ class TokenEmbedding(object): """ - def __init__(self, unknown_token='', init_unknown_vec=nd.zeros, - allow_extend=False, unknown_lookup=None): - self._unknown_token = unknown_token - self._init_unknown_vec = init_unknown_vec - self._allow_extend = allow_extend - self._unknown_lookup = unknown_lookup - self._idx_to_token = [unknown_token] if unknown_token else [] - if unknown_token: - self._token_to_idx = DefaultLookupDict(UNK_IDX) + def __init__(self, unknown_token=C.UNK_TOKEN, init_unknown_vec=INIT_UNKNOWN_VEC, + allow_extend=False, unknown_lookup=None, idx_to_vec=None, + idx_to_token=None): + unknown_index = None + + # With pre-specified tokens and vectors + if idx_to_vec is not None or idx_to_token is not None: + # Sanity checks + if idx_to_vec is None or idx_to_token is None: + raise ValueError('Must specify either none or both of ' + 'idx_to_token and idx_to_vec.') + if len(idx_to_vec) != len(idx_to_token): + raise ValueError( + 'idx_to_token and idx_to_vec must be of equal length.') + if init_unknown_vec is not None: + raise ValueError('Must not specify init_unknown_vec ' + 'when specifying idx_to_vec') + if unknown_token is not None: + try: + unknown_index = idx_to_token.index(unknown_token) + except ValueError: + raise ValueError( + 'unknown_token \'{}\' must be part of idx_to_token'. + format(unknown_token)) + + # Initialization + self._unknown_token = unknown_token + assert init_unknown_vec is None + self._init_unknown_vec = init_unknown_vec + self._allow_extend = allow_extend + self._unknown_lookup = unknown_lookup + + self._idx_to_token = idx_to_token + self._idx_to_vec = idx_to_vec + + # Empty token-embedding + else: + # Initialization + self._unknown_token = unknown_token + if self._unknown_token is not None: + unknown_index = UNK_IDX + self._init_unknown_vec = init_unknown_vec + self._allow_extend = allow_extend + self._unknown_lookup = unknown_lookup + + assert UNK_IDX == 0 + self._idx_to_token = [unknown_token] if unknown_token else [] + self._idx_to_vec = None + + # Initialization of token_to_idx mapping + if self._unknown_token: + assert unknown_index is not None + self._token_to_idx = DefaultLookupDict(unknown_index) else: self._token_to_idx = {} self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token)) - self._idx_to_vec = None @staticmethod def _get_file_url(cls_name, source_file_hash, source): @@ -221,8 +264,9 @@ def _get_file_path(cls, source_file_hash, embedding_root, source): return pretrained_file_path - def _load_embedding(self, pretrained_file_path, elem_delim, - encoding='utf8'): + @staticmethod + def _load_embedding(pretrained_file_path, elem_delim, unknown_token, + init_unknown_vec, encoding=ENCODING): """Load embedding vectors from a pre-trained token embedding file. Both text files and TokenEmbedding serialization files are supported. @@ -248,24 +292,40 @@ def _load_embedding(self, pretrained_file_path, elem_delim, pretrained_file_path) if pretrained_file_path.endswith('.npz'): - self._load_embedding_serialized( - pretrained_file_path=pretrained_file_path) + return TokenEmbedding._load_embedding_serialized( + pretrained_file_path=pretrained_file_path, + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec) else: - self._load_embedding_txt( + return TokenEmbedding._load_embedding_txt( pretrained_file_path=pretrained_file_path, - elem_delim=elem_delim, encoding=encoding) + elem_delim=elem_delim, + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) - def _load_embedding_txt(self, pretrained_file_path, elem_delim, encoding='utf8'): + + @staticmethod + def _load_embedding_txt(pretrained_file_path, elem_delim, unknown_token, + init_unknown_vec, encoding=ENCODING): """Load embedding vectors from a pre-trained token embedding file. - For every unknown token, if its representation `self.unknown_token` is encountered in the - pre-trained token embedding file, index 0 of `self.idx_to_vec` maps to the pre-trained token - embedding vector loaded from the file; otherwise, index 0 of `self.idx_to_vec` maps to the - text embedding vector initialized by `self._init_unknown_vec`. + Returns idx_to_token, idx_to_vec and unknown_token suitable for the + TokenEmbedding constructor. + + For every unknown token, if its representation `unknown_token` is encountered in the + pre-trained token embedding file, index 0 of `idx_to_vec` maps to the pre-trained token + embedding vector loaded from the file; otherwise, index 0 of `idx_to_vec` maps to the + text embedding vector initialized by `init_unknown_vec`. If a token is encountered multiple times in the pre-trained text embedding file, only the first-encountered token embedding vector will be loaded and the rest will be skipped. + """ + idx_to_token = [unknown_token] if unknown_token else [] + unk_idx = None + if unknown_token: + unk_idx = 0 vec_len = None all_elems = [] @@ -287,9 +347,9 @@ def _load_embedding_txt(self, pretrained_file_path, elem_delim, encoding='utf8') token, elems = elems[0], [float(i) for i in elems[1:]] - if token == self.unknown_token and loaded_unknown_vec is None: + if loaded_unknown_vec is None and token == unknown_token: loaded_unknown_vec = elems - tokens.add(self.unknown_token) + tokens.add(unknown_token) elif token in tokens: warnings.warn('line {} in {}: duplicate embedding found for ' 'token "{}". Skipped.'.format(line_num, pretrained_file_path, @@ -300,9 +360,10 @@ def _load_embedding_txt(self, pretrained_file_path, elem_delim, encoding='utf8') else: if not vec_len: vec_len = len(elems) - if self.unknown_token: + if unknown_token: # Reserve a vector slot for the unknown token at the very beggining # because the unknown token index is 0. + assert len(all_elems) == 0 all_elems.extend([0] * vec_len) else: assert len(elems) == vec_len, \ @@ -311,76 +372,83 @@ def _load_embedding_txt(self, pretrained_file_path, elem_delim, encoding='utf8') pretrained_file_path, token, vec_len, len(elems)) all_elems.extend(elems) - self._idx_to_token.append(token) - self._token_to_idx[token] = len(self._idx_to_token) - 1 + idx_to_token.append(token) tokens.add(token) - self._idx_to_vec = nd.array(all_elems).reshape((-1, vec_len)) + idx_to_vec = nd.array(all_elems).reshape((-1, vec_len)) - if self.unknown_token: + if unknown_token: if loaded_unknown_vec is None: - self._idx_to_vec[UNK_IDX] = self._init_unknown_vec(shape=vec_len) + idx_to_vec[unk_idx] = init_unknown_vec(shape=vec_len) else: - self._idx_to_vec[UNK_IDX] = nd.array(loaded_unknown_vec) + idx_to_vec[unk_idx] = nd.array(loaded_unknown_vec) + + return idx_to_token, idx_to_vec, unknown_token - def _load_embedding_serialized(self, pretrained_file_path): + @staticmethod + def _load_embedding_serialized(pretrained_file_path, unknown_token, init_unknown_vec): """Load embedding vectors from a pre-trained token embedding file. - For every unknown token, if its representation `self.unknown_token` is encountered in the - pre-trained token embedding file, index 0 of `self.idx_to_vec` maps to the pre-trained token - embedding vector loaded from the file; otherwise, index 0 of `self.idx_to_vec` maps to the - text embedding vector initialized by `self._init_unknown_vec`. + Returns idx_to_token, idx_to_vec and unknown_token suitable for the + TokenEmbedding constructor. ValueError is raised if a token occurs multiple times. """ - deserialized_embedding = TokenEmbedding.deserialize(pretrained_file_path) + + idx_to_token = deserialized_embedding.idx_to_token + if len(set(idx_to_token)) != len(idx_to_token): + raise ValueError('Serialized embedding contains duplicate tokens.') + idx_to_vec = deserialized_embedding.idx_to_vec + vec_len = idx_to_vec.shape[1] + loaded_unknown_vec = False if deserialized_embedding.unknown_token: - # Some .npz files on S3 may contain an unknown token and its - # respective embedding. As a workaround, we assume that UNK_IDX - # is the same now as it was when the .npz was generated. Under this - # assumption we can safely overwrite the respective token and - # vector from the npz. - if deserialized_embedding.unknown_token: - idx_to_token = deserialized_embedding.idx_to_token - idx_to_vec = deserialized_embedding.idx_to_vec - idx_to_token[UNK_IDX] = self.unknown_token - if self._init_unknown_vec: - vec_len = idx_to_vec.shape[1] - idx_to_vec[UNK_IDX] = self._init_unknown_vec(shape=vec_len) + if not unknown_token: + # If the TokenEmbedding shall not have an unknown token but the + # serialized file provided one, delete the provided one. + unk_idx = deserialized_embedding.token_to_idx[ + deserialized_embedding.unknown_token] + assert unk_idx >= 0 + if unk_idx == 0: + idx_to_token = idx_to_token[1:] + idx_to_vec = idx_to_vec[1:] + else: + idx_to_token = idx_to_token[:unk_idx] + idx_to_token[unk_idx + 1:] + idx_to_vec = nd.concat(idx_to_vec[:unk_idx], idx_to_vec[unk_idx + 1:], dim=0) else: - # If the TokenEmbedding shall not have an unknown token, we - # just delete the one in the npz. - assert UNK_IDX == 0 - idx_to_token = deserialized_embedding.idx_to_token[UNK_IDX + 1:] - idx_to_vec = deserialized_embedding.idx_to_vec[UNK_IDX + 1:] + # If the TokenEmbedding shall have an unknown token and the + # serialized file provided one, replace the representation. + unk_idx = deserialized_embedding.token_to_idx[ + deserialized_embedding.unknown_token] + idx_to_token[unk_idx] = unknown_token + loaded_unknown_vec = True else: - idx_to_token = deserialized_embedding.idx_to_token - idx_to_vec = deserialized_embedding.idx_to_vec + if unknown_token and unknown_token not in idx_to_token: + # If the TokenEmbedding shall have an unknown token but the + # serialized file didn't provided one, insert a new one + idx_to_token = [unknown_token] + idx_to_token + idx_to_vec = nd.concat(nd.zeros((1, vec_len)), idx_to_vec, dim=0) + elif unknown_token: + # The serialized file did define a unknown token, but contains + # the token that is specified by the user to represent the + # unknown token. + assert not deserialized_embedding.unknown_token + loaded_unknown_vec = True + # Move unknown_token to idx 0 to replicate the behavior of + # _load_embedding_text + unk_idx = idx_to_token.index(unknown_token) + if unk_idx > 0: + idx_to_token[0], idx_to_token[unk_idx] = idx_to_token[unk_idx], idx_to_token[0] + idx_to_vec[[0, unk_idx]] = idx_to_vec[[unk_idx, 0]] + else: + assert not deserialized_embedding.unknown_token + assert not unknown_token - if not len(set(idx_to_token)) == len(idx_to_token): - raise ValueError('Serialized embedding invalid. ' - 'It contains duplicate tokens.') + if unknown_token and init_unknown_vec and not loaded_unknown_vec: + unk_idx = idx_to_token.index(unknown_token) + idx_to_vec[unk_idx] = init_unknown_vec(shape=vec_len) - if self.unknown_token: - try: - unknown_token_idx = deserialized_embedding.idx_to_token.index( - self.unknown_token) - idx_to_token[UNK_IDX], idx_to_token[ - unknown_token_idx] = idx_to_token[ - unknown_token_idx], idx_to_token[UNK_IDX] - idxs = [UNK_IDX, unknown_token_idx] - idx_to_vec[idxs] = idx_to_vec[idxs[::-1]] - except ValueError: - vec_len = idx_to_vec.shape[1] - idx_to_token.insert(0, self.unknown_token) - idx_to_vec = nd.concat( - self._init_unknown_vec(shape=vec_len).reshape((1, -1)), - idx_to_vec, dim=0) - - self._idx_to_token = idx_to_token - self._idx_to_vec = idx_to_vec - self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token)) + return idx_to_token, idx_to_vec, unknown_token @property def idx_to_token(self): @@ -635,7 +703,7 @@ def __setitem__(self, tokens, new_embedding): raise KeyError(('Token "{}" is unknown. To update the embedding vector for an' ' unknown token, please explicitly include "{}" as the ' '`unknown_token` in `tokens`. This is to avoid unintended ' - 'updates.').format(token, self._idx_to_token[UNK_IDX])) + 'updates.').format(token, self.unknown_token)) else: raise KeyError(('Token "{}" is unknown. Updating the embedding vector for an ' 'unknown token is not allowed because `unknown_token` is not ' @@ -661,7 +729,7 @@ def _check_source(cls, source_file_hash, source): ', '.join(source_file_hash.keys()))) @staticmethod - def from_file(file_path, elem_delim=' ', encoding='utf8', **kwargs): + def from_file(file_path, elem_delim=' ', encoding=ENCODING, **kwargs): """Creates a user-defined token embedding from a pre-trained embedding file. @@ -693,9 +761,22 @@ def from_file(file_path, elem_delim=' ', encoding='utf8', **kwargs): instance of :class:`gluonnlp.embedding.TokenEmbedding` The user-defined token embedding instance. """ - embedding = TokenEmbedding(**kwargs) - embedding._load_embedding(file_path, elem_delim=elem_delim, encoding=encoding) - return embedding + unknown_token = kwargs.pop('unknown_token', C.UNK_TOKEN) + init_unknown_vec = kwargs.pop('init_unknown_vec', INIT_UNKNOWN_VEC) + idx_to_token, idx_to_vec, unknown_token = TokenEmbedding._load_embedding( + file_path, + elem_delim=elem_delim, + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) + + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + return TokenEmbedding(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) def serialize(self, file_path, compress=True): """Serializes the TokenEmbedding to a file specified by file_path. @@ -740,8 +821,8 @@ def serialize(self, file_path, compress=True): idx_to_token=idx_to_token, idx_to_vec=idx_to_vec) - @classmethod - def deserialize(cls, file_path, **kwargs): + @staticmethod + def deserialize(file_path, **kwargs): """Create a new TokenEmbedding from a serialized one. TokenEmbedding is serialized by converting the list of tokens, the @@ -774,18 +855,15 @@ def deserialize(cls, file_path, **kwargs): idx_to_token = npz_dict['idx_to_token'].tolist() idx_to_vec = nd.array(npz_dict['idx_to_vec']) - embedding = cls(unknown_token=unknown_token, **kwargs) - if unknown_token: - assert unknown_token == idx_to_token[UNK_IDX] - embedding._token_to_idx = DefaultLookupDict(UNK_IDX) - else: - embedding._token_to_idx = {} - - embedding._idx_to_token = idx_to_token - embedding._idx_to_vec = idx_to_vec - embedding._token_to_idx.update((token, idx) for idx, token in enumerate(idx_to_token)) - - return embedding + assert 'unknown_token' not in kwargs + assert 'init_unknown_vec' not in kwargs + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + return TokenEmbedding(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) @register @@ -846,11 +924,24 @@ class GloVe(TokenEmbedding): def __init__(self, source='glove.6B.50d', embedding_root=os.path.join(get_home_dir(), 'embedding'), **kwargs): self._check_source(self.source_file_hash, source) - - super(GloVe, self).__init__(**kwargs) pretrained_file_path = GloVe._get_file_path(self.source_file_hash, embedding_root, source) - - self._load_embedding(pretrained_file_path, elem_delim=' ') + unknown_token = kwargs.pop('unknown_token', C.UNK_TOKEN) + init_unknown_vec = kwargs.pop('init_unknown_vec', INIT_UNKNOWN_VEC) + encoding = kwargs.pop('encoding', ENCODING) + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + pretrained_file_path=pretrained_file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) + + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + super(GloVe, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) @register @@ -939,7 +1030,8 @@ class FastText(TokenEmbedding): def __init__(self, source='wiki.simple', embedding_root=os.path.join( get_home_dir(), 'embedding'), load_ngrams=False, ctx=cpu(), **kwargs): self._check_source(self.source_file_hash, source) - + pretrained_file_path = FastText._get_file_path(self.source_file_hash, + embedding_root, source) if load_ngrams: try: self._check_source(self.source_bin_file_hash, source) @@ -956,11 +1048,23 @@ def __init__(self, source='wiki.simple', embedding_root=os.path.join( else: unknown_lookup = None - super(FastText, self).__init__(unknown_lookup=unknown_lookup, **kwargs) - pretrained_file_path = FastText._get_file_path(self.source_file_hash, embedding_root, - source) - - self._load_embedding(pretrained_file_path, elem_delim=' ') + unknown_token = kwargs.pop('unknown_token', C.UNK_TOKEN) + init_unknown_vec = kwargs.pop('init_unknown_vec', INIT_UNKNOWN_VEC) + encoding = kwargs.pop('encoding', ENCODING) + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + pretrained_file_path=pretrained_file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) + + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + super(FastText, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + unknown_lookup=unknown_lookup, **kwargs) @register @@ -1027,19 +1131,37 @@ class Word2Vec(TokenEmbedding): source_file_hash = C.WORD2VEC_NPZ_SHA1 def __init__(self, source='GoogleNews-vectors-negative300', - embedding_root=os.path.join(get_home_dir(), 'embedding'), encoding='utf8', + embedding_root=os.path.join(get_home_dir(), 'embedding'), encoding=ENCODING, **kwargs): - super(Word2Vec, self).__init__(**kwargs) + unknown_token = kwargs.pop('unknown_token', C.UNK_TOKEN) + init_unknown_vec = kwargs.pop('init_unknown_vec', INIT_UNKNOWN_VEC) if source.endswith('.bin'): pretrained_file_path = os.path.expanduser(source) - self._load_w2v_binary(pretrained_file_path, encoding=encoding) + idx_to_token, idx_to_vec, unknown_token = self._load_w2v_binary( + pretrained_file_path, unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, encoding=encoding) else: self._check_source(self.source_file_hash, source) pretrained_file_path = self._get_file_path(self.source_file_hash, embedding_root, source) - self._load_embedding(pretrained_file_path, elem_delim=' ') + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + pretrained_file_path=pretrained_file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) + + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + super(Word2Vec, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) - def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): + @classmethod + def _load_w2v_binary(cls, pretrained_file_path, unknown_token, + init_unknown_vec=INIT_UNKNOWN_VEC, encoding=ENCODING): """Load embedding vectors from a binary pre-trained token embedding file. Parameters @@ -1050,13 +1172,11 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): encoding: str The encoding type of the file. """ - self._idx_to_token = [self.unknown_token] if self.unknown_token else [] - if self.unknown_token: - self._token_to_idx = DefaultLookupDict(UNK_IDX) - else: - self._token_to_idx = {} - self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token)) - self._idx_to_vec = None + idx_to_token = [unknown_token] if unknown_token else [] + unk_idx = None + if unknown_token: + unk_idx = 0 + all_elems = [] tokens = set() loaded_unknown_vec = None @@ -1064,7 +1184,7 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): with io.open(pretrained_file_path, 'rb') as f: header = f.readline().decode(encoding=encoding) vocab_size, vec_len = (int(x) for x in header.split()) - if self.unknown_token: + if unknown_token: # Reserve a vector slot for the unknown token at the very beggining # because the unknown token index is 0. all_elems.extend([0] * vec_len) @@ -1091,9 +1211,9 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): assert len(elems) > 1, 'line {} in {}: unexpected data format.'.format( line_num, pretrained_file_path) - if token == self.unknown_token and loaded_unknown_vec is None: + if token == unknown_token and loaded_unknown_vec is None: loaded_unknown_vec = elems - tokens.add(self.unknown_token) + tokens.add(unknown_token) elif token in tokens: warnings.warn('line {} in {}: duplicate embedding found for ' 'token "{}". Skipped.'.format(line_num, pretrained_file_path, @@ -1105,19 +1225,21 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): pretrained_file_path, token, vec_len, len(elems)) all_elems.extend(elems) - self._idx_to_token.append(token) - self._token_to_idx[token] = len(self._idx_to_token) - 1 + idx_to_token.append(token) tokens.add(token) - self._idx_to_vec = nd.array(all_elems).reshape((-1, vec_len)) - if self.unknown_token: + idx_to_vec = nd.array(all_elems).reshape((-1, vec_len)) + + if unknown_token: if loaded_unknown_vec is None: - self._idx_to_vec[UNK_IDX] = self._init_unknown_vec(shape=vec_len) + idx_to_vec[unk_idx] = init_unknown_vec(shape=vec_len) else: - self._idx_to_vec[UNK_IDX] = nd.array(loaded_unknown_vec) + idx_to_vec[unk_idx] = nd.array(loaded_unknown_vec) + + return idx_to_token, idx_to_vec, unknown_token @classmethod - def from_w2v_binary(cls, pretrained_file_path, encoding='utf8'): + def from_w2v_binary(cls, pretrained_file_path, encoding=ENCODING): """Load embedding vectors from a binary pre-trained token embedding file. Parameters diff --git a/src/gluonnlp/vocab/vocab.py b/src/gluonnlp/vocab/vocab.py index 13ac920494..61947b5a04 100644 --- a/src/gluonnlp/vocab/vocab.py +++ b/src/gluonnlp/vocab/vocab.py @@ -395,10 +395,6 @@ def set_embedding(self, *embeddings): 'Either all or none of the TokenEmbeddings must have an ' \ 'unknown_token set.' - new_embedding = emb.TokenEmbedding(self.unknown_token, allow_extend=False) - new_embedding._token_to_idx = self.token_to_idx - new_embedding._idx_to_token = self.idx_to_token - new_vec_len = sum(embs.idx_to_vec.shape[1] for embs in embeddings) new_idx_to_vec = nd.zeros(shape=(len(self), new_vec_len)) @@ -412,8 +408,11 @@ def set_embedding(self, *embeddings): new_idx_to_vec[1:, col_start:col_end] = embs[self._idx_to_token[1:]] col_start = col_end - new_embedding._idx_to_vec = new_idx_to_vec - self._embedding = new_embedding + self._embedding = emb.TokenEmbedding(self.unknown_token, + init_unknown_vec=None, + allow_extend=False, + idx_to_token=self.idx_to_token, + idx_to_vec=new_idx_to_vec) def to_tokens(self, indices): """Converts token indices to tokens according to the vocabulary. diff --git a/tests/unittest/test_vocab_embed.py b/tests/unittest/test_vocab_embed.py index 95bd35e66f..7f6d4e96b6 100644 --- a/tests/unittest/test_vocab_embed.py +++ b/tests/unittest/test_vocab_embed.py @@ -820,12 +820,20 @@ def __init__(self, embedding_root='embedding', init_unknown_vec=nd.zeros, **kwar source = 'embedding_test' Test._check_source(self.source_file_hash, source) - super(Test, self).__init__(**kwargs) - file_path = Test._get_file_path(self.source_file_hash, embedding_root, source) - - self._load_embedding(file_path, ' ') + unknown_token = kwargs.pop('unknown_token', '') + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec) + + return super(Test, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) test_embed = nlp.embedding.create('test', embedding_root='tests/data/embedding') assert_almost_equal(test_embed['hello'].asnumpy(), (nd.arange(5) + 1).asnumpy()) @@ -1050,12 +1058,23 @@ def __init__(self, embedding_root='tests/data/embedding', **kwargs): source = 'embedding_test' Test._check_source(self.source_file_hash, source) - super(Test, self).__init__(**kwargs) - file_path = Test._get_file_path(self.source_file_hash, embedding_root, source) - self._load_embedding(file_path, ' ') + unknown_token = kwargs.pop('unknown_token', '') + init_unknown_vec = kwargs.pop('init_unknown_vec', nd.zeros) + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec) + + super(Test, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) + emb = nlp.embedding.create('test', embedding_root='tests/data/embedding') From 1fcb344e25c1adb862d8ee1006aa6263714885de Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Sun, 9 Jun 2019 11:48:42 +0000 Subject: [PATCH 2/6] Add idx_to_token and idx_to_vec docstrings --- src/gluonnlp/embedding/token_embedding.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/gluonnlp/embedding/token_embedding.py b/src/gluonnlp/embedding/token_embedding.py index 8db31faafd..96fa943683 100644 --- a/src/gluonnlp/embedding/token_embedding.py +++ b/src/gluonnlp/embedding/token_embedding.py @@ -170,7 +170,8 @@ class TokenEmbedding(object): not specified. init_unknown_vec : callback, default nd.zeros The callback used to initialize the embedding vector for the unknown - token. Only used if `unknown_token` is not None. + token. Only used if `unknown_token` is not None and `idx_to_token` is + not None and does not contain `unknown_vec`. allow_extend : bool, default False If True, embedding vectors for previously unknown words can be added via token_embedding[tokens] = vecs. If False, only vectors for known @@ -180,6 +181,24 @@ class TokenEmbedding(object): automatically from `unknown_lookup[unknown_tokens]`. For example, in a FastText model, embeddings for unknown tokens can be computed from the subword information. + idx_to_token : list of str or None, default None + If not None, a list of tokens for which the `idx_to_vec` argument + provides embeddings. The list indices and the indices of `idx_to_vec` + must be aligned. + If `idx_to_token` is not None, `idx_to_vec` must not be None either. + If `idx_to_token` is None, an empty TokenEmbedding object is created. + If `allow_extend` is True, tokens and their embeddings can be added to + the TokenEmbedding at a later stage. + idx_to_vec : mxnet.ndarray.NDArray or None, default None + If not None, a NDArray containing embeddings for the tokens specified + in `idx_to_token`. The first dimension of `idx_to_vec` must be aligned + with `idx_to_token`. + If `idx_to_vec` is not None, `idx_to_token` must not be None either. + If `idx_to_vec` is None, an empty TokenEmbedding object is created. + If `allow_extend` is True, tokens and their embeddings can be added to + the TokenEmbedding at a later stage. + No copy of the idx_to_vec array is made as long as unknown_token is + None or an embedding for unknown_token is specified in `idx_to_vec`. """ From 25d0ee8d890427460053697f27b5039c679d26db Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Sun, 9 Jun 2019 11:50:01 +0000 Subject: [PATCH 3/6] Add more tests --- src/gluonnlp/embedding/token_embedding.py | 23 ++--- tests/unittest/test_token_embedding.py | 111 ++++++++++++++++++++++ tests/unittest/test_vocab_embed.py | 17 ++-- 3 files changed, 129 insertions(+), 22 deletions(-) create mode 100644 tests/unittest/test_token_embedding.py diff --git a/src/gluonnlp/embedding/token_embedding.py b/src/gluonnlp/embedding/token_embedding.py index 96fa943683..ef6aa5ee90 100644 --- a/src/gluonnlp/embedding/token_embedding.py +++ b/src/gluonnlp/embedding/token_embedding.py @@ -203,33 +203,34 @@ class TokenEmbedding(object): """ def __init__(self, unknown_token=C.UNK_TOKEN, init_unknown_vec=INIT_UNKNOWN_VEC, - allow_extend=False, unknown_lookup=None, idx_to_vec=None, - idx_to_token=None): + allow_extend=False, unknown_lookup=None, idx_to_token=None, + idx_to_vec=None): unknown_index = None # With pre-specified tokens and vectors if idx_to_vec is not None or idx_to_token is not None: + idx_to_token = idx_to_token.copy() + # Sanity checks if idx_to_vec is None or idx_to_token is None: raise ValueError('Must specify either none or both of ' 'idx_to_token and idx_to_vec.') - if len(idx_to_vec) != len(idx_to_token): - raise ValueError( - 'idx_to_token and idx_to_vec must be of equal length.') + if idx_to_vec.shape[0] != len(idx_to_token): + raise ValueError('idx_to_token and idx_to_vec must contain ' + 'the same number of tokens and embeddings respectively.') if init_unknown_vec is not None: - raise ValueError('Must not specify init_unknown_vec ' - 'when specifying idx_to_vec') + logging.info('Ignoring init_unknown_vec as idx_to_vec is specified') if unknown_token is not None: try: unknown_index = idx_to_token.index(unknown_token) except ValueError: - raise ValueError( - 'unknown_token \'{}\' must be part of idx_to_token'. - format(unknown_token)) + idx_to_token.insert(0, unknown_token) + idx_to_vec = nd.concat(init_unknown_vec((1, idx_to_vec.shape[1])), idx_to_vec, + dim=0) + unknown_index = 0 # Initialization self._unknown_token = unknown_token - assert init_unknown_vec is None self._init_unknown_vec = init_unknown_vec self._allow_extend = allow_extend self._unknown_lookup = unknown_lookup diff --git a/tests/unittest/test_token_embedding.py b/tests/unittest/test_token_embedding.py new file mode 100644 index 0000000000..7f5a635332 --- /dev/null +++ b/tests/unittest/test_token_embedding.py @@ -0,0 +1,111 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import functools + +import mxnet as mx +import pytest + +import gluonnlp as nlp +from gluonnlp.base import _str_types + + +class NaiveUnknownLookup(object): + def __init__(self, embsize): + self.embsize = embsize + + def __contains__(self, token): + return True + + def __getitem__(self, tokens): + if isinstance(tokens, _str_types): + return mx.nd.ones(self.embsize) + else: + return mx.nd.ones((len(tokens), self.embsize)) + + +@pytest.mark.parametrize('unknown_token', [None, '', '[UNK]']) +@pytest.mark.parametrize('init_unknown_vec', [mx.nd.zeros, mx.nd.ones]) +@pytest.mark.parametrize('allow_extend', [True, False]) +@pytest.mark.parametrize('unknown_lookup', [None, NaiveUnknownLookup]) +@pytest.mark.parametrize( + 'idx_token_vec_mapping', + [ + (None, None), + (['', 'hello', 'world'], mx.nd.zeros(shape=[3, 300])), # 300 == embsize + (['hello', 'world'], mx.nd.zeros(shape=[2, 300])), # 300 == embsize + ]) +def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_extend, unknown_lookup, + idx_token_vec_mapping, embsize=300): + idx_to_token, idx_to_vec = idx_token_vec_mapping + + TokenEmbedding = functools.partial( + nlp.embedding.TokenEmbedding, unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, allow_extend=allow_extend, + unknown_lookup=unknown_lookup(embsize) if unknown_lookup is not None else None, + idx_to_token=idx_to_token, idx_to_vec=idx_to_vec) + + ## Test "legacy" constructor + if idx_to_token is None: + emb = TokenEmbedding() + assert len(emb.idx_to_token) == 1 if unknown_token else len(emb.idx_to_token) == 0 + # emb does not know the embsize, thus idx_to_vec could not be initialized + assert emb.idx_to_vec is None + + # Set unknown_token + if unknown_token: + emb[unknown_token] = mx.nd.zeros(embsize) - 1 + assert (emb[unknown_token].asnumpy() == mx.nd.zeros(embsize).asnumpy() - 1).all() + assert emb.idx_to_vec.shape[1] == embsize + + if allow_extend: + emb = TokenEmbedding() + emb[unknown_token] = mx.nd.zeros(embsize) - 1 + assert emb.idx_to_vec.shape[1] == embsize + + emb = TokenEmbedding() + emb[''] = mx.nd.zeros(embsize) - 1 + assert emb.idx_to_vec.shape[0] == 2 if unknown_token else emb.idx_to_vec.shape[0] == 1 + assert (emb[''].asnumpy() == (mx.nd.zeros(embsize) - 1).asnumpy()).all() + + ## Test with idx_to_vec and idx_to_token arguments + else: + emb = TokenEmbedding() + + if unknown_token and unknown_token not in idx_to_token: + assert emb.idx_to_token == [unknown_token] + idx_to_token + assert (emb.idx_to_vec[1:].asnumpy() == idx_to_vec.asnumpy()).all() + assert (emb.idx_to_vec[0].asnumpy() == init_unknown_vec(embsize).asnumpy()).all() + else: + assert emb.idx_to_token == idx_to_token + assert (emb.idx_to_vec.asnumpy() == idx_to_vec.asnumpy()).all() + + if allow_extend: + emb = TokenEmbedding() + emb[unknown_token] = mx.nd.zeros(embsize) - 1 + assert emb.idx_to_vec.shape[1] == embsize + + emb = TokenEmbedding() + emb[''] = mx.nd.zeros(embsize) - 1 + assert (emb[''].asnumpy() == (mx.nd.zeros(embsize) - 1).asnumpy()).all() + + if unknown_token and unknown_token not in idx_to_token: + assert emb.idx_to_vec.shape[0] == len(idx_to_token) + 2 + else: + assert emb.idx_to_vec.shape[0] == len(idx_to_token) + 1 diff --git a/tests/unittest/test_vocab_embed.py b/tests/unittest/test_vocab_embed.py index 7f6d4e96b6..0a102b829b 100644 --- a/tests/unittest/test_vocab_embed.py +++ b/tests/unittest/test_vocab_embed.py @@ -17,26 +17,21 @@ # specific language governing permissions and limitations # under the License. -from __future__ import absolute_import -from __future__ import print_function +from __future__ import absolute_import, print_function +import functools +import os import random import re -import os import sys -import functools +import numpy as np import pytest - -import gluonnlp as nlp from mxnet import ndarray as nd from mxnet.test_utils import * -import numpy as np -if sys.version_info[0] == 3: - _str_types = (str, ) -else: - _str_types = (str, unicode) +import gluonnlp as nlp +from gluonnlp.base import _str_types @pytest.fixture From 8ea19facdc67d580379e5f7617d0559c11de14fa Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Mon, 10 Jun 2019 14:58:53 +0000 Subject: [PATCH 4/6] Fix Py2 support --- src/gluonnlp/embedding/token_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonnlp/embedding/token_embedding.py b/src/gluonnlp/embedding/token_embedding.py index ef6aa5ee90..171be71e9f 100644 --- a/src/gluonnlp/embedding/token_embedding.py +++ b/src/gluonnlp/embedding/token_embedding.py @@ -209,7 +209,7 @@ def __init__(self, unknown_token=C.UNK_TOKEN, init_unknown_vec=INIT_UNKNOWN_VEC, # With pre-specified tokens and vectors if idx_to_vec is not None or idx_to_token is not None: - idx_to_token = idx_to_token.copy() + idx_to_token = idx_to_token[:] # Sanity checks if idx_to_vec is None or idx_to_token is None: From 3103f4f05ed1d4d74d7457499a38d272b52b14b5 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Mon, 10 Jun 2019 21:18:51 +0000 Subject: [PATCH 5/6] Update deprecated configs in docs/conf.py --- docs/conf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d30f84e683..c3d022d6f2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -40,9 +40,7 @@ # add markdown parser CommonMarkParser.github_doc_root = github_doc_root -source_parsers = { - '.md': CommonMarkParser -} +extensions = ['recommonmark'] # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones From 7fdd0bbec72f02c85c6257d3f421a254dcb1d924 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Mon, 10 Jun 2019 21:44:11 +0000 Subject: [PATCH 6/6] Make linkcheck optional --- ci/jenkins/build_steps.groovy | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/jenkins/build_steps.groovy b/ci/jenkins/build_steps.groovy index b7b900f753..f2563adad7 100644 --- a/ci/jenkins/build_steps.groovy +++ b/ci/jenkins/build_steps.groovy @@ -117,6 +117,7 @@ def website_linkcheck(workspace_name, conda_env_name) { if [[ ${enforce_linkcheck} == true ]]; then make -C docs linkcheck SPHINXOPTS=-W else + set +e make -C docs linkcheck fi; set +ex