Skip to content

Commit

Permalink
Fix for isolated filesystems in multi-node setting
Browse files Browse the repository at this point in the history
Signed-off-by: Greg Heinrich <[email protected]>
  • Loading branch information
gheinrich committed Jun 6, 2023
1 parent 62718fd commit b97c522
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 34 deletions.
105 changes: 78 additions & 27 deletions nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import torch

from nemo.core import Dataset
from nemo.utils import logging
from nemo.utils import AppState, logging

__all__ = ['TextMemMapDataset', 'CSVMemMapDataset', 'build_index_files']
__idx_version__ = '0.2' # index file version
__idx_suffix__ = 'idx' # index file suffix
__all__ = ["TextMemMapDataset", "CSVMemMapDataset", "build_index_files"]
__idx_version__ = "0.2" # index file version
__idx_suffix__ = "idx" # index file suffix


def _build_index_from_memdata(fn, newline_int):
Expand All @@ -40,7 +40,7 @@ def _build_index_from_memdata(fn, newline_int):
Returns a 1D array of ints.
"""
# use memmap to read file
mdata = np.memmap(fn, dtype=np.uint8, mode='r')
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
# find newline positions
midx = np.where(mdata == newline_int)[0]
midx_dtype = midx.dtype
Expand Down Expand Up @@ -74,7 +74,9 @@ def __init__(
header_lines: Optional[int] = 0,
workers: Optional[int] = None,
tokenizer: Optional[Type["TokenizerSpec"]] = None,
build_index_fn: Optional[Callable[[str, Optional[int]], bool]] = _build_index_from_memdata,
build_index_fn: Optional[
Callable[[str, Optional[int]], bool]
] = _build_index_from_memdata,
sort_dataset_paths: Optional[bool] = True,
index_mapping_dir: Optional[str] = None,
):
Expand Down Expand Up @@ -115,9 +117,37 @@ def __init__(

logging.info(f"Building data files")
# load all files into memmap
is_ditributed = torch.distributed.is_available() and torch.distributed.is_initialized()
is_distributed = (
torch.distributed.is_available() and torch.distributed.is_initialized()
)

if not is_distributed or (is_distributed and torch.distributed.get_rank() == 0):
# Create index files on global rank 0.
build_index_files(
dataset_paths,
newline_int,
workers=self._worker,
build_index_fn=build_index_fn,
index_mapping_dir=index_mapping_dir,
)

if is_distributed:
torch.distributed.barrier()

if not is_ditributed or (is_ditributed and torch.distributed.get_rank() == 0):
if is_distributed and AppState().local_rank == 0:
# If we are in a distributed multi-node set-up and index files are not stored on
# a shared filesystem, then the index files created on global rank 0 are only
# accessible to the workers on that node.
#
# Two cases may occur here:
#
# 1. case of a shared filesystem, or global_rank==0: the index files are present in
# the locally available filesystem, calling build_index_files() again is a no-op.
# 2. case of a non-shared filesystem, and global_rank>0: the index files are not
# present in the locally available filesystem, calling build_index_files() again
# will create them.
#
# Outcome in all cases: all nodes have access to the index files in their filesystem.
build_index_files(
dataset_paths,
newline_int,
Expand All @@ -126,18 +156,22 @@ def __init__(
index_mapping_dir=index_mapping_dir,
)

if is_ditributed:
if is_distributed:
torch.distributed.barrier()

logging.info(f"Loading data files")
start_time = time.time()
mdata_midx_list = [self.load_file(fn, index_mapping_dir) for fn in self._files_list]
mdata_midx_list = [
self.load_file(fn, index_mapping_dir) for fn in self._files_list
]
logging.info(
f'Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
f"Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
)

logging.info("Computing global indices")
midx_bins = np.cumsum([(len(midx) - header_lines) for _, midx in mdata_midx_list])
midx_bins = np.cumsum(
[(len(midx) - header_lines) for _, midx in mdata_midx_list]
)

self.midx_bins = midx_bins
self.mdata_midx_list = mdata_midx_list
Expand All @@ -158,7 +192,9 @@ def __getitem__(self, idx):
Return a string from binary memmap
"""
if (idx >= len(self)) or (idx < 0):
raise IndexError(f"Index {idx} if out of dataset range with {len(self)} samples")
raise IndexError(
f"Index {idx} if out of dataset range with {len(self)} samples"
)

