From a83171a96809e039e743b658f7529a2149c238f4 Mon Sep 17 00:00:00 2001
From: arendu
Date: Mon, 8 Jan 2024 23:56:08 +0000
Subject: [PATCH 1/2] adding OnlineSampleMapping

Signed-off-by: arendu
---
 .../megatron/gpt_sft_dataset.py               |   3 +
 .../language_modeling/text_memmap_dataset.py  | 187 ++++++++++++++++++
 2 files changed, 190 insertions(+)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index 63c4f3459682..1dad925cd6ff 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -26,6 +26,7 @@
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import get_samples_mapping
 from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset
+from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import OnlineSampleMapping
 from nemo.core.classes import Dataset
 from nemo.utils import logging
 
@@ -155,6 +156,7 @@ def _maybe_validate_prompt_template(self):
 
     def _build_samples_mapping(self):
         if self.max_num_samples is not None:
+            osm = OnlineSampleMapping(dataset_size=len(self.indexed_dataset), num_samples=self.max_num_samples)
             self.samples_mapping = get_samples_mapping(
                 indexed_dataset=self.indexed_dataset,
                 data_prefix=self.file_path,
@@ -166,6 +168,7 @@ def _build_samples_mapping(self):
                 name=self.file_path.split('/')[-1],
                 binary_head=False,
                 index_mapping_dir=self.index_mapping_dir,
+                samples_mapping=osm,
             )
         else:
             self.samples_mapping = None
diff --git a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
index b5504d8f7cd1..a964bf6093fa 100644
--- a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import datetime
+import functools
 import json
 import multiprocessing as mp
 import os
@@ -535,3 +536,189 @@ def build_index_files(
     logging.info(
         f"Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
     )
+
+
+def handle_index(dataset, idx):
+    """
+    Remaps negative indices and handles numpy int indices.
+
+    Args:
+        dataset (Dataset): Dataset to index into.
+        idx (int): Index. Can be negative.
+
+    Returns:
+        int: Remapped, fully qualified index.
+
+    Raises:
+        IndexError: If a negative index is out of range.
+
+    Examples:
+        >>> import numpy as np
+        >>> import torch
+        >>> from torch.utils.data import TensorDataset
+        >>> from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import handle_index
+        >>> dataset = TensorDataset(torch.tensor(-np.arange(5)))
+        >>> handle_index(dataset, 1)
+        1
+        >>> handle_index(dataset, -2)
+        3
+    """
+    if idx < 0 and idx > -len(dataset) - 1:
+        idx = len(dataset) + idx
+    elif idx < 0:
+        raise IndexError(f'Index out of range: {idx}')
+    return idx
+
+
+class OnlineSampleMapping:
+    """
+    Replaces NeMo's get_samples_mapping function, which pre-computes the entire samples
+    mapping up front. This class instead creates the sample mapping on the fly for a given
+    number of samples, including pseudo-random shuffling.
+    The sampler allows a given dataset to be down-sampled or up-sampled.
+    Shuffling is pseudo-random: blocks of samples are shuffled, and each block is then
+    shuffled internally.
+ """ + + def __init__( + self, + dataset_size: int, + num_samples: int, + block_size: int = 1000000, + cache_maxsize: int = 2, + seed: int = 1, + shuffle: bool = True, + truncate_to_block_boundary: bool = False, + ): + """ + Args: + dataset_size (int): Size of the dataset. + num_samples (int): Number of samples the dataset should contain. + block_size (int): Size of each sample block. This is used to shuffle the samples. + None will be replaced with dataset size. + cache_maxsize (int): Maximum size of the blocks cache for the get_sample_block function. + seed (int): Seed for the random number generator used for shuffling. + shuffle (bool): Whether to shuffle the samples. + truncate_to_block_boundary (bool): Whether to truncate the last block to the block boundary (could drop samples). + """ + self.dataset_size = dataset_size + self.num_samples = num_samples + self.block_size = block_size if block_size is not None else self.dataset_size + self.cache_maxsize = cache_maxsize + self.seed = seed + self.shuffle = shuffle + self.truncate_to_block_boundary = truncate_to_block_boundary + + # we need at least num_samples (up-sampling) or dataset_size samples (correct down-sampling) + self.required_samples = max(self.num_samples, self.dataset_size) + # block size cannot be larger than dataset size + self.block_size = min(self.block_size, self.dataset_size) + # reduce the last block if needed, to match the required number of samples + last_block_size = self.required_samples % self.block_size + # store required blocks to cover num_samples samples and dataset_size samples + self.num_blocks = int(np.ceil(self.required_samples / self.block_size)) + + # if required, truncate the last block to the block boundary + if self.truncate_to_block_boundary and last_block_size: + # update num_samples to account for truncated last block only if needed + if self.required_samples == self.num_samples: + self.num_samples -= last_block_size + + # apdate num_blocks to account for truncated last block + self.num_blocks -= 1 + self.required_samples -= last_block_size + last_block_size = 0 + + # create a list of blocks (should cover the entire dataset for correct down sampling) + block_idx_list = np.arange(self.num_blocks) + # compute the size of each block + block_size_list = np.full(self.num_blocks, self.block_size) + if last_block_size: + block_size_list[-1] = last_block_size + self.use_digitize = True + else: + self.use_digitize = False + if shuffle: + local_rng = np.random.RandomState(seed=self.seed) + idx = local_rng.permutation(np.arange(self.num_blocks)) + block_idx_list = block_idx_list[idx] + block_size_list = block_size_list[idx] + + # store only required number of blocks + self.block_idx_list = block_idx_list + self.block_size_list = block_size_list + self.block_bins = np.cumsum(block_size_list) + + # NOTE: MAKE get_sample_block A CACHED FUNCTION!!! 
+        self.get_sample_block = functools.lru_cache(maxsize=cache_maxsize, typed=False)(self.get_sample_block)
+
+    def __str__(self):
+        return f"OnlineSampleMapping(dataset_size={self.dataset_size}, num_samples={self.num_samples}, block_size={self.block_size}, cache_maxsize={self.cache_maxsize}, seed={self.seed}, shuffle={self.shuffle}, truncate_to_block_boundary={self.truncate_to_block_boundary})"
+
+    def __getitem__(self, idx: int) -> int:
+        # handle slices
+        if isinstance(idx, slice):
+            slc = idx
+            start, stop, step = slc.start, slc.stop, slc.step
+
+            # handle None values
+            start = handle_index(self, start if start is not None else 0)
+            if start >= self.num_samples:
+                start = self.num_samples
+            stop = handle_index(self, stop if stop is not None else self.num_samples)
+            if stop >= self.num_samples:
+                stop = self.num_samples
+            step = step if step is not None else 1
+            sample_slice = [self[idx] for idx in range(start, stop, step)]
+            return sample_slice
+        # handle indices
+        else:
+            # if the index is out of range, raise IndexError
+            if idx >= self.num_samples:
+                raise IndexError("Index out of range")
+
+            # support negative indices
+            if idx < 0:
+                idx += self.num_samples
+
+                if idx < 0:
+                    raise IndexError("Index out of range")
+
+            # fetch the block the sample belongs to
+            if self.use_digitize:
+                block_idx = np.digitize(idx, self.block_bins)
+            else:
+                block_idx = idx // self.block_size
+            sample_block = self.get_sample_block(block_idx)
+
+            # use the (negative) local index to fetch the sample from the block
+            local_idx = idx - self.block_bins[block_idx]
+            sample_idx = sample_block[local_idx]
+
+            return sample_idx, None, None # for compatibility with NeMo's get_samples_mapping
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def get_sample_block(self, block_idx: int) -> np.ndarray:
+        """
+        Returns a block of sample indices of size self.block_size, shuffled if needed.
+        This method is wrapped with functools.lru_cache in __init__, so recently used
+        blocks are kept in memory.
+        """
+        if block_idx >= self.num_blocks:
+            raise IndexError(f"block_idx {block_idx} is out of range. Maximum block_idx is {self.num_blocks-1}")
+
+        # recover the index of the original block (before shuffling)
+        start_idx = self.block_idx_list[block_idx] * self.block_size
+        end_idx = start_idx + self.block_size_list[block_idx]
+        sample_block = np.arange(start_idx, end_idx)
+
+        # shuffle if needed
+        if self.shuffle:
+            local_rng = np.random.RandomState(seed=self.seed + block_idx)
+            sample_block = local_rng.permutation(sample_block)
+
+        # project indices onto the dataset size (wrapping around enables up-sampling)
+        sample_block = sample_block % self.dataset_size
+
+        return sample_block
\ No newline at end of file

From 854e7bb44c772032d51eb1099facde021bc6e5ba Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 8 Jan 2024 23:58:25 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../nlp/data/language_modeling/megatron/gpt_sft_dataset.py | 3 +--
 .../nlp/data/language_modeling/text_memmap_dataset.py      | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index 1dad925cd6ff..f99371876c87 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -25,8 +25,7 @@
 
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import get_samples_mapping
-from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset
-from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import OnlineSampleMapping
+from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import JSONLMemMapDataset, OnlineSampleMapping
 from nemo.core.classes import Dataset
 from nemo.utils import logging
 
diff --git a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
index a964bf6093fa..8065d489259b 100644
--- a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
@@ -695,7 +695,7 @@ def __getitem__(self, idx: int) -> int:
         local_idx = idx - self.block_bins[block_idx]
         sample_idx = sample_block[local_idx]
 
-        return sample_idx, None, None # for compatibility with NeMo's get_samples_mapping
+        return sample_idx, None, None  # for compatibility with NeMo's get_samples_mapping
 
     def __len__(self) -> int:
         return self.num_samples
@@ -721,4 +721,4 @@ def get_sample_block(self, block_idx: int) -> np.ndarray:
         # project indices onto the dataset size (wrapping around enables up-sampling)
         sample_block = sample_block % self.dataset_size
 
-        return sample_block
\ No newline at end of file
+        return sample_block
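
A minimal usage sketch of OnlineSampleMapping, not part of the patch series: it assumes
the patches above are applied, so the class is importable from text_memmap_dataset. The
dataset size, sample count, and block size are arbitrary values chosen for illustration.

    import numpy as np

    from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import OnlineSampleMapping

    # up-sample a 10-element dataset to 25 samples, shuffled in blocks of 4
    mapping = OnlineSampleMapping(dataset_size=10, num_samples=25, block_size=4, seed=1, shuffle=True)

    print(len(mapping))            # 25
    sample_idx, _, _ = mapping[0]  # __getitem__ returns (sample_idx, None, None)
    print(sample_idx)              # an index in [0, 10); indices wrap via modulo for up-sampling
    print(mapping[2:5])            # slices return a list of (sample_idx, None, None) tuples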
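
A note on the local index arithmetic in __getitem__, with illustrative numbers (the
arrays below are examples, not values from the patch): block_bins holds the cumulative
*end* offsets of the (shuffled) blocks, so local_idx = idx - block_bins[block_idx] is
negative, and numpy's negative indexing then selects the correct element counting from
the start of the block.

    import numpy as np

    block_size_list = np.array([4, 4, 4, 4, 4, 4, 1])  # e.g. 25 required samples, block_size=4
    block_bins = np.cumsum(block_size_list)            # [ 4  8 12 16 20 24 25]

    idx = 9                                   # global sample index
    block_idx = np.digitize(idx, block_bins)  # -> 2, since 8 <= 9 < 12
    local_idx = idx - block_bins[block_idx]   # 9 - 12 = -3
    block = np.arange(8, 12)                  # the (unshuffled) indices of block 2
    print(block[local_idx])                   # 9, i.e. offset 9 - 8 = 1 within the block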