Skip to content

Commit

Permalink
Fix TextMemMapDataset index file creation in multi-node setup (#6768)
Browse files Browse the repository at this point in the history
* Fix for isolated filesystems in multi-node setting

Signed-off-by: Greg Heinrich <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Greg Heinrich <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Micha Livne <[email protected]>
  • Loading branch information
3 people authored Jun 26, 2023
1 parent c4e677a commit f344fdb
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 24 deletions.
66 changes: 46 additions & 20 deletions nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import torch

from nemo.core import Dataset
from nemo.utils import logging
from nemo.utils import AppState, logging

__all__ = ['TextMemMapDataset', 'CSVMemMapDataset', 'build_index_files']
__idx_version__ = '0.2' # index file version
__idx_suffix__ = 'idx' # index file suffix
__all__ = ["TextMemMapDataset", "CSVMemMapDataset", "build_index_files"]
__idx_version__ = "0.2" # index file version
__idx_suffix__ = "idx" # index file suffix


def _build_index_from_memdata(fn, newline_int):
Expand All @@ -40,7 +40,7 @@ def _build_index_from_memdata(fn, newline_int):
Returns a 1D array of ints.
"""
# use memmap to read file
mdata = np.memmap(fn, dtype=np.uint8, mode='r')
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
# find newline positions
midx = np.where(mdata == newline_int)[0]
midx_dtype = midx.dtype
Expand Down Expand Up @@ -115,9 +115,10 @@ def __init__(

logging.info(f"Building data files")
# load all files into memmap
is_ditributed = torch.distributed.is_available() and torch.distributed.is_initialized()
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

if not is_ditributed or (is_ditributed and torch.distributed.get_rank() == 0):
if not is_distributed or (is_distributed and torch.distributed.get_rank() == 0):
# Create index files on global rank 0.
build_index_files(
dataset_paths,
newline_int,
Expand All @@ -126,14 +127,39 @@ def __init__(
index_mapping_dir=index_mapping_dir,
)

if is_ditributed:
if is_distributed:
torch.distributed.barrier()

if is_distributed and AppState().local_rank == 0:
# If we are in a distributed multi-node set-up and index files are not stored on
# a shared filesystem, then the index files created on global rank 0 are only
# accessible to the workers on that node.
#
# Two cases may occur here:
#
# 1. case of a shared filesystem, or global_rank==0: the index files are present in
# the locally available filesystem, calling build_index_files() again is a no-op.
# 2. case of a non-shared filesystem, and global_rank>0: the index files are not
# present in the locally available filesystem, calling build_index_files() again
# will create them.
#
# Outcome in all cases: all nodes have access to the index files in their filesystem.
build_index_files(
dataset_paths,
newline_int,
workers=self._worker,
build_index_fn=build_index_fn,
index_mapping_dir=index_mapping_dir,
)

if is_distributed:
torch.distributed.barrier()

logging.info(f"Loading data files")
start_time = time.time()
mdata_midx_list = [self.load_file(fn, index_mapping_dir) for fn in self._files_list]
logging.info(
f'Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
f"Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
)

logging.info("Computing global indices")
Expand Down Expand Up @@ -224,34 +250,34 @@ def load_file(self, fn, index_mapping_dir: Optional[str] = None):
idx_fn = _index_fn(fn, index_mapping_dir)

# create data map
mdata = np.memmap(fn, dtype=np.uint8, mode='r')
mdata = np.memmap(fn, dtype=np.uint8, mode="r")

if _index_file_exists(idx_fn):
# load index file into memory map
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode='r')
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
# test for header
if len(midx) < self._header_lines:
raise RuntimeError(f"Missing header, expected {self._header_lines} header lines")

# load meta info
idx_info_dict = pickle.load(open(idx_fn + ".info", 'rb'))
idx_info_dict = pickle.load(open(idx_fn + ".info", "rb"))
# test for mismatch in expected newline_int
if 'newline_int' in idx_info_dict:
newline_int = idx_info_dict['newline_int']
if "newline_int" in idx_info_dict:
newline_int = idx_info_dict["newline_int"]
if self._newline_int != newline_int:
logging.warning(
f"Mismatch in newline_int, expected = {self._newline_int} but loaded {newline_int}"
)

# test for version mismatch (useful to force recreation of index files)
idx_version = idx_info_dict.get('version', '0.0')
idx_version = idx_info_dict.get("version", "0.0")
if __idx_version__ != idx_version:
raise RuntimeError(
f"Version mismatch: Please delete existing '.{__idx_suffix__}' files. Expected version = {__idx_version__}, but file version = {idx_version}. File path = {idx_fn}"
)
else:
raise ValueError(
f'Memory Map for {fn} is not found, missing one or more of files: {idx_fn}.{{.npy,.info}}'
f"Memory Map for {fn} is not found, missing one or more of files: {idx_fn}.{{.npy,.info}}"
)

return (mdata, midx)
Expand All @@ -271,7 +297,7 @@ def __init__(
tokenizer: Optional[Type["TokenizerSpec"]] = None,
sort_dataset_paths: Optional[bool] = True,
data_col=1,
data_sep=',',
data_sep=",",
index_mapping_dir: Optional[str] = None,
):
"""
Expand Down Expand Up @@ -424,7 +450,7 @@ def _build_memmap_index_files(newline_int, build_index_fn, fn, index_mapping_dir


def build_index_files(
dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata, index_mapping_dir: str = None
dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata, index_mapping_dir: str = None,
):
"""Auxiliary method to build multiple index files"""
if len(dataset_paths) < 1:
Expand All @@ -438,10 +464,10 @@ def build_index_files(
start_time = time.time()
with mp.Pool(workers) as p:
build_status = p.map(
partial(_build_memmap_index_files, newline_int, build_index_fn, index_mapping_dir=index_mapping_dir),
partial(_build_memmap_index_files, newline_int, build_index_fn, index_mapping_dir=index_mapping_dir,),
dataset_paths,
)

logging.info(
f'Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
f"Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
)
12 changes: 8 additions & 4 deletions tests/collections/nlp/test_mem_map_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ def jsonl_file(tmp_path):
file_path = tmp_path / "data.jsonl"

# Generate data to write to the JSONL file
data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}, {"name": "Bob", "age": 35}]
data = [
{"name": "John", "age": 30},
{"name": "Jane", "age": 25},
{"name": "Bob", "age": 35},
]

# Write data to the JSONL file
with open(file_path, mode="w") as file:
for item in data:
json.dump(item, file)
file.write('\n')
file.write("\n")

# Provide the file path to the test function
yield str(file_path)
Expand Down Expand Up @@ -81,12 +85,12 @@ def test_csv_mem_map_dataset(csv_file):


@pytest.mark.parametrize(
"dataset_class", [text_memmap_dataset.JSONLMemMapDataset, text_memmap_dataset.CSVMemMapDataset]
"dataset_class", [text_memmap_dataset.JSONLMemMapDataset, text_memmap_dataset.CSVMemMapDataset],
)
@pytest.mark.parametrize("use_alternative_index_mapping_dir", [True, False])
@pytest.mark.parametrize("relative_index_fn", [True, False])
def test_mem_map_dataset_index_mapping_dir(
tmp_path, dataset_class, jsonl_file, use_alternative_index_mapping_dir, relative_index_fn
tmp_path, dataset_class, jsonl_file, use_alternative_index_mapping_dir, relative_index_fn,
):
"""Test for index_mapping_dir."""
if relative_index_fn:
Expand Down

0 comments on commit f344fdb

Please sign in to comment.