# Identify the file containing the record
file_id = np.digitize(idx, self.midx_bins, right=False)
Expand Down Expand Up @@ -189,7 +225,9 @@ def __getitem__(self, idx):
logging.error(
f"Error while building data from text, possible issue with sample expected format (see offending sample below): {e}"
)
logging.error(f"sample: {sample}, file_id: {file_id}, file_idx: {file_idx}, i: {i}, j: {j}")
logging.error(
f"sample: {sample}, file_id: {file_id}, file_idx: {file_idx}, i: {i}, j: {j}"
)
raise e

return data
Expand Down Expand Up @@ -224,34 +262,36 @@ def load_file(self, fn, index_mapping_dir: Optional[str] = None):
idx_fn = _index_fn(fn, index_mapping_dir)

# create data map
mdata = np.memmap(fn, dtype=np.uint8, mode='r')
mdata = np.memmap(fn, dtype=np.uint8, mode="r")

if _index_file_exists(idx_fn):
# load index file into memory map
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode='r')
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
# test for header
if len(midx) < self._header_lines:
raise RuntimeError(f"Missing header, expected {self._header_lines} header lines")
raise RuntimeError(
f"Missing header, expected {self._header_lines} header lines"
)

# load meta info
idx_info_dict = pickle.load(open(idx_fn + ".info", 'rb'))
idx_info_dict = pickle.load(open(idx_fn + ".info", "rb"))
# test for mismatch in expected newline_int
if 'newline_int' in idx_info_dict:
newline_int = idx_info_dict['newline_int']
if "newline_int" in idx_info_dict:
newline_int = idx_info_dict["newline_int"]
if self._newline_int != newline_int:
logging.warning(
f"Mismatch in newline_int, expected = {self._newline_int} but loaded {newline_int}"
)

# test for version mismatch (useful to force recreation of index files)
idx_version = idx_info_dict.get('version', '0.0')
idx_version = idx_info_dict.get("version", "0.0")
if __idx_version__ != idx_version:
raise RuntimeError(
f"Version mismatch: Please delete existing '.{__idx_suffix__}' files. Expected version = {__idx_version__}, but file version = {idx_version}. File path = {idx_fn}"
)
else:
raise ValueError(
f'Memory Map for {fn} is not found, missing one or more of files: {idx_fn}.{{.npy,.info}}'
f"Memory Map for {fn} is not found, missing one or more of files: {idx_fn}.{{.npy,.info}}"
)

