diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index ce5ed73c17..3818fa2831 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -55,6 +55,9 @@ policy: dynamic_batching: enabled: False + sequence_packing: + enabled: False + # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} diff --git a/examples/configs/grpo_deepscaler-1.5b-24K.yaml b/examples/configs/grpo_deepscaler-1.5b-24K.yaml index f2552eea7e..dc9db4ceab 100644 --- a/examples/configs/grpo_deepscaler-1.5b-24K.yaml +++ b/examples/configs/grpo_deepscaler-1.5b-24K.yaml @@ -21,6 +21,9 @@ policy: dynamic_batching: enabled: False + sequence_packing: + enabled: False + optimizer: name: "torch.optim.AdamW" kwargs: diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index fd944fa9e7..da71d08108 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -63,6 +63,16 @@ policy: logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} sequence_length_round: 64 + sequence_packing: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "concatenative" + sequence_length_round: 64 + + megatron_cfg: + enabled: false + # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index 7a8a651a54..93404a086a 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml 
+++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -51,12 +51,8 @@ policy: # training and logprob stages respectively. dynamic_batching: enabled: False - - sequence_packing: - enabled: False # coming soon train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - algorithm: "modified_ffd" sequence_length_round: 64 max_grad_norm: 1.0 diff --git a/examples/configs/grpo_math_8B_megatron.yaml b/examples/configs/grpo_math_8B_megatron.yaml index fc839c8239..55f3b38073 100644 --- a/examples/configs/grpo_math_8B_megatron.yaml +++ b/examples/configs/grpo_math_8B_megatron.yaml @@ -72,4 +72,4 @@ policy: cluster: gpus_per_node: 8 - num_nodes: 1 \ No newline at end of file + num_nodes: 1 diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 3839d455e2..17600c05ac 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -44,6 +44,12 @@ policy: dynamic_batching: enabled: false + sequence_packing: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + algorithm: "concatenative" + sequence_length_round: 64 + # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} @@ -121,7 +127,7 @@ policy: average_in_collective: true data_parallel_sharding_strategy: "optim_grads_params" - + data: max_input_seq_length: ${policy.max_total_sequence_length} dataset_name: "squad" diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index 2040bdd5ff..2e953a5d69 100644 --- a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -37,6 +37,9 @@ policy: context_parallel_size: 1 custom_parallel_plan: null + sequence_packing: + enabled: False + 
class SequencePackingLossWrapper:
    """Adapt a per-sequence loss function to packed (concatenated) logits.

    The wrapper cuts the packed logits back into individual sequences using
    cumulative sequence lengths, calls the wrapped loss function once per
    sequence on unpadded inputs, and sums the per-sequence losses and metrics.
    """

    def __init__(
        self,
        loss_fn: LossFunction,
        cu_seqlens_q: Tensor,
        cu_seqlens_q_padded: Optional[Tensor] = None,
    ):
        """Store the wrapped loss function and the packing offsets.

        Args:
            loss_fn: Loss function invoked once per packed sequence.
            cu_seqlens_q: Cumulative *unpadded* sequence lengths; consecutive
                differences give each sequence's true length.
            cu_seqlens_q_padded: Optional cumulative *padded* sequence lengths.
                When given, they locate where each sequence starts inside the
                packed logits; otherwise ``cu_seqlens_q`` is used for that too.
        """
        self.loss_fn = loss_fn
        self.cu_seqlens_q = cu_seqlens_q
        self.cu_seqlens_q_padded = cu_seqlens_q_padded

    def __call__(
        self,
        next_token_logits: Tensor,
        data: BatchedDataDict[Any],
        global_valid_seqs: Tensor | None,
        global_valid_toks: Tensor | None,
        vocab_parallel_rank: Optional[int] = None,
        vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    ) -> tuple[Tensor, dict[str, Any]]:
        """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid padding.

        Args:
            next_token_logits: Packed logits, indexed as ``[:, start:end, :]``,
                i.e. the packed sequence dimension is dim 1.
            data: Unpacked batch; entry ``i`` corresponds to packed sequence ``i``.
            global_valid_seqs: Forwarded unchanged to the wrapped loss function.
            global_valid_toks: Forwarded unchanged to the wrapped loss function.
            vocab_parallel_rank: Forwarded unchanged to the wrapped loss function.
            vocab_parallel_group: Forwarded unchanged to the wrapped loss function.

        Returns:
            A ``(loss, metrics)`` pair holding the sum of the per-sequence
            losses and the key-wise sum of the per-sequence metrics. With zero
            sequences the loss is the integer ``0`` and the metrics are empty.
        """
        # Start offsets come from the padded layout when one is provided,
        # because that is how the sequences are laid out in the packed logits.
        if self.cu_seqlens_q_padded is not None:
            packed_offsets = self.cu_seqlens_q_padded
        else:
            packed_offsets = self.cu_seqlens_q

        # Convert once to Python ints instead of calling .item() per sequence.
        seq_starts = packed_offsets[:-1].tolist()
        unpadded_seq_lengths = (self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1]).tolist()

        loss_accum = 0
        metrics_accum: dict[str, Any] = {}
        for seq_idx, seq_start in enumerate(seq_starts):
            seq_len = unpadded_seq_lengths[seq_idx]

            # The data dict holds unpacked tensors: take this sequence's row and
            # trim every tensor that has a real sequence dimension (dim 1 > 1).
            seq_data = data.slice(seq_idx, seq_idx + 1)
            unpadded_seq_data = {}
            for key, value in seq_data.items():
                has_seq_dim = (
                    isinstance(value, torch.Tensor)
                    and value.ndim > 1
                    and value.shape[1] > 1
                )
                unpadded_seq_data[key] = value[:, :seq_len] if has_seq_dim else value

            # Only the unpadded part of this sequence's span is passed on.
            next_token_logits_slice = next_token_logits[
                :, seq_start : seq_start + seq_len, :
            ]

            loss, metrics = self.loss_fn(
                next_token_logits_slice,
                unpadded_seq_data,
                global_valid_seqs,
                global_valid_toks,
                vocab_parallel_rank=vocab_parallel_rank,
                vocab_parallel_group=vocab_parallel_group,
            )
            loss_accum += loss
            for key, value in metrics.items():
                metrics_accum[key] = metrics_accum.get(key, 0) + value

        return loss_accum, metrics_accum
"""Sequence packing algorithms for efficient batching of variable-length sequences."""

import enum
import math
import random
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Type, Union


class PackingAlgorithm(enum.Enum):
    """Enum for supported sequence packing algorithms."""

    CONCATENATIVE = "concatenative"
    FIRST_FIT_DECREASING = "first_fit_decreasing"
    FIRST_FIT_SHUFFLE = "first_fit_shuffle"
    MODIFIED_FIRST_FIT_DECREASING = "modified_first_fit_decreasing"


class SequencePacker(ABC):
    """Abstract base class for sequence packing algorithms.

    Sequence packing is the process of efficiently arranging sequences of different
    lengths into fixed-capacity bins (batches) to maximize computational efficiency.
    """

    def __init__(self, bin_capacity: int, collect_metrics: bool = False):
        """Initialize the sequence packer.

        Args:
            bin_capacity: The maximum capacity of each bin.
            collect_metrics: Whether to collect metrics across multiple packing operations.
        """
        self.bin_capacity = bin_capacity
        self.collect_metrics = collect_metrics
        self.metrics = None

        if collect_metrics:
            # Imported lazily so the metrics module is only loaded when needed.
            from nemo_rl.data.packing.metrics import PackingMetrics

            self.metrics = PackingMetrics()

    @abstractmethod
    def _pack_implementation(self, sequence_lengths: List[int]) -> List[List[int]]:
        """Implementation of the packing algorithm.

        Args:
            sequence_lengths: A list of sequence lengths to pack.

        Returns:
            A list of bins, where each bin is a list of indices into the original
            sequence_lengths list.
        """

    def pack(self, sequence_lengths: List[int]) -> List[List[int]]:
        """Pack sequences into bins and update metrics if enabled.

        Args:
            sequence_lengths: A list of sequence lengths to pack.

        Returns:
            A list of bins, where each bin is a list of indices into the original
            sequence_lengths list.
        """
        bins = self._pack_implementation(sequence_lengths)

        if self.collect_metrics and self.metrics is not None:
            self.metrics.update(sequence_lengths, bins, self.bin_capacity)

        return bins

    def reset_metrics(self) -> None:
        """Reset collected metrics (no-op when metrics are not collected)."""
        if self.metrics is not None:
            self.metrics.reset()

    def compute_metrics(
        self, sequence_lengths: List[int], bins: List[List[int]]
    ) -> Dict[str, float]:
        """Calculate metrics for a packing solution without updating the metrics tracker.

        Args:
            sequence_lengths: List of sequence lengths
            bins: List of bins, where each bin is a list of indices

        Returns:
            Dictionary of packing metrics
        """
        metrics = self.metrics
        if metrics is None:
            # Use a throwaway tracker when metrics collection is disabled.
            from nemo_rl.data.packing.metrics import PackingMetrics

            metrics = PackingMetrics()
        return metrics.calculate_stats_only(sequence_lengths, bins, self.bin_capacity)

    def get_aggregated_metrics(self) -> Dict[str, float]:
        """Get aggregated metrics across all packing operations.

        Returns:
            Dictionary of aggregated metrics, or empty dict if not collecting
        """
        if self.metrics is None:
            return {}
        return self.metrics.get_aggregated_stats()

    def print_metrics(self) -> None:
        """Print the current metrics in a formatted way."""
        if self.metrics is None:
            print(
                "Metrics collection is not enabled. Initialize with collect_metrics=True."
            )
            return

        self.metrics.print_aggregated_stats()

    def _validate_sequence_lengths(self, sequence_lengths: List[int]) -> None:
        """Validate that all sequence lengths are within bin capacity.

        Args:
            sequence_lengths: A list of sequence lengths to validate.

        Raises:
            ValueError: If any sequence length exceeds bin capacity.
        """
        for length in sequence_lengths:
            if length > self.bin_capacity:
                raise ValueError(
                    f"Sequence length {length} exceeds bin capacity {self.bin_capacity}"
                )

    def _create_indexed_lengths(
        self, sequence_lengths: List[int], reverse: bool = False
    ) -> List[Tuple[int, int]]:
        """Create a list of (length, index) pairs from sequence lengths.

        Args:
            sequence_lengths: A list of sequence lengths.
            reverse: If True, sort the pairs in descending order. If False, the
                pairs are returned in the original input order (no sorting).

        Returns:
            A list of (length, index) pairs, optionally sorted.
        """
        indexed_lengths = [(length, i) for i, length in enumerate(sequence_lengths)]
        if reverse:
            indexed_lengths.sort(reverse=True)  # descending by (length, index)
        return indexed_lengths

    def _estimate_bins_needed(self, sequence_lengths: List[int]) -> int:
        """Estimate the number of bins needed based on total length.

        Args:
            sequence_lengths: A list of sequence lengths.

        Returns:
            Estimated number of bins needed (at least 1).
        """
        total_length = sum(sequence_lengths)
        return max(1, math.ceil(total_length / self.bin_capacity))


class ConcatenativePacker(SequencePacker):
    """Concatenative packing algorithm.

    This algorithm simply concatenates sequences in order until reaching the bin capacity,
    then starts a new bin. It doesn't try to optimize the packing in any way.

    Time complexity: O(n) where n is the number of sequences.

    Example:
    ```python
    >>> examples = {
    ...     "sequence_lengths": [4, 1, 3, 2, 1, 3, 4, 5]
    ... }
    >>> # If packed with seq_length=5:
    ... {"bins": [ [0, 1], [2, 3], [4, 5], [6], [7] ]}
    >>> # If packed with seq_length=8:
    ... {"bins": [ [0, 1, 2], [3, 4, 5], [6], [7] ]}
    ```
    """

    # Global class variable to limit the number of sequences packed in a unit.
    # -1 disables this limit. The limit is off by default so that bins are
    # bounded by capacity only; set it to a positive value when debugging/testing.
    max_sequences_per_bin = -1

    def _pack_implementation(self, sequence_lengths: List[int]) -> List[List[int]]:
        """Pack sequences using the Concatenative algorithm.

        Args:
            sequence_lengths: A list of sequence lengths to pack.

        Returns:
            A list of bins, where each bin is a list of indices into the original
            sequence_lengths list.

        Raises:
            ValueError: If any sequence length exceeds bin capacity.
        """
        self._validate_sequence_lengths(sequence_lengths)

        bins: List[List[int]] = []
        current_bin: List[int] = []
        current_length = 0  # total length of the sequences in current_bin

        for i, length in enumerate(sequence_lengths):
            exceeds_capacity = current_length + length > self.bin_capacity
            exceeds_sequence_limit = (
                self.max_sequences_per_bin != -1
                and len(current_bin) >= self.max_sequences_per_bin
            )

            if exceeds_capacity or exceeds_sequence_limit:
                # Close the current bin (if it holds anything) and start a new one.
                if current_bin:
                    bins.append(current_bin)
                current_bin = [i]
                current_length = length
            else:
                current_bin.append(i)
                current_length += length

        if current_bin:
            bins.append(current_bin)

        return bins


