Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Replace gluster paths with local file paths for NLG configs #1197

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pytext/data/data_handler.py
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@
from pytext.fields import Field, FieldMeta, RawField, VocabUsingField
from pytext.utils import cuda, distributed, embeddings as embeddings_utils
from pytext.utils.data import parse_json_array
from pytext.utils.path import get_absolute_path
from torchtext import data as textdata

from .utils import align_target_labels
@@ -249,6 +250,7 @@ def load_vocab(self, vocab_file, vocab_size, lowercase_tokens: bool = False):
lowercase_tokens (bool): if the tokens should be lowercased
"""
vocab: Set[str] = set()
vocab_file = get_absolute_path(vocab_file)
if os.path.isfile(vocab_file):
with open(vocab_file, "r") as f:
for i, line in enumerate(f):
@@ -727,6 +729,7 @@ def read_from_file(
columns_to_use (Union[Dict[str, int], List[str]]): either a list of
column names or a dict of column name -> column index in the file
"""
file_name = get_absolute_path(file_name)
print("reading data from {}".format(file_name))
if isinstance(columns_to_use, list):
columns_to_use = {
Expand Down
2 changes: 2 additions & 0 deletions pytext/utils/embeddings.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
from pytext.common.constants import PackageFileName
from pytext.config.field_config import EmbedInitStrategy
from pytext.utils.file_io import PathManager
from pytext.utils.path import get_absolute_path


class PretrainedEmbedding(object):
@@ -24,6 +25,7 @@ def __init__(
delimiter: str = " ",
) -> None:
if embeddings_path:
embeddings_path = get_absolute_path(embeddings_path)
if PathManager.isdir(embeddings_path):
serialized_embed_path = os.path.join(
embeddings_path, PackageFileName.SERIALIZED_EMBED
Expand Down