diff --git a/entity_embed/data_utils/field_config_parser.py b/entity_embed/data_utils/field_config_parser.py index e548235..84a8ccf 100644 --- a/entity_embed/data_utils/field_config_parser.py +++ b/entity_embed/data_utils/field_config_parser.py @@ -2,7 +2,7 @@ import logging from importlib import import_module -from torchtext.vocab import Vocab +from torchtext.vocab import Vocab, Vectors from .numericalizer import ( AVAILABLE_VOCABS, @@ -93,7 +93,11 @@ def _parse_field_config(cls, field, field_config, record_list): "an field name." ) vocab = Vocab(vocab_counter) - vocab.load_vectors(vocab_type) + if vocab_type in {'tx_embeddings_large.vec','tx_embeddings.vec'}: + vectors = Vectors(vocab_type, cache='.vector_cache') + vocab.load_vectors(vectors) + else: + vocab.load_vectors(vocab_type) # Compute max_str_len if necessary if field_type in (FieldType.STRING, FieldType.MULTITOKEN) and (max_str_len is None): diff --git a/entity_embed/data_utils/numericalizer.py b/entity_embed/data_utils/numericalizer.py index dbdbdc8..4bd90df 100644 --- a/entity_embed/data_utils/numericalizer.py +++ b/entity_embed/data_utils/numericalizer.py @@ -27,6 +27,7 @@ "glove.6B.100d", "glove.6B.200d", "glove.6B.300d", + "tx_embeddings_large.vec", ]