class FirstFitPacker(SequencePacker):
    """Base class for First-Fit algorithms.

    First-Fit algorithms place each sequence into the first bin where it fits.
    If no bin can fit the sequence, a new bin is created.

    This base class provides the common implementation for First-Fit variants.
    Subclasses must implement the _prepare_sequences method to determine the
    order in which sequences are processed.
    """

    def _prepare_sequences(self, sequence_lengths: List[int]) -> List[Tuple[int, int]]:
        """Prepare sequences for packing.

        This method determines the order in which sequences are processed.
        Subclasses must override this method.

        Args:
            sequence_lengths: A list of sequence lengths to pack.

        Returns:
            A list of (length, index) pairs.
        """
        raise NotImplementedError("Subclasses must implement _prepare_sequences")

    def _pack_implementation(self, sequence_lengths: List[int]) -> List[List[int]]:
        """Pack sequences using the First-Fit algorithm.

        Args:
            sequence_lengths: A list of sequence lengths to pack.

        Returns:
            A list of bins, where each bin is a list of indices into the original
            sequence_lengths list.

        Raises:
            ValueError: If any sequence length exceeds bin capacity.
        """
        # A sequence larger than the bin capacity can never be packed.
        self._validate_sequence_lengths(sequence_lengths)

        # Processing order is determined by the subclass.
        indexed_lengths = self._prepare_sequences(sequence_lengths)

        bins: List[List[int]] = []
        bin_remaining: List[int] = []  # remaining capacity, parallel to `bins`

        for length, idx in indexed_lengths:
            for bin_idx, remaining in enumerate(bin_remaining):
                if remaining >= length:
                    bins[bin_idx].append(idx)
                    bin_remaining[bin_idx] -= length
                    break
            else:
                # No existing bin can take this sequence: open a new one.
                bins.append([idx])
                bin_remaining.append(self.bin_capacity - length)

        return bins


class FirstFitDecreasingPacker(FirstFitPacker):
    """First-Fit Decreasing (FFD) algorithm for sequence packing.

    This algorithm sorts sequences by length in descending order and then
    places each sequence into the first bin where it fits.

    Time complexity: O(n log n) for sorting + O(n * m) for packing,
    where n is the number of sequences and m is the number of bins.
    """

    def _prepare_sequences(self, sequence_lengths: List[int]) -> List[Tuple[int, int]]:
        """Prepare sequences for packing by sorting them in descending order.

        Args:
            sequence_lengths: A list of sequence lengths to pack.

        Returns:
            A list of (length, index) pairs sorted by length in descending order.
        """
        return self._create_indexed_lengths(sequence_lengths, reverse=True)


class FirstFitShufflePacker(FirstFitPacker):
    """First-Fit Shuffle algorithm for sequence packing.

    This algorithm randomly shuffles the sequences and then places each
    sequence into the first bin where it fits.

    Time complexity: O(n * m) for packing, where n is the number of sequences
    and m is the number of bins.
    """

    def _prepare_sequences(self, sequence_lengths: List[int]) -> List[Tuple[int, int]]:
        """Prepare sequences for packing by randomly shuffling them.

        Args:
            sequence_lengths: A list of sequence lengths to pack.

        Returns:
            A list of (length, index) pairs in random order.
        """
        indexed_lengths = self._create_indexed_lengths(sequence_lengths)
        random.shuffle(indexed_lengths)
        return indexed_lengths


class ModifiedFirstFitDecreasingPacker(SequencePacker):
    """Modified First-Fit Decreasing (MFFD) algorithm for sequence packing.

    This algorithm implements the Johnson & Garey (1985) Modified First-Fit-Decreasing
    heuristic. It classifies items into four categories (large, medium, small, tiny)
    and packs them in several phases to achieve better bin utilization than
    standard First-Fit Decreasing.

    The algorithm phases:
    0. Classify items by size relative to bin capacity
    1. Create one bin per large item
    2. Add one medium item to each large bin where it fits (forward pass)
    3. Add pairs of small items to bins *without* a medium item (backward pass)
    4. Greedily fit remaining items into the existing bins
    5. Apply FFD to any leftovers

    Time complexity: O(n log n) for sorting + O(n * m) for packing,
    where n is the number of sequences and m is the number of bins.
    """

    def _classify_items(
        self, items: List[Tuple[int, int]]
    ) -> Tuple[
        List[Tuple[int, int]],
        List[Tuple[int, int]],
        List[Tuple[int, int]],
        List[Tuple[int, int]],
    ]:
        """Split items into large / medium / small / tiny classes.

        Follows the classification used by Johnson & Garey:
            large   : (C/2, C]
            medium  : (C/3, C/2]
            small   : (C/6, C/3]
            tiny    : (0  , C/6]

        Args:
            items: List of (index, size) tuples

        Returns:
            Tuple of four lists (large, medium, small, tiny) without additional sorting.
        """
        large, medium, small, tiny = [], [], [], []
        for idx, size in items:
            if size > self.bin_capacity / 2:
                large.append((idx, size))
            elif size > self.bin_capacity / 3:
                medium.append((idx, size))
            elif size > self.bin_capacity / 6:
                small.append((idx, size))
            else:
                tiny.append((idx, size))
        return large, medium, small, tiny

    def _pack_implementation(self, sequence_lengths: List[int]) -> List[List[int]]:
        """Pack sequences using the Modified First-Fit Decreasing algorithm.

        Args:
            sequence_lengths: A list of sequence lengths to pack.

        Returns:
            A list of bins, where each bin is a list of indices into the original
            sequence_lengths list.

        Raises:
            ValueError: If bin_capacity is not positive, if any sequence length is
                not positive, or if any sequence length exceeds bin capacity.
        """
        if self.bin_capacity <= 0:
            raise ValueError("bin_capacity must be positive")
        if any(length <= 0 for length in sequence_lengths):
            raise ValueError("sequence lengths must be positive")
        self._validate_sequence_lengths(sequence_lengths)

        capacity = self.bin_capacity
        items: List[Tuple[int, int]] = list(enumerate(sequence_lengths))

        # Phase-0: classify
        large, medium, small, tiny = self._classify_items(items)

        # Sort according to the rules of MFFD
        large.sort(key=lambda item: item[1], reverse=True)  # descending size
        medium.sort(key=lambda item: item[1], reverse=True)
        small.sort(key=lambda item: item[1])  # ascending size
        tiny.sort(key=lambda item: item[1])

        # Phase-1: start one bin per large item
        bins: List[List[Tuple[int, int]]] = [[item] for item in large]

        # Phase-2: try to add one medium item to each large bin (forward pass)
        for b in bins:
            remaining = capacity - sum(size for _, size in b)
            for i, (_, size) in enumerate(medium):
                if size <= remaining:
                    b.append(medium.pop(i))
                    break

        # Phase-3: backward pass – fill bins that hold no medium item with two
        # small items where possible.
        for b in reversed(bins):
            if len(small) < 2:
                break  # `small` only shrinks, so no later bin can get a pair
            has_medium = any(capacity / 3 < size <= capacity / 2 for _, size in b)
            if has_medium:
                continue
            remaining = capacity - sum(size for _, size in b)
            if small[0][1] + small[1][1] > remaining:
                continue  # even the two smallest small items do not fit
            # Pair the smallest small item with the *largest* small item that
            # still fits next to it. Index 1 always qualifies because of the
            # check above, so a partner is guaranteed and nothing is dropped.
            budget = remaining - small[0][1]
            partner_idx = 1
            for j in range(len(small) - 1, 0, -1):
                if small[j][1] <= budget:
                    partner_idx = j
                    break
            second_small = small.pop(partner_idx)
            first_small = small.pop(0)
            b.extend([first_small, second_small])

        # Phase-4: forward greedy fit of remaining items
        remaining_items = sorted(
            medium + small + tiny, key=lambda item: item[1], reverse=True
        )
        for b in bins:
            while remaining_items:
                rem = capacity - sum(size for _, size in b)
                # if even the smallest remaining doesn't fit we break
                if rem < remaining_items[-1][1]:
                    break

                # pick the first (largest) that fits
                chosen_idx = None
                for i, (_, size) in enumerate(remaining_items):
                    if size <= rem:
                        chosen_idx = i
                        break
                if chosen_idx is None:
                    break
                b.append(remaining_items.pop(chosen_idx))

        # Phase-5: FFD on leftovers
        ffd_bins: List[List[Tuple[int, int]]] = []
        for idx, size in sorted(remaining_items, key=lambda item: item[1], reverse=True):
            for bin_ffd in ffd_bins:
                if size <= capacity - sum(s for _, s in bin_ffd):
                    bin_ffd.append((idx, size))
                    break
            else:
                ffd_bins.append([(idx, size)])
        bins.extend(ffd_bins)

        # Convert to list of index lists (discard sizes)
        return [[idx for idx, _ in b] for b in bins]


def get_packer(
    algorithm: Union[PackingAlgorithm, str],
    bin_capacity: int,
    collect_metrics: bool = False,
) -> SequencePacker:
    """Factory function to get a sequence packer based on the algorithm.

    Args:
        algorithm: The packing algorithm to use. Can be either a PackingAlgorithm enum value
            or a string (case-insensitive) matching one of the enum names.
        bin_capacity: The maximum capacity of each bin.
        collect_metrics: Whether to collect metrics across multiple packing operations.

    Returns:
        A SequencePacker instance for the specified algorithm.

    Raises:
        ValueError: If the algorithm is not recognized.
    """
    packers: Dict[PackingAlgorithm, Type[SequencePacker]] = {
        PackingAlgorithm.CONCATENATIVE: ConcatenativePacker,
        PackingAlgorithm.FIRST_FIT_DECREASING: FirstFitDecreasingPacker,
        PackingAlgorithm.FIRST_FIT_SHUFFLE: FirstFitShufflePacker,
        PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING: ModifiedFirstFitDecreasingPacker,
    }

    # Convert string to enum if needed (lookup is by enum *name*).
    if isinstance(algorithm, str):
        try:
            algorithm = PackingAlgorithm[algorithm.upper()]
        except KeyError:
            available_algorithms = ", ".join([alg.name for alg in PackingAlgorithm])
            raise ValueError(
                f"Unknown packing algorithm: {algorithm}. "
                f"Available algorithms: {available_algorithms}"
            ) from None

    if algorithm not in packers:
        available_algorithms = ", ".join([alg.name for alg in PackingAlgorithm])
        raise ValueError(
            f"Unknown packing algorithm: {algorithm}. "
            f"Available algorithms: {available_algorithms}"
        )

    return packers[algorithm](bin_capacity, collect_metrics=collect_metrics)
"""Metrics for evaluating sequence packing algorithms."""

import math
import statistics
from typing import Dict, List, Optional


