diff --git a/nemo/collections/llm/gpt/data/chat.py b/nemo/collections/llm/gpt/data/chat.py index 06eaab9ec1e5..7a8098046a78 100644 --- a/nemo/collections/llm/gpt/data/chat.py +++ b/nemo/collections/llm/gpt/data/chat.py @@ -77,16 +77,17 @@ def __init__( @lru_cache def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs): # pylint: disable=C0115,C0116 + is_not_packing = self.packed_sequence_size <= 0 return create_sft_dataset( path, tokenizer=self.tokenizer, - seq_length=(self.seq_length if is_test or self.packed_sequence_size <= 0 else self.packed_sequence_size), + seq_length=(self.seq_length if is_not_packing else self.packed_sequence_size), memmap_workers=self.memmap_workers, seed=self.seed, chat=True, is_test=is_test, - pack_metadata_file_path=None, # packing is not supported - pad_cu_seqlens=False, - use_hf_tokenizer_chat_template=self.use_hf_tokenizer_chat_template, + pack_metadata_file_path=None if is_not_packing else pack_metadata_path, + pad_cu_seqlens=False if is_not_packing else self.pad_cu_seqlens, + use_hf_tokenizer_chat_template=self.use_hf_tokenizer_chat_template if is_not_packing else False, **kwargs, ) diff --git a/nemo/collections/llm/gpt/data/core.py b/nemo/collections/llm/gpt/data/core.py index 9719ff96c9ea..119f88b0d328 100644 --- a/nemo/collections/llm/gpt/data/core.py +++ b/nemo/collections/llm/gpt/data/core.py @@ -34,6 +34,7 @@ ) from nemo.core.classes import Dataset from nemo.lightning.base import NEMO_DATASETS_CACHE +from nemo.utils.sequence_packing_utils import generate_positional_ids_for_cp, pad_thd_sequences_for_cp logger = logging.getLogger(__name__) @@ -758,7 +759,28 @@ def _build_loss_mask(self, processed_example): def _maybe_cast_to_list(self, x): return [item.tolist() if isinstance(item, np.ndarray) else item for item in x] + @staticmethod + def _remove_empty_sequences(batch): + # remove sequence boundaries relative to empty sequences + for idx, item in enumerate(batch): + total_seqlens = np.array(item["seq_boundaries"][1:]) - np.array(item["seq_boundaries"][:-1]) + if item["seq_boundaries"][0] != 0: + raise ValueError("First element of seq_boundaries must be 0") + if (total_seqlens < 0).any(): + raise ValueError("seq_boundaries are not strictly increasing") + if (total_seqlens == 0).any(): + new_seq_boundaries = np.concatenate( + ( + np.array(item["seq_boundaries"][0:1]), + np.array(item["seq_boundaries"][1:])[total_seqlens != 0], + ) + ) + batch[idx]["seq_boundaries"] = new_seq_boundaries.tolist() + return batch + def collate_fn(self, batch): + batch = self._remove_empty_sequences(batch) + input_ids = [ np.concatenate( [ @@ -782,6 +804,45 @@ def collate_fn(self, batch): token_count = [item.shape[0] for item in input_ids] + cu_seqlens_unpadded = [torch.zeros(0)] * len(batch) # will be overwritten below + cu_seqlens_padded = [torch.zeros(0)] * len(batch) # will be overwritten below + position_ids = [torch.zeros(0)] * len(batch) # will be overwritten below + for i, item in enumerate(batch): + input_ids[i] = torch.tensor(input_ids[i], dtype=torch.long) + labels[i] = torch.tensor(labels[i], dtype=torch.long) + loss_mask[i] = torch.tensor(loss_mask[i], dtype=torch.long) + + _seqlens_item = ( + np.array(item["seq_boundaries"][1:]) - np.array(item["seq_boundaries"][:-1]) - 1 + ) # -1 because input_ids is truncated by 1 for labels, see above + cu_seqlens_unpadded[i] = torch.cat((torch.zeros(1), torch.cumsum(torch.tensor(_seqlens_item), 0))).to( + torch.int32 + ) + + # Pad input_ids, labels, loss_mask so that every sequence in the pack + # 
reaches a length which is a multiple of self.pad_seq_length_to_mult. + # Generate position_ids for the padded sequences. + input_ids[i], labels[i], cu_seqlens_padded[i] = pad_thd_sequences_for_cp( + input_ids=input_ids[i], + labels=labels[i], + cu_seqlens=cu_seqlens_unpadded[i], + divisibility_factor=self.pad_seq_length_to_mult, + padding_token_id=self.tokenizer.eos_id, + padding_label_id=self.tokenizer.eos_id, + ) + loss_mask[i], _, _ = pad_thd_sequences_for_cp( + input_ids=loss_mask[i], + labels=labels[i], # not used + cu_seqlens=cu_seqlens_unpadded[i], + divisibility_factor=self.pad_seq_length_to_mult, + padding_token_id=0, + padding_label_id=-1, # not used + ) + position_ids[i] = generate_positional_ids_for_cp( + cu_seqlens=cu_seqlens_unpadded[i], + divisibility_factor=self.pad_seq_length_to_mult, + ) + if self.pad_to_max_length: max_length = self.max_seq_length else: @@ -795,59 +856,33 @@ def collate_fn(self, batch): ) assert max_length <= self.max_seq_length - position_ids: List[List[int]] = [] - cu_seqlens: List[List[int]] = [] - cu_seqlens_unpadded: List[List[int]] = [] - for item in batch: - position_ids.append([]) - cu_seqlens.append([0]) - cu_seqlens_unpadded.append([0]) - seqlens = np.array(item["seq_boundaries"][1:]) - np.array(item["seq_boundaries"][:-1]) - for length in seqlens: - # length minus 1 because input_ids is truncated by 1 for labels - position_ids[-1].extend(list(range(length - 1))) - cu_seqlens[-1].append(cu_seqlens[-1][-1] + length - 1) - - # the last seq needs to be the max seq len because rope and attn kernels expect no padding - assert cu_seqlens[-1][-1] <= max_length - - # since data is prepadded when cp_size > 1, there may be some extra padding at the end - # of the packed sequence. In this case, we need to add the max seq len to the end. - if cu_seqlens[-1][-1] != max_length: - cu_seqlens[-1].append(max_length) - - for i in range(len(item["seq_boundaries"]) - 1): - current_seq = item["input_ids"][item["seq_boundaries"][i] : item["seq_boundaries"][i + 1] - 1] - - # since the data could be prepadded with tokenizer's eos_id, - # we can find out the index of all the eos_id - eos_idx = np.where(np.array(current_seq) == self.tokenizer.eos_id) - - # The second eos_id index marks the length of the original unpadded sequence if the sequence is - # prepadded for cp_size > 1. Otherwise, there is no extra padding. - seqlen_unpadded = eos_idx[0][1] + 1 if eos_idx[0].shape[0] > 1 else len(current_seq) - cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1] + seqlen_unpadded) - - # if extra paddings are added in the packed sequence, they can't be counted as - # actual tokens for training - if len(cu_seqlens[-1]) > len(cu_seqlens_unpadded[-1]): - cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1]) + # the last seq needs to be the max seq len because rope and attn kernels expect no padding. + assert all(x[-1] <= max_length for x in cu_seqlens_padded) - if self.pad_cu_seqlens: - # pad cu_seqlens to a constant shape with zero length sequences - max_samples_per_bin = max(p["max_samples_per_bin"] for p in self.pack_metadata) - # plus 2 since cu_seqlens additionally contains 0 and may append max_length - pad_num = max_samples_per_bin - len(cu_seqlens[-1]) + 2 - cu_seqlens[-1].extend([max_length] * pad_num) + for i in range(len(cu_seqlens_padded)): + if cu_seqlens_padded[i][-1] != max_length: + cu_seqlens_padded[i] = torch.cat((cu_seqlens_padded[i], torch.tensor([max_length]))) + # Keep same number of elements in cu_seqlens_unpadded and cu_seqlens_padded. 
+ # Simply add sequence with length 0, have the same element twice at the end + # of cu_seqlens_unpadded. + cu_seqlens_unpadded[i] = torch.cat((cu_seqlens_unpadded[i], cu_seqlens_unpadded[i][-1:])) assert len(input_ids[0]) == len( position_ids[0] ), "Dataset problem: input_ids and position_ids lengths don't match" - input_ids = self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id) - labels = self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id) - loss_mask = self._collate_item(loss_mask, max_length=max_length, pad_id=0) - position_ids = self._collate_item(position_ids, max_length=max_length, pad_id=0) + input_ids = self._collate_item( + [x.tolist() for x in input_ids], + max_length=max_length, + pad_id=self.tokenizer.eos_id, + ) + labels = self._collate_item( + [x.tolist() for x in labels], + max_length=max_length, + pad_id=self.tokenizer.eos_id, + ) + position_ids = self._collate_item([x.tolist() for x in position_ids], max_length=max_length, pad_id=0) + loss_mask = self._collate_item([x.tolist() for x in loss_mask], max_length=max_length, pad_id=0) processed_batch = { "tokens": torch.LongTensor(input_ids), @@ -858,21 +893,22 @@ def collate_fn(self, batch): } if self.return_cu_seqlen: - cu_seqlens = self._collate_item( - cu_seqlens, - max_length=max(len(length) for length in cu_seqlens) + 1, + # Finalize the cu_seqlens values, add them to the batch. The cu_seqlens_padded + # and unpadded need to terminate with a -1 (NeMo processes them that way, no + # idea why). + cu_seqlens_padded = self._collate_item( + [x.tolist() for x in cu_seqlens_padded], + max_length=max(len(l) for l in cu_seqlens_padded) + 1, pad_id=-1, ) + cu_seqlens_padded = torch.IntTensor(cu_seqlens_padded) + cu_seqlens_padded_argmin = torch.argmin(cu_seqlens_padded, dim=1, keepdim=True) + cu_seqlens_unpadded = self._collate_item( - cu_seqlens_unpadded, - max_length=max(len(length) for length in cu_seqlens_unpadded) + 1, + [x.tolist() for x in cu_seqlens_unpadded], + max_length=max(len(l) for l in cu_seqlens_unpadded) + 1, pad_id=-1, ) - # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies. 
- cu_seqlens = torch.IntTensor(cu_seqlens) - cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True) - seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1] - max_seqlen, _ = seqlens.max(dim=1, keepdim=True) cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded) cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True) @@ -886,20 +922,20 @@ def collate_fn(self, batch): # Use the larger of the two values to avoid NAN issues with attention kernel safe_max_seqlen = max(dataset_max_seqlen, padding_gap) - max_seqlen = torch.IntTensor([safe_max_seqlen] * len(cu_seqlens)) + max_seqlen_padded = torch.IntTensor([safe_max_seqlen] * len(cu_seqlens_padded)) else: - seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1] - max_seqlen, _ = seqlens.max(dim=1, keepdim=True) + seqlens_padded = cu_seqlens_padded[:, 1:] - cu_seqlens_padded[:, :-1] + max_seqlen_padded, _ = seqlens_padded.max(dim=1, keepdim=True) processed_batch.update( { "attention_mask": torch.LongTensor( [1] * len(input_ids) ), # no attention mask is needed for packed seq - "cu_seqlens": torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32 - "cu_seqlens_argmin": cu_seqlens_argmin, # only required for perf - "max_seqlen": max_seqlen, # only required for perf - "cu_seqlens_unpadded": torch.IntTensor(cu_seqlens_unpadded), + "cu_seqlens": cu_seqlens_padded, + "cu_seqlens_unpadded": cu_seqlens_unpadded, + "cu_seqlens_argmin": cu_seqlens_padded_argmin, # only required for perf "cu_seqlens_unpadded_argmin": cu_seqlens_unpadded_argmin, + "max_seqlen": max_seqlen_padded, # only required for perf } ) else: diff --git a/nemo/collections/llm/gpt/data/packed_sequence.py b/nemo/collections/llm/gpt/data/packed_sequence.py index e0f3fbe6a109..105ec0cbe5fa 100644 --- a/nemo/collections/llm/gpt/data/packed_sequence.py +++ b/nemo/collections/llm/gpt/data/packed_sequence.py @@ -76,9 +76,11 @@ def prepare_packed_sequence_data( packed_sequence_size: int, tokenizer: TokenizerSpec, max_seq_length: int, - seed: Optional[int] = 0, - packing_algorithm: str = "first_fit_shuffle", - dataset_kwargs: dict = None, + seed: int = 0, + packing_algorithm: str = "first_fit_shuffle_with_heap", + chat: bool = False, + divisibility_factor: Optional[int] = 16, + dataset_kwargs: Optional[dict] = None, ): """ Prepares a packed sequence dataset from a given input file and saves it to an output file. @@ -92,14 +94,17 @@ def prepare_packed_sequence_data( seed (Optional[int]): Random seed for shuffling (optional). packing_algorithm (str): The algorithm used for packing sequences currently supports "first_fit_shuffle" and "first_fit_decreasing". + chat (bool): Whether the dataset is a chat dataset. Defaults to False. + divisibility_factor (Optional[int]): If specified, each sequence length will be + rounded to the next integer multiple of this factor. Returns: None: Saves the packed sequence data to the specified output path. 
""" logging.info(f"Preparing packed sequence from {input_path}") - dataset = tokenize_dataset(input_path, tokenizer, max_seq_length, seed, dataset_kwargs) - sequences, histogram = create_hist(dataset, max_seq_length) + dataset = tokenize_dataset(input_path, tokenizer, max_seq_length, seed, chat, dataset_kwargs) + sequences, histogram = create_hist(dataset, max_seq_length, divisibility_factor) assignments, packing_metadata = create_packing_strategy(histogram, packed_sequence_size, packing_algorithm) output_data = fill_packing_strategy(assignments, sequences, packed_sequence_size, tokenizer.eos_id) diff --git a/nemo/utils/sequence_packing_utils.py b/nemo/utils/sequence_packing_utils.py index 49b49576961f..624e734b0049 100644 --- a/nemo/utils/sequence_packing_utils.py +++ b/nemo/utils/sequence_packing_utils.py @@ -13,9 +13,11 @@ # limitations under the License. import collections -from typing import Dict, List, Tuple +import heapq +from typing import Dict, List, Optional, Tuple import numpy as np +import torch from tqdm import tqdm from nemo.utils import logging @@ -98,7 +100,63 @@ def first_fit_shuffle(seqlens: List[int], pack_size: int) -> List[List[int]]: return first_fit(shuffled_seqlens, pack_size) -def create_hist(dataset: np.array, truncate_seq_len: int): +def first_fit_shuffle_with_heap( + seqlens: list[int], pack_size: int, shuffle: bool = True, seed: int | None = 234 +) -> list[list[int]]: + """A custom packing routine. + Packs sequences of varying lengths into bins using a First-Fit-like algorithm. + + This routine is similar in logic to First-Fit: for every new sequence, look for an + existing bin that can fit it, otherwise open a new bin. + While the original First-Fit version uses a greedy function called + `find_first_bin_that_fits`, here we greedily look for an accomodating bin using a + continuously updated heap. For large datasets, this makes it 100x-1000x faster. + + In this routine, seqlens can be shuffled before packing, which is necessary to + preserve the packing efficiency (i.e. the average number of sequences per pack). + + It is recommended to use shuffle=True (default) to increase the packing efficiency. + + Args: + seqlens: A list of integers, representing the lengths of the sequences to be packed. + pack_size: The maximum capacity of each bin. + shuffle: Whether to shuffle the sequence lengths before packing. + seed: Random seed for shuffling. + + Returns: + A list of lists, similar to the output of the 'first_fit' function. + """ + + if not seqlens: + return [] + + if shuffle: + rng = np.random.default_rng(seed) + rng.shuffle(seqlens) + + s = seqlens[0] + res = [(s, [s])] + for s in tqdm(seqlens[1:], desc="Creating packing strategy"): + # Check the first bin: it is the one with the smallest total sequence length. + # If it is possible to add the sequence to it without exceeding the pack size, + # then add the sequence to the bin. Otherwise, open a new bin. + first_bin_sum = res[0][0] + if first_bin_sum + s <= pack_size: + first_bin_sum, first_bin = heapq.heappop(res) + first_bin.append(s) + first_bin_sum += s + heapq.heappush(res, (first_bin_sum, first_bin)) + else: + heapq.heappush(res, (s, [s])) + return [bin for _, bin in res] + + +def next_multiple_of(n, m): + """Return the next multiple of m greater than or equal to n.""" + return (n + m - 1) // m * m + + +def create_hist(dataset: np.array, truncate_seq_len: int, divisibility_factor: Optional[int] = 16): """ Creates a histogram of sequence lengths from a tokenized dataset. 
@@ -116,16 +174,27 @@ def create_hist(dataset: np.array, truncate_seq_len: int): """ logging.info("Creating histogram from tokenized dataset...") + if divisibility_factor is not None and truncate_seq_len % divisibility_factor: + raise ValueError(f"{truncate_seq_len=} must be a multiple of {divisibility_factor=}") + sequences = collections.defaultdict(list) counts = [0] * (truncate_seq_len + 1) for item_dict in dataset: - # Minus 1 here to account for the fact that transformer input and label - # have one less token than the full sequence. - # Input is missing the last token and label is missing the first token - # (this way the tokens are aligned for next token prediction). - # We want pack size to be the length of the actual input and label, hence minus 1. + # The data processing pipeline downstream is expected to be the following: + # - REMOVE THE LAST TOKEN -> the -1 here to account for the fact that + # transformer input and labels have one less token than the full sequence: + # input is missing the last token and label is missing the first token + # (this way the tokens are aligned for next token prediction). + # - (POSSIBLY) PAD TO THE NEXT MULTIPLE OF `divisibility_factor` -> we + # virtually pad the sequence length to the next multiple of this value (the + # sequence is not modified, it is only assigned to a different length bin). If + # the sequence is not padded downstream, nothing is impacted except the data + # packing is slightly less optimal, since we may pack less sequences together. + # - PACKING -> concatenate the resulting sequences into a single packed one. seq_len = len(item_dict["input_ids"]) - 1 + if divisibility_factor is not None: + seq_len = next_multiple_of(seq_len, divisibility_factor) sequences[seq_len].append(item_dict) counts[seq_len] += 1 @@ -222,24 +291,21 @@ def fill_packing_strategy( per_seq_data = sequences[seq_len] if len(per_seq_data) > 0: perm = np.random.permutation(len(per_seq_data)) - input_ids = np.array([x["input_ids"] for x in per_seq_data])[perm].tolist() + input_ids = [per_seq_data[idx]["input_ids"] for idx in perm] try: - loss_mask = np.array([x["loss_mask"] for x in per_seq_data])[perm].tolist() - # roll loss mask by 1 to align with labels. 
We want to train on the output after the last context token - loss_mask = [x[1:] + [False] for x in loss_mask] + loss_mask = [per_seq_data[idx]["loss_mask"] for idx in perm] except KeyError: try: - loss_mask = np.array( + loss_mask = [ [ - [ - # (x['answer_start_idx'] - 1) because we want to train on the output - # after the last context token - idx >= (x["answer_start_idx"] - 1) - for idx in range(len(x["input_ids"])) - ] - for x in per_seq_data + # (x['answer_start_idx'] - 1) because we want to train on the output + # after the last context token + idx >= (x["answer_start_idx"] - 1) + for idx in range(len(x["input_ids"])) ] - )[perm].tolist() + for x in per_seq_data + ] + loss_mask = [loss_mask[idx] for idx in perm] except KeyError as err: err_msg = "Key errors loss_mask and answer_start_idx missing in example - " err_msg += f"{err} {per_seq_data[0]}" @@ -248,25 +314,25 @@ def fill_packing_strategy( ifile_handles[seq_len] = (input_ids, loss_mask) - input_ids, loss_mask, seq_start_id = {}, {}, {} - - for oindex, assignment in tqdm(enumerate(assignments), total=len(assignments)): - _input_ids, _loss_mask, _seq_start_id = [], [], [0] - - for seq_length in assignment: - _input_ids.extend(ifile_handles[seq_length][0].pop()) - _loss_mask.extend(ifile_handles[seq_length][1].pop()) - _seq_start_id.append(len(_input_ids)) - - input_ids[oindex] = _input_ids - loss_mask[oindex] = _loss_mask - seq_start_id[oindex] = _seq_start_id[:-1] + input_ids = [[0] * len(assignment) for assignment in assignments] + loss_mask = [[0] * len(assignment) for assignment in assignments] + seq_start_id = [[0] * (len(assignment) + 1) for assignment in assignments] + for oindex, assignment in tqdm( + enumerate(assignments), + total=len(assignments), + desc="Creating packed sequences", + ): + seq_start_id[oindex][0] = 0 + for j, seq_length in enumerate(assignment): + input_ids[oindex][j] = ifile_handles[seq_length][0].pop() + loss_mask[oindex][j] = ifile_handles[seq_length][1].pop() + seq_start_id[oindex][j + 1] = len(input_ids[oindex][j]) + seq_start_id[oindex][j] output_data = [] for i in range(len(input_ids)): item_dict = { - "input_ids": input_ids[i], - "loss_mask": loss_mask[i], + "input_ids": np.concatenate([np.array(x) for x in input_ids[i]]).reshape(-1), + "loss_mask": np.concatenate([np.array(x) for x in loss_mask[i]]).reshape(-1), "seq_start_id": seq_start_id[i], } output_data.append(item_dict) @@ -274,3 +340,100 @@ def fill_packing_strategy( assert all(not seq[0] for seq in ifile_handles.values()), "Error: There are items left over from the assignment" assert all(not seq[1] for seq in ifile_handles.values()), "Error: There are items left over from the assignment" return output_data + + +def pad_thd_sequences_for_cp( + input_ids: torch.Tensor, + labels: torch.Tensor, + cu_seqlens: torch.Tensor, + divisibility_factor: int, + padding_token_id: int = 0, + padding_label_id: int = -100, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pads sequences to be divisible by the divisibility factor. 
+ Literally a copy-paste of the same function from transformer_engine, see + https://github.com/NVIDIA/TransformerEngine/blob/dfacd9f76bcabcdd53cb30a17679ad6032cf54f4/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py + + Args: + input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences + labels: Tensor of shape (1, N) or (N,) containing labels for each token + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + padding_token_id: Token ID to use for padding (default: 0) + padding_label_id: Label ID to use for padding (default: -100) + + Returns: + Tuple of: + - input_ids_padded: Padded input_ids tensor + - labels_padded: Padded labels tensor + - cu_seqlens_padded: Cumulative sequence lengths accounting for padding + """ + # Flatten input_ids and labels if needed + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + if labels.dim() == 2: + labels = labels.squeeze(0) + + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence (make length a multiple of divisibility_factor) + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor - l.item() for l in seqlens + ] + + # Extract sequences and labels for each batch item + batch_sequences = [input_ids[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])] + batch_labels = [labels[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])] + + # Pad sequences and labels to required length + input_ids_padded = torch.cat( + [ + (torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)]) if pad > 0 else seq) + for seq, pad in zip(batch_sequences, padding_amounts) + ] + ) + labels_padded = torch.cat( + [ + (torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)]) if pad > 0 else seq) + for seq, pad in zip(batch_labels, padding_amounts) + ] + ) + + # Compute cumulative padded sequence lengths, starting from 0 + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + cu_seqlens_padded = torch.cumsum(torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), padded_lengths]), dim=0) + + return input_ids_padded, labels_padded, cu_seqlens_padded + + +def generate_positional_ids_for_cp( + cu_seqlens: torch.Tensor, + divisibility_factor: int, + dtype: torch.dtype = torch.long, +) -> torch.Tensor: + """Generate positional IDs for sequences padded to be divisible by divisibility_factor. 
+    Copied verbatim from the equivalent function in TransformerEngine; see
+    https://github.com/NVIDIA/TransformerEngine/blob/dfacd9f76bcabcdd53cb30a17679ad6032cf54f4/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+
+    Args:
+        cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths
+        divisibility_factor: Each sequence length must be divisible by this factor
+        dtype: Data type for the generated positional IDs (default: torch.long)
+
+    Returns:
+        Generated positional_ids tensor where each sequence starts from 0 and continues through padding
+    """
+    # Compute the sequence lengths from cu_seqlens
+    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+
+    # List: amount of padding needed for each sequence
+    padding_amounts = [
+        ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor - l.item() for l in seqlens
+    ]
+
+    # Generate positional IDs for each padded sequence (each starts from 0)
+    padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype)
+    positional_ids = torch.cat([torch.arange(0, int(length), dtype=dtype) for length in padded_lengths])
+
+    return positional_ids
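
The new `first_fit_shuffle_with_heap` only ever tries the open bin with the smallest running total (the heap root) and opens a new bin otherwise, which is what makes it much faster than scanning all bins. A minimal usage sketch, assuming this patch is applied so the helper is importable; the sequence lengths and pack size below are arbitrary toy values:

# Illustrative usage only, not part of the patch; toy values.
from nemo.utils.sequence_packing_utils import first_fit_shuffle_with_heap

seqlens = [1500, 900, 2100, 300, 1200, 700, 1800, 400]
pack_size = 4096

bins = first_fit_shuffle_with_heap(seqlens, pack_size, shuffle=True, seed=0)

# No bin exceeds the pack size, and every sequence ends up in exactly one bin.
assert all(sum(b) <= pack_size for b in bins)
assert sorted(s for b in bins for s in b) == sorted(seqlens)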
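
With the new `divisibility_factor` argument, `create_hist` books each example under its length rounded up to the next multiple of the factor (after the usual minus-one for the label shift), so the packing strategy reserves room for the padding that is added later for context parallelism. A toy sketch, assuming this patch is applied; the example lengths are made up:

# Illustrative only; toy examples with arbitrary lengths.
import numpy as np
from nemo.utils.sequence_packing_utils import create_hist

dataset = np.array(
    [{"input_ids": list(range(n))} for n in (37, 41, 120, 513)],
    dtype=object,
)

# 37 and 41 tokens -> effective lengths 36 and 40 -> both binned as 48;
# 120 -> 119 -> 128; 513 -> 512 -> 512 (already a multiple of 16).
sequences, histogram = create_hist(dataset, truncate_seq_len=1024, divisibility_factor=16)
assert histogram[48] == 2 and histogram[128] == 1 and histogram[512] == 1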
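
Finally, a sketch of the padding pair that `collate_fn` now relies on: each sequence in a pack is padded up to a multiple of `divisibility_factor` (in the data module this is `pad_seq_length_to_mult`), and position ids restart at every padded boundary. Toy tensors, assuming this patch is applied:

# Illustrative only; toy tensors.
import torch
from nemo.utils.sequence_packing_utils import (
    generate_positional_ids_for_cp,
    pad_thd_sequences_for_cp,
)

# A pack of two sequences of lengths 5 and 7 in THD layout, cu_seqlens = [0, 5, 12].
input_ids = torch.arange(12, dtype=torch.long)
labels = torch.arange(1, 13, dtype=torch.long)
cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32)

padded_ids, padded_labels, cu_seqlens_padded = pad_thd_sequences_for_cp(
    input_ids=input_ids,
    labels=labels,
    cu_seqlens=cu_seqlens,
    divisibility_factor=4,
    padding_token_id=0,
    padding_label_id=-100,
)
position_ids = generate_positional_ids_for_cp(cu_seqlens=cu_seqlens, divisibility_factor=4)

# 5 -> 8 and 7 -> 8, so the padded pack is 16 tokens long and every sequence
# length is a multiple of the divisibility factor.
assert cu_seqlens_padded.tolist() == [0, 8, 16]
assert padded_ids.shape[0] == 16 and position_ids.shape[0] == 16
# Position ids restart from 0 at each (padded) sequence boundary.
assert position_ids.tolist() == list(range(8)) + list(range(8))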