adding OnlineSampleMapping (#8137)
* adding OnlineSampleMapping

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
arendu and pre-commit-ci[bot] authored Jan 11, 2024
1 parent 8d4218e commit 6c006df
Showing 2 changed files with 190 additions and 1 deletion.
@@ -25,7 +25,7 @@
 
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import get_samples_mapping
-from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset
+from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset, OnlineSampleMapping
 from nemo.core.classes import Dataset
 from nemo.utils import logging
@@ -155,6 +155,7 @@ def _maybe_validate_prompt_template(self):
 
     def _build_samples_mapping(self):
         if self.max_num_samples is not None:
+            osm = OnlineSampleMapping(dataset_size=len(self.indexed_dataset), num_samples=self.max_num_samples)
             self.samples_mapping = get_samples_mapping(
                 indexed_dataset=self.indexed_dataset,
                 data_prefix=self.file_path,
@@ -166,6 +167,7 @@ def _build_samples_mapping(self):
                 name=self.file_path.split('/')[-1],
                 binary_head=False,
                 index_mapping_dir=self.index_mapping_dir,
+                samples_mapping=osm,
             )
         else:
             self.samples_mapping = None
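
Read together, the hunks above change _build_samples_mapping so that, when max_num_samples is set, an OnlineSampleMapping is built over the indexed dataset and handed to get_samples_mapping through the new samples_mapping keyword. A minimal standalone sketch of that wiring (the sizes are hypothetical, and it assumes get_samples_mapping simply adopts the mapping it is given rather than precomputing one):

from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import OnlineSampleMapping

# Hypothetical sizes: lazily up-sample a 50,000-example dataset to 200,000 samples.
osm = OnlineSampleMapping(dataset_size=50_000, num_samples=200_000)

# Entries are produced on demand; each is (sample_idx, None, None), matching
# the row shape of the mapping NeMo's get_samples_mapping would precompute.
sample_idx, _, _ = osm[0]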
187 changes: 187 additions & 0 deletions nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import datetime
+import functools
 import json
 import multiprocessing as mp
 import os
@@ -535,3 +536,189 @@ def build_index_files(
    logging.info(
        f"Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
    )


def handle_index(dataset, idx):
    """
    Remaps negative indices and handles numpy int indices.
    Arguments:
        dataset (Dataset): dataset to index into
        idx (int): Index. Can include negative indices.
    Returns:
        int: Remapped and fully qualified index.
    Raises:
        IndexError: If a negative index is out of range.
    Examples:
        >>> import numpy as np
        >>> import torch
        >>> from torch.utils.data import TensorDataset
        >>> from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import handle_index
        >>> dataset = TensorDataset(torch.tensor(-np.arange(5)))
        >>> handle_index(dataset, 1)
        1
        >>> handle_index(dataset, -2)
        3
    """
    if idx < 0 and idx > -len(dataset) - 1:
        idx = len(dataset) + idx
    elif idx < 0:
        raise IndexError(f'Index out of range: {idx}')
    return idx


class OnlineSampleMapping:
    """
    This class replaces NeMo's get_samples_mapping function, which pre-computes the
    entire sample mapping up front. It builds the mapping for a given number of
    samples on the fly, including pseudo-random shuffling.
    The sampler allows down- or up-sampling of a given dataset.
    Shuffling is pseudo-random: blocks of samples are shuffled, and each block is
    then shuffled internally.
    """

    def __init__(
        self,
        dataset_size: int,
        num_samples: int,
        block_size: int = 1000000,
        cache_maxsize: int = 2,
        seed: int = 1,
        shuffle: bool = True,
        truncate_to_block_boundary: bool = False,
    ):
        """
        Args:
            dataset_size (int): Size of the dataset.
            num_samples (int): Number of samples the dataset should contain.
            block_size (int): Size of each sample block. This is used to shuffle the samples.
                None will be replaced with dataset size.
            cache_maxsize (int): Maximum size of the blocks cache for the get_sample_block function.
            seed (int): Seed for the random number generator used for shuffling.
            shuffle (bool): Whether to shuffle the samples.
            truncate_to_block_boundary (bool): Whether to truncate the last block to the block boundary (could drop samples).
        """
        self.dataset_size = dataset_size
        self.num_samples = num_samples
        self.block_size = block_size if block_size is not None else self.dataset_size
        self.cache_maxsize = cache_maxsize
        self.seed = seed
        self.shuffle = shuffle
        self.truncate_to_block_boundary = truncate_to_block_boundary

        # we need at least num_samples (up-sampling) or dataset_size samples (correct down-sampling)
        self.required_samples = max(self.num_samples, self.dataset_size)
        # block size cannot be larger than dataset size
        self.block_size = min(self.block_size, self.dataset_size)
        # reduce the last block if needed, to match the required number of samples
        last_block_size = self.required_samples % self.block_size
        # store required blocks to cover num_samples samples and dataset_size samples
        self.num_blocks = int(np.ceil(self.required_samples / self.block_size))

        # if required, truncate the last block to the block boundary
        if self.truncate_to_block_boundary and last_block_size:
            # update num_samples to account for the truncated last block, but only if needed
            if self.required_samples == self.num_samples:
                self.num_samples -= last_block_size

            # update num_blocks to account for the truncated last block
            self.num_blocks -= 1
            self.required_samples -= last_block_size
            last_block_size = 0

        # create a list of blocks (should cover the entire dataset for correct down-sampling)
        block_idx_list = np.arange(self.num_blocks)
        # compute the size of each block
        block_size_list = np.full(self.num_blocks, self.block_size)
        if last_block_size:
            block_size_list[-1] = last_block_size
            self.use_digitize = True
        else:
            self.use_digitize = False
        if shuffle:
            local_rng = np.random.RandomState(seed=self.seed)
            idx = local_rng.permutation(np.arange(self.num_blocks))
            block_idx_list = block_idx_list[idx]
            block_size_list = block_size_list[idx]

        # store only the required number of blocks
        self.block_idx_list = block_idx_list
        self.block_size_list = block_size_list
        self.block_bins = np.cumsum(block_size_list)

        # wrap get_sample_block in an LRU cache so repeated block lookups are cheap
        self.get_sample_block = functools.lru_cache(maxsize=cache_maxsize, typed=False)(self.get_sample_block)

    def __str__(self):
        return f"OnlineSampleMapping(dataset_size={self.dataset_size}, num_samples={self.num_samples}, block_size={self.block_size}, cache_maxsize={self.cache_maxsize}, seed={self.seed}, shuffle={self.shuffle}, truncate_to_block_boundary={self.truncate_to_block_boundary})"

    def __getitem__(self, idx):
        # handle slices
        if isinstance(idx, slice):
            slc = idx
            start, stop, step = slc.start, slc.stop, slc.step

            # handle None values
            start = handle_index(self, start if start is not None else 0)
            if start >= self.num_samples:
                start = self.num_samples
            stop = handle_index(self, stop if stop is not None else self.num_samples)
            if stop >= self.num_samples:
                stop = self.num_samples
            step = step if step is not None else 1
            sample_slice = [self[idx] for idx in range(start, stop, step)]
            return sample_slice
        # handle indices
        else:
            # if the index is out of range, raise IndexError
            if idx >= self.num_samples:
                raise IndexError("Index out of range")

            # support negative indices
            if idx < 0:
                idx += self.num_samples

            if idx < 0:
                raise IndexError("Index out of range")

            # fetch the block holding this sample index
            if self.use_digitize:
                block_idx = np.digitize(idx, self.block_bins)
            else:
                block_idx = idx // self.block_size
            sample_block = self.get_sample_block(block_idx)

            # use the local index to fetch the sample; idx - block_bins[block_idx]
            # is negative, so this relies on numpy's indexing from the block's end
            local_idx = idx - self.block_bins[block_idx]
            sample_idx = sample_block[local_idx]

            return sample_idx, None, None  # for compatibility with NeMo's get_samples_mapping

    def __len__(self) -> int:
        return self.num_samples

    def get_sample_block(self, block_idx: int) -> np.ndarray:
        """
        Returns a block of samples of size self.block_size, shuffled if needed.
        This method is wrapped in functools.lru_cache during construction for efficiency.
        """
        if block_idx >= self.num_blocks:
            raise IndexError(f"block_idx {block_idx} is out of range. Maximum block_idx is {self.num_blocks - 1}")

        # recover the index range of the original block (before block shuffling)
        start_idx = self.block_idx_list[block_idx] * self.block_size
        end_idx = start_idx + self.block_size_list[block_idx]
        sample_block = np.arange(start_idx, end_idx)

        # shuffle within the block if needed
        if self.shuffle:
            local_rng = np.random.RandomState(seed=self.seed + block_idx)
            sample_block = local_rng.permutation(sample_block)

        # project indices back onto the dataset (wraps around when up-sampling)
        sample_block = sample_block % self.dataset_size

        return sample_block
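
To make the block arithmetic concrete, here is a small usage sketch, hand-traced against the code above (the import path follows this commit's file location; shuffle=False is used so the expected values are easy to verify by hand):

from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import OnlineSampleMapping

# Up-sample a 10-example dataset to 25 samples with blocks of 4:
# required_samples = max(25, 10) = 25, num_blocks = ceil(25 / 4) = 7,
# last_block_size = 25 % 4 = 1, block_bins = [4, 8, 12, 16, 20, 24, 25].
osm = OnlineSampleMapping(dataset_size=10, num_samples=25, block_size=4, shuffle=False)

assert len(osm) == 25
assert osm[0][0] == 0                       # each entry is (sample_idx, None, None)
assert osm[12][0] == 2                      # unshuffled indices wrap around: 12 % 10 == 2
assert osm[-1][0] == 4                      # negative indexing: sample 24 maps to 24 % 10 == 4
assert [s[0] for s in osm[0:3]] == [0, 1, 2]  # slices return a list of entries

# With shuffle=True (the default), the block order and the order within each
# block are permuted deterministically from `seed`, so two instances built
# with the same arguments yield the same pseudo-random mapping.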