class PackingMetrics:
    """Class for tracking and computing metrics for sequence packing algorithms.

    This class provides methods to calculate various metrics that evaluate the
    efficiency and effectiveness of sequence packing algorithms, such as bin
    utilization, waste, and imbalance.
    """

    def __init__(self):
        """Initialize the metrics tracker."""
        self.reset()

    def reset(self) -> None:
        """Reset all metrics."""
        # Counters for aggregated metrics
        self.total_sequences = 0
        self.total_bins = 0
        self.total_sequence_length = 0
        self.total_bin_capacity = 0
        self.total_waste = 0
        self.bin_utilizations = []  # average utilization of each update() call
        self.bin_counts = []  # number of bins of each update() call
        self.packing_times = []

        # Tracking best and worst cases; initialised to the opposite extremes so
        # the first update() always overwrites them.
        self.min_utilization = 1.0
        self.max_utilization = 0.0
        self.min_waste_ratio = 1.0
        self.max_waste_ratio = 0.0

    def update(
        self,
        sequence_lengths: List[int],
        bins: List[List[int]],
        bin_capacity: int,
        packing_time: Optional[float] = None,
    ) -> Dict[str, float]:
        """Update metrics with a new packing solution.

        Args:
            sequence_lengths: List of sequence lengths
            bins: List of bins, where each bin is a list of indices
            bin_capacity: Maximum capacity of each bin
            packing_time: Optional time taken to compute the packing solution

        Returns:
            Dictionary of metrics for this packing solution
        """
        stats = self.calculate_stats_only(sequence_lengths, bins, bin_capacity)

        # Update counters
        self.total_sequences += len(sequence_lengths)
        self.total_bins += len(bins)
        self.total_sequence_length += sum(sequence_lengths)
        self.total_bin_capacity += len(bins) * bin_capacity
        self.total_waste += stats["total_waste"]
        self.bin_utilizations.append(stats["average_utilization"])
        self.bin_counts.append(len(bins))

        if packing_time is not None:
            self.packing_times.append(packing_time)

        # Update min/max values
        self.min_utilization = min(self.min_utilization, stats["average_utilization"])
        self.max_utilization = max(self.max_utilization, stats["average_utilization"])
        self.min_waste_ratio = min(self.min_waste_ratio, stats["waste_ratio"])
        self.max_waste_ratio = max(self.max_waste_ratio, stats["waste_ratio"])

        return stats

    def calculate_stats_only(
        self, sequence_lengths: List[int], bins: List[List[int]], bin_capacity: int
    ) -> Dict[str, float]:
        """Calculate metrics for a packing solution without updating the tracker.

        Args:
            sequence_lengths: List of sequence lengths
            bins: List of bins, where each bin is a list of indices
            bin_capacity: Maximum capacity of each bin

        Returns:
            Dictionary of metrics for this packing solution. All values are zero
            when ``bins`` is empty.
        """
        if not bins:
            return {
                "num_sequences": 0,
                "num_bins": 0,
                "total_sequence_length": 0,
                "total_bin_capacity": 0,
                "total_waste": 0,
                "average_utilization": 0.0,
                "waste_ratio": 0.0,
                "bin_balance": 0.0,
                "theoretical_min_bins": 0,
                "bin_efficiency": 0.0,
            }

        # Calculate bin loads
        bin_loads = [
            sum(sequence_lengths[idx] for idx in bin_indices) for bin_indices in bins
        ]

        # Calculate basic metrics
        num_sequences = len(sequence_lengths)
        num_bins = len(bins)
        total_sequence_length = sum(sequence_lengths)
        total_bin_capacity = num_bins * bin_capacity
        total_waste = total_bin_capacity - total_sequence_length

        # Calculate utilization metrics
        bin_utilizations = [load / bin_capacity for load in bin_loads]
        average_utilization = total_sequence_length / total_bin_capacity
        waste_ratio = total_waste / total_bin_capacity

        # Bin balance: 1 minus the coefficient of variation of the utilizations.
        # With a single bin, or when every bin is empty (zero mean utilization),
        # there is no spread to measure, so the bins count as perfectly balanced.
        if num_bins > 1 and average_utilization > 0:
            bin_balance = 1.0 - statistics.stdev(bin_utilizations) / average_utilization
        else:
            bin_balance = 1.0

        # Calculate theoretical minimum number of bins
        theoretical_min_bins = math.ceil(total_sequence_length / bin_capacity)

        # Calculate bin efficiency (ratio of theoretical min bins to actual bins)
        bin_efficiency = theoretical_min_bins / num_bins if num_bins > 0 else 0.0

        return {
            "num_sequences": num_sequences,
            "num_bins": num_bins,
            "total_sequence_length": total_sequence_length,
            "total_bin_capacity": total_bin_capacity,
            "total_waste": total_waste,
            "average_utilization": average_utilization,
            "waste_ratio": waste_ratio,
            "bin_balance": bin_balance,
            "theoretical_min_bins": theoretical_min_bins,
            "bin_efficiency": bin_efficiency,
        }

    def get_aggregated_stats(self) -> Dict[str, float]:
        """Get aggregated metrics across all packing operations.

        Returns:
            Dictionary of aggregated metrics, or an empty dict if update() has
            not been called since the last reset.
        """
        if not self.bin_utilizations:
            return {}

        has_capacity = self.total_bin_capacity > 0
        has_bins = self.total_bins > 0

        # Calculate aggregated metrics
        avg_utilization = (
            self.total_sequence_length / self.total_bin_capacity
            if has_capacity
            else 0.0
        )
        avg_waste_ratio = (
            self.total_waste / self.total_bin_capacity if has_capacity else 0.0
        )
        avg_bin_count = (
            sum(self.bin_counts) / len(self.bin_counts) if self.bin_counts else 0.0
        )

        # Theoretical minimum number of bins, using the average bin capacity
        # over all recorded bins; guarded so a zero capacity cannot divide by zero.
        theoretical_min_bins = (
            math.ceil(
                self.total_sequence_length / (self.total_bin_capacity / self.total_bins)
            )
            if has_bins and has_capacity
            else 0
        )

        # Calculate bin efficiency (ratio of theoretical min bins to actual bins)
        bin_efficiency = theoretical_min_bins / self.total_bins if has_bins else 0.0

        stats = {
            "total_sequences": self.total_sequences,
            "total_bins": self.total_bins,
            "average_utilization": avg_utilization,
            "min_utilization": self.min_utilization,
            "max_utilization": self.max_utilization,
            "average_waste_ratio": avg_waste_ratio,
            "min_waste_ratio": self.min_waste_ratio,
            "max_waste_ratio": self.max_waste_ratio,
            "average_bin_count": avg_bin_count,
            "bin_efficiency": bin_efficiency,
        }

        # Average packing time is reported only if at least one was recorded.
        if self.packing_times:
            stats["average_packing_time"] = sum(self.packing_times) / len(
                self.packing_times
            )

        return stats

    def print_aggregated_stats(self) -> None:
        """Print the aggregated metrics in a formatted way."""
        stats = self.get_aggregated_stats()

        if not stats:
            print("No metrics collected yet.")
            return

        print("\n=== Packing Metrics Summary ===")
        print(f"Total sequences packed: {stats['total_sequences']}")
        print(f"Total bins used: {stats['total_bins']}")
        print(
            f"Average bin utilization: {stats['average_utilization']:.4f} (min: {stats['min_utilization']:.4f}, max: {stats['max_utilization']:.4f})"
        )
        print(
            f"Average waste ratio: {stats['average_waste_ratio']:.4f} (min: {stats['min_waste_ratio']:.4f}, max: {stats['max_waste_ratio']:.4f})"
        )
        print(
            f"Bin efficiency (theoretical min bins / actual bins): {stats['bin_efficiency']:.4f}"
        )

        if "average_packing_time" in stats:
            print(f"Average packing time: {stats['average_packing_time']:.6f} seconds")

        print("===============================\n")
class SequencePackingArgs(TypedDict):
    """Configuration settings for sequence packing.

    Pass this to 'shard_by_batch_size()' to preprocess batches for sequence packing.
    """

    # Maximum number of tokens in a packed microbatch; used as the bin capacity
    # of the packer.
    max_tokens_per_microbatch: int
    # Key in the batch which holds the input ids.
    input_key: str
    # Key in the batch which holds the per-sample sequence lengths.
    input_lengths_key: str
    # Name of the packing algorithm, forwarded to get_packer().
    algorithm: str
    sequence_length_pad_multiple: (
        int  # pad each sequence to a multiple of this value (for CP/TP alignment)
    )
