adding OnlineSampleMapping #8137

Merged · 3 commits · Jan 11, 2024
@@ -25,7 +25,7 @@

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import get_samples_mapping
-from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset
+from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset, OnlineSampleMapping
from nemo.core.classes import Dataset
from nemo.utils import logging

@@ -155,6 +155,7 @@ def _maybe_validate_prompt_template(self):

def _build_samples_mapping(self):
    if self.max_num_samples is not None:
+       osm = OnlineSampleMapping(dataset_size=len(self.indexed_dataset), num_samples=self.max_num_samples)
        self.samples_mapping = get_samples_mapping(
            indexed_dataset=self.indexed_dataset,
            data_prefix=self.file_path,
@@ -166,6 +167,7 @@ def _build_samples_mapping(self):
            name=self.file_path.split('/')[-1],
            binary_head=False,
            index_mapping_dir=self.index_mapping_dir,
+           samples_mapping=osm,
        )
    else:
        self.samples_mapping = None
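For context, a minimal sketch of the object this hunk constructs (the sizes are hypothetical, not taken from the PR); the instance is then passed to get_samples_mapping through the new samples_mapping argument shown above:

    from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import OnlineSampleMapping

    # up-sample a hypothetical 1,000-record dataset to 2,500 samples, mapped lazily
    osm = OnlineSampleMapping(dataset_size=1000, num_samples=2500)

    len(osm)  # 2500
    osm[0]    # (sample_idx, None, None), the row layout NeMo's get_samples_mapping also produces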
nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py (187 additions & 0 deletions)
@@ -13,6 +13,7 @@
# limitations under the License.

import datetime
+import functools

import json
import multiprocessing as mp
import os
@@ -535,3 +536,189 @@
    logging.info(
        f"Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
    )


def handle_index(dataset, idx):
    """
    Remaps negative indices and handles numpy int indices.

    Arguments:
        dataset (Dataset): dataset to index into
        idx (int): Index. Can include negative indices.
    Returns:
        int: Remapped and fully qualified index.

    Raises:
        IndexError: If a negative index is out of range.

    Examples:
        >>> import numpy as np
        >>> import torch
        >>> from torch.utils.data import TensorDataset
        >>> from nemo_chem.data.fasta_dataset import handle_index
        >>> dataset = TensorDataset(torch.tensor(-np.arange(5)))
        >>> handle_index(dataset, 1)
        1
        >>> handle_index(dataset, -2)
        3

    """
    if idx < 0 and idx > -len(dataset) - 1:
        idx = len(dataset) + idx
    elif idx < 0:
        raise IndexError(f'Index out of range: {idx}')
    return idx


class OnlineSampleMapping:
    """
    This class replaces NeMo's get_samples_mapping function, which pre-computes the
    entire sample mapping up front. Instead, the mapping for a given number of
    samples is built lazily, block by block, including pseudo-random shuffling.
    The sampler supports both down-sampling and up-sampling of a given dataset.
    Shuffling is pseudo-random rather than global: blocks of indices are shuffled,
    and each block is shuffled internally.
    """

    def __init__(
        self,
        dataset_size: int,
        num_samples: int,
        block_size: int = 1000000,
        cache_maxsize: int = 2,
        seed: int = 1,
        shuffle: bool = True,
        truncate_to_block_boundary: bool = False,
    ):
        """
        Args:
            dataset_size (int): Size of the dataset.
            num_samples (int): Number of samples the dataset should contain.
            block_size (int): Size of each sample block. This is used to shuffle the samples.
                None is replaced with the dataset size.
            cache_maxsize (int): Maximum size of the blocks cache for the get_sample_block function.
            seed (int): Seed for the random number generator used for shuffling.
            shuffle (bool): Whether to shuffle the samples.
            truncate_to_block_boundary (bool): Whether to truncate the last block to the block
                boundary (could drop samples).
        """
        self.dataset_size = dataset_size
        self.num_samples = num_samples
        self.block_size = block_size if block_size is not None else self.dataset_size
        self.cache_maxsize = cache_maxsize
        self.seed = seed
        self.shuffle = shuffle
        self.truncate_to_block_boundary = truncate_to_block_boundary

        # we need at least num_samples (up-sampling) or dataset_size samples (correct down-sampling)
        self.required_samples = max(self.num_samples, self.dataset_size)
        # block size cannot be larger than dataset size
        self.block_size = min(self.block_size, self.dataset_size)
        # the last block may be partial, covering the remainder of the required samples
        last_block_size = self.required_samples % self.block_size
        # number of blocks needed to cover both num_samples and dataset_size samples
        self.num_blocks = int(np.ceil(self.required_samples / self.block_size))

        # if required, truncate the last block to the block boundary
        if self.truncate_to_block_boundary and last_block_size:
            # update num_samples to account for the truncated last block, only if needed
            if self.required_samples == self.num_samples:
                self.num_samples -= last_block_size

            # update num_blocks to account for the truncated last block
            self.num_blocks -= 1
            self.required_samples -= last_block_size
            last_block_size = 0

        # create a list of blocks (should cover the entire dataset for correct down-sampling)
        block_idx_list = np.arange(self.num_blocks)
        # compute the size of each block
        block_size_list = np.full(self.num_blocks, self.block_size)
        if last_block_size:
            block_size_list[-1] = last_block_size
            self.use_digitize = True
        else:
            self.use_digitize = False
        if shuffle:
            local_rng = np.random.RandomState(seed=self.seed)
            idx = local_rng.permutation(np.arange(self.num_blocks))
            block_idx_list = block_idx_list[idx]
            block_size_list = block_size_list[idx]

        # store only the required number of blocks
        self.block_idx_list = block_idx_list
        self.block_size_list = block_size_list
        self.block_bins = np.cumsum(block_size_list)

        # NOTE: wrap get_sample_block in an LRU cache so recently used blocks are served from memory
        self.get_sample_block = functools.lru_cache(maxsize=cache_maxsize, typed=False)(self.get_sample_block)

    def __str__(self):
        return f"OnlineSampleMapping(dataset_size={self.dataset_size}, num_samples={self.num_samples}, block_size={self.block_size}, cache_maxsize={self.cache_maxsize}, seed={self.seed}, shuffle={self.shuffle}, truncate_to_block_boundary={self.truncate_to_block_boundary})"

    def __getitem__(self, idx):
        """Returns a (sample_idx, None, None) tuple for an int index, or a list of such tuples for a slice."""
        # handle slices
        if isinstance(idx, slice):
            slc = idx
            start, stop, step = slc.start, slc.stop, slc.step

            # handle None values
            start = handle_index(self, start if start is not None else 0)
            if start >= self.num_samples:
                start = self.num_samples
            stop = handle_index(self, stop if stop is not None else self.num_samples)
            if stop >= self.num_samples:
                stop = self.num_samples
            step = step if step is not None else 1
            sample_slice = [self[idx] for idx in range(start, stop, step)]
            return sample_slice
        # handle indices
        else:
            # if the index is out of range, raise IndexError
            if idx >= self.num_samples:
                raise IndexError("Index out of range")

            # support negative indices
            if idx < 0:
                idx += self.num_samples

                if idx < 0:
                    raise IndexError("Index out of range")

            # fetch the block sample index
            if self.use_digitize:
                block_idx = np.digitize(idx, self.block_bins)
            else:
                block_idx = idx // self.block_size
            sample_block = self.get_sample_block(block_idx)

            # use the local index to fetch the sample;
            # local_idx is negative and indexes the block from its end, which selects
            # the same element as the positive offset of idx within the block
            local_idx = idx - self.block_bins[block_idx]
            sample_idx = sample_block[local_idx]

            return sample_idx, None, None  # for compatibility with NeMo's get_samples_mapping

    def __len__(self) -> int:
        return self.num_samples

    def get_sample_block(self, block_idx: int) -> np.ndarray:
        """
        Returns a block of samples of size self.block_size, shuffled if needed.
        This method is wrapped with functools.lru_cache in __init__ for efficiency.
        """
        if block_idx >= self.num_blocks:
            raise IndexError(f"block_idx {block_idx} is out of range. Maximum block_idx is {self.num_blocks - 1}")

        # recover the index of the original block (before shuffling)
        start_idx = self.block_idx_list[block_idx] * self.block_size
        end_idx = start_idx + self.block_size_list[block_idx]
        sample_block = np.arange(start_idx, end_idx)

        # shuffle if needed
        if self.shuffle:
            local_rng = np.random.RandomState(seed=self.seed + block_idx)
            sample_block = local_rng.permutation(sample_block)

        # project indices onto the dataset size (supports up-sampling)
        sample_block = sample_block % self.dataset_size

        return sample_block
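
To make the block mechanics above concrete, a short usage sketch (all numbers are illustrative, not from the PR):

    # 10-record dataset up-sampled to 25 samples, drawn from shuffled blocks of 4
    mapping = OnlineSampleMapping(dataset_size=10, num_samples=25, block_size=4, seed=1, shuffle=True)

    len(mapping)            # 25
    idx, _, _ = mapping[0]  # indices are projected into range via `% dataset_size`
    assert 0 <= idx < 10
    mapping[-1]             # negative indices are supported
    mapping[0:5]            # slices return a list of (sample_idx, None, None) tuples

    # truncate_to_block_boundary drops the partial last block: 25 % 4 == 1 sample is dropped
    truncated = OnlineSampleMapping(dataset_size=10, num_samples=25, block_size=4, truncate_to_block_boundary=True)
    len(truncated)          # 24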