diff --git a/ci/jenkins/build_steps.groovy b/ci/jenkins/build_steps.groovy index b7b900f753..f2563adad7 100644 --- a/ci/jenkins/build_steps.groovy +++ b/ci/jenkins/build_steps.groovy @@ -117,6 +117,7 @@ def website_linkcheck(workspace_name, conda_env_name) { if [[ ${enforce_linkcheck} == true ]]; then make -C docs linkcheck SPHINXOPTS=-W else + set +e make -C docs linkcheck fi; set +ex diff --git a/docs/conf.py b/docs/conf.py index d30f84e683..c3d022d6f2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -40,9 +40,7 @@ # add markdown parser CommonMarkParser.github_doc_root = github_doc_root -source_parsers = { - '.md': CommonMarkParser -} +extensions = ['recommonmark'] # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones diff --git a/src/gluonnlp/embedding/token_embedding.py b/src/gluonnlp/embedding/token_embedding.py index c52b99c568..171be71e9f 100644 --- a/src/gluonnlp/embedding/token_embedding.py +++ b/src/gluonnlp/embedding/token_embedding.py @@ -42,9 +42,9 @@ from ..data.utils import DefaultLookupDict from ..model.train import FasttextEmbeddingModel -UNK_IDX = 0 # This should not be changed as long as serialized token - # embeddings redistributed on S3 contain an unknown token. Commit - # 46eb28ffcf542ab8ed801a727a45404a73adbce3 has more context. +UNK_IDX = 0 +ENCODING = 'utf8' +INIT_UNKNOWN_VEC = nd.zeros def register(embedding_cls): """Registers a new token embedding. @@ -168,9 +168,10 @@ class TokenEmbedding(object): Any unknown token will be replaced by unknown_token and consequently will be indexed as the same representation. Only used if oov_imputer is not specified. - init_unknown_vec : callback + init_unknown_vec : callback, default nd.zeros The callback used to initialize the embedding vector for the unknown - token. Only used if `unknown_token` is not None. + token. Only used if `unknown_token` is not None and `idx_to_token` is + not None and does not contain `unknown_token`. allow_extend : bool, default False If True, embedding vectors for previously unknown words can be added via token_embedding[tokens] = vecs. If False, only vectors for known @@ -180,22 +181,84 @@ class TokenEmbedding(object): automatically from `unknown_lookup[unknown_tokens]`. For example, in a FastText model, embeddings for unknown tokens can be computed from the subword information. + idx_to_token : list of str or None, default None + If not None, a list of tokens for which the `idx_to_vec` argument + provides embeddings. The list indices and the indices of `idx_to_vec` + must be aligned. + If `idx_to_token` is not None, `idx_to_vec` must not be None either. + If `idx_to_token` is None, an empty TokenEmbedding object is created. + If `allow_extend` is True, tokens and their embeddings can be added to + the TokenEmbedding at a later stage. + idx_to_vec : mxnet.ndarray.NDArray or None, default None + If not None, an NDArray containing embeddings for the tokens specified + in `idx_to_token`. The first dimension of `idx_to_vec` must be aligned + with `idx_to_token`. + If `idx_to_vec` is not None, `idx_to_token` must not be None either. + If `idx_to_vec` is None, an empty TokenEmbedding object is created. + If `allow_extend` is True, tokens and their embeddings can be added to + the TokenEmbedding at a later stage. + No copy of the idx_to_vec array is made as long as unknown_token is + None or an embedding for unknown_token is specified in `idx_to_vec`.
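A minimal usage sketch of the two new constructor arguments documented above, assuming the constructor as introduced in this patch; the token list and vector values are made up for illustration.

    import mxnet as mx
    import gluonnlp as nlp

    # Hypothetical token list and aligned embedding matrix (one row per token).
    tokens = ['hello', 'world']
    vecs = mx.nd.random.uniform(shape=(2, 5))

    # '<unk>' is not in `tokens`, so the constructor inserts it at index 0 and
    # initializes its row with `init_unknown_vec` (nd.zeros by default).
    emb = nlp.embedding.TokenEmbedding(unknown_token='<unk>',
                                       idx_to_token=tokens, idx_to_vec=vecs)
    assert emb.idx_to_token[0] == '<unk>'
    assert emb.idx_to_vec.shape == (3, 5)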
""" - def __init__(self, unknown_token='', init_unknown_vec=nd.zeros, - allow_extend=False, unknown_lookup=None): - self._unknown_token = unknown_token - self._init_unknown_vec = init_unknown_vec - self._allow_extend = allow_extend - self._unknown_lookup = unknown_lookup - self._idx_to_token = [unknown_token] if unknown_token else [] - if unknown_token: - self._token_to_idx = DefaultLookupDict(UNK_IDX) + def __init__(self, unknown_token=C.UNK_TOKEN, init_unknown_vec=INIT_UNKNOWN_VEC, + allow_extend=False, unknown_lookup=None, idx_to_token=None, + idx_to_vec=None): + unknown_index = None + + # With pre-specified tokens and vectors + if idx_to_vec is not None or idx_to_token is not None: + idx_to_token = idx_to_token[:] + + # Sanity checks + if idx_to_vec is None or idx_to_token is None: + raise ValueError('Must specify either none or both of ' + 'idx_to_token and idx_to_vec.') + if idx_to_vec.shape[0] != len(idx_to_token): + raise ValueError('idx_to_token and idx_to_vec must contain ' + 'the same number of tokens and embeddings respectively.') + if init_unknown_vec is not None: + logging.info('Ignoring init_unknown_vec as idx_to_vec is specified') + if unknown_token is not None: + try: + unknown_index = idx_to_token.index(unknown_token) + except ValueError: + idx_to_token.insert(0, unknown_token) + idx_to_vec = nd.concat(init_unknown_vec((1, idx_to_vec.shape[1])), idx_to_vec, + dim=0) + unknown_index = 0 + + # Initialization + self._unknown_token = unknown_token + self._init_unknown_vec = init_unknown_vec + self._allow_extend = allow_extend + self._unknown_lookup = unknown_lookup + + self._idx_to_token = idx_to_token + self._idx_to_vec = idx_to_vec + + # Empty token-embedding + else: + # Initialization + self._unknown_token = unknown_token + if self._unknown_token is not None: + unknown_index = UNK_IDX + self._init_unknown_vec = init_unknown_vec + self._allow_extend = allow_extend + self._unknown_lookup = unknown_lookup + + assert UNK_IDX == 0 + self._idx_to_token = [unknown_token] if unknown_token else [] + self._idx_to_vec = None + + # Initialization of token_to_idx mapping + if self._unknown_token: + assert unknown_index is not None + self._token_to_idx = DefaultLookupDict(unknown_index) else: self._token_to_idx = {} self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token)) - self._idx_to_vec = None @staticmethod def _get_file_url(cls_name, source_file_hash, source): @@ -221,8 +284,9 @@ def _get_file_path(cls, source_file_hash, embedding_root, source): return pretrained_file_path - def _load_embedding(self, pretrained_file_path, elem_delim, - encoding='utf8'): + @staticmethod + def _load_embedding(pretrained_file_path, elem_delim, unknown_token, + init_unknown_vec, encoding=ENCODING): """Load embedding vectors from a pre-trained token embedding file. Both text files and TokenEmbedding serialization files are supported. 
@@ -248,24 +312,40 @@ def _load_embedding(self, pretrained_file_path, elem_delim, pretrained_file_path) if pretrained_file_path.endswith('.npz'): - self._load_embedding_serialized( - pretrained_file_path=pretrained_file_path) + return TokenEmbedding._load_embedding_serialized( + pretrained_file_path=pretrained_file_path, + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec) else: - self._load_embedding_txt( + return TokenEmbedding._load_embedding_txt( pretrained_file_path=pretrained_file_path, - elem_delim=elem_delim, encoding=encoding) + elem_delim=elem_delim, + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) - def _load_embedding_txt(self, pretrained_file_path, elem_delim, encoding='utf8'): + + @staticmethod + def _load_embedding_txt(pretrained_file_path, elem_delim, unknown_token, + init_unknown_vec, encoding=ENCODING): """Load embedding vectors from a pre-trained token embedding file. - For every unknown token, if its representation `self.unknown_token` is encountered in the - pre-trained token embedding file, index 0 of `self.idx_to_vec` maps to the pre-trained token - embedding vector loaded from the file; otherwise, index 0 of `self.idx_to_vec` maps to the - text embedding vector initialized by `self._init_unknown_vec`. + Returns idx_to_token, idx_to_vec and unknown_token suitable for the + TokenEmbedding constructor. + + For every unknown token, if its representation `unknown_token` is encountered in the + pre-trained token embedding file, index 0 of `idx_to_vec` maps to the pre-trained token + embedding vector loaded from the file; otherwise, index 0 of `idx_to_vec` maps to the + text embedding vector initialized by `init_unknown_vec`. If a token is encountered multiple times in the pre-trained text embedding file, only the first-encountered token embedding vector will be loaded and the rest will be skipped. + """ + idx_to_token = [unknown_token] if unknown_token else [] + unk_idx = None + if unknown_token: + unk_idx = 0 vec_len = None all_elems = [] @@ -287,9 +367,9 @@ def _load_embedding_txt(self, pretrained_file_path, elem_delim, encoding='utf8') token, elems = elems[0], [float(i) for i in elems[1:]] - if token == self.unknown_token and loaded_unknown_vec is None: + if loaded_unknown_vec is None and token == unknown_token: loaded_unknown_vec = elems - tokens.add(self.unknown_token) + tokens.add(unknown_token) elif token in tokens: warnings.warn('line {} in {}: duplicate embedding found for ' 'token "{}". Skipped.'.format(line_num, pretrained_file_path, @@ -300,9 +380,10 @@ def _load_embedding_txt(self, pretrained_file_path, elem_delim, encoding='utf8') else: if not vec_len: vec_len = len(elems) - if self.unknown_token: + if unknown_token: # Reserve a vector slot for the unknown token at the very beggining # because the unknown token index is 0. 
+ assert len(all_elems) == 0 all_elems.extend([0] * vec_len) else: assert len(elems) == vec_len, \ @@ -311,76 +392,83 @@ def _load_embedding_txt(self, pretrained_file_path, elem_delim, encoding='utf8') pretrained_file_path, token, vec_len, len(elems)) all_elems.extend(elems) - self._idx_to_token.append(token) - self._token_to_idx[token] = len(self._idx_to_token) - 1 + idx_to_token.append(token) tokens.add(token) - self._idx_to_vec = nd.array(all_elems).reshape((-1, vec_len)) + idx_to_vec = nd.array(all_elems).reshape((-1, vec_len)) - if self.unknown_token: + if unknown_token: if loaded_unknown_vec is None: - self._idx_to_vec[UNK_IDX] = self._init_unknown_vec(shape=vec_len) + idx_to_vec[unk_idx] = init_unknown_vec(shape=vec_len) else: - self._idx_to_vec[UNK_IDX] = nd.array(loaded_unknown_vec) + idx_to_vec[unk_idx] = nd.array(loaded_unknown_vec) + + return idx_to_token, idx_to_vec, unknown_token - def _load_embedding_serialized(self, pretrained_file_path): + @staticmethod + def _load_embedding_serialized(pretrained_file_path, unknown_token, init_unknown_vec): """Load embedding vectors from a pre-trained token embedding file. - For every unknown token, if its representation `self.unknown_token` is encountered in the - pre-trained token embedding file, index 0 of `self.idx_to_vec` maps to the pre-trained token - embedding vector loaded from the file; otherwise, index 0 of `self.idx_to_vec` maps to the - text embedding vector initialized by `self._init_unknown_vec`. + Returns idx_to_token, idx_to_vec and unknown_token suitable for the + TokenEmbedding constructor. ValueError is raised if a token occurs multiple times. """ deserialized_embedding = TokenEmbedding.deserialize(pretrained_file_path) + + idx_to_token = deserialized_embedding.idx_to_token + if len(set(idx_to_token)) != len(idx_to_token): + raise ValueError('Serialized embedding contains duplicate tokens.') + idx_to_vec = deserialized_embedding.idx_to_vec + vec_len = idx_to_vec.shape[1] + loaded_unknown_vec = False if deserialized_embedding.unknown_token: - # Some .npz files on S3 may contain an unknown token and its - # respective embedding. As a workaround, we assume that UNK_IDX - # is the same now as it was when the .npz was generated. Under this - # assumption we can safely overwrite the respective token and - # vector from the npz. - if deserialized_embedding.unknown_token: - idx_to_token = deserialized_embedding.idx_to_token - idx_to_vec = deserialized_embedding.idx_to_vec - idx_to_token[UNK_IDX] = self.unknown_token - if self._init_unknown_vec: - vec_len = idx_to_vec.shape[1] - idx_to_vec[UNK_IDX] = self._init_unknown_vec(shape=vec_len) + if not unknown_token: + # If the TokenEmbedding shall not have an unknown token but the + # serialized file provided one, delete the provided one. + unk_idx = deserialized_embedding.token_to_idx[ + deserialized_embedding.unknown_token] + assert unk_idx >= 0 + if unk_idx == 0: + idx_to_token = idx_to_token[1:] + idx_to_vec = idx_to_vec[1:] + else: + idx_to_token = idx_to_token[:unk_idx] + idx_to_token[unk_idx + 1:] + idx_to_vec = nd.concat(idx_to_vec[:unk_idx], idx_to_vec[unk_idx + 1:], dim=0) else: - # If the TokenEmbedding shall not have an unknown token, we - # just delete the one in the npz. - assert UNK_IDX == 0 - idx_to_token = deserialized_embedding.idx_to_token[UNK_IDX + 1:] - idx_to_vec = deserialized_embedding.idx_to_vec[UNK_IDX + 1:] + # If the TokenEmbedding shall have an unknown token and the + # serialized file provided one, replace the representation.
+ unk_idx = deserialized_embedding.token_to_idx[ + deserialized_embedding.unknown_token] + idx_to_token[unk_idx] = unknown_token + loaded_unknown_vec = True else: - idx_to_token = deserialized_embedding.idx_to_token - idx_to_vec = deserialized_embedding.idx_to_vec + if unknown_token and unknown_token not in idx_to_token: + # If the TokenEmbedding shall have an unknown token but the + # serialized file did not provide one, insert a new one + idx_to_token = [unknown_token] + idx_to_token + idx_to_vec = nd.concat(nd.zeros((1, vec_len)), idx_to_vec, dim=0) + elif unknown_token: + # The serialized file did not define an unknown token, but it + # contains the token that is specified by the user to represent the + # unknown token. + assert not deserialized_embedding.unknown_token + loaded_unknown_vec = True + # Move unknown_token to idx 0 to replicate the behavior of + # _load_embedding_txt + unk_idx = idx_to_token.index(unknown_token) + if unk_idx > 0: + idx_to_token[0], idx_to_token[unk_idx] = idx_to_token[unk_idx], idx_to_token[0] + idx_to_vec[[0, unk_idx]] = idx_to_vec[[unk_idx, 0]] + else: + assert not deserialized_embedding.unknown_token + assert not unknown_token - if not len(set(idx_to_token)) == len(idx_to_token): - raise ValueError('Serialized embedding invalid. ' - 'It contains duplicate tokens.') + if unknown_token and init_unknown_vec and not loaded_unknown_vec: + unk_idx = idx_to_token.index(unknown_token) + idx_to_vec[unk_idx] = init_unknown_vec(shape=vec_len) - if self.unknown_token: - try: - unknown_token_idx = deserialized_embedding.idx_to_token.index( - self.unknown_token) - idx_to_token[UNK_IDX], idx_to_token[ - unknown_token_idx] = idx_to_token[ - unknown_token_idx], idx_to_token[UNK_IDX] - idxs = [UNK_IDX, unknown_token_idx] - idx_to_vec[idxs] = idx_to_vec[idxs[::-1]] - except ValueError: - vec_len = idx_to_vec.shape[1] - idx_to_token.insert(0, self.unknown_token) - idx_to_vec = nd.concat( - self._init_unknown_vec(shape=vec_len).reshape((1, -1)), - idx_to_vec, dim=0) - - self._idx_to_token = idx_to_token - self._idx_to_vec = idx_to_vec - self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token)) + return idx_to_token, idx_to_vec, unknown_token @property def idx_to_token(self): @@ -635,7 +723,7 @@ def __setitem__(self, tokens, new_embedding): raise KeyError(('Token "{}" is unknown. To update the embedding vector for an' ' unknown token, please explicitly include "{}" as the ' '`unknown_token` in `tokens`. This is to avoid unintended ' - 'updates.').format(token, self._idx_to_token[UNK_IDX])) + 'updates.').format(token, self.unknown_token)) else: raise KeyError(('Token "{}" is unknown. Updating the embedding vector for an ' 'unknown token is not allowed because `unknown_token` is not ' @@ -661,7 +749,7 @@ def _check_source(cls, source_file_hash, source): ', '.join(source_file_hash.keys()))) @staticmethod - def from_file(file_path, elem_delim=' ', encoding='utf8', **kwargs): + def from_file(file_path, elem_delim=' ', encoding=ENCODING, **kwargs): """Creates a user-defined token embedding from a pre-trained embedding file. @@ -693,9 +781,22 @@ def from_file(file_path, elem_delim=' ', encoding='utf8', **kwargs): instance of :class:`gluonnlp.embedding.TokenEmbedding` The user-defined token embedding instance.
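As a usage sketch, `from_file` (rewritten below) now builds the token list and vector matrix via the static loader and forwards them to the constructor; the file name here is hypothetical, and the file is expected to hold one token per line followed by its `elem_delim`-separated values.

    import gluonnlp as nlp

    # 'my_vectors.txt' is a made-up path to a whitespace-delimited embedding file.
    emb = nlp.embedding.TokenEmbedding.from_file('my_vectors.txt', elem_delim=' ')
    print(emb.idx_to_vec.shape)  # (number of tokens incl. '<unk>', vec_len)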
""" - embedding = TokenEmbedding(**kwargs) - embedding._load_embedding(file_path, elem_delim=elem_delim, encoding=encoding) - return embedding + unknown_token = kwargs.pop('unknown_token', C.UNK_TOKEN) + init_unknown_vec = kwargs.pop('init_unknown_vec', INIT_UNKNOWN_VEC) + idx_to_token, idx_to_vec, unknown_token = TokenEmbedding._load_embedding( + file_path, + elem_delim=elem_delim, + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) + + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + return TokenEmbedding(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) def serialize(self, file_path, compress=True): """Serializes the TokenEmbedding to a file specified by file_path. @@ -740,8 +841,8 @@ def serialize(self, file_path, compress=True): idx_to_token=idx_to_token, idx_to_vec=idx_to_vec) - @classmethod - def deserialize(cls, file_path, **kwargs): + @staticmethod + def deserialize(file_path, **kwargs): """Create a new TokenEmbedding from a serialized one. TokenEmbedding is serialized by converting the list of tokens, the @@ -774,18 +875,15 @@ def deserialize(cls, file_path, **kwargs): idx_to_token = npz_dict['idx_to_token'].tolist() idx_to_vec = nd.array(npz_dict['idx_to_vec']) - embedding = cls(unknown_token=unknown_token, **kwargs) - if unknown_token: - assert unknown_token == idx_to_token[UNK_IDX] - embedding._token_to_idx = DefaultLookupDict(UNK_IDX) - else: - embedding._token_to_idx = {} - - embedding._idx_to_token = idx_to_token - embedding._idx_to_vec = idx_to_vec - embedding._token_to_idx.update((token, idx) for idx, token in enumerate(idx_to_token)) - - return embedding + assert 'unknown_token' not in kwargs + assert 'init_unknown_vec' not in kwargs + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + return TokenEmbedding(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) @register @@ -846,11 +944,24 @@ class GloVe(TokenEmbedding): def __init__(self, source='glove.6B.50d', embedding_root=os.path.join(get_home_dir(), 'embedding'), **kwargs): self._check_source(self.source_file_hash, source) - - super(GloVe, self).__init__(**kwargs) pretrained_file_path = GloVe._get_file_path(self.source_file_hash, embedding_root, source) - - self._load_embedding(pretrained_file_path, elem_delim=' ') + unknown_token = kwargs.pop('unknown_token', C.UNK_TOKEN) + init_unknown_vec = kwargs.pop('init_unknown_vec', INIT_UNKNOWN_VEC) + encoding = kwargs.pop('encoding', ENCODING) + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + pretrained_file_path=pretrained_file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) + + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + super(GloVe, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) @register @@ -939,7 +1050,8 @@ class FastText(TokenEmbedding): def __init__(self, source='wiki.simple', embedding_root=os.path.join( get_home_dir(), 'embedding'), load_ngrams=False, ctx=cpu(), **kwargs): self._check_source(self.source_file_hash, source) - + pretrained_file_path = FastText._get_file_path(self.source_file_hash, + embedding_root, source) if load_ngrams: try: self._check_source(self.source_bin_file_hash, source) @@ -956,11 +1068,23 @@ def __init__(self, 
source='wiki.simple', embedding_root=os.path.join( else: unknown_lookup = None - super(FastText, self).__init__(unknown_lookup=unknown_lookup, **kwargs) - pretrained_file_path = FastText._get_file_path(self.source_file_hash, embedding_root, - source) - - self._load_embedding(pretrained_file_path, elem_delim=' ') + unknown_token = kwargs.pop('unknown_token', C.UNK_TOKEN) + init_unknown_vec = kwargs.pop('init_unknown_vec', INIT_UNKNOWN_VEC) + encoding = kwargs.pop('encoding', ENCODING) + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + pretrained_file_path=pretrained_file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) + + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + super(FastText, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + unknown_lookup=unknown_lookup, **kwargs) @register @@ -1027,19 +1151,37 @@ class Word2Vec(TokenEmbedding): source_file_hash = C.WORD2VEC_NPZ_SHA1 def __init__(self, source='GoogleNews-vectors-negative300', - embedding_root=os.path.join(get_home_dir(), 'embedding'), encoding='utf8', + embedding_root=os.path.join(get_home_dir(), 'embedding'), encoding=ENCODING, **kwargs): - super(Word2Vec, self).__init__(**kwargs) + unknown_token = kwargs.pop('unknown_token', C.UNK_TOKEN) + init_unknown_vec = kwargs.pop('init_unknown_vec', INIT_UNKNOWN_VEC) if source.endswith('.bin'): pretrained_file_path = os.path.expanduser(source) - self._load_w2v_binary(pretrained_file_path, encoding=encoding) + idx_to_token, idx_to_vec, unknown_token = self._load_w2v_binary( + pretrained_file_path, unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, encoding=encoding) else: self._check_source(self.source_file_hash, source) pretrained_file_path = self._get_file_path(self.source_file_hash, embedding_root, source) - self._load_embedding(pretrained_file_path, elem_delim=' ') + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + pretrained_file_path=pretrained_file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, + encoding=encoding) + + assert 'idx_to_vec' not in kwargs + assert 'idx_to_token' not in kwargs + super(Word2Vec, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) - def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): + @classmethod + def _load_w2v_binary(cls, pretrained_file_path, unknown_token, + init_unknown_vec=INIT_UNKNOWN_VEC, encoding=ENCODING): """Load embedding vectors from a binary pre-trained token embedding file. Parameters @@ -1050,13 +1192,11 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): encoding: str The encoding type of the file. 
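A hedged sketch of how the binary loader below is reached from the public API, given the `.bin` branch added above; the file path is hypothetical and must point to a file in the original word2vec binary format.

    import gluonnlp as nlp

    # A `source` ending in '.bin' is treated as a local word2vec binary file
    # and routed through _load_w2v_binary instead of the hosted .npz sources.
    w2v = nlp.embedding.Word2Vec(source='~/data/GoogleNews-vectors-negative300.bin')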
""" - self._idx_to_token = [self.unknown_token] if self.unknown_token else [] - if self.unknown_token: - self._token_to_idx = DefaultLookupDict(UNK_IDX) - else: - self._token_to_idx = {} - self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token)) - self._idx_to_vec = None + idx_to_token = [unknown_token] if unknown_token else [] + unk_idx = None + if unknown_token: + unk_idx = 0 + all_elems = [] tokens = set() loaded_unknown_vec = None @@ -1064,7 +1204,7 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): with io.open(pretrained_file_path, 'rb') as f: header = f.readline().decode(encoding=encoding) vocab_size, vec_len = (int(x) for x in header.split()) - if self.unknown_token: + if unknown_token: # Reserve a vector slot for the unknown token at the very beggining # because the unknown token index is 0. all_elems.extend([0] * vec_len) @@ -1091,9 +1231,9 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): assert len(elems) > 1, 'line {} in {}: unexpected data format.'.format( line_num, pretrained_file_path) - if token == self.unknown_token and loaded_unknown_vec is None: + if token == unknown_token and loaded_unknown_vec is None: loaded_unknown_vec = elems - tokens.add(self.unknown_token) + tokens.add(unknown_token) elif token in tokens: warnings.warn('line {} in {}: duplicate embedding found for ' 'token "{}". Skipped.'.format(line_num, pretrained_file_path, @@ -1105,19 +1245,21 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'): pretrained_file_path, token, vec_len, len(elems)) all_elems.extend(elems) - self._idx_to_token.append(token) - self._token_to_idx[token] = len(self._idx_to_token) - 1 + idx_to_token.append(token) tokens.add(token) - self._idx_to_vec = nd.array(all_elems).reshape((-1, vec_len)) - if self.unknown_token: + idx_to_vec = nd.array(all_elems).reshape((-1, vec_len)) + + if unknown_token: if loaded_unknown_vec is None: - self._idx_to_vec[UNK_IDX] = self._init_unknown_vec(shape=vec_len) + idx_to_vec[unk_idx] = init_unknown_vec(shape=vec_len) else: - self._idx_to_vec[UNK_IDX] = nd.array(loaded_unknown_vec) + idx_to_vec[unk_idx] = nd.array(loaded_unknown_vec) + + return idx_to_token, idx_to_vec, unknown_token @classmethod - def from_w2v_binary(cls, pretrained_file_path, encoding='utf8'): + def from_w2v_binary(cls, pretrained_file_path, encoding=ENCODING): """Load embedding vectors from a binary pre-trained token embedding file. Parameters diff --git a/src/gluonnlp/vocab/vocab.py b/src/gluonnlp/vocab/vocab.py index 13ac920494..61947b5a04 100644 --- a/src/gluonnlp/vocab/vocab.py +++ b/src/gluonnlp/vocab/vocab.py @@ -395,10 +395,6 @@ def set_embedding(self, *embeddings): 'Either all or none of the TokenEmbeddings must have an ' \ 'unknown_token set.' 
- new_embedding = emb.TokenEmbedding(self.unknown_token, allow_extend=False) - new_embedding._token_to_idx = self.token_to_idx - new_embedding._idx_to_token = self.idx_to_token - new_vec_len = sum(embs.idx_to_vec.shape[1] for embs in embeddings) new_idx_to_vec = nd.zeros(shape=(len(self), new_vec_len)) @@ -412,8 +408,11 @@ def set_embedding(self, *embeddings): new_idx_to_vec[1:, col_start:col_end] = embs[self._idx_to_token[1:]] col_start = col_end - new_embedding._idx_to_vec = new_idx_to_vec - self._embedding = new_embedding + self._embedding = emb.TokenEmbedding(self.unknown_token, + init_unknown_vec=None, + allow_extend=False, + idx_to_token=self.idx_to_token, + idx_to_vec=new_idx_to_vec) def to_tokens(self, indices): """Converts token indices to tokens according to the vocabulary. diff --git a/tests/unittest/test_token_embedding.py b/tests/unittest/test_token_embedding.py new file mode 100644 index 0000000000..7f5a635332 --- /dev/null +++ b/tests/unittest/test_token_embedding.py @@ -0,0 +1,111 @@ +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import functools + +import mxnet as mx +import pytest + +import gluonnlp as nlp +from gluonnlp.base import _str_types + + +class NaiveUnknownLookup(object): + def __init__(self, embsize): + self.embsize = embsize + + def __contains__(self, token): + return True + + def __getitem__(self, tokens): + if isinstance(tokens, _str_types): + return mx.nd.ones(self.embsize) + else: + return mx.nd.ones((len(tokens), self.embsize)) + + +@pytest.mark.parametrize('unknown_token', [None, '', '[UNK]']) +@pytest.mark.parametrize('init_unknown_vec', [mx.nd.zeros, mx.nd.ones]) +@pytest.mark.parametrize('allow_extend', [True, False]) +@pytest.mark.parametrize('unknown_lookup', [None, NaiveUnknownLookup]) +@pytest.mark.parametrize( + 'idx_token_vec_mapping', + [ + (None, None), + (['', 'hello', 'world'], mx.nd.zeros(shape=[3, 300])), # 300 == embsize + (['hello', 'world'], mx.nd.zeros(shape=[2, 300])), # 300 == embsize + ]) +def test_token_embedding_constructor(unknown_token, init_unknown_vec, allow_extend, unknown_lookup, + idx_token_vec_mapping, embsize=300): + idx_to_token, idx_to_vec = idx_token_vec_mapping + + TokenEmbedding = functools.partial( + nlp.embedding.TokenEmbedding, unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec, allow_extend=allow_extend, + unknown_lookup=unknown_lookup(embsize) if unknown_lookup is not None else None, + idx_to_token=idx_to_token, idx_to_vec=idx_to_vec) + + ## Test "legacy" constructor + if idx_to_token is None: + emb = TokenEmbedding() + assert len(emb.idx_to_token) == 1 if unknown_token else len(emb.idx_to_token) == 0 + # emb does not know the embsize, thus idx_to_vec could not be initialized + assert emb.idx_to_vec is None + + # Set unknown_token + if unknown_token: + emb[unknown_token] = mx.nd.zeros(embsize) - 1 + assert (emb[unknown_token].asnumpy() == mx.nd.zeros(embsize).asnumpy() - 1).all() + assert emb.idx_to_vec.shape[1] == embsize + + if allow_extend: + emb = TokenEmbedding() + emb[unknown_token] = mx.nd.zeros(embsize) - 1 + assert emb.idx_to_vec.shape[1] == embsize + + emb = TokenEmbedding() + emb[''] = mx.nd.zeros(embsize) - 1 + assert emb.idx_to_vec.shape[0] == 2 if unknown_token else emb.idx_to_vec.shape[0] == 1 + assert (emb[''].asnumpy() == (mx.nd.zeros(embsize) - 1).asnumpy()).all() + + ## Test with idx_to_vec and idx_to_token arguments + else: + emb = TokenEmbedding() + + if unknown_token and unknown_token not in idx_to_token: + assert emb.idx_to_token == [unknown_token] + idx_to_token + assert (emb.idx_to_vec[1:].asnumpy() == idx_to_vec.asnumpy()).all() + assert (emb.idx_to_vec[0].asnumpy() == init_unknown_vec(embsize).asnumpy()).all() + else: + assert emb.idx_to_token == idx_to_token + assert (emb.idx_to_vec.asnumpy() == idx_to_vec.asnumpy()).all() + + if allow_extend: + emb = TokenEmbedding() + emb[unknown_token] = mx.nd.zeros(embsize) - 1 + assert emb.idx_to_vec.shape[1] == embsize + + emb = TokenEmbedding() + emb[''] = mx.nd.zeros(embsize) - 1 + assert (emb[''].asnumpy() == (mx.nd.zeros(embsize) - 1).asnumpy()).all() + + if unknown_token and unknown_token not in idx_to_token: + assert emb.idx_to_vec.shape[0] == len(idx_to_token) + 2 + else: + assert emb.idx_to_vec.shape[0] == len(idx_to_token) + 1 diff --git a/tests/unittest/test_vocab_embed.py b/tests/unittest/test_vocab_embed.py index 95bd35e66f..0a102b829b 100644 --- a/tests/unittest/test_vocab_embed.py +++ b/tests/unittest/test_vocab_embed.py @@ -17,26 +17,21 @@ # specific language governing permissions and limitations # under the License. 
-from __future__ import absolute_import -from __future__ import print_function +from __future__ import absolute_import, print_function +import functools +import os import random import re -import os import sys -import functools +import numpy as np import pytest - -import gluonnlp as nlp from mxnet import ndarray as nd from mxnet.test_utils import * -import numpy as np -if sys.version_info[0] == 3: - _str_types = (str, ) -else: - _str_types = (str, unicode) +import gluonnlp as nlp +from gluonnlp.base import _str_types @pytest.fixture @@ -820,12 +815,20 @@ def __init__(self, embedding_root='embedding', init_unknown_vec=nd.zeros, **kwar source = 'embedding_test' Test._check_source(self.source_file_hash, source) - super(Test, self).__init__(**kwargs) - file_path = Test._get_file_path(self.source_file_hash, embedding_root, source) - - self._load_embedding(file_path, ' ') + unknown_token = kwargs.pop('unknown_token', '') + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec) + + return super(Test, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) test_embed = nlp.embedding.create('test', embedding_root='tests/data/embedding') assert_almost_equal(test_embed['hello'].asnumpy(), (nd.arange(5) + 1).asnumpy()) @@ -1050,12 +1053,23 @@ def __init__(self, embedding_root='tests/data/embedding', **kwargs): source = 'embedding_test' Test._check_source(self.source_file_hash, source) - super(Test, self).__init__(**kwargs) - file_path = Test._get_file_path(self.source_file_hash, embedding_root, source) - self._load_embedding(file_path, ' ') + unknown_token = kwargs.pop('unknown_token', '') + init_unknown_vec = kwargs.pop('init_unknown_vec', nd.zeros) + idx_to_token, idx_to_vec, unknown_token = self._load_embedding( + file_path, + elem_delim=' ', + unknown_token=unknown_token, + init_unknown_vec=init_unknown_vec) + + super(Test, self).__init__(unknown_token=unknown_token, + init_unknown_vec=None, + idx_to_token=idx_to_token, + idx_to_vec=idx_to_vec, + **kwargs) + emb = nlp.embedding.create('test', embedding_root='tests/data/embedding')
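Taken together, the updated Test classes above illustrate the pattern this patch establishes for custom embedding subclasses: pop the constructor-only keyword arguments, call the static loader, and forward the resulting arrays to TokenEmbedding.__init__. A condensed sketch, with the class name and file path made up for illustration:

    from mxnet import nd
    import gluonnlp as nlp

    @nlp.embedding.register
    class MyEmbedding(nlp.embedding.TokenEmbedding):
        """Hypothetical subclass following the constructor pattern used above."""

        def __init__(self, file_path='my_vectors.vec', **kwargs):
            unknown_token = kwargs.pop('unknown_token', '<unk>')
            init_unknown_vec = kwargs.pop('init_unknown_vec', nd.zeros)
            # The static loader returns arrays aligned for the new constructor.
            idx_to_token, idx_to_vec, unknown_token = self._load_embedding(
                file_path, elem_delim=' ', unknown_token=unknown_token,
                init_unknown_vec=init_unknown_vec)
            super(MyEmbedding, self).__init__(
                unknown_token=unknown_token, init_unknown_vec=None,
                idx_to_token=idx_to_token, idx_to_vec=idx_to_vec, **kwargs)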