diff --git a/src/megatron/bridge/data/datasets/packed_sequence.py b/src/megatron/bridge/data/datasets/packed_sequence.py
index cfe83c575a..07c1a9c58b 100644
--- a/src/megatron/bridge/data/datasets/packed_sequence.py
+++ b/src/megatron/bridge/data/datasets/packed_sequence.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 import json
 import logging
+import multiprocessing as mp
 from dataclasses import dataclass
+from multiprocessing import Pool
 from pathlib import Path
 
 import numpy as np
 from megatron.core.msc_utils import MultiStorageClientFeature
+from tqdm import tqdm
 
 from megatron.bridge.data.datasets.packing_utils import create_hist, create_packing_strategy, fill_packing_strategy
 from megatron.bridge.data.datasets.sft import create_sft_dataset
@@ -26,6 +29,25 @@
 
 logger = logging.getLogger(__name__)
 
+_shared_dataset = None  # populated in each pool worker by _tokenize_init_worker
+
+
+def _tokenize_get_item(i):
+    return _shared_dataset[i]
+
+
+def _tokenize_init_worker(dataset):
+    global _shared_dataset
+    _shared_dataset = dataset
+
+
+def _retrieve_tokenized(dataset, num_workers):
+    if num_workers == 1:  # serial path, avoids spinning up a pool
+        return np.array([dataset[i] for i in tqdm(range(len(dataset)))])
+    num_workers = num_workers if num_workers > 0 else mp.cpu_count()  # non-positive means "use all cores"
+    with Pool(num_workers, initializer=_tokenize_init_worker, initargs=(dataset,)) as pool:
+        return np.array(list(tqdm(pool.imap(_tokenize_get_item, range(len(dataset))), total=len(dataset))))
+
 
 def tokenize_dataset(
     path: Path,
@@ -34,6 +56,7 @@ def tokenize_dataset(
     seed: int,
     dataset_kwargs: dict | None = None,
     pad_seq_to_mult: int | None = 1,
+    num_tokenizer_workers: int = -1,
 ):
     """
     Tokenizes a dataset from the provided path using the specified tokenizer
@@ -88,7 +111,7 @@
     pad_id = dataset.tokenizer.eod
     pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
     max_seq_length = dataset.max_seq_length
-    dataset = np.array([dataset[i] for i in range(len(dataset))])
+    dataset = _retrieve_tokenized(dataset, num_tokenizer_workers)
 
 
     if pad_seq_to_mult > 1:
@@ -132,6 +155,7 @@
     packing_algorithm: str = "first_fit_shuffle",
     dataset_kwargs: dict | None = None,
     pad_seq_to_mult: int | None = 1,
+    num_tokenizer_workers: int = -1,
 ):
     """
     Prepares a packed sequence dataset from a given input file and saves it to an output file.
@@ -162,6 +186,7 @@
         seed,
         dataset_kwargs,
         pad_seq_to_mult=pad_seq_to_mult,
+        num_tokenizer_workers=num_tokenizer_workers,
     )
 
     sequences, histogram = create_hist(dataset, max_seq_length)
@@ -220,6 +245,12 @@ class PackedSequenceSpecs:
     This field is set by llm.finetune api.
     """
 
+    num_tokenizer_workers: int = -1
+    """
+    The number of worker processes to use for tokenization when preparing the packed sequence dataset.
+    If -1 (or any value below 1), the number of workers defaults to the number of available CPU cores.
+    """
+
     packed_train_data_path: str = None
     """
     If specified, use this file for the packed training dataset instead of the default path.
diff --git a/src/megatron/bridge/data/datasets/packing_utils.py b/src/megatron/bridge/data/datasets/packing_utils.py
index b53edd3563..8606eb5e77 100644
--- a/src/megatron/bridge/data/datasets/packing_utils.py
+++ b/src/megatron/bridge/data/datasets/packing_utils.py
@@ -25,7 +25,7 @@
 logger = logging.getLogger(__name__)
 
 
-def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> int:
+def find_first_bin_that_fits(bin_sums: List[int], s: int, bin_size: int) -> int:
     """
-    Finds the first bin in a list of bins that has enough space to fit a sequence of size 's'.
+    Finds the first bin, given each bin's current sum, that has enough space to fit a sequence of size 's'.
 
@@ -37,8 +37,8 @@ def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> in
     Returns:
         The index of the first bin that can fit the sequence 's', or -1 if no such bin exists.
     """
-    for i, abin in enumerate(bins):
-        if sum(abin) + s <= bin_size:
+    for i, cur_sum in enumerate(bin_sums):
+        if cur_sum + s <= bin_size:
             return i
     return -1
 
@@ -56,12 +56,15 @@ def first_fit(seqlens: List[int], pack_size: int) -> List[List[int]]:
         of the sequences assigned to that bin.
     """
     res = []
-    for s in seqlens:
-        first_bin = find_first_bin_that_fits(res, s, pack_size)
+    res_sums = []  # running total of each bin in res, so fit checks stay O(1)
+    for s in tqdm(seqlens):
+        first_bin = find_first_bin_that_fits(res_sums, s, pack_size)
         if first_bin == -1:
             # open a new bin
             res.append([s])
+            res_sums.append(s)
         else:
             res[first_bin].append(s)
+            res_sums[first_bin] += s
 
     return res
diff --git a/tests/functional_tests/data/datasets/test_packing_utils.py b/tests/functional_tests/data/datasets/test_packing_utils.py
index cc53be34a7..6b08311110 100644
--- a/tests/functional_tests/data/datasets/test_packing_utils.py
+++ b/tests/functional_tests/data/datasets/test_packing_utils.py
@@ -34,14 +34,15 @@ def test_find_first_bin_that_fits(self):
             [17, 11, 0, -5],
             [100, 200],
         ]
+        bin_sums = list(map(sum, bins))
         bin_size = 1
         s = 11
-        first_bin_that_fits = find_first_bin_that_fits(bins, s, bin_size)
+        first_bin_that_fits = find_first_bin_that_fits(bin_sums, s, bin_size)
         assert first_bin_that_fits == -1
 
         bin_size = 1000
-        first_bin_that_fits = find_first_bin_that_fits(bins, s, bin_size)
+        first_bin_that_fits = find_first_bin_that_fits(bin_sums, s, bin_size)
        assert first_bin_that_fits == 1
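
The `first_fit` change is easiest to see in isolation: keeping a running total per bin turns each fit check into a single integer comparison, whereas the old code re-summed every bin via `sum(abin)` on every placement, making packing roughly quadratic in the number of open bins. A minimal, self-contained sketch of the same strategy (`first_fit_sketch` is a hypothetical name, not part of the patch):

```python
from typing import List


def first_fit_sketch(seqlens: List[int], pack_size: int) -> List[List[int]]:
    bins: List[List[int]] = []  # sequences assigned to each bin
    sums: List[int] = []        # running total of each bin, kept in lockstep with bins
    for s in seqlens:
        for i, cur in enumerate(sums):
            if cur + s <= pack_size:  # O(1) check instead of sum(bins[i])
                bins[i].append(s)
                sums[i] += s
                break
        else:  # no existing bin fits: open a new one
            bins.append([s])
            sums.append(s)
    return bins


assert first_fit_sketch([5, 3, 7, 2], 8) == [[5, 3], [7], [2]]
```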
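The tokenization change relies on the standard pool-initializer pattern: the dataset is handed to each worker once via `initializer`/`initargs`, so every `pool.imap` task carries only an integer index instead of pickling the dataset per item, and `imap` yields results in input order, matching the serial path. A self-contained sketch under those assumptions (`_init`, `_get`, and `_shared` are illustrative names, not part of the patch):

```python
import multiprocessing as mp

_shared = None  # set once per worker by _init


def _init(data):
    global _shared
    _shared = data


def _get(i):
    # stand-in for dataset[i]; the real code returns the tokenized example
    return _shared[i] * 2


if __name__ == "__main__":
    data = list(range(100))
    with mp.Pool(4, initializer=_init, initargs=(data,)) as pool:
        out = list(pool.imap(_get, range(len(data))))
    assert out == [i * 2 for i in range(100)]  # imap preserves input order
```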