@@ -58,6 +74,7 @@ def __init__(self, *args, **kwargs): self.micro_batch_indices = None self.micro_batch_lengths = None + self.elem_counts_per_gb = None @classmethod def from_batches( @@ -204,6 +221,7 @@ def shard_by_batch_size( batch_size: Optional[int] = None, allow_uneven_shards: bool = False, dynamic_batching_args: Optional[DynamicBatchingArgs] = None, + sequence_packing_args: Optional[SequencePackingArgs] = None, ) -> list["SlicedDataDict"] | tuple[list["SlicedDataDict"], list[int]]: """Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into shards equal parts. Finally aggregates the sub-shards by their position. @@ -219,7 +237,7 @@ def shard_by_batch_size( allow_uneven_shards (bool): Whether to allow shards to be unevenly sized. If True, the last shard may be smaller than the others. dynamic_batching_args (dict): If passed, preprocess batch for dynamic batching. This - dict requires two keys: + dict requires four keys: 1. max_tokens_per_microbatch (int): the maximum number of tokens in a microbatch 2. sequence_length_round (int): round each all @@ -229,6 +247,19 @@ def shard_by_batch_size( 4. input_lengths_key (str): the key in the batch which holds the sequence length per value. The sequence dim index is assumed to be 1. + Cannot be passed with sequence_packing_args. + + sequence_packing_args (dict): If passed, preprocess batch for sequence packing. This + dict requires three keys: + 1. max_tokens_per_microbatch (int): the maximum + number of tokens in a microbatch + 2. input_key (str): the key in the batch + which holds input ids. + 3. input_lengths_key (str): the key in the batch + which holds the sequence length per value. + The sequence dim index is assumed to be 1. + 4. algorithm (str): the algorithm to use for sequence packing. + Cannot be passed with dynamic_batching_args. Returns: list[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards. 
@@ -268,6 +299,9 @@ def shard_by_batch_size( assert batch_size is None, ( "batch_size must be None if allow_uneven_shards is True" ) + assert dynamic_batching_args is None or sequence_packing_args is None, ( + "dynamic_batching_args and sequence_packing_args cannot be passed together" + ) # Get the total batch size batch_sizes = set() @@ -336,6 +370,112 @@ def shard_by_batch_size( else: sorted_v = [v[i] for i in batch_sorted_indices] data[k] = sorted_v + + elif sequence_packing_args is not None: + bin_packer = get_packer( + algorithm=sequence_packing_args["algorithm"], + bin_capacity=sequence_packing_args["max_tokens_per_microbatch"], + collect_metrics=False, # TODO(ahmadki): make configurable + ) + + input_lengths_key = sequence_packing_args["input_lengths_key"] + input_lens = self.data[input_lengths_key] + if not isinstance(input_lens, torch.Tensor): + input_lens = torch.tensor(input_lens) + + pad_multiple = sequence_packing_args["sequence_length_pad_multiple"] + + def _get_padded_seqlen(seqlen: int) -> int: + return (seqlen + pad_multiple - 1) // pad_multiple * pad_multiple + + # Store bin assignments for each chunk to reuse later + all_chunk_bin_assignments = [] + + # Process each chunk separately to respect chunk boundaries + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * batch_size + chunk_end = (chunk_idx + 1) * batch_size + + # Get sequence lengths for this chunk + chunk_seqlens = input_lens[chunk_start:chunk_end] + chunk_padded_seqlens_list = [ + _get_padded_seqlen(seq_len.item()) for seq_len in chunk_seqlens + ] + + # Pack sequences in this chunk into bins + chunk_bin_assignments = bin_packer.pack( + sequence_lengths=chunk_padded_seqlens_list, + ) + all_chunk_bin_assignments.append(chunk_bin_assignments) + + # create shards with the packed bins + sharded_data: list[list[dict]] = [[] for _ in range(shards)] + sharded_micro_indices: list = [[] for _ in range(shards)] + sharded_micro_lengths: list = [[] for _ in range(shards)] + 
sharded_elem_counts_per_gb: list = [[] for _ in range(shards)] + global_indices_per_shard: list[list[int]] = [[] for _ in range(shards)] + for chunk_idx in range(num_chunks): + chunk_sharded_micro_indices: list[list[list[int]]] = [ + [] for _ in range(shards) + ] + chunk_sharded_micro_lengths: list[list[int]] = [ + [] for _ in range(shards) + ] + + num_bins = len(all_chunk_bin_assignments[chunk_idx]) + chunk_start = chunk_idx * batch_size + for bin_idx in range(num_bins): + shard_idx = bin_idx % shards + bin_indices = all_chunk_bin_assignments[chunk_idx][bin_idx] + global_bin_indices = [i + chunk_start for i in bin_indices] + sharded_data[shard_idx].append( + self.select_indices(global_bin_indices) + ) + global_indices_per_shard[shard_idx].extend(global_bin_indices) + bin_seqlen = sum( + [ + _get_padded_seqlen(input_lens[i].item()) + for i in global_bin_indices + ] + ) + + if chunk_sharded_micro_indices[shard_idx] == []: + chunk_sharded_micro_indices[shard_idx].append( + [0, len(bin_indices)] + ) + else: + prev_bin_end = chunk_sharded_micro_indices[shard_idx][-1][1] + chunk_sharded_micro_indices[shard_idx].append( + [prev_bin_end, prev_bin_end + len(bin_indices)] + ) + chunk_sharded_micro_lengths[shard_idx].append(bin_seqlen) + + for shard_idx in range(shards): + sharded_micro_indices[shard_idx].append( + chunk_sharded_micro_indices[shard_idx] + ) + sharded_micro_lengths[shard_idx].append( + chunk_sharded_micro_lengths[shard_idx] + ) + sharded_elem_counts_per_gb[shard_idx].append( + chunk_sharded_micro_indices[shard_idx][-1][1] + ) + + # flatten global_indices_per_shard + batch_sorted_indices = [] + for shard_idx in range(shards): + batch_sorted_indices.extend(global_indices_per_shard[shard_idx]) + + aggregated_shards = [] + for shard_idx in range(shards): + shard = SlicedDataDict.from_batches(sharded_data[shard_idx]) + shard.micro_batch_indices = sharded_micro_indices[shard_idx] + shard.micro_batch_lengths = sharded_micro_lengths[shard_idx] + 
shard.elem_counts_per_gb = sharded_elem_counts_per_gb[shard_idx] + aggregated_shards.append(shard) + + return aggregated_shards, batch_sorted_indices + else: data = self.data @@ -457,7 +597,7 @@ def shard_by_batch_size( return aggregated_shards - def get_batch(self, batch_idx, batch_size) -> "SlicedDataDict": + def get_batch(self, batch_idx, batch_size=None) -> "SlicedDataDict": """Slices a subbatch from the batch. Args: @@ -467,6 +607,21 @@ def get_batch(self, batch_idx, batch_size) -> "SlicedDataDict": Returns: BatchedDataDict: A new BatchedDataDict containing the sliced data """ + if self.elem_counts_per_gb is not None: + assert self.micro_batch_indices is not None, ( + "micro_batch_indices must be provided if sequence_packing is True" + ) + elem_count = self.elem_counts_per_gb[batch_idx] + cum_elem_count = [0] + for i in range(len(self.elem_counts_per_gb)): + cum_elem_count.append(cum_elem_count[i] + self.elem_counts_per_gb[i]) + + batch = self.slice(cum_elem_count[batch_idx], cum_elem_count[batch_idx + 1]) + batch.micro_batch_indices = [self.micro_batch_indices[batch_idx]] + batch.micro_batch_lengths = [self.micro_batch_lengths[batch_idx]] # type: ignore # This exists if idxs do + batch.elem_counts_per_gb = [elem_count] + return batch + start = batch_size * batch_idx end = batch_size * (batch_idx + 1) batch = self.slice(start, end) @@ -520,7 +675,7 @@ def make_microbatch_iterator_with_dynamic_shapes( self, sequence_dim: int = 1, ) -> Iterator["SlicedDataDict"]: - """Makes an interator that yields microbatchs of dynamic batch and sequence sizes. + """Makes an iterator that yields microbatchs of dynamic batch and sequence sizes. 
Args: sequence_dim: the index of the sequence dim for all tensors in the data dict @@ -542,9 +697,29 @@ def make_microbatch_iterator_with_dynamic_shapes( yield mb def get_microbatch_iterator_dynamic_shapes_len(self) -> int: - """Get the length of the microbatch iterator with dynamic shapes.""" + """Get the length of the microbatch iterator for dynamic shapes.""" return len(self.micro_batch_indices[0]) + def make_microbatch_iterator_for_packable_sequences( + self, + ) -> Iterator["SlicedDataDict"]: + """Make an iterator over the batch that yields microbatches that can be packed into a given max_tokens_per_microbatch.""" + assert ( + self.micro_batch_indices is not None + and len(self.micro_batch_indices) == 1 + and self.micro_batch_lengths is not None + ) + + for seqlen, (start_idx, end_idx) in zip( + self.micro_batch_lengths[0], self.micro_batch_indices[0] + ): + mb = self.slice(start_idx, end_idx) + yield mb + + def get_microbatch_iterator_for_packable_sequences_len(self) -> tuple[int, int]: + """Get the length of the microbatch iterator for sequence packing and the max packed seqlen.""" + return len(self.micro_batch_indices[0]), max(self.micro_batch_lengths[0]) + def make_microbatch_iterator( self, microbatch_size: int ) -> Iterator["SlicedDataDict"]: diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 31ac71cc23..2f54b97265 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -188,3 +188,96 @@ def from_parallel_logits_to_logprobs( assert probs.shape == target_shape return probs[:, :-1] + + +def from_parallel_logits_to_logprobs_packed_sequences( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + cu_seqlens: torch.Tensor, + unpacked_seqlen: int, + vocab_start_index: int, + vocab_end_index: int, + group: torch.distributed.ProcessGroup, + inference_only: bool = False, +) -> torch.Tensor: + """Get log probabilities from TP sharded vocab logits for packed sequences. 
+ + Args: + vocab_parallel_logits (torch.Tensor): Packed logits tensor with shape [1, T, vocab_size//TP] + where T is the total number of tokens across all packed sequences. + target (torch.Tensor): Packed target token indices with shape [1, T]. + NOTE: Must be the unmodified targets as this function will shift them internally. + cu_seqlens (torch.Tensor): Cumulative sequence lengths tensor with shape [batch_size + 1]. + cu_seqlens[i] indicates the start position of sequence i in the packed format. + unpacked_seqlen (int): The length of the unpacked sequence tensor. + vocab_start_index (int): Starting vocabulary index for this worker's partition. + vocab_end_index (int): Ending vocabulary index for this worker's partition. + group (torch.distributed.ProcessGroup): Process group for distributed communication. + inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. + + Returns: + torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. + The total length is reduced by batch_size due to target shifting (one token per sequence). 
+ """ + # Remove batch dimension to work with [T, vocab_size] and [T] + vocab_parallel_logits = vocab_parallel_logits.squeeze(0) + target = target.squeeze(0) + + batch_size = cu_seqlens.shape[0] - 1 + + # Roll each sequence individually + rolled_targets = torch.zeros_like(target) + for i in range(batch_size): + start_idx = cu_seqlens[i].item() + end_idx = cu_seqlens[i + 1].item() + + # Get the sequence targets and roll by -1 + seq_targets = target[start_idx:end_idx] + rolled_seq_targets = seq_targets.roll(shifts=-1, dims=0) + rolled_targets[start_idx:end_idx] = rolled_seq_targets + + # Add batch dimension back for DistributedLogprob + rolled_targets = rolled_targets.unsqueeze(0) + vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) + + # Apply distributed log probability computation + probs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + vocab_start_index, + vocab_end_index, + group, + inference_only, + ).contiguous() + + # Remove batch dimension for filtering + probs = probs.squeeze(0) + + # Ensure probs is 1D after squeezing + if probs.dim() != 1: + raise ValueError( + f"Expected probs to be 1D after squeezing, but got shape {probs.shape}. 
" + f"Original shape before squeeze: {probs.unsqueeze(0).shape}" + ) + + out_logprobs = torch.zeros( + (batch_size, unpacked_seqlen - 1), dtype=probs.dtype, device=probs.device + ) + # Filter out the last token of each sequence + for i in range(batch_size): + start_idx = cu_seqlens[i].item() + end_idx = cu_seqlens[i + 1].item() + + # Exclude the last position (which has the rolled target from position 0) + if end_idx - start_idx > 0: + seq_probs = probs[start_idx : end_idx - 1] + # Ensure seq_probs is 1D + if seq_probs.dim() > 1: + seq_probs = seq_probs.squeeze() + + # Ensure we don't exceed the unpacked sequence length + seq_len = min(seq_probs.shape[0], unpacked_seqlen - 1) + if seq_len > 0: + out_logprobs[i, :seq_len] = seq_probs[:seq_len] + + return out_logprobs diff --git a/nemo_rl/models/huggingface/common.py b/nemo_rl/models/huggingface/common.py index df913f95b4..cdfab8ef04 100644 --- a/nemo_rl/models/huggingface/common.py +++ b/nemo_rl/models/huggingface/common.py @@ -12,10 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from enum import Enum, auto +from typing import Optional, Tuple, TypeVar +import torch from transformers import AutoConfig +Tensor = TypeVar("Tensor", bound=torch.Tensor) + + +@dataclass +class FlashAttentionKwargs: + """Dataclass to hold FlashAttention v2 kwargs.""" + + cu_seqlens_q: Tensor + cu_seqlens_k: Tensor + max_seqlen_q: int + max_seqlen_k: int + class ModelFlag(Enum): """Enum that defines special flags for model-specific behaviors. @@ -53,3 +68,234 @@ def is_gemma_model(model_name: str) -> bool: "gemma3", "gemma3_text", ] + + +def group_and_cat_tensors( + tensors: list[torch.Tensor], group_sizes: list[int], padding_value: int = 0 +) -> torch.Tensor: + """Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. 
+ + Each group of 1D tensors is concatenated into a single 1D tensor, and all resulting + group tensors are padded to the same length and stacked into a 2D tensor. + + Args: + tensors: List of 1D tensors of varying lengths. + group_sizes: List of integers. Each integer specifies how many tensors to group. + padding_value: Integer used to pad shorter sequences. + + Returns: + A 2D tensor where each row is a padded concatenation of the grouped tensors. + + Example: + >>> tensors = [ + ... torch.tensor([1, 2]), + ... torch.tensor([3]), + ... torch.tensor([4, 5, 6]), + ... torch.tensor([7]) + ... ] + >>> group_sizes = [2, 2] + >>> group_and_cat_tensors(tensors, group_sizes, padding_value=-1) + tensor([[ 1, 2, 3, -1, -1], + [ 4, 5, 6, 7, -1]]) + """ + grouped = [] + index = 0 + for size in group_sizes: + group = tensors[index : index + size] + concat = torch.cat(group, dim=0) + grouped.append(concat) + index += size + + # Compute the maximum length for padding + max_len = max(t.size(0) for t in grouped) + + # Pad each tensor to max_len + padded = torch.stack( + [ + torch.nn.functional.pad(t, (0, max_len - t.size(0)), value=padding_value) + for t in grouped + ] + ) + + return padded + + +def pack_sequences( + input_ids: torch.Tensor, + input_lengths: torch.Tensor, + packed_sequence_size: list[int], + padding_value: int = 0, + return_attention_mask: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Packs sequences into rows where each row concatenates multiple sequences. + + Useful for sequence packing in transformer models (e.g. for SFT training). Returns: + packed input_ids, packed position_ids, and optional attention_mask. 
+ + Args: + input_ids (torch.Tensor): Tensor of shape [num_sequences, max_seq_len] + input_lengths (torch.Tensor): Tensor of shape [num_sequences], containing true lengths + packed_sequence_size (List[int]): How many sequences to pack per row + padding_value (int): Pad value for input_ids + return_attention_mask (bool): Whether to return per-row causal attention mask + + Returns: + Tuple: + input_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] + position_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len] + attention_mask (Optional[torch.Tensor]): [batch_size, max_len, max_len] if requested + + Example: + >>> input_ids = torch.tensor([ + ... [1, 2, 0, 0], # len 2 + ... [3, 4, 5, 0], # len 3 + ... [6, 0, 0, 0], # len 1 + ... [7, 8, 9, 9], # len 4 + ... [8, 7, 0, 0], # len 2 + ... [6, 0, 0, 0], # len 1 + ... [5, 4, 3, 0], # len 3 + ... ]) + >>> input_lengths = torch.tensor([2, 3, 1, 4, 2, 1, 3]) + >>> packed_sequence_size = [3, 4] + >>> input_ids_packed, position_ids_packed, attention_mask = pack_sequences( + ... input_ids, input_lengths, packed_sequence_size, padding_value=-1, return_attention_mask=True + ... 
) + >>> input_ids_packed + tensor([ + [ 1, 2, 3, 4, 5, 6, -1, -1, -1, -1], + [ 7, 8, 9, 9, 8, 7, 6, 5, 4, 3] + ]) + >>> position_ids_packed + tensor([ + [0, 1, 0, 1, 2, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 0, 1, 0, 0, 1, 2] + ]) + >>> attention_mask[0] + tensor([ + [ True, True, False, False, False, False, False, False, False, False], + [False, False, True, True, True, False, False, False, False, False], + [False, False, False, False, False, True, False, False, False, False], + [False, False, False, False, False, False, False, False, False, False], + ]) + >>> attention_mask[1] + tensor([ + [ True, True, True, True, False, False, False, False, False, False], + [False, False, False, False, True, True, True, False, False, False], + [False, False, False, False, False, False, True, True, True, True], + [False, False, False, False, False, False, False, True, True, True], + ]) + """ + flat_input_ids = [] + position_ids = [] + flat_lengths = input_lengths.tolist() + + for i, seq_len in enumerate(flat_lengths): + flat_input_ids.append(input_ids[i, :seq_len]) + position_ids.append( + torch.arange(seq_len, dtype=torch.long, device=input_ids.device) + ) + + # Group and pad + input_ids_packed = group_and_cat_tensors( + flat_input_ids, packed_sequence_size, padding_value + ) + position_ids_packed = group_and_cat_tensors( + position_ids, packed_sequence_size, padding_value=0 + ) + + # Compute max length + batch_size, max_seq_len = input_ids_packed.shape + + attention_mask = None + if return_attention_mask: + attention_mask = torch.zeros( + (batch_size, max_seq_len, max_seq_len), + dtype=torch.bool, + device=input_ids.device, + ) + index = 0 + for i, group_size in enumerate(packed_sequence_size): + group_lengths = flat_lengths[index : index + group_size] + total_len = sum(group_lengths) + attention_mask[i, :total_len, :total_len] = torch.tril( + torch.ones( + (total_len, total_len), dtype=torch.bool, device=input_ids.device + ) + ) + index += group_size + + return input_ids_packed, 
position_ids_packed, attention_mask + + +# TODO(ahmadki): the function doesn't actually handle returning 2D tensors because none of the backends support this. +# but we should support this anyways +def unpack_tensor(tensor, input_lengths): + """Unpacks a packed tensor into individual sequences padded to the same length. + + Args: + tensor (torch.Tensor): Packed tensor of shape [batch_size, packed_seq_len]. + packed_lengths (List[int]): Original sequence lengths in the order they were packed. + + Returns: + torch.Tensor: [num_sequences, max_seq_len], each row is one unpacked and padded sequence. + + Example: + >>> packed_tensor = torch.tensor([ + ... [1, 2, 3, 4, 5, 6, -1, -1], + ... [7, 8, 9, 9, 8, 7, 6, -1] + ... ]) + >>> packed_lengths = [2, 3, 1, 4, 2] + >>> unpack_tensor(packed_tensor, packed_lengths) + tensor([ + [1, 2, 0, 0], + [3, 4, 5, 0], + [6, 0, 0, 0], + [7, 8, 9, 9], + [8, 7, 0, 0], + ]) + """ + packed_seqlen = tensor.shape[1] + splitsizes = input_lengths.tolist() + splitsizes.append(packed_seqlen - sum(splitsizes)) + tensor_split = torch.split(tensor, tuple(splitsizes), dim=1) + + max_len = max(input_lengths.tolist()) # max sequence length in the batch + + tensor_stacked = [] + for t in tensor_split[0:-1]: + padding_needed = max_len - t.shape[1] + tensor_stacked.append( + torch.nn.functional.pad( + t, (0, 0, 0, padding_needed), mode="constant", value=0.0 + ) + ) + return torch.cat(tensor_stacked, dim=0) + + +def get_flash_attention_kwargs(input_lengths: torch.Tensor) -> FlashAttentionKwargs: + """Returns kwargs required for FlashAttention v2 forward functions. 
+ + Args: + input_lengths (torch.Tensor): [batch_size] containing lengths of each sequence + + Returns: + Dict[str, torch.Tensor | int]: + { + "cu_seqlens_q": Tensor[int32], + "cu_seqlens_k": Tensor[int32], + "max_seqlen_q": int, + "max_seqlen_k": int + } + """ + input_lengths_int32 = input_lengths.to(torch.int32) + cu_seqlens = torch.nn.functional.pad( + input_lengths_int32.cumsum(dim=0), (1, 0) + ) # prepend 0 + max_len = input_lengths.max().item() + + return FlashAttentionKwargs( + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens.clone(), # same for self-attention + max_seqlen_q=max_len, + max_seqlen_k=max_len, + ) diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 5c6431b15e..15e3a8a0e6 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -12,22 +12,224 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Any, Iterator +from typing import Any, Iterator, Optional import torch import torch.distributed as dist from megatron.core.models.gpt import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import ( + get_context_parallel_world_size, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, ) from megatron.training.utils import get_ltor_masks_and_position_ids from nemo.tron.state import GlobalState -from nemo_rl.algorithms.loss_functions import LossFunction +from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict +def _pack_sequences_for_megatron( + input_ids: torch.Tensor, + seq_lengths: torch.Tensor, + pad_individual_seqs_to_multiple_of: int = 1, + pad_packed_seq_to: Optional[int] = None, +) -> tuple[torch.Tensor, PackedSeqParams, torch.Tensor, Optional[torch.Tensor]]: + """Pack sequences for Megatron model processing with 
optional context parallelism. + + Args: + input_ids: Input token IDs [batch_size, seq_length] + seq_lengths: Actual sequence lengths for each sample [batch_size] + pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value + pad_packed_seq_to: Pad packed sequences to this value + + Returns: + Tuple of: + - packed_input_ids: Packed input tensor [1, T] + - packed_seq_params: PackedSeqParams object + - cu_seqlens: Cumulative sequence lengths + - cu_seqlens_padded: Padded cumulative sequence lengths (if CP > 1) + """ + batch_size = input_ids.shape[0] + + # Build cumulative sequence lengths (cu_seqlens) and extract valid tokens + cu_seqlens = [0] + cu_seqlens_padded = ( + [0] + if pad_individual_seqs_to_multiple_of > 1 or pad_packed_seq_to is not None + else None + ) + valid_tokens = [] + + pad_factor = pad_individual_seqs_to_multiple_of + + for b in range(batch_size): + seq_len = ( + seq_lengths[b].item() if torch.is_tensor(seq_lengths[b]) else seq_lengths[b] + ) + + # Extract valid tokens for this sequence + valid_tokens.append(input_ids[b, :seq_len]) + + # Update cumulative sequence lengths + cu_seqlens.append(cu_seqlens[-1] + seq_len) + + # For context parallelism, track padded sequence lengths + if pad_factor > 1 or pad_packed_seq_to is not None: + # Pad sequence length to multiple of (cp_size * 2) + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + cu_seqlens_padded.append(cu_seqlens_padded[-1] + padded_seq_len) + + # Convert to tensors + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=input_ids.device) + if pad_factor > 1 or pad_packed_seq_to is not None: + cu_seqlens_padded = torch.tensor( + cu_seqlens_padded, dtype=torch.int32, device=input_ids.device + ) + if pad_packed_seq_to is not None: + cu_seqlens_padded[-1] = pad_packed_seq_to + + # Calculate max sequence length (padded if using CP) + if pad_factor > 1 or (pad_packed_seq_to is not None): + seq_lens_padded = cu_seqlens_padded[1:] - 
cu_seqlens_padded[:-1] + max_seqlen = seq_lens_padded.max().item() + else: + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + + # Concatenate all valid tokens + # If using individual padding, we need to pad individual sequences + if pad_factor > 1: + padded_tokens = [] + for b in range(batch_size): + seq_len = ( + seq_lengths[b].item() + if torch.is_tensor(seq_lengths[b]) + else seq_lengths[b] + ) + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + + # Pad this sequence to the required length + seq_tokens = input_ids[b, :seq_len] + if padded_seq_len > seq_len: + # Pad with zeros (or use a padding token if available) + padding = torch.zeros( + padded_seq_len - seq_len, + dtype=seq_tokens.dtype, + device=seq_tokens.device, + ) + seq_tokens = torch.cat([seq_tokens, padding]) + + padded_tokens.append(seq_tokens) + + # Concatenate all padded tokens + # For 'thd' format, the shape should be [1, T] where T is total tokens + packed_input_ids = torch.cat(padded_tokens, dim=0).unsqueeze(0) + else: + # No individual padding, just concatenate valid tokens + # For 'thd' format, the shape should be [1, T] where T is total tokens + packed_input_ids = torch.cat(valid_tokens, dim=0).unsqueeze(0) + + if pad_packed_seq_to is not None: + packed_input_ids = torch.cat( + [ + packed_input_ids, + torch.zeros( + 1, + pad_packed_seq_to - packed_input_ids.shape[1], + dtype=packed_input_ids.dtype, + device=packed_input_ids.device, + ), + ], + dim=1, + ) + + if cu_seqlens_padded is None: + cu_seqlens_padded = cu_seqlens.clone() + + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=int(max_seqlen), + max_seqlen_kv=int(max_seqlen), + qkv_format="thd", + ) + + return ( + packed_input_ids.contiguous(), + packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) + + +def _unpack_sequences_from_megatron( 
+ output_tensor: torch.Tensor, + seq_lengths: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqlens_padded: Optional[torch.Tensor], + original_batch_size: int, + original_seq_length: int, +) -> torch.Tensor: + """Unpack sequences from Megatron output format. + + Args: + output_tensor: Packed output tensor [1, T, vocab_size] + seq_lengths: Actual sequence lengths for each sample + cu_seqlens: Cumulative sequence lengths + cu_seqlens_padded: Padded cumulative sequence lengths (if CP was used) + original_batch_size: Original batch size + original_seq_length: Original maximum sequence length + + Returns: + Unpacked output tensor [batch_size, seq_length, vocab_size] + """ + # Remove the batch dimension to get [T, vocab_size] + output_tensor = output_tensor.squeeze(0) + + # Create a padded output tensor with original shape + vocab_size = output_tensor.shape[-1] + unpacked_output = torch.zeros( + (original_batch_size, original_seq_length, vocab_size), + dtype=output_tensor.dtype, + device=output_tensor.device, + ) + + # Get context parallel size to determine which cu_seqlens to use + cp_size = get_context_parallel_world_size() + + # Fill in the unpacked output tensor with valid tokens + for b in range(original_batch_size): + # Get actual sequence length for this sample + seq_len = ( + seq_lengths[b].item() if torch.is_tensor(seq_lengths[b]) else seq_lengths[b] + ) + + if cp_size > 1 and cu_seqlens_padded is not None: + # When using CP, we need to account for padding + # Calculate the padded sequence boundaries + pad_factor = cp_size * 2 + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + start_idx = cu_seqlens_padded[b].item() + + # Only copy the valid tokens (not the padding) + unpacked_output[b, :seq_len] = output_tensor[ + start_idx : start_idx + seq_len + ] + else: + # No CP, use regular cu_seqlens + start_idx = cu_seqlens[b].item() + end_idx = cu_seqlens[b + 1].item() + + # Copy the valid tokens to the unpacked tensor + unpacked_output[b, 
:seq_len] = output_tensor[start_idx:end_idx] + + return unpacked_output + + def forward_step_arbitrary_loss( state: GlobalState, global_valid_seqs: torch.Tensor, @@ -35,6 +237,10 @@ def forward_step_arbitrary_loss( data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel, loss_fn: LossFunction, + pack_sequences: bool = False, + seq_length_key: Optional[str] = None, + pad_individual_seqs_to_multiple_of: int = 1, + pad_full_seq_to: Optional[int] = None, ): """Forward training step with support for packed sequences and context parallelism. @@ -45,18 +251,76 @@ def forward_step_arbitrary_loss( data_iterator: Input data iterator model (GPTModel): The GPT Model loss_fn (LossFunction): Loss function to apply + pack_sequences (bool): Whether to pack sequences for efficiency + seq_length_key (Optional[str]): Key in data_dict containing actual sequence lengths + + Notes on packed sequences with context parallelism (CP): + - When CP > 1, each sequence is padded to a multiple of (cp_size * 2) + - The factor of 2 ensures load balancing for causal attention + - cu_seqlens tracks actual sequence boundaries + - cu_seqlens_padded tracks padded sequence boundaries for CP + - TransformerEngine automatically distributes tokens across CP ranks + - Requires TransformerEngine >= 1.10 for CP support """ straggler_timer = state.straggler_timer with straggler_timer(bdata=True): data_dict = next(data_iterator).to("cuda") input_ids = data_dict["input_ids"] - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - input_ids, 0, False, False, False - ) + attention_mask = None + position_ids = None + packed_seq_params = None + + original_batch_size = input_ids.shape[0] + original_seq_length = input_ids.shape[1] + seq_lengths = None # Will be set if using packed sequences + cu_seqlens = None + cu_seqlens_padded = None + + if pack_sequences: + # For packed sequences with padded input, we need sequence lengths + assert seq_length_key is not None, ( + "seq_length_key must be 
provided for packed sequences" + ) + assert seq_length_key in data_dict, ( + f"{seq_length_key} not found in data_dict" + ) + + # Get sequence lengths and context parallel size + seq_lengths = data_dict[seq_length_key] + + # Pack sequences + input_ids, packed_seq_params, cu_seqlens, cu_seqlens_padded = ( + _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of, + pad_full_seq_to, + ) + ) + + # For packed sequences, position_ids and attention_mask are typically None + # The PackedSeqParams handles all necessary sequence information + position_ids = None + attention_mask = None + else: + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + input_ids, 0, False, False, False + ) with straggler_timer: - output_tensor = model(input_ids, position_ids, attention_mask) + output_tensor = model( + input_ids, position_ids, attention_mask, packed_seq_params=packed_seq_params + ) + + # Unpack the output tensor if we did packed sequences + if pack_sequences and packed_seq_params is not None: + # remove padding + loss_fn = SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=packed_seq_params.cu_seqlens_q, + cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q, + ) loss_data = data_dict diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index b04ff2d159..c36cbf5899 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -117,7 +117,7 @@ class DynamicBatchingConfig(TypedDict): # training and logprob stages respectively. 
enabled: bool train_mb_tokens: int - logprob_mb_tokens: int + logprob_mb_tokens: NotRequired[int] = None sequence_length_round: int diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index df8b4e734f..5ca03b47a2 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -40,6 +40,7 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.dtensor.parallelize import ( _parallelize_model, @@ -48,7 +49,11 @@ get_logprobs_from_vocab_parallel_logits, to_local_if_dtensor, ) -from nemo_rl.models.huggingface.common import ModelFlag +from nemo_rl.models.huggingface.common import ( + ModelFlag, + get_flash_attention_kwargs, + pack_sequences, +) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( LogprobOutputSpec, @@ -170,6 +175,14 @@ def __init__( else: raise ValueError(f"Unknown precision: {self.cfg['precision']}") + print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") + self.enable_seq_packing = self.cfg["sequence_packing"]["enabled"] + if self.enable_seq_packing: + print( + f"[Rank {self.rank}] Sequence packing is enabled for model {model_name}" + ) + print(f"[Rank {self.rank}] Using FlashAttention2 for sequence packing") + model_config = AutoConfig.from_pretrained( model_name, # Always load the model in float32 to keep master weights in float32. 
@@ -179,6 +192,9 @@ def __init__( **sliding_window_overwrite( model_name ), # due to https://github.com/huggingface/transformers/issues/38002 + attn_implementation="flash_attention_2" + if self.enable_seq_packing + else None, ) full_state_dict = None @@ -499,30 +515,45 @@ def train( # so its safe to not check for the case where the last data slice is smaller than mbs if self.cfg["dynamic_batching"]["enabled"]: mb_iterator = batch.make_microbatch_iterator_with_dynamic_shapes() + elif self.enable_seq_packing: + mb_iterator = ( + batch.make_microbatch_iterator_for_packable_sequences() + ) else: mb_iterator = batch.make_microbatch_iterator(mbs) for mb in mb_iterator: - input_ids = mb.get("input_ids").cuda() - input_lengths = mb.get("input_lengths") - batch_size, seq_len = input_ids.shape + with torch.autocast(device_type="cuda", dtype=self.dtype): + if self.enable_seq_packing: + input_ids = mb.get("input_ids").cuda() + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=mb["input_lengths"], + packed_sequence_size=[ + len(mb["input_lengths"]) + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=mb["input_lengths"], + ) - attention_mask = torch.zeros( - (batch_size, seq_len), dtype=torch.long, device=input_ids.device - ) - for i, length in enumerate(input_lengths): - # For right-padded sequence, set 1s at the beginning of the sequence - attention_mask[i, :length] = 1 + else: + input_ids = mb.get("input_ids").cuda() + batch_size, seq_len = input_ids.shape - with torch.autocast(device_type="cuda", dtype=self.dtype): - attention_mask_input_all_ones = torch.ones( - (batch_size, seq_len), - dtype=torch.long, - device=input_ids.device, - ) - position_ids = torch.arange( - seq_len, device=input_ids.device - ).repeat(batch_size, 1) + attention_mask = 
torch.ones( + (batch_size, seq_len), + dtype=torch.long, + device=input_ids.device, + ) + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + flash_attn_kwargs = {} context_parallel_ctx = None if self.cp_size > 1: @@ -547,9 +578,10 @@ def train( with torch.autocast(device_type="cuda", dtype=self.dtype): outputs = self.model( input_ids=input_ids, - attention_mask=attention_mask_input_all_ones, + attention_mask=attention_mask, position_ids=position_ids, use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, ) # Get logprobs @@ -612,8 +644,20 @@ def train( placements=[Shard(sequence_dim), Shard(-1)], ) - loss, loss_metrics = loss_fn( - logits, mb, global_valid_seqs, global_valid_toks + if self.enable_seq_packing: + loss_fn_ = SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=flash_attn_kwargs.cu_seqlens_q, + cu_seqlens_q_padded=flash_attn_kwargs.cu_seqlens_q, + ) + else: + loss_fn_ = loss_fn + + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + global_valid_toks, ) ## scale by the number of global batches so we get the correct diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 5e82b61d72..173e7bdb7f 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -24,6 +24,7 @@ from nemo_rl.distributed.batched_data_dict import ( BatchedDataDict, DynamicBatchingArgs, + SequencePackingArgs, SlicedDataDict, ) from nemo_rl.distributed.named_sharding import NamedSharding @@ -143,6 +144,26 @@ def __init__( else: self.use_dynamic_batches = False + if config["sequence_packing"]["enabled"]: + assert ( + config["megatron_cfg"]["enabled"] or config["dtensor_cfg"]["enabled"] + ), "Sequence packing requires Megatron or DTensor policies." 
+ self.use_sequence_packing = True + self.sequence_packing_args: SequencePackingArgs = { + "train_mb_tokens": config["sequence_packing"]["train_mb_tokens"], + "logprob_mb_tokens": config["sequence_packing"].get( + "logprob_mb_tokens", None + ), + "algorithm": config["sequence_packing"]["algorithm"], + "input_key": "input_ids", + "input_lengths_key": "input_lengths", + "sequence_length_pad_multiple": (cp_size * 2 * tp_size) + if cp_size > 1 + else tp_size, + } + else: + self.use_sequence_packing = False + self.cfg = config def init_collective( @@ -179,6 +200,15 @@ def get_logprobs( batch_size=None, dynamic_batching_args=self.dynamic_batching_args, ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["logprob_mb_tokens"] + sharded_data, unsorted_data_indices = data.shard_by_batch_size( + dp_size, + batch_size=None, + sequence_packing_args=self.sequence_packing_args, + ) else: sharded_data = data.shard_by_batch_size( # type: ignore cp_size * dp_size, @@ -208,7 +238,7 @@ def get_logprobs( # dynamic batching sorts the inputs by sequence length to improve load balancing, # so change it back here - if self.use_dynamic_batches: + if self.use_dynamic_batches or self.use_sequence_packing: logprobs.reorder_data(unsorted_data_indices) return logprobs @@ -235,6 +265,15 @@ def get_reference_policy_logprobs( batch_size=None, dynamic_batching_args=self.dynamic_batching_args, ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["logprob_mb_tokens"] + sharded_data, unsorted_data_indices = data.shard_by_batch_size( + dp_size, + batch_size=None, + sequence_packing_args=self.sequence_packing_args, + ) else: sharded_data = data.shard_by_batch_size( # type: ignore cp_size * dp_size, @@ -267,7 +306,7 @@ def get_reference_policy_logprobs( # dynamic batching sorts the inputs by sequence length to improve load balancing, # so change it 
back here - if self.use_dynamic_batches: + if self.use_dynamic_batches or self.use_sequence_packing: logprobs.reorder_data(unsorted_data_indices) return logprobs @@ -294,6 +333,15 @@ def train( batch_size=batch_size, dynamic_batching_args=self.dynamic_batching_args, ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["train_mb_tokens"] + sharded_data, _ = data.shard_by_batch_size( + dp_size, + batch_size=batch_size, + sequence_packing_args=self.sequence_packing_args, + ) else: sharded_data = data.shard_by_batch_size( dp_size, diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 691e1ce5b3..5519f42671 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -88,7 +88,10 @@ from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs +from nemo_rl.distributed.model_utils import ( + from_parallel_logits_to_logprobs, + from_parallel_logits_to_logprobs_packed_sequences, +) from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, @@ -96,6 +99,7 @@ verify_right_padding, ) from nemo_rl.models.megatron.common import ( + _pack_sequences_for_megatron, broadcast_tensor, forward_step_arbitrary_loss, ) @@ -800,11 +804,33 @@ def train( ) batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) + pack_seqs = False + seqlen_key = None + pad_factor = 1 + pad_full_seq_to = None if self.cfg["dynamic_batching"]["enabled"]: data_iterator = batch.make_microbatch_iterator_with_dynamic_shapes() data_iterator_len = ( batch.get_microbatch_iterator_dynamic_shapes_len() ) + micro_batch_size = self.cfg["train_micro_batch_size"] + elif 
self.cfg["sequence_packing"]["enabled"]: + data_iterator = ( + batch.make_microbatch_iterator_for_packable_sequences() + ) + data_iterator_len, seq_dim_size = ( + batch.get_microbatch_iterator_for_packable_sequences_len() + ) + micro_batch_size = 1 + pack_seqs = True + seqlen_key = "input_lengths" + tp_size = self.cfg["megatron_cfg"]["tensor_model_parallel_size"] + cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] + pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size + if self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] > 1: + _, pad_full_seq_to = ( + batch.get_microbatch_iterator_for_packable_sequences_len() + ) else: data_iterator = batch.make_microbatch_iterator(mbs) data_iterator_len = local_gbs // mbs @@ -823,6 +849,10 @@ def train( self.mcore_state, global_valid_seqs, global_valid_toks, + pack_sequences=pack_seqs, + seq_length_key=seqlen_key, + pad_individual_seqs_to_multiple_of=pad_factor, + pad_full_seq_to=pad_full_seq_to, ), data_iterator=data_iterator, model=self.model, @@ -968,32 +998,79 @@ def get_logprobs( pp_grp = get_pipeline_model_parallel_group() pp_size = get_pipeline_model_parallel_world_size() + # if pp_size > 1, we need to pad the full sequence to the max sequence length to maintain a static PP buffer + if ( + self.cfg["sequence_packing"]["enabled"] + and self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] > 1 + ): + _, pad_full_seq_to = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + pp_seq_dim_size = pad_full_seq_to + else: + pad_full_seq_to = None + def forward_step_fn( data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel ): + nonlocal pad_full_seq_to data_dict = next(data_iterator).to("cuda") - input_ids = data_dict["input_ids"] - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - input_ids, 0, False, False, False - ) + if self.cfg["sequence_packing"]["enabled"]: + original_seq_length = data_dict["input_ids"].shape[1] + tp_size = 
self.cfg["megatron_cfg"]["tensor_model_parallel_size"] + pp_size = self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] + cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] + pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size + input_ids, packed_seq_params, cu_seqlens, cu_seqlens_padded = ( + _pack_sequences_for_megatron( + data_dict["input_ids"].clone(), + data_dict["input_lengths"], + pad_individual_seqs_to_multiple_of=pad_factor, + pad_packed_seq_to=pad_full_seq_to, + ) + ) + input_ids = input_ids + attention_mask, position_ids = None, None + unpacked_input_ids = data_dict["input_ids"] + else: + input_ids = data_dict["input_ids"] + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + input_ids, 0, False, False, False + ) + packed_seq_params = None + unpacked_input_ids = input_ids output_tensor = model( input_ids, position_ids, attention_mask, + packed_seq_params=packed_seq_params, ) def collection_fn(output_tensor): + stc = time.time() tp_grp = get_tensor_model_parallel_group() tp_rank = get_tensor_model_parallel_rank() - token_logprobs = from_parallel_logits_to_logprobs( - output_tensor.to(torch.float32), - target=input_ids, - vocab_start_index=tp_rank * output_tensor.shape[-1], - vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], - tp_group=tp_grp, - inference_only=True, - ) + if self.cfg["sequence_packing"]["enabled"]: + token_logprobs = from_parallel_logits_to_logprobs_packed_sequences( + output_tensor, + target=input_ids, + cu_seqlens=cu_seqlens_padded, + unpacked_seqlen=original_seq_length, + vocab_start_index=tp_rank * output_tensor.shape[-1], + vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], + group=tp_grp, + inference_only=True, + ) + else: + token_logprobs = from_parallel_logits_to_logprobs( + output_tensor.to(torch.float32), + target=unpacked_input_ids, + vocab_start_index=tp_rank * output_tensor.shape[-1], + vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], + tp_group=tp_grp, + 
inference_only=True, + ) # Prepend 0 logprob for first token to maintain same sequence length as input token_logprobs = torch.cat( @@ -1008,6 +1085,13 @@ def collection_fn(output_tensor): if self.cfg["dynamic_batching"]["enabled"]: mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() data_iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + micro_batch_size = logprob_batch_size + elif self.cfg["sequence_packing"]["enabled"]: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + data_iterator_len, _ = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + micro_batch_size = 1 else: mb_iterator = data.make_microbatch_iterator(logprob_batch_size) data_iterator_len = max(1, data.size // logprob_batch_size) diff --git a/pyproject.toml b/pyproject.toml index 62b78d6d4d..892005cf4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,6 @@ requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" - [tool.setuptools] packages = ["nemo_rl"] diff --git a/tests/unit/data/packing/__init__.py b/tests/unit/data/packing/__init__.py new file mode 100644 index 0000000000..913e5a1c57 --- /dev/null +++ b/tests/unit/data/packing/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for sequence packing algorithms.""" diff --git a/tests/unit/data/packing/test_algorithms.py b/tests/unit/data/packing/test_algorithms.py new file mode 100644 index 0000000000..765de4f246 --- /dev/null +++ b/tests/unit/data/packing/test_algorithms.py @@ -0,0 +1,330 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for sequence packing algorithms.""" + +import random +from typing import Dict, List + +import pytest + +from nemo_rl.data.packing.algorithms import ( + PackingAlgorithm, + SequencePacker, + get_packer, +) + + +def validate_solution( + sequence_lengths: List[int], bins: List[List[int]], bin_capacity: int +) -> bool: + """Validate that a packing solution is valid. + + Args: + sequence_lengths: The original list of sequence lengths. + bins: The packing solution, where each bin is a list of indices into sequence_lengths. + bin_capacity: The maximum capacity of each bin. + + Returns: + True if the packing is valid, False otherwise. 
+ """ + # Check that all sequences are packed + all_indices = set() + for bin_indices in bins: + all_indices.update(bin_indices) + + if len(all_indices) != len(sequence_lengths): + return False + + # Check that each bin doesn't exceed capacity + for bin_indices in bins: + bin_load = sum(sequence_lengths[idx] for idx in bin_indices) + if bin_load > bin_capacity: + return False + + return True + + +class TestSequencePacker: + """Test suite for sequence packing algorithms.""" + + @pytest.fixture + def bin_capacity(self) -> int: + """Fixture for bin capacity.""" + return 100 + + @pytest.fixture + def small_sequence_lengths(self) -> List[int]: + """Fixture for a small list of sequence lengths.""" + return [10, 20, 30, 40, 50, 60, 70, 80, 90] + + @pytest.fixture + def medium_sequence_lengths(self) -> List[int]: + """Fixture for a medium-sized list of sequence lengths.""" + return [25, 35, 45, 55, 65, 75, 85, 95, 15, 25, 35, 45, 55, 65, 75, 85, 95] + + @pytest.fixture + def large_sequence_lengths(self) -> List[int]: + """Fixture for a large list of sequence lengths.""" + # Set a seed for reproducibility + random.seed(42) + return [random.randint(10, 90) for _ in range(100)] + + @pytest.fixture + def edge_cases(self) -> Dict[str, List[int]]: + """Fixture for edge cases.""" + return { + "empty": [], + "single_item": [50], + "all_same_size": [30, 30, 30, 30, 30], + "max_size": [100, 100, 100], + "mixed_sizes": [10, 50, 100, 20, 80, 30, 70, 40, 60, 90], + } + + # TODO(ahmadki): use the function to specify all test algorithms instead of lists below + @pytest.fixture + def algorithms(self) -> List[PackingAlgorithm]: + """Fixture for packing algorithms.""" + return [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ] + + def test_get_packer(self, bin_capacity: int, algorithms: List[PackingAlgorithm]): + """Test the get_packer factory function.""" + # Test that 
each algorithm name returns the correct packer + for algorithm in algorithms: + packer = get_packer(algorithm, bin_capacity) + assert isinstance(packer, SequencePacker) + + # Test with an invalid algorithm value + with pytest.raises(ValueError): + # Create a non-existent enum value by using an arbitrary object + invalid_algorithm = object() + get_packer(invalid_algorithm, bin_capacity) # type: ignore + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_small_sequences( + self, + bin_capacity: int, + small_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test packing small sequences with all algorithms.""" + packer = get_packer(algorithm, bin_capacity) + bins = packer.pack(small_sequence_lengths) + + # Validate the packing + assert validate_solution(small_sequence_lengths, bins, bin_capacity) + + # Print the number of bins used (for information) + print(f"{algorithm.name} used {len(bins)} bins for small sequences") + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_medium_sequences( + self, + bin_capacity: int, + medium_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test packing medium-sized sequences with all algorithms.""" + packer = get_packer(algorithm, bin_capacity) + bins = packer.pack(medium_sequence_lengths) + + # Validate the packing + assert validate_solution(medium_sequence_lengths, bins, bin_capacity) + + # Print the number of bins used (for information) + print(f"{algorithm.name} used {len(bins)} bins for medium sequences") + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + 
PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_large_sequences( + self, + bin_capacity: int, + large_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test packing large sequences with all algorithms.""" + packer = get_packer(algorithm, bin_capacity) + bins = packer.pack(large_sequence_lengths) + + # Validate the packing + assert validate_solution(large_sequence_lengths, bins, bin_capacity) + + # Print the number of bins used (for information) + print(f"{algorithm.name} used {len(bins)} bins for large sequences") + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + # TODO(ahmadki): use the function to specify all test algorithms instead of lists below + @pytest.mark.parametrize( + "case_name, sequence_lengths", + [ + ("single_item", [50]), + ("all_same_size", [30, 30, 30, 30, 30]), + ("max_size", [100, 100, 100]), + ("mixed_sizes", [10, 50, 100, 20, 80, 30, 70, 40, 60, 90]), + ], + ) + def test_edge_cases( + self, + bin_capacity: int, + algorithm: PackingAlgorithm, + case_name: str, + sequence_lengths: List[int], + ): + """Test edge cases with all algorithms.""" + packer = get_packer(algorithm, bin_capacity) + bins = packer.pack(sequence_lengths) + + # Validate the packing + assert validate_solution(sequence_lengths, bins, bin_capacity) + + # For single item, check that only one bin is created + if case_name == "single_item": + assert len(bins) == 1 + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_empty_list(self, bin_capacity: int, algorithm: PackingAlgorithm): + """Test empty list with algorithms that can handle it.""" + packer = 
get_packer(algorithm, bin_capacity) + bins = packer.pack([]) + + # For empty list, check that no bins are created + assert len(bins) == 0 + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.FIRST_FIT_SHUFFLE, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_error_cases(self, bin_capacity: int, algorithm: PackingAlgorithm): + """Test error cases with all algorithms.""" + # Test with a sequence length that exceeds bin capacity + sequence_lengths = [50, 150, 70] # 150 > bin_capacity (100) + + packer = get_packer(algorithm, bin_capacity) + with pytest.raises(ValueError): + packer.pack(sequence_lengths) + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.CONCATENATIVE, + PackingAlgorithm.FIRST_FIT_DECREASING, + PackingAlgorithm.MODIFIED_FIRST_FIT_DECREASING, + ], + ) + def test_deterministic( + self, + bin_capacity: int, + medium_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test that deterministic algorithms produce the same result on multiple runs.""" + packer = get_packer(algorithm, bin_capacity) + + # Run the algorithm twice and check that the results are the same + bins1 = packer.pack(medium_sequence_lengths) + bins2 = packer.pack(medium_sequence_lengths) + + # Convert to a format that can be compared (sort each bin and then sort the bins) + sorted_bins1 = sorted([sorted(bin_indices) for bin_indices in bins1]) + sorted_bins2 = sorted([sorted(bin_indices) for bin_indices in bins2]) + + assert sorted_bins1 == sorted_bins2 + + @pytest.mark.parametrize( + "algorithm", + [ + PackingAlgorithm.FIRST_FIT_SHUFFLE, + ], + ) + def test_randomized( + self, + bin_capacity: int, + medium_sequence_lengths: List[int], + algorithm: PackingAlgorithm, + ): + """Test that randomized algorithms can produce different results on multiple runs.""" + # Note: This test might occasionally fail due to randomness + + # Set different seeds to 
ensure different random behavior + random.seed(42) + packer1 = get_packer(algorithm, bin_capacity) + bins1 = packer1.pack(medium_sequence_lengths) + + random.seed(43) + packer2 = get_packer(algorithm, bin_capacity) + bins2 = packer2.pack(medium_sequence_lengths) + + # Convert to a format that can be compared + sorted_bins1 = sorted([sorted(bin_indices) for bin_indices in bins1]) + sorted_bins2 = sorted([sorted(bin_indices) for bin_indices in bins2]) + + # Check if the results are different + # This is a weak test, as randomness might still produce the same result + if sorted_bins1 == sorted_bins2: + print( + f"Warning: {algorithm.name} produced the same result with different seeds" + ) + + +if __name__ == "__main__": + pytest.main(["-xvs", __file__]) diff --git a/tests/unit/distributed/test_batched_data_dict.py b/tests/unit/distributed/test_batched_data_dict.py index 6b6c95c092..eaebf2dd8a 100644 --- a/tests/unit/distributed/test_batched_data_dict.py +++ b/tests/unit/distributed/test_batched_data_dict.py @@ -14,7 +14,11 @@ import pytest import torch -from nemo_rl.distributed.batched_data_dict import BatchedDataDict, DynamicBatchingArgs +from nemo_rl.distributed.batched_data_dict import ( + BatchedDataDict, + DynamicBatchingArgs, + SequencePackingArgs, +) def test_shard_by_batch_size_basic(): @@ -236,3 +240,367 @@ def test_shard_by_batch_size_dynamic(): batch_size, seqlen = mb["data"].shape assert seqlen % 4 == 0 assert seqlen <= 32 + + +def test_sequence_packing_basic(): + """Test basic functionality of sequence packing with modified FFD algorithm.""" + # Create sample data with varying sequence lengths + batch_size = 8 + max_seq_length = 512 + + # Generate random sequence lengths between 50 and 400 + torch.manual_seed(42) + sequence_lengths = torch.randint(50, 400, (batch_size,)) + + # Create input tensors with padding + input_ids = [] + for seq_len in sequence_lengths: + # Create a sequence with actual tokens up to seq_len, then padding + seq = torch.cat( + [ + 
torch.randint(1, 1000, (seq_len,)), # Actual tokens + torch.zeros(max_seq_length - seq_len, dtype=torch.long), # Padding + ] + ) + input_ids.append(seq) + + input_ids = torch.stack(input_ids) + + # Create batch data dict + batch_data = BatchedDataDict( + { + "input_ids": input_ids, + "sequence_lengths": sequence_lengths, + "problem_ids": torch.arange(batch_size), + } + ) + + # Configure sequence packing + sequence_packing_args = SequencePackingArgs( + max_tokens_per_microbatch=1024, + input_key="input_ids", + input_lengths_key="sequence_lengths", + algorithm="modified_first_fit_decreasing", + sequence_length_pad_multiple=1, + ) + + # Shard the batch with sequence packing + shards = 2 + sharded_batches, sorted_indices = batch_data.shard_by_batch_size( + shards=shards, sequence_packing_args=sequence_packing_args + ) + + # Verify output structure + assert len(sharded_batches) == shards + assert len(sorted_indices) == batch_size + + # Verify each shard has microbatch indices and lengths + for shard in sharded_batches: + assert hasattr(shard, "micro_batch_indices") + assert hasattr(shard, "micro_batch_lengths") + assert len(shard.micro_batch_indices) > 0 + assert len(shard.micro_batch_lengths) > 0 + + problem_ids_seen = set() + + # Verify microbatch structure + for chunk_indices, chunk_lengths in zip( + shard.micro_batch_indices, shard.micro_batch_lengths + ): + assert len(chunk_indices) == len(chunk_lengths) + + # Verify each microbatch respects the token limit + for (start_idx, end_idx), packed_len in zip(chunk_indices, chunk_lengths): + assert packed_len <= sequence_packing_args["max_tokens_per_microbatch"] + + for s in sharded_batches: + for mb in s.make_microbatch_iterator_for_packable_sequences(): + mb_len = mb["sequence_lengths"].sum().item() + assert mb_len <= sequence_packing_args["max_tokens_per_microbatch"] + for i in range(mb["input_ids"].shape[0]): + problem_id = mb["problem_ids"][i].item() + assert problem_id not in problem_ids_seen, ( + f"Problem ID 
{problem_id} seen twice" + ) + problem_ids_seen.add(problem_id) + assert len(problem_ids_seen) == batch_size + + +def test_sequence_packing_uniform_lengths(): + """Test sequence packing when all sequences have the same length.""" + batch_size = 12 + seq_length = 256 + + batch_data = BatchedDataDict( + { + "input_ids": torch.ones(batch_size, seq_length, dtype=torch.long), + "sequence_lengths": torch.full((batch_size,), seq_length), + "problem_ids": torch.arange(batch_size), + } + ) + + sequence_packing_args = SequencePackingArgs( + max_tokens_per_microbatch=1024, + input_key="input_ids", + input_lengths_key="sequence_lengths", + algorithm="modified_first_fit_decreasing", + sequence_length_pad_multiple=1, + ) + + sharded_batches, sorted_indices = batch_data.shard_by_batch_size( + shards=2, sequence_packing_args=sequence_packing_args + ) + + # With uniform lengths, sequences should be efficiently packed + assert len(sharded_batches) == 2 + len_0 = len( + list(sharded_batches[0].make_microbatch_iterator_for_packable_sequences()) + ) + len_1 = len( + list(sharded_batches[1].make_microbatch_iterator_for_packable_sequences()) + ) + assert len_0 + len_1 == 3 + assert min(len_0, len_1) == 1 + + # Each microbatch should pack as many sequences as possible + for shard in sharded_batches: + for chunk_indices, chunk_lengths in zip( + shard.micro_batch_indices, shard.micro_batch_lengths + ): + for (start_idx, end_idx), packed_len in zip(chunk_indices, chunk_lengths): + # With 256 tokens per sequence and 1024 max, should pack 4 sequences + assert packed_len <= 1024 + num_seqs = end_idx - start_idx + assert num_seqs <= 4 # Can fit at most 4 sequences of length 256 + + problem_ids_seen = set() + for s in sharded_batches: + for mb in s.make_microbatch_iterator_for_packable_sequences(): + mb_len = mb["sequence_lengths"].sum().item() + assert mb_len <= sequence_packing_args["max_tokens_per_microbatch"] + for i in range(mb["input_ids"].shape[0]): + problem_id = 
mb["problem_ids"][i].item() + assert problem_id not in problem_ids_seen, ( + f"Problem ID {problem_id} seen twice" + ) + problem_ids_seen.add(problem_id) + assert len(problem_ids_seen) == batch_size + + +def test_sequence_packing_long_sequences(): + """Test sequence packing with very long sequences that require individual microbatches.""" + batch_size = 4 + + batch_data = BatchedDataDict( + { + "input_ids": torch.ones(batch_size, 2048, dtype=torch.long), + "sequence_lengths": torch.tensor([900, 850, 1000, 950]), + "problem_ids": torch.arange(batch_size), + } + ) + + sequence_packing_args = SequencePackingArgs( + max_tokens_per_microbatch=1024, + input_key="input_ids", + input_lengths_key="sequence_lengths", + algorithm="modified_first_fit_decreasing", + sequence_length_pad_multiple=1, + ) + + sharded_batches, sorted_indices = batch_data.shard_by_batch_size( + shards=2, sequence_packing_args=sequence_packing_args + ) + + # Each sequence should be in its own microbatch due to length + for shard in sharded_batches: + for chunk_indices, chunk_lengths in zip( + shard.micro_batch_indices, shard.micro_batch_lengths + ): + for (start_idx, end_idx), max_len in zip(chunk_indices, chunk_lengths): + num_seqs = end_idx - start_idx + # Each long sequence should be alone in its microbatch + assert num_seqs == 1 + + problem_ids_seen = set() + for s in sharded_batches: + for mb in s.make_microbatch_iterator_for_packable_sequences(): + mb_len = mb["sequence_lengths"].sum().item() + assert mb_len <= sequence_packing_args["max_tokens_per_microbatch"] + for i in range(mb["input_ids"].shape[0]): + problem_id = mb["problem_ids"][i].item() + assert problem_id not in problem_ids_seen, ( + f"Problem ID {problem_id} seen twice" + ) + problem_ids_seen.add(problem_id) + assert len(problem_ids_seen) == batch_size + + +def test_sequence_packing_with_dynamic_batching_conflict(): + """Test that sequence packing and dynamic batching cannot be used together.""" + batch_data = BatchedDataDict( + { + 
"input_ids": torch.ones(4, 100, dtype=torch.long),
+            "sequence_lengths": torch.tensor([50, 60, 70, 80]),
+        }
+    )
+
+    sequence_packing_args = SequencePackingArgs(
+        max_tokens_per_microbatch=1024,
+        input_key="input_ids",
+        input_lengths_key="sequence_lengths",
+        algorithm="modified_first_fit_decreasing",
+    )
+
+    dynamic_batching_args: DynamicBatchingArgs = {
+        "input_key": "input_ids",
+        "input_lengths_key": "sequence_lengths",
+        "sequence_length_round": 4,
+        "max_tokens_per_microbatch": 1024,
+    }
+
+    with pytest.raises(
+        AssertionError,
+        match="dynamic_batching_args and sequence_packing_args cannot be passed together",
+    ):
+        batch_data.shard_by_batch_size(
+            shards=2,
+            sequence_packing_args=sequence_packing_args,
+            dynamic_batching_args=dynamic_batching_args,
+        )
+
+
+@pytest.mark.parametrize("pad_to_multiple_of", [1, 32, 64, 256])
+def test_sequence_packing_microbatch_boundaries(pad_to_multiple_of):
+    """Test that microbatch boundaries are correctly maintained across chunks with random sequences.
+
+    A batch of randomly sized sequences is sharded into ``num_global_batches``
+    chunks with sequence packing enabled. For every chunk the test checks that:
+
+    * each sample belongs to that chunk and is seen exactly once overall,
+    * the padded lengths in every microbatch sum to at most
+      ``max_tokens_per_microbatch``,
+    * microbatch counts differ by at most one between shards,
+    * packing efficiency lies inside the band expected for the pad multiple.
+
+    Finally it checks that ``reorder_data(sorted_indices)`` restores the
+    original sample order.
+
+    Args:
+        pad_to_multiple_of: Value passed as ``sequence_length_pad_multiple``;
+            padded lengths are rounded up to a multiple of it.
+    """
+    torch.manual_seed(123)  # For reproducible tests
+    batch_size = 1024
+    num_global_batches = 4
+    max_seq_length = 1024
+    max_tokens_per_microbatch = 1200
+    shards = 4
+    chunk_batch_size = batch_size // num_global_batches
+
+    # Accepted (lower, upper) packing-efficiency band per pad multiple.
+    pack_efficiency_standards = {
+        1: (0.97, 1.0),
+        32: (0.92, 0.97),
+        64: (0.85, 0.92),
+        256: (0.60, 0.80),
+    }
+    min_efficiency, max_efficiency = pack_efficiency_standards[pad_to_multiple_of]
+
+    def _get_padded_seqlen(seqlens: torch.Tensor) -> torch.Tensor:
+        """Round each length up to the next multiple of ``pad_to_multiple_of``."""
+        # Ceil-divide to get the number of padding blocks, then scale back to
+        # tokens; without the multiplication this would be a block count, not
+        # a padded length, and the token-budget check would be vacuous.
+        num_blocks = (seqlens + (pad_to_multiple_of - 1)) // pad_to_multiple_of
+        return num_blocks * pad_to_multiple_of
+
+    # Generate random sequence lengths with good variety
+    sequence_lengths = torch.randint(50, 800, (batch_size,))
+
+    # Create input tensors: real tokens up to seq_len, zero padding afterwards.
+    # The RNG call order (lengths first, then one randint per row) is kept
+    # fixed so the seeded data stays stable.
+    input_rows = []
+    for seq_len in sequence_lengths.tolist():
+        tokens = torch.randint(1, 1000, (seq_len,))  # Actual tokens
+        padding = torch.zeros(max_seq_length - seq_len, dtype=torch.long)
+        input_rows.append(torch.cat([tokens, padding]))
+    input_ids = torch.stack(input_rows)
+
+    batch_data = BatchedDataDict(
+        {
+            "input_ids": input_ids,
+            "sequence_lengths": sequence_lengths,
+            "problem_ids": torch.arange(batch_size),
+        }
+    )
+
+    sequence_packing_args = SequencePackingArgs(
+        max_tokens_per_microbatch=max_tokens_per_microbatch,
+        input_key="input_ids",
+        input_lengths_key="sequence_lengths",
+        algorithm="modified_first_fit_decreasing",
+        sequence_length_pad_multiple=pad_to_multiple_of,
+    )
+
+    # Test with multiple shards and explicit batch_size to create chunks
+    sharded_batches, sorted_indices = batch_data.shard_by_batch_size(
+        shards=shards,
+        batch_size=chunk_batch_size,
+        sequence_packing_args=sequence_packing_args,
+    )
+
+    # Verify output structure
+    assert len(sharded_batches) == shards
+    assert len(sorted_indices) == batch_size
+
+    # Track all problem IDs to ensure completeness and no duplicates
+    problem_ids_seen = set()
+
+    for gb_idx in range(num_global_batches):
+        chunk_start = gb_idx * chunk_batch_size
+        chunk_end = chunk_start + chunk_batch_size
+        legal_problem_ids = set(range(chunk_start, chunk_end))
+
+        # Number of microbatches each shard produced for this global batch.
+        mb_counts = []
+        for shard_idx in range(shards):
+            shard_batch = sharded_batches[shard_idx].get_batch(gb_idx)
+            mb_count = 0
+            for mb in shard_batch.make_microbatch_iterator_for_packable_sequences():
+                mb_count += 1
+                for problem_id in mb["problem_ids"].tolist():
+                    assert problem_id in legal_problem_ids, (
+                        f"Problem ID {problem_id} not in legal problem IDs"
+                    )
+                    assert problem_id not in problem_ids_seen, (
+                        f"Problem ID {problem_id} seen twice"
+                    )
+                    problem_ids_seen.add(problem_id)
+
+                padded_tokens = _get_padded_seqlen(mb["sequence_lengths"]).sum().item()
+                assert padded_tokens <= max_tokens_per_microbatch, (
+                    f"Sequence length {padded_tokens} is greater than max tokens per microbatch {max_tokens_per_microbatch}"
+                )
+            mb_counts.append(mb_count)
+
+        # Shards must be balanced to within one microbatch of each other.
+        assert max(mb_counts) - min(mb_counts) <= 1
+
+        num_actual_tokens = sequence_lengths[chunk_start:chunk_end].sum().item()
+        packing_efficiency = num_actual_tokens / (
+            sum(mb_counts) * max_tokens_per_microbatch
+        )
+
+        assert packing_efficiency >= min_efficiency, (
+            f"We expect packing efficiency to be above {min_efficiency} for these nice random inputs with padding to multiples of {pad_to_multiple_of}. Got {packing_efficiency}"
+        )
+        assert packing_efficiency <= max_efficiency, (
+            f"We expect packing efficiency to be below {max_efficiency} for these nice random inputs with padding to multiples of {pad_to_multiple_of}. Got {packing_efficiency}"
+        )
+
+    assert len(problem_ids_seen) == batch_size
+
+    # Finally, test that we can reorder everything back to how it was before
+    reconstructed = BatchedDataDict.from_batches(sharded_batches)
+    compared_keys = ("problem_ids", "input_ids", "sequence_lengths")
+
+    # check that it's different from the original
+    for key in compared_keys:
+        assert not torch.all(reconstructed[key] == batch_data[key])
+
+    reconstructed.reorder_data(sorted_indices)
+    # check that it's the same as the original
+    for key in compared_keys:
+        assert torch.all(reconstructed[key] == batch_data[key])