return (mdata, midx)
Expand All @@ -271,7 +311,7 @@ def __init__(
tokenizer: Optional[Type["TokenizerSpec"]] = None,
sort_dataset_paths: Optional[bool] = True,
data_col=1,
data_sep=',',
data_sep=",",
index_mapping_dir: Optional[str] = None,
):
"""
Expand Down Expand Up @@ -409,7 +449,9 @@ def _build_memmap_index_files(newline_int, build_index_fn, fn, index_mapping_dir
# validate midx
midx = np.asarray(midx)
if not np.issubdtype(midx.dtype, np.integer):
raise TypeError(f"midx must be an integer array, but got type = {midx.dtype}")
raise TypeError(
f"midx must be an integer array, but got type = {midx.dtype}"
)

# create e metadata file
data = dict(newline_int=newline_int, version=__idx_version__)
Expand All @@ -424,7 +466,11 @@ def _build_memmap_index_files(newline_int, build_index_fn, fn, index_mapping_dir


def build_index_files(
dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata, index_mapping_dir: str = None
dataset_paths,
newline_int,
workers=None,
build_index_fn=_build_index_from_memdata,
index_mapping_dir: str = None,
):
"""Auxiliary method to build multiple index files"""
if len(dataset_paths) < 1:
Expand All @@ -438,10 +484,15 @@ def build_index_files(
start_time = time.time()
with mp.Pool(workers) as p:
build_status = p.map(
partial(_build_memmap_index_files, newline_int, build_index_fn, index_mapping_dir=index_mapping_dir),
partial(
_build_memmap_index_files,
newline_int,
build_index_fn,
index_mapping_dir=index_mapping_dir,
),
dataset_paths,
)

logging.info(
f'Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
f"Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
)
29 changes: 22 additions & 7 deletions tests/collections/nlp/test_mem_map_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ def jsonl_file(tmp_path):
file_path = tmp_path / "data.jsonl"

# Generate data to write to the JSONL file
data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}, {"name": "Bob", "age": 35}]
data = [
{"name": "John", "age": 30},
{"name": "Jane", "age": 25},
{"name": "Bob", "age": 35},
]

# Write data to the JSONL file
with open(file_path, mode="w") as file:
for item in data:
json.dump(item, file)
file.write('\n')
file.write("\n")

# Provide the file path to the test function
yield str(file_path)
Expand Down Expand Up @@ -65,7 +69,9 @@ def csv_file(tmp_path):
def test_jsonl_mem_map_dataset(jsonl_file):
"""Test for JSONL memory-mapped datasets."""

indexed_dataset = text_memmap_dataset.JSONLMemMapDataset(dataset_paths=[jsonl_file], header_lines=0)
indexed_dataset = text_memmap_dataset.JSONLMemMapDataset(
dataset_paths=[jsonl_file], header_lines=0
)
assert indexed_dataset[0] == {"name": "John", "age": 30}
assert indexed_dataset[1] == {"name": "Jane", "age": 25}
assert indexed_dataset[2] == {"name": "Bob", "age": 35}
Expand All @@ -74,19 +80,26 @@ def test_jsonl_mem_map_dataset(jsonl_file):
def test_csv_mem_map_dataset(csv_file):
"""Test for CSV memory-mapped datasets."""

indexed_dataset = text_memmap_dataset.CSVMemMapDataset(dataset_paths=[csv_file], data_col=1, header_lines=1)
indexed_dataset = text_memmap_dataset.CSVMemMapDataset(
dataset_paths=[csv_file], data_col=1, header_lines=1
)
assert indexed_dataset[0].strip() == "John"
assert indexed_dataset[1].strip() == "Jane"
assert indexed_dataset[2].strip() == "Bob"


@pytest.mark.parametrize(
"dataset_class", [text_memmap_dataset.JSONLMemMapDataset, text_memmap_dataset.CSVMemMapDataset]
"dataset_class",
[text_memmap_dataset.JSONLMemMapDataset, text_memmap_dataset.CSVMemMapDataset],
)
@pytest.mark.parametrize("use_alternative_index_mapping_dir", [True, False])
@pytest.mark.parametrize("relative_index_fn", [True, False])
def test_mem_map_dataset_index_mapping_dir(
tmp_path, dataset_class, jsonl_file, use_alternative_index_mapping_dir, relative_index_fn
tmp_path,
dataset_class,
jsonl_file,
use_alternative_index_mapping_dir,
relative_index_fn,
):
"""Test for index_mapping_dir."""
if relative_index_fn:
Expand All @@ -108,6 +121,8 @@ def test_mem_map_dataset_index_mapping_dir(
assert os.path.isfile(f"{idx_fn}.npy")
assert os.path.isfile(f"{idx_fn}.info")
else:
text_memmap_dataset.JSONLMemMapDataset(dataset_paths=[jsonl_file], header_lines=0)
text_memmap_dataset.JSONLMemMapDataset(
dataset_paths=[jsonl_file], header_lines=0
)
assert os.path.isfile(f"{jsonl_file}.idx.npy")
assert os.path.isfile(f"{jsonl_file}.idx.info")

0 comments on commit b97c522

Please sign in to comment.