Skip to content

Commit

Permalink
Add optional index mapping dir in mmap text datasets
Browse files Browse the repository at this point in the history
If datasets are stored on a read-only medium, index files
cannot be created into adjacent files and an
alternative directory must be specified for index
mapping files.

This commit adds an optional `index_mapping_dir` to
the constructors.
Unit tests are also added.

Signed-off-by: Greg Heinrich <[email protected]>
  • Loading branch information
gheinrich committed May 23, 2023
1 parent 57824e0 commit 584ea4f
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
file_path: Path to a JSONL GPT supervised fine-tuning dataset. Data is formatted as multiple JSON lines with each line formatted as follows. {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'}
tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece).
max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated.
min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements.
min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements.
add_bos (bool): Whether to add a beginning of sentence token to each data example
add_eos (bool): Whether to add an end of sentence token to each data example
add_sep (bool): Whether to add a separation token to each data example (goes between prompt and answer)
Expand Down Expand Up @@ -93,7 +93,9 @@ def __init__(
self.prompt_template = self.prompt_template.encode('utf-8').decode('unicode_escape')
assert self.truncation_field in ["answer", "context"]

self.indexed_dataset = JSONLMemMapDataset(dataset_paths=[file_path], tokenizer=None, header_lines=0)
self.indexed_dataset = JSONLMemMapDataset(
dataset_paths=[file_path], tokenizer=None, header_lines=0, index_mapping_dir=index_mapping_dir
)

# Will be None after this call if `max_num_samples` is None
self._build_samples_mapping()
Expand Down
146 changes: 121 additions & 25 deletions nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import pickle
import time
from typing import Callable, List, Optional, Type
from functools import partial

import numpy as np
Expand All @@ -35,7 +36,7 @@ def _build_index_from_memdata(fn, newline_int):
"""
Build index of delimiter positions between samples in memmap.
Can be provided externally.
Returns a 1D array of ints.
"""
# use memmap to read file
Expand Down Expand Up @@ -68,17 +69,28 @@ class TextMemMapDataset(Dataset):

def __init__(
self,
dataset_paths,
newline_int=10,
header_lines=0,
workers=None,
tokenizer=None,
sort_dataset_paths=True,
build_index_fn=_build_index_from_memdata,
dataset_paths: List[str],
newline_int: Optional[int] = 10,
header_lines: Optional[int] = 0,
workers: Optional[int] = None,
tokenizer: Optional[Type["TokenizerSpec"]] = None,
build_index_fn: Optional[Callable[[str, Optional[int]], bool]] = _build_index_from_memdata,
sort_dataset_paths: Optional[bool] = True,
index_mapping_dir: Optional[str] = None,
):
"""
build_index_fn - a callable build_index_fn(fn, newline_int) -> midx [np.array] that returns the index of newlines in a file fn
must be pickleable (to be used in multiprocessing.Pool.map)
Args:
dataset_paths: list of JSONL file paths.
newline_int: ASCII code to use to interpret newlines in file.
header_lines: number of header lines in JSON files.
workers: number of workers to use for creating index files.
tokenizer: tokenizer to use to convert text to tokens.
build_index_fn: a callable build_index_fn(fn, newline_int) -> midx [np.array]
that returns the index of newlines in a file fn must be pickleable
(to be used in multiprocessing.Pool.map).
sort_dataset_paths: whether to sort datasets by paths.
index_mapping_dir: directory to save the index mapping to.
If None, will write to the same folder as the dataset.
"""
super().__init__()
self.mdata_midx_list = []
Expand Down Expand Up @@ -106,14 +118,20 @@ def __init__(
is_ditributed = torch.distributed.is_available() and torch.distributed.is_initialized()

if not is_ditributed or (is_ditributed and torch.distributed.get_rank() == 0):
build_index_files(dataset_paths, newline_int, workers=self._worker, build_index_fn=build_index_fn)
build_index_files(
dataset_paths,
newline_int,
workers=self._worker,
build_index_fn=build_index_fn,
index_mapping_dir=index_mapping_dir,
)

if is_ditributed:
torch.distributed.barrier()

logging.info(f"Loading data files")
start_time = time.time()
mdata_midx_list = [self.load_file(fn) for fn in self._files_list]
mdata_midx_list = [self.load_file(fn, index_mapping_dir) for fn in self._files_list]
logging.info(
f'Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
)
Expand Down Expand Up @@ -193,7 +211,7 @@ def _build_data_from_text(self, text):

return data

def load_file(self, fn):
def load_file(self, fn, index_mapping_dir: Optional[str] = None):
"""
Loads a text file as np.int8.
Expand All @@ -203,7 +221,7 @@ def load_file(self, fn):
size - number of lines in file
"""
logging.info(f"Loading {fn}")
idx_fn = f"{fn}.{__idx_suffix__}"
idx_fn = _index_fn(fn, index_mapping_dir)

# create data map
mdata = np.memmap(fn, dtype=np.uint8, mode='r')
Expand Down Expand Up @@ -246,30 +264,47 @@ class CSVMemMapDataset(TextMemMapDataset):

def __init__(
self,
dataset_paths,
newline_int=10,
header_lines=1,
workers=None,
tokenizer=None,
sort_dataset_paths=True,
dataset_paths: List[str],
newline_int: Optional[int] = 10,
header_lines: Optional[int] = 0,
workers: Optional[int] = None,
tokenizer: Optional[Type["TokenizerSpec"]] = None,
sort_dataset_paths: Optional[bool] = True,
data_col=1,
data_sep=',',
index_mapping_dir: Optional[str] = None,
):
"""
Args:
dataset_paths: list of JSONL file paths.
newline_int: ASCII code to use to interpret newlines in file.
header_lines: number of header lines in JSON files.
workers: number of workers to use for creating index files.
tokenizer: tokenizer to use to convert text to tokens.
sort_dataset_paths: whether to sort datasets by paths.
data_col: index of data column.
data_sep: data separator.
index_mapping_dir: directory to save the index mapping to.
If None, will write to the same folder as the dataset.
"""
super().__init__(
dataset_paths=dataset_paths,
newline_int=newline_int,
header_lines=header_lines,
workers=workers,
tokenizer=tokenizer,
sort_dataset_paths=sort_dataset_paths,
index_mapping_dir=index_mapping_dir,
)
self._data_col = data_col
self._data_sep = data_sep

def _build_data_from_text(self, text):
"""Return a CSV field from text"""
# get CSV field
print("text", text)
text = text.split(self._data_sep)[self._data_col]
print("text_", text)
# tokenize
return super()._build_data_from_text(text)

Expand All @@ -280,15 +315,34 @@ class JSONLMemMapDataset(TextMemMapDataset):
"""

def __init__(
self, dataset_paths, newline_int=10, header_lines=1, workers=None, tokenizer=None, sort_dataset_paths=True,
self,
dataset_paths: List[str],
newline_int: Optional[int] = 10,
header_lines: Optional[int] = 0,
workers: Optional[int] = None,
tokenizer: Optional[Type["TokenizerSpec"]] = None,
sort_dataset_paths: Optional[bool] = True,
index_mapping_dir: Optional[str] = None,
):
"""
Args:
dataset_paths: list of JSONL file paths.
newline_int: ASCII code to use to interpret newlines in file.
header_lines: number of header lines in JSON files.
workers: number of workers to use for creating index files.
tokenizer: tokenizer to use to convert text to tokens.
sort_dataset_paths: whether to sort datasets by paths.
index_mapping_dir: directory to save the index mapping to.
If None, will write to the same folder as the dataset.
"""
super().__init__(
dataset_paths=dataset_paths,
newline_int=newline_int,
header_lines=header_lines,
workers=workers,
tokenizer=tokenizer,
sort_dataset_paths=sort_dataset_paths,
index_mapping_dir=index_mapping_dir,
)

def _build_data_from_text(self, text):
Expand All @@ -304,9 +358,46 @@ def _index_file_exists(idx_fn):
return False


def _build_memmap_index_files(newline_int, build_index_fn, fn):
def _index_fn(fn: str, index_mapping_dir: str) -> str:
"""Return base file name of index files.
This returns the base file name associated with specified index
files. This base name is the base on top of which suffixes
like .npy or .info are added.
The parent directory is created if it does not already exist.
fn may be specified in multiple ways:
1. file name: data.jsonl,
2. relative path to a file: relative/path/to/data.jsonl,
3. absolute path to a file: /absolute/path/to/data.jsonl.
This function returns paths in the pattern of:
1. /path/to/input_mapping_dir/data.jsonl.idx
2. /path/to/input_mapping_dir/relative/path/to/data.jsonl.idx
3. /path/to/input_mapping_dir/absolute/path/to/data.jsonl.idx
Args:
fn: filename to get base name for.
index_mapping_dir: directory to save the index mapping to.
If None, will write to the same folder as the dataset.
"""
if index_mapping_dir:
# Note: we don't want to use `os.path.join()` here because
# the supplied `fn` may be passed as an absolute path, and
# calling `os.path.join(index_mapping_dir, absolute_path)` would
# return `absolute_path`, i.e. the index mapping dir would be ignored.
idx_fn = f"{index_mapping_dir}/{fn}.{__idx_suffix__}"
# Create parent directory if needed.
os.makedirs(os.path.dirname(idx_fn), exist_ok=True)
else:
idx_fn = f"{fn}.{__idx_suffix__}"
return idx_fn


def _build_memmap_index_files(newline_int, build_index_fn, fn, index_mapping_dir: str):
"""Helper function to build an index file"""
idx_fn = f"{fn}.{__idx_suffix__}"
idx_fn = _index_fn(fn, index_mapping_dir)

# create data map
if _index_file_exists(idx_fn):
Expand All @@ -332,7 +423,9 @@ def _build_memmap_index_files(newline_int, build_index_fn, fn):
return True


def build_index_files(dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata):
def build_index_files(
dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata, index_mapping_dir: str = None
):
"""Auxiliary method to build multiple index files"""
if len(dataset_paths) < 1:
raise ValueError("files_list must contain at leat one file name")
Expand All @@ -344,7 +437,10 @@ def build_index_files(dataset_paths, newline_int, workers=None, build_index_fn=_
# load all files into memmap
start_time = time.time()
with mp.Pool(workers) as p:
build_status = p.map(partial(_build_memmap_index_files, newline_int, build_index_fn), dataset_paths)
build_status = p.map(
partial(_build_memmap_index_files, newline_int, build_index_fn, index_mapping_dir=index_mapping_dir),
dataset_paths,
)

logging.info(
f'Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
Expand Down
109 changes: 109 additions & 0 deletions tests/collections/nlp/test_mem_map_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import os
import pytest
import json

from nemo.collections.nlp.data.language_modeling import text_memmap_dataset


@pytest.fixture
def jsonl_file(tmp_path):
# Create a temporary file path
file_path = tmp_path / "data.jsonl"

# Generate data to write to the JSONL file
data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}, {"name": "Bob", "age": 35}]

# Write data to the JSONL file
with open(file_path, mode="w") as file:
for item in data:
json.dump(item, file)
file.write('\n')

# Provide the file path to the test function
yield str(file_path)

# Optional: Clean up the temporary file after the test
file_path.unlink()


@pytest.fixture
def csv_file(tmp_path):
# Create a temporary file path
file_path = tmp_path / "data.csv"

# Generate data to write to the CSV file
data = [["ID", "Name"], [1, "John"], [2, "Jane"], [3, "Bob"]]

# Write data to the CSV file
with open(file_path, mode="w", newline="") as file:
writer = csv.writer(file)
writer.writerows(data)

# Provide the file path to the test function
yield str(file_path)

# Optional: Clean up the temporary file after the test
file_path.unlink()


def test_jsonl_mem_map_dataset(jsonl_file):
"""Test for JSONL memory-mapped datasets."""

indexed_dataset = text_memmap_dataset.JSONLMemMapDataset(dataset_paths=[jsonl_file], header_lines=0)
assert indexed_dataset[0] == {"name": "John", "age": 30}
assert indexed_dataset[1] == {"name": "Jane", "age": 25}
assert indexed_dataset[2] == {"name": "Bob", "age": 35}


def test_csv_mem_map_dataset(csv_file):
"""Test for CSV memory-mapped datasets."""

indexed_dataset = text_memmap_dataset.CSVMemMapDataset(dataset_paths=[csv_file], data_col=1, header_lines=1)
assert indexed_dataset[0].strip() == "John"
assert indexed_dataset[1].strip() == "Jane"
assert indexed_dataset[2].strip() == "Bob"


@pytest.mark.parametrize(
"dataset_class", [text_memmap_dataset.JSONLMemMapDataset, text_memmap_dataset.CSVMemMapDataset]
)
@pytest.mark.parametrize("use_alternative_index_mapping_dir", [True, False])
@pytest.mark.parametrize("relative_index_fn", [True, False])
def test_mem_map_dataset_index_mapping_dir(
tmp_path, dataset_class, jsonl_file, use_alternative_index_mapping_dir, relative_index_fn
):
"""Test for index_mapping_dir."""

if use_alternative_index_mapping_dir:
index_mapping_dir = tmp_path / "subdir"
if relative_index_fn:
jsonl_file = os.path.relpath(jsonl_file)
else:
jsonl_file = os.path.abspath(jsonl_file)
dataset_class(dataset_paths=[jsonl_file], header_lines=0, index_mapping_dir=str(index_mapping_dir))
# Index files should not be created in default location.
assert not os.path.isfile(f"{jsonl_file}.idx.npy")
assert not os.path.isfile(f"{jsonl_file}.idx.info")
# Index file names are calculated using a hash of the JSON file name.
idx_fn = f"{str(index_mapping_dir)}/{jsonl_file}.idx"
assert os.path.isfile(f"{idx_fn}.npy")
assert os.path.isfile(f"{idx_fn}.info")
else:
text_memmap_dataset.JSONLMemMapDataset(dataset_paths=[jsonl_file], header_lines=0)
assert os.path.isfile(f"{jsonl_file}.idx.npy")
assert os.path.isfile(f"{jsonl_file}.idx.info")

0 comments on commit 584ea4f

Please sign in to comment.