9 changes: 5 additions & 4 deletions nemo/collections/llm/gpt/data/chat.py
@@ -77,16 +77,17 @@ def __init__(
@lru_cache
def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):
# pylint: disable=C0115,C0116
is_not_packing = self.packed_sequence_size <= 0
return create_sft_dataset(
path,
tokenizer=self.tokenizer,
seq_length=(self.seq_length if is_test or self.packed_sequence_size <= 0 else self.packed_sequence_size),
seq_length=(self.seq_length if is_not_packing else self.packed_sequence_size),
memmap_workers=self.memmap_workers,
seed=self.seed,
chat=True,
is_test=is_test,
pack_metadata_file_path=None, # packing is not supported
pad_cu_seqlens=False,
use_hf_tokenizer_chat_template=self.use_hf_tokenizer_chat_template,
pack_metadata_file_path=None if is_not_packing else pack_metadata_path,
pad_cu_seqlens=False if is_not_packing else self.pad_cu_seqlens,
use_hf_tokenizer_chat_template=self.use_hf_tokenizer_chat_template if is_not_packing else False,
**kwargs,
)
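# For illustration, a minimal standalone sketch of how the is_not_packing flag above
# resolves the dataset kwargs (function and argument names here are illustrative
# assumptions, not NeMo API):
def resolve_packed_kwargs(seq_length, packed_sequence_size, pack_metadata_path,
                          pad_cu_seqlens, use_hf_tokenizer_chat_template):
    is_not_packing = packed_sequence_size <= 0
    return {
        "seq_length": seq_length if is_not_packing else packed_sequence_size,
        "pack_metadata_file_path": None if is_not_packing else pack_metadata_path,
        "pad_cu_seqlens": False if is_not_packing else pad_cu_seqlens,
        "use_hf_tokenizer_chat_template": use_hf_tokenizer_chat_template if is_not_packing else False,
    }

# With packing enabled (packed_sequence_size=2048), the packed size and metadata win:
print(resolve_packed_kwargs(4096, 2048, "metadata.jsonl", True, True))
# {'seq_length': 2048, 'pack_metadata_file_path': 'metadata.jsonl',
#  'pad_cu_seqlens': True, 'use_hf_tokenizer_chat_template': False}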
164 changes: 100 additions & 64 deletions nemo/collections/llm/gpt/data/core.py
@@ -34,6 +34,7 @@
)
from nemo.core.classes import Dataset
from nemo.lightning.base import NEMO_DATASETS_CACHE
from nemo.utils.sequence_packing_utils import generate_positional_ids_for_cp, pad_thd_sequences_for_cp

logger = logging.getLogger(__name__)

@@ -758,7 +759,28 @@ def _build_loss_mask(self, processed_example):
def _maybe_cast_to_list(self, x):
return [item.tolist() if isinstance(item, np.ndarray) else item for item in x]

@staticmethod
def _remove_empty_sequences(batch):
# drop seq_boundaries entries that delimit empty (zero-length) sequences
for idx, item in enumerate(batch):
total_seqlens = np.array(item["seq_boundaries"][1:]) - np.array(item["seq_boundaries"][:-1])
if item["seq_boundaries"][0] != 0:
raise ValueError("First element of seq_boundaries must be 0")
if (total_seqlens < 0).any():
raise ValueError("seq_boundaries are not strictly increasing")
if (total_seqlens == 0).any():
new_seq_boundaries = np.concatenate(
(
np.array(item["seq_boundaries"][0:1]),
np.array(item["seq_boundaries"][1:])[total_seqlens != 0],
)
)
batch[idx]["seq_boundaries"] = new_seq_boundaries.tolist()
return batch
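# Standalone sketch of the cleanup above (illustrative, not the dataset class itself):
# a zero-length sequence shows up as a repeated boundary and is dropped.
import numpy as np

def remove_empty_sequences(batch):
    for idx, item in enumerate(batch):
        bounds = np.array(item["seq_boundaries"])
        seqlens = bounds[1:] - bounds[:-1]
        if (seqlens == 0).any():
            batch[idx]["seq_boundaries"] = np.concatenate((bounds[:1], bounds[1:][seqlens != 0])).tolist()
    return batch

print(remove_empty_sequences([{"seq_boundaries": [0, 5, 5, 12]}]))
# [{'seq_boundaries': [0, 5, 12]}]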

def collate_fn(self, batch):
batch = self._remove_empty_sequences(batch)

input_ids = [
np.concatenate(
[
@@ -782,6 +804,45 @@ def collate_fn(self, batch):

token_count = [item.shape[0] for item in input_ids]

cu_seqlens_unpadded = [torch.zeros(0)] * len(batch) # will be overwritten below
cu_seqlens_padded = [torch.zeros(0)] * len(batch) # will be overwritten below
position_ids = [torch.zeros(0)] * len(batch) # will be overwritten below
for i, item in enumerate(batch):
input_ids[i] = torch.tensor(input_ids[i], dtype=torch.long)
labels[i] = torch.tensor(labels[i], dtype=torch.long)
loss_mask[i] = torch.tensor(loss_mask[i], dtype=torch.long)

_seqlens_item = (
np.array(item["seq_boundaries"][1:]) - np.array(item["seq_boundaries"][:-1]) - 1
) # -1 because input_ids is truncated by 1 for labels, see above
cu_seqlens_unpadded[i] = torch.cat((torch.zeros(1), torch.cumsum(torch.tensor(_seqlens_item), 0))).to(
torch.int32
)

# Pad input_ids, labels, loss_mask so that every sequence in the pack
# reaches a length which is a multiple of self.pad_seq_length_to_mult.
# Generate position_ids for the padded sequences.
input_ids[i], labels[i], cu_seqlens_padded[i] = pad_thd_sequences_for_cp(
input_ids=input_ids[i],
labels=labels[i],
cu_seqlens=cu_seqlens_unpadded[i],
divisibility_factor=self.pad_seq_length_to_mult,
padding_token_id=self.tokenizer.eos_id,
padding_label_id=self.tokenizer.eos_id,
)
loss_mask[i], _, _ = pad_thd_sequences_for_cp(
input_ids=loss_mask[i],
labels=loss_mask[i], # placeholder with matching length; this output is discarded
cu_seqlens=cu_seqlens_unpadded[i],
divisibility_factor=self.pad_seq_length_to_mult,
padding_token_id=0,
padding_label_id=-1, # not used
)
position_ids[i] = generate_positional_ids_for_cp(
cu_seqlens=cu_seqlens_unpadded[i],
divisibility_factor=self.pad_seq_length_to_mult,
)
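# Standalone sketch of the padding arithmetic in the loop above (an assumption about
# what pad_thd_sequences_for_cp does to cu_seqlens, not the utility itself): every
# sequence length is rounded up to a multiple of the divisibility factor so the pack
# splits evenly across context-parallel ranks.
import torch

def pad_cu_seqlens_sketch(cu_seqlens_unpadded, divisibility_factor):
    seqlens = cu_seqlens_unpadded[1:] - cu_seqlens_unpadded[:-1]
    padded = torch.ceil(seqlens / divisibility_factor).to(torch.int64) * divisibility_factor
    return torch.cat((torch.zeros(1, dtype=torch.int64), torch.cumsum(padded, 0)))

print(pad_cu_seqlens_sketch(torch.tensor([0, 5, 12]), 8))  # tensor([ 0,  8, 16])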

if self.pad_to_max_length:
max_length = self.max_seq_length
else:
@@ -795,59 +856,33 @@
)
assert max_length <= self.max_seq_length

position_ids: List[List[int]] = []
cu_seqlens: List[List[int]] = []
cu_seqlens_unpadded: List[List[int]] = []
for item in batch:
position_ids.append([])
cu_seqlens.append([0])
cu_seqlens_unpadded.append([0])
seqlens = np.array(item["seq_boundaries"][1:]) - np.array(item["seq_boundaries"][:-1])
for length in seqlens:
# length minus 1 because input_ids is truncated by 1 for labels
position_ids[-1].extend(list(range(length - 1)))
cu_seqlens[-1].append(cu_seqlens[-1][-1] + length - 1)

# the last seq needs to be the max seq len because rope and attn kernels expect no padding
assert cu_seqlens[-1][-1] <= max_length

# since data is prepadded when cp_size > 1, there may be some extra padding at the end
# of the packed sequence. In this case, we need to add the max seq len to the end.
if cu_seqlens[-1][-1] != max_length:
cu_seqlens[-1].append(max_length)

for i in range(len(item["seq_boundaries"]) - 1):
current_seq = item["input_ids"][item["seq_boundaries"][i] : item["seq_boundaries"][i + 1] - 1]

# since the data could be prepadded with tokenizer's eos_id,
# we can find out the index of all the eos_id
eos_idx = np.where(np.array(current_seq) == self.tokenizer.eos_id)

# The second eos_id index marks the length of the original unpadded sequence if the sequence is
# prepadded for cp_size > 1. Otherwise, there is no extra padding.
seqlen_unpadded = eos_idx[0][1] + 1 if eos_idx[0].shape[0] > 1 else len(current_seq)
cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1] + seqlen_unpadded)

# if extra paddings are added in the packed sequence, they can't be counted as
# actual tokens for training
if len(cu_seqlens[-1]) > len(cu_seqlens_unpadded[-1]):
cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1])
# the last seq needs to be the max seq len because rope and attn kernels expect no padding.
assert all(x[-1] <= max_length for x in cu_seqlens_padded)

if self.pad_cu_seqlens:
# pad cu_seqlens to a constant shape with zero length sequences
max_samples_per_bin = max(p["max_samples_per_bin"] for p in self.pack_metadata)
# plus 2 since cu_seqlens additionally contains 0 and may append max_length
pad_num = max_samples_per_bin - len(cu_seqlens[-1]) + 2
cu_seqlens[-1].extend([max_length] * pad_num)
for i in range(len(cu_seqlens_padded)):
if cu_seqlens_padded[i][-1] != max_length:
cu_seqlens_padded[i] = torch.cat((cu_seqlens_padded[i], torch.tensor([max_length])))
# Keep cu_seqlens_unpadded the same length as cu_seqlens_padded by appending
# a zero-length sequence, i.e. repeating the last element of cu_seqlens_unpadded.
cu_seqlens_unpadded[i] = torch.cat((cu_seqlens_unpadded[i], cu_seqlens_unpadded[i][-1:]))
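# Worked example of the branch above with illustrative values: when the pack ends
# before max_length, the padded boundaries get an extra entry at max_length, and
# the unpadded boundaries get a matching zero-length sequence.
import torch

max_length = 32
cu_seqlens_padded_ex = torch.tensor([0, 8, 16])
cu_seqlens_unpadded_ex = torch.tensor([0, 5, 12])
if cu_seqlens_padded_ex[-1] != max_length:
    cu_seqlens_padded_ex = torch.cat((cu_seqlens_padded_ex, torch.tensor([max_length])))
    cu_seqlens_unpadded_ex = torch.cat((cu_seqlens_unpadded_ex, cu_seqlens_unpadded_ex[-1:]))
print(cu_seqlens_padded_ex)    # tensor([ 0,  8, 16, 32])
print(cu_seqlens_unpadded_ex)  # tensor([ 0,  5, 12, 12])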

assert len(input_ids[0]) == len(
position_ids[0]
), "Dataset problem: input_ids and position_ids lengths don't match"

input_ids = self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id)
labels = self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id)
loss_mask = self._collate_item(loss_mask, max_length=max_length, pad_id=0)
position_ids = self._collate_item(position_ids, max_length=max_length, pad_id=0)
input_ids = self._collate_item(
[x.tolist() for x in input_ids],
max_length=max_length,
pad_id=self.tokenizer.eos_id,
)
labels = self._collate_item(
[x.tolist() for x in labels],
max_length=max_length,
pad_id=self.tokenizer.eos_id,
)
position_ids = self._collate_item([x.tolist() for x in position_ids], max_length=max_length, pad_id=0)
loss_mask = self._collate_item([x.tolist() for x in loss_mask], max_length=max_length, pad_id=0)

processed_batch = {
"tokens": torch.LongTensor(input_ids),
Expand All @@ -858,21 +893,22 @@ def collate_fn(self, batch):
}

if self.return_cu_seqlen:
cu_seqlens = self._collate_item(
cu_seqlens,
max_length=max(len(length) for length in cu_seqlens) + 1,
# Finalize the cu_seqlens values and add them to the batch. Both the padded and
# unpadded cu_seqlens are padded to a common length with -1, which lets the
# argmin calls below locate where the valid entries end.
cu_seqlens_padded = self._collate_item(
[x.tolist() for x in cu_seqlens_padded],
max_length=max(len(length) for length in cu_seqlens_padded) + 1,
pad_id=-1,
)
cu_seqlens_padded = torch.IntTensor(cu_seqlens_padded)
cu_seqlens_padded_argmin = torch.argmin(cu_seqlens_padded, dim=1, keepdim=True)

cu_seqlens_unpadded = self._collate_item(
cu_seqlens_unpadded,
max_length=max(len(length) for length in cu_seqlens_unpadded) + 1,
[x.tolist() for x in cu_seqlens_unpadded],
max_length=max(len(length) for length in cu_seqlens_unpadded) + 1,
pad_id=-1,
)
# Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
cu_seqlens = torch.IntTensor(cu_seqlens)
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded)
cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True)
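# Why the -1 terminator matters (illustrative): argmin over a -1-padded row returns
# the index of the first pad entry, i.e. the number of valid cu_seqlens values, so
# the training step can recover it without a device-to-host copy.
import torch

row = torch.tensor([[0, 8, 16, -1, -1]], dtype=torch.int32)
print(torch.argmin(row, dim=1, keepdim=True))  # tensor([[3]])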

@@ -886,20 +922,20 @@

# Use the larger of the two values to avoid NAN issues with attention kernel
safe_max_seqlen = max(dataset_max_seqlen, padding_gap)
max_seqlen = torch.IntTensor([safe_max_seqlen] * len(cu_seqlens))
max_seqlen_padded = torch.IntTensor([safe_max_seqlen] * len(cu_seqlens_padded))
else:
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
seqlens_padded = cu_seqlens_padded[:, 1:] - cu_seqlens_padded[:, :-1]
max_seqlen_padded, _ = seqlens_padded.max(dim=1, keepdim=True)
processed_batch.update(
{
"attention_mask": torch.LongTensor(
[1] * len(input_ids)
), # no attention mask is needed for packed seq
"cu_seqlens": torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32
"cu_seqlens_argmin": cu_seqlens_argmin, # only required for perf
"max_seqlen": max_seqlen, # only required for perf
"cu_seqlens_unpadded": torch.IntTensor(cu_seqlens_unpadded),
"cu_seqlens": cu_seqlens_padded,
"cu_seqlens_unpadded": cu_seqlens_unpadded,
"cu_seqlens_argmin": cu_seqlens_padded_argmin, # only required for perf
"cu_seqlens_unpadded_argmin": cu_seqlens_unpadded_argmin,
"max_seqlen": max_seqlen_padded, # only required for perf
}
)
else:
15 changes: 10 additions & 5 deletions nemo/collections/llm/gpt/data/packed_sequence.py
@@ -76,9 +76,11 @@ def prepare_packed_sequence_data(
packed_sequence_size: int,
tokenizer: TokenizerSpec,
max_seq_length: int,
seed: Optional[int] = 0,
packing_algorithm: str = "first_fit_shuffle",
dataset_kwargs: dict = None,
seed: int = 0,
packing_algorithm: str = "first_fit_shuffle_with_heap",
chat: bool = False,
divisibility_factor: Optional[int] = 16,
dataset_kwargs: Optional[dict] = None,
):
"""
Prepares a packed sequence dataset from a given input file and saves it to an output file.
Expand All @@ -92,14 +94,17 @@ def prepare_packed_sequence_data(
seed (int): Random seed for shuffling. Defaults to 0.
packing_algorithm (str): The algorithm used for packing sequences; supports
"first_fit_shuffle", "first_fit_decreasing", and "first_fit_shuffle_with_heap" (the default).
chat (bool): Whether the dataset is a chat dataset. Defaults to False.
divisibility_factor (Optional[int]): If specified, each sequence length is
rounded up to the nearest multiple of this factor. Defaults to 16.

Returns:
None: Saves the packed sequence data to the specified output path.
"""

logging.info(f"Preparing packed sequence from {input_path}")
dataset = tokenize_dataset(input_path, tokenizer, max_seq_length, seed, dataset_kwargs)
sequences, histogram = create_hist(dataset, max_seq_length)
dataset = tokenize_dataset(input_path, tokenizer, max_seq_length, seed, chat, dataset_kwargs)
sequences, histogram = create_hist(dataset, max_seq_length, divisibility_factor)

assignments, packing_metadata = create_packing_strategy(histogram, packed_sequence_size, packing_algorithm)
output_data = fill_packing_strategy(assignments, sequences, packed_sequence_size, tokenizer.eos_id)
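# Minimal sketch (an assumption about the effect of the new divisibility_factor, not
# the create_hist implementation) of the rounding described in the docstring above:
# each raw sequence length is rounded up to a multiple of the factor before packing,
# so the packed sequences stay friendly to context parallelism.
def round_up(seq_len: int, divisibility_factor: int) -> int:
    return ((seq_len + divisibility_factor - 1) // divisibility_factor) * divisibility_factor

print([round_up(n, 16) for n in (1, 16, 17, 100)])  # [16, 16, 32, 112]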