Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/megatron/bridge/data/datasets/packed_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
# limitations under the License.
import json
import logging
import multiprocessing as mp
from dataclasses import dataclass
from multiprocessing import Pool
from pathlib import Path

import numpy as np
from megatron.core.msc_utils import MultiStorageClientFeature
from tqdm import tqdm

from megatron.bridge.data.datasets.packing_utils import create_hist, create_packing_strategy, fill_packing_strategy
from megatron.bridge.data.datasets.sft import create_sft_dataset
Expand All @@ -26,6 +29,25 @@

logger = logging.getLogger(__name__)

# Dataset handle shared with pool workers; populated by _tokenize_init_worker.
_shared_dataset = None


def _tokenize_get_item(i):
    """Return the tokenized sample at index ``i`` from the shared dataset."""
    return _shared_dataset[i]


def _tokenize_init_worker(dataset):
global _shared_dataset
_shared_dataset = dataset


def _retrieve_tokenized(dataset, num_workers):
    """Materialize every tokenized sample of ``dataset`` into a numpy array.

    With ``num_workers == 1`` samples are fetched serially in the current
    process; otherwise a process pool is used, where any non-positive worker
    count is replaced by the machine's CPU count. Results keep index order.
    """
    total = len(dataset)
    if num_workers == 1:
        return np.array([dataset[idx] for idx in tqdm(range(total))])
    pool_size = mp.cpu_count() if num_workers <= 0 else num_workers
    with Pool(pool_size, initializer=_tokenize_init_worker, initargs=(dataset,)) as pool:
        # imap preserves submission order, so the array matches dataset indexing.
        samples = tqdm(pool.imap(_tokenize_get_item, range(total)), total=total)
        return np.array(list(samples))


def tokenize_dataset(
path: Path,
Expand All @@ -34,6 +56,7 @@ def tokenize_dataset(
seed: int,
dataset_kwargs: dict | None = None,
pad_seq_to_mult: int | None = 1,
num_tokenizer_workers: int = -1,
):
"""
Tokenizes a dataset from the provided path using the specified tokenizer
Expand Down Expand Up @@ -88,7 +111,7 @@ def tokenize_dataset(
pad_id = dataset.tokenizer.eod
pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
max_seq_length = dataset.max_seq_length
dataset = np.array([dataset[i] for i in range(len(dataset))])
dataset = _retrieve_tokenized(dataset, num_tokenizer_workers)

if pad_seq_to_mult > 1:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Potential TypeError when pad_seq_to_mult is None.

If pad_seq_to_mult is None, the comparison pad_seq_to_mult > 1 will raise a TypeError in Python 3. The parameter has a default of 1 but can explicitly be passed as None.

Proposed fix
-    if pad_seq_to_mult > 1:
+    if pad_seq_to_mult is not None and pad_seq_to_mult > 1:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/datasets/packed_sequence.py` at line 112, The
condition `if pad_seq_to_mult > 1:` can raise TypeError when pad_seq_to_mult is
None; change the guard to explicitly handle None (e.g. check `pad_seq_to_mult is
not None and pad_seq_to_mult > 1`) or normalize pad_seq_to_mult to an int before
use; update the check wherever `pad_seq_to_mult` is evaluated in
packed_sequence.py (the variable `pad_seq_to_mult` and any function/method that
uses it) and ensure you also validate/coerce its type so comparisons always
occur against an int.


Expand Down Expand Up @@ -132,6 +155,7 @@ def prepare_packed_sequence_data(
packing_algorithm: str = "first_fit_shuffle",
dataset_kwargs: dict | None = None,
pad_seq_to_mult: int | None = 1,
num_tokenizer_workers: int = -1,
):
"""
Prepares a packed sequence dataset from a given input file and saves it to an output file.
Expand Down Expand Up @@ -162,6 +186,7 @@ def prepare_packed_sequence_data(
seed,
dataset_kwargs,
pad_seq_to_mult=pad_seq_to_mult,
num_tokenizer_workers=num_tokenizer_workers,
)
sequences, histogram = create_hist(dataset, max_seq_length)

Expand Down Expand Up @@ -220,6 +245,12 @@ class PackedSequenceSpecs:
This field is set by llm.finetune api.
"""

num_tokenizer_workers: int = -1
"""
The number of worker processes to use for tokenization when preparing the packed sequence dataset.
If set to a value less than 1 (e.g. the default of -1), the number of workers defaults to the number of available CPU cores.
"""

packed_train_data_path: str = None
"""
If specified, use this file for the packed training dataset instead of the default path.
Expand Down
13 changes: 8 additions & 5 deletions src/megatron/bridge/data/datasets/packing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
logger = logging.getLogger(__name__)


def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> int:
def find_first_bin_that_fits(bin_sums: List[int], s: int, bin_size: int) -> int:
"""
Finds the first bin in a list of bins that has enough space to fit a sequence of size 's'.

Expand All @@ -37,8 +37,8 @@ def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> in
Returns:
The index of the first bin that can fit the sequence 's', or -1 if no such bin exists.
"""
for i, abin in enumerate(bins):
if sum(abin) + s <= bin_size:
for i, cur_sum in enumerate(bin_sums):
if cur_sum + s <= bin_size:
return i
return -1

Expand All @@ -56,12 +56,15 @@ def first_fit(seqlens: List[int], pack_size: int) -> List[List[int]]:
of the sequences assigned to that bin.
"""
res = []
for s in seqlens:
first_bin = find_first_bin_that_fits(res, s, pack_size)
res_sums = []
for s in tqdm(seqlens):
first_bin = find_first_bin_that_fits(res_sums, s, pack_size)
if first_bin == -1: # open a new bin
res.append([s])
res_sums.append(s)
else:
res[first_bin].append(s)
res_sums[first_bin] += s
return res


Expand Down
5 changes: 3 additions & 2 deletions tests/functional_tests/data/datasets/test_packing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ def test_find_first_bin_that_fits(self):
[17, 11, 0, -5],
[100, 200],
]
bin_sums = list(map(sum, bins))
bin_size = 1
s = 11
first_bin_that_fits = find_first_bin_that_fits(bins, s, bin_size)
first_bin_that_fits = find_first_bin_that_fits(bin_sums, s, bin_size)

assert first_bin_that_fits == -1

bin_size = 1000
first_bin_that_fits = find_first_bin_that_fits(bins, s, bin_size)
first_bin_that_fits = find_first_bin_that_fits(bin_sums, s, bin_size)

assert first_bin_that_fits == 1

Expand Down
Loading