From 399dd7f988f6ccbb40268bed6474a355d57204ec Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Mon, 16 Feb 2026 20:20:19 +0200 Subject: [PATCH 1/5] Enables packed sequence dataset in mmap format Adds support for creating and using packed sequence datasets in `mmap` format for faster loading and reduced memory usage. This includes changes to: - Allow specifying `packed_sequence_format` as `mmap` in `PackedSequenceSpecs` - Modify the data preparation to save `mmap` compatible data files with `.idx.npy` and `.bin` extensions - Update dataset loading logic to use `numpy.memmap` for zero-copy access. - Adds number of tokenizer workers to `prepare_packed_sequence_data` to speed up tokenization. --- .../data/builders/finetuning_dataset.py | 29 ++++- .../bridge/data/datasets/packed_sequence.py | 109 +++++++++++++++++- .../bridge/data/datasets/packing_utils.py | 15 ++- src/megatron/bridge/data/datasets/sft.py | 62 ++++++++++ 4 files changed, 197 insertions(+), 18 deletions(-) diff --git a/src/megatron/bridge/data/builders/finetuning_dataset.py b/src/megatron/bridge/data/builders/finetuning_dataset.py index 0dff48fda1..a6f21eade4 100644 --- a/src/megatron/bridge/data/builders/finetuning_dataset.py +++ b/src/megatron/bridge/data/builders/finetuning_dataset.py @@ -95,7 +95,7 @@ def prepare_packed_data(self) -> None: if self.packed_sequence_size > 0: from megatron.bridge.data.datasets.packed_sequence import prepare_packed_sequence_data - if not self.train_path_packed.is_file(): + if not self._packed_path_exists(self.train_path_packed): print_rank_0(f"Preparing packed training data at {self.train_path_packed}") prepare_packed_sequence_data( input_path=self.train_path, @@ -107,9 +107,11 @@ def prepare_packed_data(self) -> None: output_metadata_path=self.pack_metadata, dataset_kwargs=self.dataset_kwargs, pad_seq_to_mult=self._pad_seq_to_mult, + num_tokenizer_workers=self.packed_sequence_specs.num_tokenizer_workers, + save_format=self.packed_sequence_specs.packed_sequence_format, ) - if self.do_validation and not self.validation_path_packed.is_file(): + if self.do_validation and not self._packed_path_exists(self.validation_path_packed): print_rank_0(f"Preparing packed validation data at {self.validation_path_packed}") prepare_packed_sequence_data( input_path=self.validation_path, @@ -121,6 +123,8 @@ def prepare_packed_data(self) -> None: output_metadata_path=self.pack_metadata, dataset_kwargs=self.dataset_kwargs, pad_seq_to_mult=self._pad_seq_to_mult, + num_tokenizer_workers=self.packed_sequence_specs.num_tokenizer_workers, + save_format=self.packed_sequence_specs.packed_sequence_format, ) def build(self) -> list[Optional[Any]]: @@ -155,6 +159,7 @@ def _build_datasets(self) -> list[Optional[Any]]: train_ds = self._create_dataset( self.train_path if self.packed_sequence_size <= 0 else self.train_path_packed, pack_metadata_path=None if self.packed_sequence_size <= 0 else self.pack_metadata, + pack_save_format=None if self.packed_sequence_size <= 0 else self.packed_sequence_specs.packed_sequence_format, max_num_samples=self.max_train_samples, **self.dataset_kwargs, ) @@ -163,6 +168,7 @@ def _build_datasets(self) -> list[Optional[Any]]: valid_ds = self._create_dataset( self.validation_path if self.packed_sequence_size <= 0 else self.validation_path_packed, pack_metadata_path=None if self.packed_sequence_size <= 0 else self.pack_metadata, + pack_save_format=None if self.packed_sequence_size <= 0 else self.packed_sequence_specs.packed_sequence_format, is_test=True, **self.dataset_kwargs, ) @@ -184,6 +190,7 @@ def _create_dataset( self, path: Union[str, Path], pack_metadata_path: Optional[Union[str, Path]] = None, + pack_save_format: Optional[str] = None, is_test: bool = False, **kwargs: Any, ) -> Optional[Any]: @@ -200,9 +207,14 @@ def _create_dataset( """ if MultiStorageClientFeature.is_enabled(): msc = MultiStorageClientFeature.import_package() - path_exists = msc.Path(path).exists() + if pack_save_format == "mmap": + path_exists = msc.Path(str(path) + '.idx.npy').is_file() and msc.Path(str(path) + '.bin').is_file() + else: path_exists = msc.Path(path).exists() else: - path_exists = Path(path).exists() + if pack_save_format == "mmap": + path_exists = Path(str(path) + '.idx.npy').is_file() and Path(str(path) + '.bin').is_file() + else: + path_exists = Path(path).exists() if not path_exists: print_rank_0(f"Warning: Dataset path {path} does not exist") @@ -283,7 +295,7 @@ def train_path_packed(self) -> Path: if self.packed_sequence_size > 0: if self.packed_sequence_specs.packed_train_data_path is not None: return self.packed_sequence_specs.packed_train_data_path - return self.default_pack_path / f"training_{self.packed_sequence_size}.npy" + return self.default_pack_path / f"training_{self.packed_sequence_size}.{self.packed_sequence_specs.packed_sequence_format}" else: raise ValueError("`train_path_packed` invalid since packed sequence size is not specified.") @@ -303,7 +315,7 @@ def validation_path_packed(self) -> Path: if self.packed_sequence_size > 0: if self.packed_sequence_specs.packed_val_data_path is not None: return self.packed_sequence_specs.packed_val_data_path - return self.default_pack_path / f"validation_{self.packed_sequence_size}.npy" + return self.default_pack_path / f"validation_{self.packed_sequence_size}.{self.packed_sequence_specs.packed_sequence_format}" else: raise ValueError("`validation_path_packed` invalid since packed sequence size is not specified.") @@ -340,3 +352,8 @@ def _extract_tokenizer_model_name(self) -> str: return tokenizer_model_name else: return f"unknown_tokenizer_{hash(self.tokenizer)}" + + def _packed_path_exists(self, path: Path) -> bool: + if self.packed_sequence_specs.packed_sequence_format == "mmap": + return path.with_suffix(path.suffix + '.idx.npy').is_file() and path.with_suffix(path.suffix + '.bin').is_file() + return path.is_file() diff --git a/src/megatron/bridge/data/datasets/packed_sequence.py b/src/megatron/bridge/data/datasets/packed_sequence.py index cfe83c575a..e4c34085b4 100644 --- a/src/megatron/bridge/data/datasets/packed_sequence.py +++ b/src/megatron/bridge/data/datasets/packed_sequence.py @@ -15,6 +15,10 @@ import logging from dataclasses import dataclass from pathlib import Path +from multiprocessing import Pool +import multiprocessing as mp +from tqdm import tqdm +from typing import Literal, List import numpy as np from megatron.core.msc_utils import MultiStorageClientFeature @@ -26,6 +30,20 @@ logger = logging.getLogger(__name__) +_shared_dataset = None +def _tokenize_get_item(i): + return _shared_dataset[i] + +def _tokenize_init_worker(dataset): + global _shared_dataset + _shared_dataset = dataset + +def _retrieve_tokenized(dataset, num_workers): + if num_workers == 1: + return np.array([dataset[i] for i in tqdm(range(len(dataset)))]) + num_workers = num_workers if num_workers > 0 else mp.cpu_count() + with Pool(num_workers, initializer=_tokenize_init_worker, initargs=(dataset,)) as pool: + return np.array(list(tqdm(pool.imap(_tokenize_get_item, range(len(dataset))), total=len(dataset)))) def tokenize_dataset( path: Path, @@ -34,6 +52,7 @@ def tokenize_dataset( seed: int, dataset_kwargs: dict | None = None, pad_seq_to_mult: int | None = 1, + num_tokenizer_workers: int = -1, ): """ Tokenizes a dataset from the provided path using the specified tokenizer @@ -88,7 +107,7 @@ def tokenize_dataset( pad_id = dataset.tokenizer.eod pad_seq_length_to_mult = dataset.pad_seq_length_to_mult max_seq_length = dataset.max_seq_length - dataset = np.array([dataset[i] for i in range(len(dataset))]) + dataset = _retrieve_tokenized(dataset, num_tokenizer_workers) if pad_seq_to_mult > 1: @@ -120,6 +139,72 @@ def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id): return dataset +def save_packed_dataset(packed_data: List[dict], output_path: Path, save_format: Literal["npy", "mmap"] = "npy"): + logging.info(f'Saving packed_sequence index to {output_path} in {save_format} format...') + + """ + Saves the packed dataset to the specified output path in the given format. + + Args: + packed_data (List[dict]): The packed dataset to be saved. + output_path (Path): The path where the packed dataset should be saved. + save_format (Literal["npy", "mmap"]): The file format to save the packed dataset. + Can be either "npy" or "mmap". Defaults to "npy". + + Returns: + None: The function saves the packed dataset to the specified output path. + """ + if save_format == "npy": + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + msc.numpy.save(output_path, packed_data) + else: + np.save(output_path, packed_data) + elif save_format == "mmap": + if not output_path.parent.exists(): + # the msc.numpy.memmap doesn't work properly with mode="w+" + raise NotImplementedError("Packed sequence dataset in 'mmap' format is not supported on non-local filesystems. Please use 'npy' format or save to a local directory.") + + """Convert npy to memmap format with .idx metadata and .bin data""" + # Compute metadata + print('Computing metadata...') + input_lens = np.array([len(item["input_ids"]) for item in packed_data], dtype=np.int32) + loss_lens = np.array([len(item["loss_mask"]) for item in packed_data], dtype=np.int32) + seq_lens = np.array([len(item["seq_start_id"]) for item in packed_data], dtype=np.int32) + + # Compute byte offsets for each sample + bytes_per_sample = (input_lens + loss_lens + seq_lens) * 4 # int32 = 4 bytes + offsets = np.concatenate([[0], np.cumsum(bytes_per_sample)[:-1]]) + + # Save index file + idx_file = output_path.with_suffix(output_path.suffix + '.idx.npy') + np.save(idx_file, { + 'offsets': offsets, + 'input_lens': input_lens, + 'loss_lens': loss_lens, + 'seq_lens': seq_lens + }) + + # Write data file + bin_file = output_path.with_suffix(output_path.suffix + '.bin') + + total_ints = int(np.sum(bytes_per_sample) // 4) + mmap = np.memmap(bin_file, dtype=np.int32, mode='w+', shape=(total_ints,)) + pos = 0 + for item in tqdm(packed_data): + inp = np.array(item["input_ids"], dtype=np.int32) + loss = np.array(item["loss_mask"], dtype=np.int32) + seq = np.array(item["seq_start_id"], dtype=np.int32) + + chunk_len = len(inp) + len(loss) + len(seq) + mmap[pos:pos+len(inp)] = inp + mmap[pos+len(inp):pos+len(inp)+len(loss)] = loss + mmap[pos+len(inp)+len(loss):pos+chunk_len] = seq + pos += chunk_len + mmap.flush() + del mmap + else: + raise ValueError(f"Unsupported save format: {save_format}") def prepare_packed_sequence_data( input_path: Path, @@ -132,6 +217,8 @@ def prepare_packed_sequence_data( packing_algorithm: str = "first_fit_shuffle", dataset_kwargs: dict | None = None, pad_seq_to_mult: int | None = 1, + num_tokenizer_workers: int = -1, + save_format: Literal["npy", "mmap"] = "npy", ): """ Prepares a packed sequence dataset from a given input file and saves it to an output file. @@ -162,6 +249,7 @@ def prepare_packed_sequence_data( seed, dataset_kwargs, pad_seq_to_mult=pad_seq_to_mult, + num_tokenizer_workers=num_tokenizer_workers ) sequences, histogram = create_hist(dataset, max_seq_length) @@ -169,11 +257,7 @@ def prepare_packed_sequence_data( output_data = fill_packing_strategy(assignments, sequences, packed_sequence_size, tokenizer.eos_id) # save output data - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - msc.numpy.save(output_path, output_data) - else: - np.save(output_path, output_data) + save_packed_dataset(output_data, output_path, save_format) # save packing metadata, packing_metadata is appended to the packing file if it exists if output_metadata_path is not None: @@ -214,12 +298,25 @@ class PackedSequenceSpecs: of the original sequence (i.e. the length to truncate long sequences in the input data). """ + packed_sequence_format: Literal["npy", "mmap"] = "npy" + """ + The file format for the packed sequence dataset. Can be either "npy" (default) or "mmap". + "mmap" produces two files: .mmap.idx.npy and .mmap.bin. The `packed_train_data_path` and `packed_val_data_path` should point + to the prefix of the .idx.npy and .bin files (i.e. without the .idx.npy or .bin suffix). + """ + tokenizer_model_name: str = None """ Keep track of tokenizer model name, since each tokenizer produces a different packed sequence dataset file. This field is set by llm.finetune api. """ + num_tokenizer_workers: int = -1 + """ + The number of worker processes to use for tokenization when preparing the packed sequence dataset. + If -1, the number of workers will be set to the number of CPU cores available + """ + packed_train_data_path: str = None """ If specified, use this file for the packed training dataset instead of the default path. diff --git a/src/megatron/bridge/data/datasets/packing_utils.py b/src/megatron/bridge/data/datasets/packing_utils.py index dd36f5e1f4..1d0b2ed722 100644 --- a/src/megatron/bridge/data/datasets/packing_utils.py +++ b/src/megatron/bridge/data/datasets/packing_utils.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) -def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> int: +def find_first_bin_that_fits(bin_sums: List[int], s: int, bin_size: int) -> int: """ Finds the first bin in a list of bins that has enough space to fit a sequence of size 's'. @@ -37,9 +37,9 @@ def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> in Returns: The index of the first bin that can fit the sequence 's', or -1 if no such bin exists. """ - for i, abin in enumerate(bins): - if sum(abin) + s <= bin_size: - return i + for i,cur_sum in enumerate(bin_sums): + if cur_sum + s <= bin_size: + return i return -1 @@ -56,12 +56,15 @@ def first_fit(seqlens: List[int], pack_size: int) -> List[List[int]]: of the sequences assigned to that bin. """ res = [] - for s in seqlens: - first_bin = find_first_bin_that_fits(res, s, pack_size) + res_sums = [] + for s in tqdm(seqlens): + first_bin = find_first_bin_that_fits(res_sums, s, pack_size) if first_bin == -1: # open a new bin res.append([s]) + res_sums.append(s) else: res[first_bin].append(s) + res_sums[first_bin] += s return res diff --git a/src/megatron/bridge/data/datasets/sft.py b/src/megatron/bridge/data/datasets/sft.py index f29900223b..cfc31eb3f3 100644 --- a/src/megatron/bridge/data/datasets/sft.py +++ b/src/megatron/bridge/data/datasets/sft.py @@ -180,6 +180,13 @@ def create_sft_dataset( **gpt_sft_dataset_kwargs, **kwargs, ) + elif path.suffix == '.mmap': + return GPTSFTMMapPackedDataset( + pack_metadata_file_path=pack_metadata_file_path, + pad_cu_seqlens=pad_cu_seqlens, + **gpt_sft_dataset_kwargs, + **kwargs, + ) elif chat: return GPTSFTChatDataset( **gpt_sft_dataset_kwargs, @@ -1012,6 +1019,16 @@ def collate_fn(self, batch): return processed_batch +class GPTSFTMMapPackedDataset(GPTSFTPackedDataset): + def _load_dataset(self): + try: + self.indexed_dataset = MemmapPackedDataset(self.file_path) + except Exception as e: + logger.error( + f"Failed to load packed dataset. The dataset should be a pair of `.idx.npy` and '.bin' files. " + f"Please check if the packed dataset was prepared correctly. The original error was:\n {e}", + ) + exit(1) class GPTSFTChatDataset(GPTSFTDataset): """Dataset class for chat-based fine-tuning with optional HuggingFace chat template support. @@ -1228,3 +1245,48 @@ def collate_fn(self, batch): processed_batch["attention_mask"] = attention_mask return processed_batch + +class MemmapPackedDataset: + """Zero-copy dataset using numpy memmap""" + + def __init__(self, path_prefix): + self.path_prefix = path_prefix + if False and MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + self.mmap = msc.numpy.memmap(self.path_prefix + '.bin', dtype=np.int32, mode='r') + idx_data = msc.numpy.load(self.path_prefix + '.idx.npy', allow_pickle=True).item() + else: + self.mmap = np.memmap(self.path_prefix + '.bin', dtype=np.int32, mode='r') + idx_data = np.load(self.path_prefix + '.idx.npy', allow_pickle=True).item() + + # Load index + self.offsets = idx_data['offsets'] + self.input_lens = idx_data['input_lens'] + self.loss_lens = idx_data['loss_lens'] + self.seq_lens = idx_data['seq_lens'] + self.length = len(self.offsets) + self.num_tokens = np.sum(self.input_lens) + + def __getitem__(self, idx): + offset = self.offsets[idx] // 4 # byte offset -> int32 offset + inp_len = self.input_lens[idx] + loss_len = self.loss_lens[idx] + seq_len = self.seq_lens[idx] + + inp_end = offset + inp_len + loss_end = inp_end + loss_len + seq_end = loss_end + seq_len + + ret = { + "input_ids": self.mmap[offset:inp_end].tolist(), + "loss_mask": self.mmap[inp_end:loss_end].astype(bool).tolist(), + "seq_start_id": self.mmap[loss_end:seq_end].tolist() + } + return ret + + def __len__(self): + return self.length + + def __del__(self): + if hasattr(self, 'mmap') and self.mmap is not None: + del self.mmap \ No newline at end of file From c02660b09808629cae69f2101dba02a382ef6fe8 Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Wed, 11 Mar 2026 23:50:17 +0200 Subject: [PATCH 2/5] Removed memmap implementation, left optimizations --- .../data/builders/finetuning_dataset.py | 29 ++----- .../bridge/data/datasets/packed_sequence.py | 81 ++----------------- src/megatron/bridge/data/datasets/sft.py | 62 -------------- 3 files changed, 11 insertions(+), 161 deletions(-) diff --git a/src/megatron/bridge/data/builders/finetuning_dataset.py b/src/megatron/bridge/data/builders/finetuning_dataset.py index a6f21eade4..0dff48fda1 100644 --- a/src/megatron/bridge/data/builders/finetuning_dataset.py +++ b/src/megatron/bridge/data/builders/finetuning_dataset.py @@ -95,7 +95,7 @@ def prepare_packed_data(self) -> None: if self.packed_sequence_size > 0: from megatron.bridge.data.datasets.packed_sequence import prepare_packed_sequence_data - if not self._packed_path_exists(self.train_path_packed): + if not self.train_path_packed.is_file(): print_rank_0(f"Preparing packed training data at {self.train_path_packed}") prepare_packed_sequence_data( input_path=self.train_path, @@ -107,11 +107,9 @@ def prepare_packed_data(self) -> None: output_metadata_path=self.pack_metadata, dataset_kwargs=self.dataset_kwargs, pad_seq_to_mult=self._pad_seq_to_mult, - num_tokenizer_workers=self.packed_sequence_specs.num_tokenizer_workers, - save_format=self.packed_sequence_specs.packed_sequence_format, ) - if self.do_validation and not self._packed_path_exists(self.validation_path_packed): + if self.do_validation and not self.validation_path_packed.is_file(): print_rank_0(f"Preparing packed validation data at {self.validation_path_packed}") prepare_packed_sequence_data( input_path=self.validation_path, @@ -123,8 +121,6 @@ def prepare_packed_data(self) -> None: output_metadata_path=self.pack_metadata, dataset_kwargs=self.dataset_kwargs, pad_seq_to_mult=self._pad_seq_to_mult, - num_tokenizer_workers=self.packed_sequence_specs.num_tokenizer_workers, - save_format=self.packed_sequence_specs.packed_sequence_format, ) def build(self) -> list[Optional[Any]]: @@ -159,7 +155,6 @@ def _build_datasets(self) -> list[Optional[Any]]: train_ds = self._create_dataset( self.train_path if self.packed_sequence_size <= 0 else self.train_path_packed, pack_metadata_path=None if self.packed_sequence_size <= 0 else self.pack_metadata, - pack_save_format=None if self.packed_sequence_size <= 0 else self.packed_sequence_specs.packed_sequence_format, max_num_samples=self.max_train_samples, **self.dataset_kwargs, ) @@ -168,7 +163,6 @@ def _build_datasets(self) -> list[Optional[Any]]: valid_ds = self._create_dataset( self.validation_path if self.packed_sequence_size <= 0 else self.validation_path_packed, pack_metadata_path=None if self.packed_sequence_size <= 0 else self.pack_metadata, - pack_save_format=None if self.packed_sequence_size <= 0 else self.packed_sequence_specs.packed_sequence_format, is_test=True, **self.dataset_kwargs, ) @@ -190,7 +184,6 @@ def _create_dataset( self, path: Union[str, Path], pack_metadata_path: Optional[Union[str, Path]] = None, - pack_save_format: Optional[str] = None, is_test: bool = False, **kwargs: Any, ) -> Optional[Any]: @@ -207,14 +200,9 @@ def _create_dataset( """ if MultiStorageClientFeature.is_enabled(): msc = MultiStorageClientFeature.import_package() - if pack_save_format == "mmap": - path_exists = msc.Path(str(path) + '.idx.npy').is_file() and msc.Path(str(path) + '.bin').is_file() - else: path_exists = msc.Path(path).exists() + path_exists = msc.Path(path).exists() else: - if pack_save_format == "mmap": - path_exists = Path(str(path) + '.idx.npy').is_file() and Path(str(path) + '.bin').is_file() - else: - path_exists = Path(path).exists() + path_exists = Path(path).exists() if not path_exists: print_rank_0(f"Warning: Dataset path {path} does not exist") @@ -295,7 +283,7 @@ def train_path_packed(self) -> Path: if self.packed_sequence_size > 0: if self.packed_sequence_specs.packed_train_data_path is not None: return self.packed_sequence_specs.packed_train_data_path - return self.default_pack_path / f"training_{self.packed_sequence_size}.{self.packed_sequence_specs.packed_sequence_format}" + return self.default_pack_path / f"training_{self.packed_sequence_size}.npy" else: raise ValueError("`train_path_packed` invalid since packed sequence size is not specified.") @@ -315,7 +303,7 @@ def validation_path_packed(self) -> Path: if self.packed_sequence_size > 0: if self.packed_sequence_specs.packed_val_data_path is not None: return self.packed_sequence_specs.packed_val_data_path - return self.default_pack_path / f"validation_{self.packed_sequence_size}.{self.packed_sequence_specs.packed_sequence_format}" + return self.default_pack_path / f"validation_{self.packed_sequence_size}.npy" else: raise ValueError("`validation_path_packed` invalid since packed sequence size is not specified.") @@ -352,8 +340,3 @@ def _extract_tokenizer_model_name(self) -> str: return tokenizer_model_name else: return f"unknown_tokenizer_{hash(self.tokenizer)}" - - def _packed_path_exists(self, path: Path) -> bool: - if self.packed_sequence_specs.packed_sequence_format == "mmap": - return path.with_suffix(path.suffix + '.idx.npy').is_file() and path.with_suffix(path.suffix + '.bin').is_file() - return path.is_file() diff --git a/src/megatron/bridge/data/datasets/packed_sequence.py b/src/megatron/bridge/data/datasets/packed_sequence.py index e4c34085b4..5739f92f32 100644 --- a/src/megatron/bridge/data/datasets/packed_sequence.py +++ b/src/megatron/bridge/data/datasets/packed_sequence.py @@ -139,73 +139,6 @@ def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id): return dataset -def save_packed_dataset(packed_data: List[dict], output_path: Path, save_format: Literal["npy", "mmap"] = "npy"): - logging.info(f'Saving packed_sequence index to {output_path} in {save_format} format...') - - """ - Saves the packed dataset to the specified output path in the given format. - - Args: - packed_data (List[dict]): The packed dataset to be saved. - output_path (Path): The path where the packed dataset should be saved. - save_format (Literal["npy", "mmap"]): The file format to save the packed dataset. - Can be either "npy" or "mmap". Defaults to "npy". - - Returns: - None: The function saves the packed dataset to the specified output path. - """ - if save_format == "npy": - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - msc.numpy.save(output_path, packed_data) - else: - np.save(output_path, packed_data) - elif save_format == "mmap": - if not output_path.parent.exists(): - # the msc.numpy.memmap doesn't work properly with mode="w+" - raise NotImplementedError("Packed sequence dataset in 'mmap' format is not supported on non-local filesystems. Please use 'npy' format or save to a local directory.") - - """Convert npy to memmap format with .idx metadata and .bin data""" - # Compute metadata - print('Computing metadata...') - input_lens = np.array([len(item["input_ids"]) for item in packed_data], dtype=np.int32) - loss_lens = np.array([len(item["loss_mask"]) for item in packed_data], dtype=np.int32) - seq_lens = np.array([len(item["seq_start_id"]) for item in packed_data], dtype=np.int32) - - # Compute byte offsets for each sample - bytes_per_sample = (input_lens + loss_lens + seq_lens) * 4 # int32 = 4 bytes - offsets = np.concatenate([[0], np.cumsum(bytes_per_sample)[:-1]]) - - # Save index file - idx_file = output_path.with_suffix(output_path.suffix + '.idx.npy') - np.save(idx_file, { - 'offsets': offsets, - 'input_lens': input_lens, - 'loss_lens': loss_lens, - 'seq_lens': seq_lens - }) - - # Write data file - bin_file = output_path.with_suffix(output_path.suffix + '.bin') - - total_ints = int(np.sum(bytes_per_sample) // 4) - mmap = np.memmap(bin_file, dtype=np.int32, mode='w+', shape=(total_ints,)) - pos = 0 - for item in tqdm(packed_data): - inp = np.array(item["input_ids"], dtype=np.int32) - loss = np.array(item["loss_mask"], dtype=np.int32) - seq = np.array(item["seq_start_id"], dtype=np.int32) - - chunk_len = len(inp) + len(loss) + len(seq) - mmap[pos:pos+len(inp)] = inp - mmap[pos+len(inp):pos+len(inp)+len(loss)] = loss - mmap[pos+len(inp)+len(loss):pos+chunk_len] = seq - pos += chunk_len - mmap.flush() - del mmap - else: - raise ValueError(f"Unsupported save format: {save_format}") - def prepare_packed_sequence_data( input_path: Path, output_path: Path, @@ -218,7 +151,6 @@ def prepare_packed_sequence_data( dataset_kwargs: dict | None = None, pad_seq_to_mult: int | None = 1, num_tokenizer_workers: int = -1, - save_format: Literal["npy", "mmap"] = "npy", ): """ Prepares a packed sequence dataset from a given input file and saves it to an output file. @@ -257,7 +189,11 @@ def prepare_packed_sequence_data( output_data = fill_packing_strategy(assignments, sequences, packed_sequence_size, tokenizer.eos_id) # save output data - save_packed_dataset(output_data, output_path, save_format) + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + msc.numpy.save(output_path, output_data) + else: + np.save(output_path, output_data) # save packing metadata, packing_metadata is appended to the packing file if it exists if output_metadata_path is not None: @@ -298,13 +234,6 @@ class PackedSequenceSpecs: of the original sequence (i.e. the length to truncate long sequences in the input data). """ - packed_sequence_format: Literal["npy", "mmap"] = "npy" - """ - The file format for the packed sequence dataset. Can be either "npy" (default) or "mmap". - "mmap" produces two files: .mmap.idx.npy and .mmap.bin. The `packed_train_data_path` and `packed_val_data_path` should point - to the prefix of the .idx.npy and .bin files (i.e. without the .idx.npy or .bin suffix). - """ - tokenizer_model_name: str = None """ Keep track of tokenizer model name, since each tokenizer produces a different packed sequence dataset file. diff --git a/src/megatron/bridge/data/datasets/sft.py b/src/megatron/bridge/data/datasets/sft.py index cfc31eb3f3..f29900223b 100644 --- a/src/megatron/bridge/data/datasets/sft.py +++ b/src/megatron/bridge/data/datasets/sft.py @@ -180,13 +180,6 @@ def create_sft_dataset( **gpt_sft_dataset_kwargs, **kwargs, ) - elif path.suffix == '.mmap': - return GPTSFTMMapPackedDataset( - pack_metadata_file_path=pack_metadata_file_path, - pad_cu_seqlens=pad_cu_seqlens, - **gpt_sft_dataset_kwargs, - **kwargs, - ) elif chat: return GPTSFTChatDataset( **gpt_sft_dataset_kwargs, @@ -1019,16 +1012,6 @@ def collate_fn(self, batch): return processed_batch -class GPTSFTMMapPackedDataset(GPTSFTPackedDataset): - def _load_dataset(self): - try: - self.indexed_dataset = MemmapPackedDataset(self.file_path) - except Exception as e: - logger.error( - f"Failed to load packed dataset. The dataset should be a pair of `.idx.npy` and '.bin' files. " - f"Please check if the packed dataset was prepared correctly. The original error was:\n {e}", - ) - exit(1) class GPTSFTChatDataset(GPTSFTDataset): """Dataset class for chat-based fine-tuning with optional HuggingFace chat template support. @@ -1245,48 +1228,3 @@ def collate_fn(self, batch): processed_batch["attention_mask"] = attention_mask return processed_batch - -class MemmapPackedDataset: - """Zero-copy dataset using numpy memmap""" - - def __init__(self, path_prefix): - self.path_prefix = path_prefix - if False and MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - self.mmap = msc.numpy.memmap(self.path_prefix + '.bin', dtype=np.int32, mode='r') - idx_data = msc.numpy.load(self.path_prefix + '.idx.npy', allow_pickle=True).item() - else: - self.mmap = np.memmap(self.path_prefix + '.bin', dtype=np.int32, mode='r') - idx_data = np.load(self.path_prefix + '.idx.npy', allow_pickle=True).item() - - # Load index - self.offsets = idx_data['offsets'] - self.input_lens = idx_data['input_lens'] - self.loss_lens = idx_data['loss_lens'] - self.seq_lens = idx_data['seq_lens'] - self.length = len(self.offsets) - self.num_tokens = np.sum(self.input_lens) - - def __getitem__(self, idx): - offset = self.offsets[idx] // 4 # byte offset -> int32 offset - inp_len = self.input_lens[idx] - loss_len = self.loss_lens[idx] - seq_len = self.seq_lens[idx] - - inp_end = offset + inp_len - loss_end = inp_end + loss_len - seq_end = loss_end + seq_len - - ret = { - "input_ids": self.mmap[offset:inp_end].tolist(), - "loss_mask": self.mmap[inp_end:loss_end].astype(bool).tolist(), - "seq_start_id": self.mmap[loss_end:seq_end].tolist() - } - return ret - - def __len__(self): - return self.length - - def __del__(self): - if hasattr(self, 'mmap') and self.mmap is not None: - del self.mmap \ No newline at end of file From e91e771713494646c0c82d810efb4e09a2013854 Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Wed, 11 Mar 2026 23:52:18 +0200 Subject: [PATCH 3/5] Updated test --- tests/functional_tests/data/datasets/test_packing_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/functional_tests/data/datasets/test_packing_utils.py b/tests/functional_tests/data/datasets/test_packing_utils.py index cc53be34a7..9115c27399 100644 --- a/tests/functional_tests/data/datasets/test_packing_utils.py +++ b/tests/functional_tests/data/datasets/test_packing_utils.py @@ -34,14 +34,15 @@ def test_find_first_bin_that_fits(self): [17, 11, 0, -5], [100, 200], ] + bin_sums = list(map(sum, bins)) bin_size = 1 s = 11 - first_bin_that_fits = find_first_bin_that_fits(bins, s, bin_size) + first_bin_that_fits = find_first_bin_that_fits(bin_sums, s, bin_size) assert first_bin_that_fits == -1 bin_size = 1000 - first_bin_that_fits = find_first_bin_that_fits(bins, s, bin_size) + first_bin_that_fits = find_first_bin_that_fits(bin_sums, s, bin_size) assert first_bin_that_fits == 1 From 96fde100649e6b63f806993d08e59481053fb99b Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Thu, 12 Mar 2026 00:03:52 +0200 Subject: [PATCH 4/5] Addressed coderabbit comments --- src/megatron/bridge/data/datasets/packed_sequence.py | 8 ++++---- src/megatron/bridge/data/datasets/packing_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/megatron/bridge/data/datasets/packed_sequence.py b/src/megatron/bridge/data/datasets/packed_sequence.py index 5739f92f32..fd2ba23a32 100644 --- a/src/megatron/bridge/data/datasets/packed_sequence.py +++ b/src/megatron/bridge/data/datasets/packed_sequence.py @@ -13,14 +13,14 @@ # limitations under the License. import json import logging +import multiprocessing as mp from dataclasses import dataclass -from pathlib import Path from multiprocessing import Pool -import multiprocessing as mp -from tqdm import tqdm -from typing import Literal, List +from pathlib import Path import numpy as np +from tqdm import tqdm + from megatron.core.msc_utils import MultiStorageClientFeature from megatron.bridge.data.datasets.packing_utils import create_hist, create_packing_strategy, fill_packing_strategy diff --git a/src/megatron/bridge/data/datasets/packing_utils.py b/src/megatron/bridge/data/datasets/packing_utils.py index 1d0b2ed722..28312fd6d2 100644 --- a/src/megatron/bridge/data/datasets/packing_utils.py +++ b/src/megatron/bridge/data/datasets/packing_utils.py @@ -39,7 +39,7 @@ def find_first_bin_that_fits(bin_sums: List[int], s: int, bin_size: int) -> int: """ for i,cur_sum in enumerate(bin_sums): if cur_sum + s <= bin_size: - return i + return i return -1 From f3dd8745471cf6eb07628d6a4cd87a07a45aa178 Mon Sep 17 00:00:00 2001 From: Shaltiel Shmidman Date: Thu, 12 Mar 2026 00:32:54 +0200 Subject: [PATCH 5/5] pre-commit run --all-files --- src/megatron/bridge/data/datasets/packed_sequence.py | 11 ++++++++--- src/megatron/bridge/data/datasets/packing_utils.py | 2 +- .../data/datasets/test_packing_utils.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/megatron/bridge/data/datasets/packed_sequence.py b/src/megatron/bridge/data/datasets/packed_sequence.py index fd2ba23a32..07c1a9c58b 100644 --- a/src/megatron/bridge/data/datasets/packed_sequence.py +++ b/src/megatron/bridge/data/datasets/packed_sequence.py @@ -19,9 +19,8 @@ from pathlib import Path import numpy as np -from tqdm import tqdm - from megatron.core.msc_utils import MultiStorageClientFeature +from tqdm import tqdm from megatron.bridge.data.datasets.packing_utils import create_hist, create_packing_strategy, fill_packing_strategy from megatron.bridge.data.datasets.sft import create_sft_dataset @@ -31,13 +30,17 @@ logger = logging.getLogger(__name__) _shared_dataset = None + + def _tokenize_get_item(i): return _shared_dataset[i] + def _tokenize_init_worker(dataset): global _shared_dataset _shared_dataset = dataset + def _retrieve_tokenized(dataset, num_workers): if num_workers == 1: return np.array([dataset[i] for i in tqdm(range(len(dataset)))]) @@ -45,6 +48,7 @@ def _retrieve_tokenized(dataset, num_workers): with Pool(num_workers, initializer=_tokenize_init_worker, initargs=(dataset,)) as pool: return np.array(list(tqdm(pool.imap(_tokenize_get_item, range(len(dataset))), total=len(dataset)))) + def tokenize_dataset( path: Path, tokenizer: MegatronTokenizer, @@ -139,6 +143,7 @@ def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id): return dataset + def prepare_packed_sequence_data( input_path: Path, output_path: Path, @@ -181,7 +186,7 @@ def prepare_packed_sequence_data( seed, dataset_kwargs, pad_seq_to_mult=pad_seq_to_mult, - num_tokenizer_workers=num_tokenizer_workers + num_tokenizer_workers=num_tokenizer_workers, ) sequences, histogram = create_hist(dataset, max_seq_length) diff --git a/src/megatron/bridge/data/datasets/packing_utils.py b/src/megatron/bridge/data/datasets/packing_utils.py index 28312fd6d2..5fa31be02c 100644 --- a/src/megatron/bridge/data/datasets/packing_utils.py +++ b/src/megatron/bridge/data/datasets/packing_utils.py @@ -37,7 +37,7 @@ def find_first_bin_that_fits(bin_sums: List[int], s: int, bin_size: int) -> int: Returns: The index of the first bin that can fit the sequence 's', or -1 if no such bin exists. """ - for i,cur_sum in enumerate(bin_sums): + for i, cur_sum in enumerate(bin_sums): if cur_sum + s <= bin_size: return i return -1 diff --git a/tests/functional_tests/data/datasets/test_packing_utils.py b/tests/functional_tests/data/datasets/test_packing_utils.py index 9115c27399..6b08311110 100644 --- a/tests/functional_tests/data/datasets/test_packing_utils.py +++ b/tests/functional_tests/data/datasets/test_packing_utils.py @@ -34,7 +34,7 @@ def test_find_first_bin_that_fits(self): [17, 11, 0, -5], [100, 200], ] - bin_sums = list(map(sum, bins)) + bin_sums = list(map(sum, bins)) bin_size = 1 s = 11 first_bin_that_fits = find_first_bin_that_fits(bin_sums, s, bin_size)