diff --git a/nemo_rl/algorithms/loss/__init__.py b/nemo_rl/algorithms/loss/__init__.py index 163ce71a24..a2d404fdaa 100644 --- a/nemo_rl/algorithms/loss/__init__.py +++ b/nemo_rl/algorithms/loss/__init__.py @@ -26,8 +26,12 @@ PreferenceLossDataDict, PreferenceLossFn, ) -from nemo_rl.algorithms.loss.utils import prepare_loss_input +from nemo_rl.algorithms.loss.utils import ( + prepare_loss_input, + prepare_packed_loss_input, +) from nemo_rl.algorithms.loss.wrapper import ( + SequencePackingFusionLossWrapper, SequencePackingLossWrapper, wrap_loss_fn_with_input_preparation, ) @@ -46,6 +50,8 @@ "PreferenceLossDataDict", "PreferenceLossFn", "prepare_loss_input", + "prepare_packed_loss_input", + "SequencePackingFusionLossWrapper", "SequencePackingLossWrapper", "wrap_loss_fn_with_input_preparation", ] diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 70edc220e0..ad92522db0 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -24,6 +24,8 @@ from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( + _get_tokens_on_this_cp_rank, + from_parallel_logits_to_logprobs_packed_sequences, get_distillation_topk_logprobs_from_logits, get_next_token_logprobs_from_logits, ) @@ -119,3 +121,159 @@ def prepare_loss_input( raise ValueError(f"Unknown loss function input type: {loss_fn.input_type}") return loss_input, data + + +def _pack_input_ids( + input_ids: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_q_padded: torch.Tensor, + cp_rank: int = 0, + cp_size: int = 1, + roll_shift: int = 0, +) -> torch.Tensor: + """Pack input_ids from [B, S] to [1, T_packed // CP] using sequence boundaries. + + Each sequence is individually padded to its padded length (from + cu_seqlens_q_padded), optionally rolled, and CP-sharded at that padded + length before being placed into the packed output. 
This matches how + Megatron packs and CP-shards sequences in _pack_sequences_for_megatron. + + Args: + input_ids: Unpacked input IDs [B, S]. + cu_seqlens_q: Unpadded cumulative sequence lengths [B+1]. + cu_seqlens_q_padded: Padded cumulative sequence lengths [B+1]. + cp_rank: Context parallelism rank. + cp_size: Context parallelism size. + roll_shift: If non-zero, roll each padded sequence by this amount + before CP-sharding. Use -1 to build shifted targets for + next-token prediction. + """ + batch_size = input_ids.shape[0] + total_packed_len = int(cu_seqlens_q_padded[-1].item()) // cp_size + packed = torch.zeros( + total_packed_len, dtype=input_ids.dtype, device=input_ids.device + ) + for i in range(batch_size): + actual_len = int((cu_seqlens_q[i + 1] - cu_seqlens_q[i]).item()) + padded_len = int((cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i]).item()) + packed_start = int(cu_seqlens_q_padded[i].item()) + seq = torch.zeros(padded_len, dtype=input_ids.dtype, device=input_ids.device) + seq[:actual_len] = input_ids[i, :actual_len] + if roll_shift != 0: + seq = seq.roll(shifts=roll_shift, dims=0) + sharded = _get_tokens_on_this_cp_rank(seq, cp_rank, cp_size, seq_dim=0) + packed[packed_start // cp_size : (packed_start + padded_len) // cp_size] = ( + sharded + ) + return packed.unsqueeze(0) + + +def prepare_packed_loss_input( + logits: torch.Tensor, + data: BatchedDataDict[Any], + loss_fn: LossFunction, + cu_seqlens_q: torch.Tensor, + cu_seqlens_q_padded: torch.Tensor, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sampling_params: Optional[TrainingSamplingParams] = None, +) -> tuple[dict[str, Any], BatchedDataDict[Any]]: + """Prepare loss input from packed logits in a single fused pass. 
+ + Unlike prepare_loss_input which operates on a single (unpacked) sequence, + this function computes log probabilities from packed logits across all + sequences at once using from_parallel_logits_to_logprobs_packed_sequences. + + Currently only supports LossInputType.LOGPROB. + + Args: + logits: Packed logits from the model [1, T_packed // CP, V // TP]. + data: Microbatch data (unpacked, [B, S]). + loss_fn: Loss function (must have input_type == LossInputType.LOGPROB). + cu_seqlens_q: Unpadded cumulative sequence lengths [B+1]. + cu_seqlens_q_padded: Padded cumulative sequence lengths [B+1]. + vocab_parallel_rank: Vocab parallel rank. + vocab_parallel_group: Vocab parallel group. + context_parallel_group: Context parallel group. + sampling_params: Sampling parameters. + + Returns: + tuple(loss_input, maybe_updated_data) + """ + if loss_fn.input_type != LossInputType.LOGPROB: + raise ValueError( + f"prepare_packed_loss_input only supports LossInputType.LOGPROB, " + f"got {loss_fn.input_type}. Use SequencePackingLossWrapper with " + f"prepare_loss_input for other types." + ) + assert vocab_parallel_group is not None, ( + "prepare_packed_loss_input requires vocab_parallel_group (Megatron TP)." + ) + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided with vocab_parallel_group." 
+ ) + + input_ids = data["input_ids"] + unpacked_seqlen = input_ids.shape[1] + cp_size = ( + 1 + if context_parallel_group is None + else torch.distributed.get_world_size(context_parallel_group) + ) + cp_rank = ( + 0 + if context_parallel_group is None + else torch.distributed.get_rank(context_parallel_group) + ) + + packed_rolled_targets = _pack_input_ids( + input_ids, + cu_seqlens_q, + cu_seqlens_q_padded, + cp_rank=cp_rank, + cp_size=cp_size, + roll_shift=-1, + ) + + logprobs = from_parallel_logits_to_logprobs_packed_sequences( + logits.to(torch.float32), + packed_rolled_targets, + cu_seqlens_q_padded, + unpacked_seqlen, + vocab_start_index=vocab_parallel_rank * logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * logits.shape[-1], + group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + sampling_params=sampling_params, + target_is_pre_rolled=True, + ) + + # Match prepare_loss_input behavior for top-k/top-p filtered training: + # use filtered curr_logprobs for actor loss, but keep unfiltered values for KL. 
+ if need_top_k_or_top_p_filtering(sampling_params): + mask = data["token_mask"] * data["sample_mask"].unsqueeze(-1) + logprobs = mask_out_neg_inf_logprobs(logprobs, mask[:, 1:], "curr_logprobs") + + if ( + hasattr(loss_fn, "reference_policy_kl_penalty") + and loss_fn.reference_policy_kl_penalty != 0 + ): + data["curr_logprobs_unfiltered"] = ( + from_parallel_logits_to_logprobs_packed_sequences( + logits.to(torch.float32), + packed_rolled_targets, + cu_seqlens_q_padded, + unpacked_seqlen, + vocab_start_index=vocab_parallel_rank * logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * logits.shape[-1], + group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + sampling_params=None, + target_is_pre_rolled=True, + ) + ) + + return {"next_token_logprobs": logprobs}, data diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py index b669c494d3..a28bb18a19 100644 --- a/nemo_rl/algorithms/loss/wrapper.py +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -152,6 +152,68 @@ def __call__( return loss_accum, metrics_accum +class SequencePackingFusionLossWrapper: + """Fused sequence packing loss wrapper that processes all sequences in one forward pass. + + Unlike SequencePackingLossWrapper which iterates over sequences one at a time, + this wrapper calls prepare_fn once on the packed logits to compute log + probabilities in a single shot, then calls the loss function once with the + pre-computed result. + + This avoids per-sequence kernel launches and TP/CP communication overhead while + producing numerically identical results. + + The prepare_fn should be prepare_packed_loss_input (from nemo_rl.algorithms.loss.utils), + which currently only supports LossInputType.LOGPROB. 
+ """ + + def __init__( + self, + loss_fn: LossFunction, + prepare_fn: Callable[..., Any], + cu_seqlens_q: Tensor, + cu_seqlens_q_padded: Optional[Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + ): + self.loss_fn = loss_fn + self.prepare_fn = prepare_fn + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_q_padded = ( + cu_seqlens_q_padded if cu_seqlens_q_padded is not None else cu_seqlens_q + ) + self.vocab_parallel_rank = vocab_parallel_rank + self.vocab_parallel_group = vocab_parallel_group + self.context_parallel_group = context_parallel_group + + def __call__( + self, + next_token_logits: Tensor, + data: BatchedDataDict[Any], + global_valid_seqs: Tensor | None, + global_valid_toks: Tensor | None, + ) -> tuple[Tensor, dict[str, Any]]: + """Compute loss for all packed sequences in one forward pass.""" + loss_input, prepared_data = self.prepare_fn( + logits=next_token_logits, + data=data, + loss_fn=self.loss_fn, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_q_padded=self.cu_seqlens_q_padded, + vocab_parallel_rank=self.vocab_parallel_rank, + vocab_parallel_group=self.vocab_parallel_group, + context_parallel_group=self.context_parallel_group, + ) + + return self.loss_fn( + data=prepared_data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, + ) + + def wrap_loss_fn_with_input_preparation( next_token_logits: Tensor, data: BatchedDataDict[Any], diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index a5837e7753..2575fd891d 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -942,16 +942,19 @@ def from_parallel_logits_to_logprobs_packed_sequences( cp_group: Optional[torch.distributed.ProcessGroup] = None, chunk_size: Optional[int] = None, sampling_params: Optional[TrainingSamplingParams] = 
None, + target_is_pre_rolled: bool = False, ) -> torch.Tensor: """Get log probabilities from TP sharded vocab logits for packed sequences. Args: vocab_parallel_logits (torch.Tensor): Packed logits tensor with shape [1, T // CP, vocab_size//TP] where T is the total number of tokens across all packed sequences. - target (torch.Tensor): Packed target token indices with shape [1, T]. - NOTE: Must be the unmodified targets as this function will shift them internally. - cu_seqlens (torch.Tensor): Cumulative sequence lengths tensor with shape [batch_size + 1]. - cu_seqlens[i] indicates the start position of sequence i in the packed format. + target (torch.Tensor): Packed target token indices. + If target_is_pre_rolled=False: shape [1, T] — unmodified targets, rolled internally. + If target_is_pre_rolled=True: shape [1, T // CP] — pre-rolled and pre-CP-sharded. + cu_seqlens_padded (torch.Tensor): Cumulative sequence lengths tensor with shape [batch_size + 1]. + cu_seqlens_padded[i] indicates the start position of sequence i in the packed format + (full, not CP-adjusted). unpacked_seqlen (int): The length of the unpacked sequence tensor. vocab_start_index (int): Starting vocabulary index for this worker's partition. vocab_end_index (int): Ending vocabulary index for this worker's partition. @@ -959,44 +962,48 @@ def from_parallel_logits_to_logprobs_packed_sequences( inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities. + sampling_params (TrainingSamplingParams, optional): Sampling parameters for Top-k/Top-p filtering. + target_is_pre_rolled (bool): If True, target is already shifted and CP-sharded to match + vocab_parallel_logits shape, skipping the internal per-sequence roll+CP-shard loop. 
Returns: torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. The total length is reduced by batch_size due to target shifting (one token per sequence). """ - # Remove batch dimension to work with [T, vocab_size] and [T] - vocab_parallel_logits = vocab_parallel_logits.squeeze(0) - target = target.squeeze(0) - batch_size = cu_seqlens_padded.shape[0] - 1 cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) - cp_rank = 0 if cp_group is None else torch.distributed.get_rank(cp_group) - # Roll each sequence individually - rolled_targets = torch.zeros( - target.shape[0] // cp_size, dtype=target.dtype, device=target.device - ) - for i in range(batch_size): - start_idx = cu_seqlens_padded[i].item() - end_idx = cu_seqlens_padded[i + 1].item() + if not target_is_pre_rolled: + # Roll each sequence individually and CP-shard the targets + # Remove batch dimension to work with [T, vocab_size] and [T] + vocab_parallel_logits = vocab_parallel_logits.squeeze(0) + target = target.squeeze(0) + cp_rank = 0 if cp_group is None else torch.distributed.get_rank(cp_group) - # Get the sequence targets and roll by -1 - seq_targets = target[start_idx:end_idx] - rolled_seq_targets = seq_targets.roll(shifts=-1, dims=0) - rolled_targets[start_idx // cp_size : end_idx // cp_size] = ( - _get_tokens_on_this_cp_rank(rolled_seq_targets, cp_rank, cp_size, seq_dim=0) + rolled_targets = torch.zeros( + target.shape[0] // cp_size, dtype=target.dtype, device=target.device ) + for i in range(batch_size): + start_idx = cu_seqlens_padded[i].item() + end_idx = cu_seqlens_padded[i + 1].item() + + seq_targets = target[start_idx:end_idx] + rolled_seq_targets = seq_targets.roll(shifts=-1, dims=0) + rolled_targets[start_idx // cp_size : end_idx // cp_size] = ( + _get_tokens_on_this_cp_rank( + rolled_seq_targets, cp_rank, cp_size, seq_dim=0 + ) + ) - # Add batch dimension back for DistributedLogprob - rolled_targets = rolled_targets.unsqueeze(0) - 
vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) + target = rolled_targets.unsqueeze(0) + vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) # Apply distributed log probability computation if need_top_k_or_top_p_filtering(sampling_params): if chunk_size is not None: probs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore vocab_parallel_logits, - rolled_targets, + target, group, sampling_params.top_k, sampling_params.top_p, @@ -1006,7 +1013,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( else: probs: torch.Tensor = DistributedLogprobWithSampling.apply( # type: ignore vocab_parallel_logits, - rolled_targets, + target, group, sampling_params.top_k, sampling_params.top_p, @@ -1016,7 +1023,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( if chunk_size is not None: probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore vocab_parallel_logits, - rolled_targets, + target, vocab_start_index, vocab_end_index, chunk_size, @@ -1026,7 +1033,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( else: probs: torch.Tensor = DistributedLogprob.apply( # type: ignore vocab_parallel_logits, - rolled_targets, + target, vocab_start_index, vocab_end_index, group, diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index cfbb913395..883aa44ad7 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -34,8 +34,10 @@ need_top_k_or_top_p_filtering, ) from nemo_rl.algorithms.loss import ( + SequencePackingFusionLossWrapper, SequencePackingLossWrapper, prepare_loss_input, + prepare_packed_loss_input, wrap_loss_fn_with_input_preparation, ) from nemo_rl.algorithms.loss.interfaces import LossFunction @@ -322,9 +324,19 @@ def __call__( # wrap loss function with loss input preparation pack_sequences = self.cfg["sequence_packing"]["enabled"] if pack_sequences and packed_seq_params is not None: - loss_fn_wrapped = SequencePackingLossWrapper( + fuse_loss = 
self.cfg.get("sequence_packing", {}).get("fuse_loss", False) + if fuse_loss: + wrapper_cls = SequencePackingFusionLossWrapper + prepare_fn = partial( + prepare_packed_loss_input, sampling_params=self.sampling_params + ) + else: + wrapper_cls = SequencePackingLossWrapper + prepare_fn = prepare_loss_input_wrapped + + loss_fn_wrapped = wrapper_cls( loss_fn=self.loss_fn, - prepare_fn=prepare_loss_input_wrapped, + prepare_fn=prepare_fn, cu_seqlens_q=packed_seq_params.cu_seqlens_q, cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, vocab_parallel_rank=get_tensor_model_parallel_rank(), diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py new file mode 100644 index 0000000000..cb887b2b37 --- /dev/null +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -0,0 +1,431 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests to ensure SequencePackingFusionLossWrapper produces identical results +to SequencePackingLossWrapper. + +Uses distributed_test_runner (torch.multiprocessing.spawn) instead of Ray actors +so that pytest + code coverage work correctly. + +For loss function, currently only supports ClippedPGLossFn. 
+""" + +import functools + +import pytest +import torch + +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams +from nemo_rl.algorithms.loss import ( + ClippedPGLossFn, + SequencePackingFusionLossWrapper, + SequencePackingLossWrapper, + prepare_loss_input, + prepare_packed_loss_input, +) +from nemo_rl.distributed.batched_data_dict import BatchedDataDict + + +def _setup_2d_process_groups(rank, world_size, cp_size, tp_size): + """Create 2D (cp, tp) process groups. + + Rank layout (outer cp, inner tp): + [[0, 1, ..., tp_size-1], + [tp_size, ..., 2*tp_size-1], + ...] + """ + cp_groups = [] + tp_groups = [] + + for tp_rank in range(tp_size): + ranks = [cp_rank * tp_size + tp_rank for cp_rank in range(cp_size)] + cp_groups.append(torch.distributed.new_group(ranks=ranks)) + + for cp_rank in range(cp_size): + ranks = [cp_rank * tp_size + tp_rank for tp_rank in range(tp_size)] + tp_groups.append(torch.distributed.new_group(ranks=ranks)) + + my_tp_rank = rank % tp_size + my_cp_rank = rank // tp_size + cp_group = cp_groups[my_tp_rank] + tp_group = tp_groups[my_cp_rank] + return my_cp_rank, my_tp_rank, cp_group, tp_group + + +def _build_test_case(cp_size, tp_size, my_tp_rank, cp_group): + """Build a small packed batch with CP-aware packing.""" + from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank + from nemo_rl.models.megatron.data import _pack_sequences_for_megatron + + device = torch.device("cuda") + torch.manual_seed(42) + + batch_size = 4 + max_seq_len = 512 + if max_seq_len % (2 * cp_size) != 0: + max_seq_len = (max_seq_len // (2 * cp_size) + 1) * (2 * cp_size) + + vocab_size_total = 512 + assert vocab_size_total % tp_size == 0 + vocab_size_local = vocab_size_total // tp_size + + seq_lengths = torch.tensor( + [max_seq_len // 4, max_seq_len // 2, max_seq_len // 3, max_seq_len * 3 // 4], + dtype=torch.int32, + device=device, + ) + + input_ids = torch.zeros(batch_size, max_seq_len, dtype=torch.long, device=device) + token_mask = 
torch.zeros( + batch_size, max_seq_len, dtype=torch.float32, device=device + ) + for i in range(batch_size): + L = int(seq_lengths[i].item()) + input_ids[i, :L] = torch.randint(0, vocab_size_total, (L,), device=device) + token_mask[i, :L] = 1.0 + + sample_mask = torch.ones(batch_size, dtype=torch.float32, device=device) + advantages = 0.1 * torch.randn(batch_size, max_seq_len, device=device) + prev_logprobs = 0.1 * torch.randn(batch_size, max_seq_len, device=device) + generation_logprobs = 0.1 * torch.randn(batch_size, max_seq_len, device=device) + reference_policy_logprobs = generation_logprobs.clone() + + data_dict = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": seq_lengths, + "token_mask": token_mask, + "sample_mask": sample_mask, + "advantages": advantages, + "prev_logprobs": prev_logprobs, + "generation_logprobs": generation_logprobs, + "reference_policy_logprobs": reference_policy_logprobs, + } + ) + + pad_to_multiple = cp_size * 2 + ( + _packed_input_ids, + _packed_input_ids_cp, + _packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) = _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of=pad_to_multiple, + pad_packed_seq_to=max_seq_len * batch_size if cp_size > 1 else None, + cp_rank=torch.distributed.get_rank(cp_group), + cp_size=cp_size, + ) + assert cu_seqlens_padded is not None + + full_logits = torch.randn( + batch_size, max_seq_len, vocab_size_total, device=device, dtype=torch.float32 + ) + + def make_logits_and_packed_logits(): + logits_local = ( + full_logits[ + :, + :, + my_tp_rank * vocab_size_local : (my_tp_rank + 1) * vocab_size_local, + ] + .clone() + .detach() + .requires_grad_(True) + ) + + total_padded_tokens = int(cu_seqlens_padded[-1].item()) + packed_logits = torch.zeros( + 1, total_padded_tokens // cp_size, vocab_size_local, device=device + ) + + run_seq = 0 + for i in range(batch_size): + seq_len = int(seq_lengths[i].item()) + padded_seq_len = int( + (cu_seqlens_padded[i 
+ 1] - cu_seqlens_padded[i]).item() + ) + tmp = torch.zeros(1, padded_seq_len, vocab_size_local, device=device) + tmp[:, :seq_len, :] = logits_local[i : i + 1, :seq_len, :] + packed_logits[ + :, + run_seq // cp_size : (run_seq + padded_seq_len) // cp_size, + :, + ] = _get_tokens_on_this_cp_rank( + tmp, torch.distributed.get_rank(cp_group), cp_size + ) + run_seq += padded_seq_len + + return logits_local, packed_logits + + loss_cfg = { + "reference_policy_kl_penalty": 0.01, + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "token_level_loss": True, + "ratio_clip_c": None, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, + "truncated_importance_sampling_ratio": None, + "sequence_level_importance_ratios": False, + "force_on_policy_ratio": False, + } + + valid_toks = int(torch.clamp(seq_lengths - 1, min=0).sum().item()) + global_valid_toks = torch.tensor(valid_toks, dtype=torch.float32, device=device) + global_valid_seqs = torch.tensor(batch_size, dtype=torch.float32, device=device) + + return { + "loss_cfg": loss_cfg, + "cu_seqlens": cu_seqlens, + "cu_seqlens_padded": cu_seqlens_padded, + "data_dict": data_dict, + "global_valid_seqs": global_valid_seqs, + "global_valid_toks": global_valid_toks, + "make_logits_and_packed_logits": make_logits_and_packed_logits, + } + + +def _run_compare_sequence_packing_wrappers(rank, world_size, cp_size, tp_size): + """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. + + Verifies that the fused wrapper produces identical loss values and + backward gradients w.r.t. vocab-parallel logits. 
+ """ + _my_cp_rank, my_tp_rank, cp_group, tp_group = _setup_2d_process_groups( + rank, world_size, cp_size, tp_size + ) + tc = _build_test_case(cp_size, tp_size, my_tp_rank, cp_group) + base_loss_fn = ClippedPGLossFn(tc["loss_cfg"]) + data_dict = tc["data_dict"] + + baseline_wrapper = SequencePackingLossWrapper( + loss_fn=base_loss_fn, + prepare_fn=prepare_loss_input, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) + + candidate_wrapper = SequencePackingFusionLossWrapper( + loss_fn=base_loss_fn, + prepare_fn=prepare_packed_loss_input, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) + + # Baseline run + baseline_logits, baseline_packed_logits = tc["make_logits_and_packed_logits"]() + baseline_loss, _baseline_metrics = baseline_wrapper( + baseline_packed_logits, + data_dict, + tc["global_valid_seqs"], + tc["global_valid_toks"], + ) + (baseline_loss / cp_size).backward() + baseline_grad = baseline_logits.grad.clone() + + # Candidate run (fresh logits, identical values) + candidate_logits, candidate_packed_logits = tc["make_logits_and_packed_logits"]() + candidate_loss, _candidate_metrics = candidate_wrapper( + candidate_packed_logits, + data_dict, + tc["global_valid_seqs"], + tc["global_valid_toks"], + ) + (candidate_loss / cp_size).backward() + candidate_grad = candidate_logits.grad.clone() + + # Sanity: gradients must be non-None and non-zero + assert baseline_grad.abs().sum() > 0, f"baseline grad is all zeros on rank {rank}" + assert candidate_grad.abs().sum() > 0, f"candidate grad is all zeros on rank {rank}" + + # Forward: loss values must match + torch.testing.assert_close( + baseline_loss, + candidate_loss, + atol=1e-5, + rtol=1e-5, + msg=f"Loss mismatch on rank {rank}", + ) + + # 
Backward: gradients w.r.t. logits must match + torch.testing.assert_close( + baseline_grad, + candidate_grad, + atol=1e-5, + rtol=1e-5, + msg=f"Gradient mismatch on rank {rank}", + ) + + +def _run_compare_sequence_packing_wrappers_with_sampling( + rank, world_size, cp_size, tp_size +): + """Compare fused vs unfused wrappers with sampling params enabled.""" + _my_cp_rank, my_tp_rank, cp_group, tp_group = _setup_2d_process_groups( + rank, world_size, cp_size, tp_size + ) + tc = _build_test_case(cp_size, tp_size, my_tp_rank, cp_group) + base_loss_fn = ClippedPGLossFn(tc["loss_cfg"]) + data_dict = tc["data_dict"] + + sampling_params = TrainingSamplingParams(top_k=8, top_p=0.9, temperature=1.0) + prepare_loss_input_wrapped = functools.partial( + prepare_loss_input, sampling_params=sampling_params + ) + prepare_packed_loss_input_wrapped = functools.partial( + prepare_packed_loss_input, sampling_params=sampling_params + ) + + baseline_wrapper = SequencePackingLossWrapper( + loss_fn=base_loss_fn, + prepare_fn=prepare_loss_input_wrapped, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) + + candidate_wrapper = SequencePackingFusionLossWrapper( + loss_fn=base_loss_fn, + prepare_fn=prepare_packed_loss_input_wrapped, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) + + # Baseline run + baseline_logits, baseline_packed_logits = tc["make_logits_and_packed_logits"]() + baseline_loss, baseline_metrics = baseline_wrapper( + baseline_packed_logits, + data_dict, + tc["global_valid_seqs"], + tc["global_valid_toks"], + ) + (baseline_loss / cp_size).backward() + baseline_grad = baseline_logits.grad.clone() + + # Candidate run (fresh logits, identical values) + candidate_logits, candidate_packed_logits = 
tc["make_logits_and_packed_logits"]() + candidate_loss, candidate_metrics = candidate_wrapper( + candidate_packed_logits, + data_dict, + tc["global_valid_seqs"], + tc["global_valid_toks"], + ) + (candidate_loss / cp_size).backward() + candidate_grad = candidate_logits.grad.clone() + + # Sanity: gradients must be non-None and non-zero + assert baseline_grad.abs().sum() > 0, f"baseline grad is all zeros on rank {rank}" + assert candidate_grad.abs().sum() > 0, f"candidate grad is all zeros on rank {rank}" + + # Forward: loss values must match + torch.testing.assert_close( + baseline_loss, + candidate_loss, + atol=1e-5, + rtol=1e-5, + msg=f"Loss mismatch with sampling params on rank {rank}", + ) + + # Metrics parity under sampling params + assert set(baseline_metrics.keys()) == set(candidate_metrics.keys()) + for k in baseline_metrics: + torch.testing.assert_close( + torch.as_tensor(baseline_metrics[k], device="cuda"), + torch.as_tensor(candidate_metrics[k], device="cuda"), + atol=1e-5, + rtol=1e-5, + msg=f"Metric mismatch for key={k} on rank {rank}", + ) + + # Backward: gradients w.r.t. logits must match + torch.testing.assert_close( + baseline_grad, + candidate_grad, + atol=1e-5, + rtol=1e-5, + msg=f"Gradient mismatch with sampling params on rank {rank}", + ) + + +@pytest.mark.parametrize( + "cp_tp", + [ + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (2, 4), + (4, 2), + ], + ids=lambda cp_tp: f"cp{cp_tp[0]}_tp{cp_tp[1]}", +) +def test_sequence_packing_fusion_vs_baseline(distributed_test_runner, cp_tp): + """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. + + Verifies that the fused wrapper produces identical: + - loss values + - backward gradients w.r.t. vocab-parallel logits + for different CP and TP configurations. 
+    """
+    cp_size, tp_size = cp_tp
+    world_size = cp_size * tp_size
+
+    test_fn = functools.partial(
+        _run_compare_sequence_packing_wrappers,
+        cp_size=cp_size,
+        tp_size=tp_size,
+    )
+    distributed_test_runner(test_fn, world_size=world_size)
+
+
+@pytest.mark.parametrize(
+    "cp_tp",  # (cp_size, tp_size) pairs, unpacked in the test body
+    [
+        (1, 1),
+        (1, 2),
+        (2, 1),
+        (2, 2),
+    ],
+    ids=lambda cp_tp: f"sampling_cp{cp_tp[0]}_tp{cp_tp[1]}",
+)
+def test_sequence_packing_fusion_vs_baseline_with_sampling_params(
+    distributed_test_runner, cp_tp
+):
+    """Compare fused vs unfused wrappers with top-k/top-p sampling params."""
+    cp_size, tp_size = cp_tp
+    world_size = cp_size * tp_size  # runner is launched with cp_size * tp_size ranks
+
+    test_fn = functools.partial(
+        _run_compare_sequence_packing_wrappers_with_sampling,
+        cp_size=cp_size,
+        tp_size=tp_size,
+    )
+    distributed_test_runner(test_fn, world_size=world_size)
diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py
index aad98d1da4..6c602440ff 100644
--- a/tests/unit/distributed/test_model_utils.py
+++ b/tests/unit/distributed/test_model_utils.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 import os
 
 import pytest
@@ -243,6 +244,194 @@ def test_from_parallel_logits_to_logprobs_packed_sequences(
     cluster.shutdown()
 
 
+# ---------------------------------------------------------------------------
+# distributed_test_runner-based packed-sequences tests (coverage-friendly)
+# ---------------------------------------------------------------------------
+
+
+def _run_packed_sequences_equivalence(rank, world_size, tp_size, cp_size, chunk_size):
+    """Check packed-sequence logprobs for both target_is_pre_rolled settings.
+
+    Packs 4 variable-length sequences on a TP x CP grid (targets via _pack_input_ids)
+    and asserts, per sequence over the first len - 1 positions, that:
+    1. target_is_pre_rolled=True matches target_is_pre_rolled=False, and
+    2. (cp_size == 1 only) target_is_pre_rolled=False matches the unpacked baseline.
+    """
+    from nemo_rl.algorithms.loss.utils import _pack_input_ids
+
+    # 2-D rank grid (rank = cp_r * tp_size + tp_r); each rank creates every group.
+    tp_groups = []
+    cp_groups = []
+    for cp_r in range(cp_size):
+        ranks = [cp_r * tp_size + tp_r for tp_r in range(tp_size)]
+        tp_groups.append(torch.distributed.new_group(ranks=ranks))
+    for tp_r in range(tp_size):
+        ranks = [cp_r * tp_size + tp_r for cp_r in range(cp_size)]
+        cp_groups.append(torch.distributed.new_group(ranks=ranks))
+
+    my_tp_rank = rank % tp_size  # column in the grid (TP is the inner axis)
+    my_cp_rank = rank // tp_size  # row in the grid (CP is the outer axis)
+    tp_group = tp_groups[my_cp_rank]
+    cp_group = cp_groups[my_tp_rank] if cp_size > 1 else None  # no CP group at CP=1
+    my_cp_rank_val = 0 if cp_group is None else torch.distributed.get_rank(cp_group)
+
+    batch_size = 4
+    vocab_size = 1024
+    vocab_part_size = vocab_size // tp_size  # vocab slice width held by each TP rank
+    vocab_start_index = my_tp_rank * vocab_part_size
+    vocab_end_index = (my_tp_rank + 1) * vocab_part_size
+
+    # Variable-length sequences; logits/input_ids are allocated at max_seq_len.
+    raw_seq_lengths = [24, 48, 16, 40]
+    max_seq_len = max(raw_seq_lengths)
+
+    if cp_size > 1 and max_seq_len % (2 * cp_size) != 0:  # round up to 2 * cp_size
+        max_seq_len = (max_seq_len // (2 * cp_size) + 1) * (2 * cp_size)
+    raw_seq_lengths = [min(l, max_seq_len) for l in raw_seq_lengths]  # no-op clamp
+
+    pad_to = 2 * cp_size if cp_size > 1 else 1  # per-sequence padding multiple
+    padded_seq_lengths = [
+        ((l + pad_to - 1) // pad_to) * pad_to for l in raw_seq_lengths
+    ]
+
+    # Cumulative offsets [B + 1] for the raw and the padded sequence lengths.
+    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
+    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
+    for i in range(batch_size):
+        cu_seqlens[i + 1] = cu_seqlens[i] + raw_seq_lengths[i]
+        cu_seqlens_padded[i + 1] = cu_seqlens_padded[i] + padded_seq_lengths[i]
+
+    total_padded = int(cu_seqlens_padded[-1].item())
+
+    torch.manual_seed(42)  # fixed seed before generating logits and input_ids
+    unpacked_logits_full = torch.randn(
+        batch_size, max_seq_len, vocab_size, device="cuda"
+    )
+    input_ids = torch.randint(0, vocab_size, (batch_size, max_seq_len), device="cuda")
+
+    unpacked_logits_local = unpacked_logits_full[
+        :, :, vocab_start_index:vocab_end_index
+    ]
+
+    # --- Pack logits: [B, S, V_local] -> [1, T_padded // CP, V_local] ---
+    # Each sequence is zero-padded to its padded length, then CP-sharded on its own.
+    packed_logits = torch.zeros(
+        1, total_padded // cp_size, vocab_part_size, device="cuda"
+    )
+    for i in range(batch_size):
+        sl = raw_seq_lengths[i]  # real (unpadded) length of sequence i
+        psl = padded_seq_lengths[i]  # padded length of sequence i
+        padded_seq = torch.zeros(1, psl, vocab_part_size, device="cuda")
+        padded_seq[:, :sl, :] = unpacked_logits_local[i : i + 1, :sl, :]
+        offset = int(cu_seqlens_padded[i].item())
+        if cp_size > 1:
+            sharded = _get_tokens_on_this_cp_rank(padded_seq, my_cp_rank_val, cp_size)
+            packed_logits[:, offset // cp_size : (offset + psl) // cp_size, :] = sharded
+        else:
+            packed_logits[:, offset : offset + psl, :] = padded_seq
+
+    # --- Path 1: target_is_pre_rolled=False ---
+    # _pack_input_ids defaults (no roll, cp_size=1) give raw targets [1, T_padded].
+    packed_target_raw = _pack_input_ids(input_ids, cu_seqlens, cu_seqlens_padded)
+
+    logprobs_not_pre_rolled = from_parallel_logits_to_logprobs_packed_sequences(
+        packed_logits,
+        packed_target_raw,
+        cu_seqlens_padded,
+        max_seq_len,
+        vocab_start_index,
+        vocab_end_index,
+        tp_group,
+        cp_group=cp_group,
+        chunk_size=chunk_size,
+        target_is_pre_rolled=False,
+    )
+
+    # --- Path 2: target_is_pre_rolled=True ---
+    packed_target_pre_rolled = _pack_input_ids(
+        input_ids,
+        cu_seqlens,
+        cu_seqlens_padded,
+        cp_rank=my_cp_rank_val,  # shard the targets for this CP rank
+        cp_size=cp_size,
+        roll_shift=-1,  # roll each padded sequence by -1: next-token targets
+    )
+
+    logprobs_pre_rolled = from_parallel_logits_to_logprobs_packed_sequences(
+        packed_logits,
+        packed_target_pre_rolled,
+        cu_seqlens_padded,
+        max_seq_len,
+        vocab_start_index,
+        vocab_end_index,
+        tp_group,
+        cp_group=cp_group,
+        chunk_size=chunk_size,
+        target_is_pre_rolled=True,
+    )
+
+    # Both paths must agree (within 1e-5) on each sequence's first len - 1 positions.
+    for i in range(batch_size):
+        valid_len = raw_seq_lengths[i] - 1  # presumably: last token has no target
+        torch.testing.assert_close(
+            logprobs_pre_rolled[i, :valid_len],
+            logprobs_not_pre_rolled[i, :valid_len],
+            rtol=1e-5,
+            atol=1e-5,
+            msg=f"pre_rolled vs not_pre_rolled mismatch on rank {rank}, seq {i}",
+        )
+
+    # --- Also compare against the unpacked baseline ---
+    # The unpacked function CP-shards each row from max_seq_len, which matches
+    # the packed per-sequence CP-sharding only when CP=1.
+    if cp_size == 1:
+        baseline_logprobs = from_parallel_logits_to_logprobs(
+            unpacked_logits_local,
+            input_ids,
+            vocab_start_index,
+            vocab_end_index,
+            tp_group,
+            cp_group=cp_group,
+        )
+        for i in range(batch_size):
+            valid_len = raw_seq_lengths[i] - 1
+            torch.testing.assert_close(
+                logprobs_not_pre_rolled[i, :valid_len],
+                baseline_logprobs[i, :valid_len],
+                rtol=1e-5,
+                atol=1e-5,
+                msg=f"packed vs unpacked mismatch on rank {rank}, seq {i}",
+            )
+
+
+@pytest.mark.parametrize(
+    "tp_size, cp_size, chunk_size",
+    [
+        (2, 1, None),
+        (1, 2, None),
+        (2, 1, 8),
+        (1, 2, 8),
+    ],
+    ids=lambda v: str(v),
+)
+def test_packed_sequences_with_distributed_runner(
+    distributed_test_runner, tp_size, cp_size, chunk_size
+):
+    """Run _run_packed_sequences_equivalence on tp_size * cp_size ranks.
+
+    Parametrized over TP-only and CP-only layouts (2 ranks each), chunk_size None/8.
+    Uses distributed_test_runner (not Ray) so that code coverage is tracked.
+    """
+    world_size = tp_size * cp_size
+    test_fn = functools.partial(
+        _run_packed_sequences_equivalence,
+        tp_size=tp_size,
+        cp_size=cp_size,
+        chunk_size=chunk_size,
+    )
+    distributed_test_runner(test_fn, world_size=world_size)
+
+
 @ray.remote(num_gpus=1)
 class AllGatherCPTestActor:
     def __init__(self, cp_size):