From cc374496dfe6ecd4b505c138227528fb05453f12 Mon Sep 17 00:00:00 2001 From: mloh Date: Mon, 9 Feb 2026 16:02:10 -0800 Subject: [PATCH 01/18] Fuse sequence packing for loss function Signed-off-by: mloh --- nemo_rl/algorithms/loss/__init__.py | 2 + nemo_rl/algorithms/loss/wrapper.py | 108 +++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/nemo_rl/algorithms/loss/__init__.py b/nemo_rl/algorithms/loss/__init__.py index 163ce71a24..ed64591544 100644 --- a/nemo_rl/algorithms/loss/__init__.py +++ b/nemo_rl/algorithms/loss/__init__.py @@ -28,6 +28,7 @@ ) from nemo_rl.algorithms.loss.utils import prepare_loss_input from nemo_rl.algorithms.loss.wrapper import ( + SequencePackingFusionLossWrapper, SequencePackingLossWrapper, wrap_loss_fn_with_input_preparation, ) @@ -46,6 +47,7 @@ "PreferenceLossDataDict", "PreferenceLossFn", "prepare_loss_input", + "SequencePackingFusionLossWrapper", "SequencePackingLossWrapper", "wrap_loss_fn_with_input_preparation", ] diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py index b669c494d3..9319e3b0db 100644 --- a/nemo_rl/algorithms/loss/wrapper.py +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -18,8 +18,11 @@ import torch import torch.distributed -from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import ( + from_parallel_logits_to_logprobs_packed_sequences, +) Tensor = TypeVar("Tensor", bound=torch.Tensor) @@ -152,6 +155,109 @@ def __call__( return loss_accum, metrics_accum +class SequencePackingFusionLossWrapper: + """Fused sequence packing loss wrapper that processes all sequences in one forward pass. + + Unlike SequencePackingLossWrapper which iterates over sequences one at a time, + this wrapper computes log probabilities from packed logits in a single shot using + from_parallel_logits_to_logprobs_packed_sequences, then calls the loss function + with the pre-computed logprobs. + + This avoids per-sequence kernel launches and TP/CP communication overhead while + producing numerically identical results. + + Requirements: + - The wrapped loss_fn must have input_type == LossInputType.LOGPROB. + - vocab_parallel_group and vocab_parallel_rank must be provided (Megatron TP). + """ + + def __init__( + self, + loss_fn: LossFunction, + cu_seqlens_q: Tensor, + cu_seqlens_q_padded: Optional[Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + ): + assert loss_fn.input_type == LossInputType.LOGPROB, ( + f"SequencePackingFusionLossWrapper only supports LossInputType.LOGPROB, " + f"got {loss_fn.input_type}. Use SequencePackingLossWrapper for other types." + ) + assert vocab_parallel_group is not None, ( + "SequencePackingFusionLossWrapper requires vocab_parallel_group (Megatron TP)." + ) + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided with vocab_parallel_group." + ) + self.loss_fn = loss_fn + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_q_padded = ( + cu_seqlens_q_padded if cu_seqlens_q_padded is not None else cu_seqlens_q + ) + self.vocab_parallel_rank = vocab_parallel_rank + self.vocab_parallel_group = vocab_parallel_group + self.context_parallel_group = context_parallel_group + + def _pack_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Pack input_ids from [B, S] to [1, T_packed] using sequence boundaries. + + Each sequence i is placed at cu_seqlens_q_padded[i] in the packed tensor, + with actual_len = cu_seqlens_q[i+1] - cu_seqlens_q[i] tokens copied. + """ + batch_size = input_ids.shape[0] + total_packed_len = int(self.cu_seqlens_q_padded[-1].item()) + packed = torch.zeros( + 1, total_packed_len, dtype=input_ids.dtype, device=input_ids.device + ) + for i in range(batch_size): + actual_len = int( + (self.cu_seqlens_q[i + 1] - self.cu_seqlens_q[i]).item() + ) + packed_start = int(self.cu_seqlens_q_padded[i].item()) + packed[0, packed_start : packed_start + actual_len] = input_ids[ + i, :actual_len + ] + return packed + + def __call__( + self, + next_token_logits: Tensor, + data: BatchedDataDict[Any], + global_valid_seqs: Tensor | None, + global_valid_toks: Tensor | None, + ) -> tuple[Tensor, dict[str, Any]]: + """Compute loss for all packed sequences in one forward pass. + + 1. Pack input_ids from [B, S] to [1, T_packed] to match packed logit layout. + 2. Compute logprobs from packed logits via + from_parallel_logits_to_logprobs_packed_sequences -> [B, S-1]. + 3. Call the loss function with the pre-computed logprobs. + """ + packed_input_ids = self._pack_input_ids(data["input_ids"]) + unpacked_seqlen = data["input_ids"].shape[1] + + logprobs = from_parallel_logits_to_logprobs_packed_sequences( + next_token_logits.to(torch.float32), + packed_input_ids, + self.cu_seqlens_q_padded, + unpacked_seqlen, + vocab_start_index=self.vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(self.vocab_parallel_rank + 1) + * next_token_logits.shape[-1], + group=self.vocab_parallel_group, + inference_only=False, + cp_group=self.context_parallel_group, + ) + + return self.loss_fn( + data=data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + next_token_logprobs=logprobs, + ) + + def wrap_loss_fn_with_input_preparation( next_token_logits: Tensor, data: BatchedDataDict[Any], From 8922e0fdc0f8ff4068110043f579e8d05f45b54d Mon Sep 17 00:00:00 2001 From: mloh Date: Mon, 9 Feb 2026 19:55:38 -0800 Subject: [PATCH 02/18] Add test case to compare with SequencePackingWrapper Signed-off-by: mloh --- .../test_sequence_packing_fusion.py | 423 ++++++++++++++++++ 1 file changed, 423 insertions(+) create mode 100644 tests/unit/algorithms/test_sequence_packing_fusion.py diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py new file mode 100644 index 0000000000..a774977f93 --- /dev/null +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -0,0 +1,423 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests to ensure SequencePackingFusionLossWrapper works as SequencePackingLossWrapper. + - Without explicitly calling the loss_fn sequence by sequence. + +During the forward pass, compare the loss and metrics from the two wrappers. +During the backward pass, compare the gradients from the two wrappers. + +For parallelism, check for CP and TP. + +For loss function, right now only supports: +- ClippedPGLossFn +""" + +import os + +import pytest +import ray +import torch + +from nemo_rl.algorithms.loss_functions import ( + ClippedPGLossFn, + SequencePackingLossWrapper, + SequencePackingFusionLossWrapper, +) +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.distributed.ray_actor_environment_registry import ( + ACTOR_ENVIRONMENT_REGISTRY, + PY_EXECUTABLES, +) +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup + + +@ray.remote(num_gpus=1) +class SequencePackingLossWrapperBaselineActor: + def __init__(self, cp_size: int, tp_size: int): + self.cp_size = cp_size + self.tp_size = tp_size + self.env_vars = dict(os.environ) + + def _setup_process_groups(self): + torch.distributed.init_process_group(backend="nccl") + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + assert world_size == self.cp_size * self.tp_size, ( + f"Expected WORLD_SIZE={self.cp_size*self.tp_size}, got {world_size}." + ) + + # --------------------------------------------------------------------- + # Create 2D (cp, tp) process groups. + # Rank layout (outer cp, inner tp): + # [[0, 1, ..., tp_size-1], + # [tp_size, ..., 2*tp_size-1], + # ...] + # --------------------------------------------------------------------- + cp_groups: list[torch.distributed.ProcessGroup] = [] + tp_groups: list[torch.distributed.ProcessGroup] = [] + + # CP groups: one per tp_rank, varying cp coordinate + for tp_rank in range(self.tp_size): + ranks = [cp_rank * self.tp_size + tp_rank for cp_rank in range(self.cp_size)] + cp_groups.append(torch.distributed.new_group(ranks=ranks)) + + # TP groups: one per cp_rank, varying tp coordinate + for cp_rank in range(self.cp_size): + ranks = [cp_rank * self.tp_size + tp_rank for tp_rank in range(self.tp_size)] + tp_groups.append(torch.distributed.new_group(ranks=ranks)) + + my_tp_rank = rank % self.tp_size + my_cp_rank = rank // self.tp_size + cp_group = cp_groups[my_tp_rank] + tp_group = tp_groups[my_cp_rank] + return rank, my_cp_rank, my_tp_rank, cp_group, tp_group + + def _build_test_case(self, cp_group, my_tp_rank: int): + from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank + from nemo_rl.models.megatron.data import _pack_sequences_for_megatron + + # --------------------------------------------------------------------- + # Build a small packed batch. + # --------------------------------------------------------------------- + device = torch.device("cuda") + torch.manual_seed(42) # For reproducibility / determinism + + batch_size = 4 + max_seq_len = 512 + # Ensure CP load balancing requirement: divisible by (2 * cp_size) + if max_seq_len % (2 * self.cp_size) != 0: + max_seq_len = (max_seq_len // (2 * self.cp_size) + 1) * (2 * self.cp_size) + + vocab_size_total = 512 + assert vocab_size_total % self.tp_size == 0 + vocab_size_local = vocab_size_total // self.tp_size + + # Variable lengths, but <= max_seq_len + seq_lengths = torch.tensor( + [max_seq_len // 4, max_seq_len // 2, max_seq_len // 3, max_seq_len * 3 // 4], + dtype=torch.int32, + device=device, + ) + + # Input ids + masks + input_ids = torch.zeros(batch_size, max_seq_len, dtype=torch.long, device=device) + token_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.float32, device=device) + for i in range(batch_size): + L = int(seq_lengths[i].item()) + input_ids[i, :L] = torch.randint(0, vocab_size_total, (L,), device=device) + token_mask[i, :L] = 1.0 + + sample_mask = torch.ones(batch_size, dtype=torch.float32, device=device) + + # Stable-ish random tensors (avoid extreme ratios/NaNs in unit test) + advantages = 0.1 * torch.randn(batch_size, max_seq_len, device=device) + prev_logprobs = 0.1 * torch.randn(batch_size, max_seq_len, device=device) + generation_logprobs = 0.1 * torch.randn(batch_size, max_seq_len, device=device) + reference_policy_logprobs = generation_logprobs.clone() + + data_dict = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": seq_lengths, + "token_mask": token_mask, + "sample_mask": sample_mask, + "advantages": advantages, + "prev_logprobs": prev_logprobs, + "generation_logprobs": generation_logprobs, + "reference_policy_logprobs": reference_policy_logprobs, + } + ) + + # Packed sequence metadata (CP-aware) + pad_to_multiple = self.cp_size * 2 + ( + _packed_input_ids, + _packed_input_ids_cp, + _packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) = _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of=pad_to_multiple, + pad_packed_seq_to=max_seq_len * batch_size if self.cp_size > 1 else None, + cp_rank=torch.distributed.get_rank(cp_group), + cp_size=self.cp_size, + ) + assert cu_seqlens_padded is not None + + # --------------------------------------------------------------------- + # Create vocab-parallel logits, then pack + CP-shard them into [1, T//CP, V//TP] + # --------------------------------------------------------------------- + # Global logits (same across ranks), then slice by TP rank + full_logits = torch.randn( + batch_size, max_seq_len, vocab_size_total, device=device, dtype=torch.float32 + ) + + def make_logits_and_packed_logits(): + logits_local = ( + full_logits[ + :, + :, + my_tp_rank * vocab_size_local : (my_tp_rank + 1) * vocab_size_local, + ] + .clone() + .detach() + .requires_grad_(True) + ) + + total_padded_tokens = int(cu_seqlens_padded[-1].item()) + packed_logits = torch.zeros( + 1, total_padded_tokens // self.cp_size, vocab_size_local, device=device + ) + + run_seq = 0 + for i in range(batch_size): + seq_len = int(seq_lengths[i].item()) + padded_seq_len = int( + (cu_seqlens_padded[i + 1] - cu_seqlens_padded[i]).item() + ) + tmp = torch.zeros(1, padded_seq_len, vocab_size_local, device=device) + tmp[:, :seq_len, :] = logits_local[i : i + 1, :seq_len, :] + packed_logits[ + :, + run_seq // self.cp_size : (run_seq + padded_seq_len) // self.cp_size, + :, + ] = _get_tokens_on_this_cp_rank( + tmp, torch.distributed.get_rank(cp_group), self.cp_size + ) + run_seq += padded_seq_len + + return logits_local, packed_logits + + # --------------------------------------------------------------------- + # Loss: SequencePackingLossWrapper + ClippedPGLossFn + # --------------------------------------------------------------------- + loss_cfg = { + # From examples/configs/grpo_math_1B_megatron.yaml (loss_fn section) + "reference_policy_kl_penalty": 0.01, + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "token_level_loss": True, + "ratio_clip_c": None, + # Required by ClippedPGLossConfig but not in that YAML + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, + "truncated_importance_sampling_ratio": None, + "sequence_level_importance_ratios": False, + "force_on_policy_ratio": False, + } + + # Global normalization factors (token-level loss uses global_valid_toks) + valid_toks = int(torch.clamp(seq_lengths - 1, min=0).sum().item()) + global_valid_toks = torch.tensor(valid_toks, dtype=torch.float32, device=device) + global_valid_seqs = torch.tensor(batch_size, dtype=torch.float32, device=device) + + return { + "loss_cfg": loss_cfg, + "cu_seqlens": cu_seqlens, + "cu_seqlens_padded": cu_seqlens_padded, + "data_dict": data_dict, + "global_valid_seqs": global_valid_seqs, + "global_valid_toks": global_valid_toks, + "make_logits_and_packed_logits": make_logits_and_packed_logits, + } + + def run_compare_sequence_packing_wrappers(self): + """ + Compare helper (for when your candidate/fused wrapper exists): + - Builds inputs ONCE + - Runs baseline wrapper and candidate wrapper on identical inputs + - Returns loss/metrics + max grad for each + + Assumes candidate_wrapper_fqn points to a class with the same constructor as + SequencePackingLossWrapper: (loss_fn, cu_seqlens_q, cu_seqlens_q_padded). + """ + rank, _my_cp_rank, my_tp_rank, cp_group, tp_group = self._setup_process_groups() + tc = self._build_test_case(cp_group=cp_group, my_tp_rank=my_tp_rank) + base_loss_fn = ClippedPGLossFn(tc["loss_cfg"]) + + # Instantiate wrappers + baseline_wrapper = SequencePackingLossWrapper( + loss_fn=base_loss_fn, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + ) + + candidate_wrapper = SequencePackingFusionLossWrapper( + loss_fn=base_loss_fn, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + ) + + # Baseline run (fresh logits) + baseline_logits, baseline_packed_logits = tc["make_logits_and_packed_logits"]() + baseline_loss, baseline_metrics = baseline_wrapper( + baseline_packed_logits, + tc["data_dict"], + tc["global_valid_seqs"], + tc["global_valid_toks"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) + (baseline_loss / self.cp_size).backward() + baseline_grad = baseline_logits.grad.clone() + + # Candidate run (fresh logits, identical values) + candidate_logits, candidate_packed_logits = tc["make_logits_and_packed_logits"]() + candidate_loss, candidate_metrics = candidate_wrapper( + candidate_packed_logits, + tc["data_dict"], + tc["global_valid_seqs"], + tc["global_valid_toks"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) + (candidate_loss / self.cp_size).backward() + candidate_grad = candidate_logits.grad.clone() + + return { + "rank": rank, + "cp_rank": int(torch.distributed.get_rank(cp_group)), + "tp_rank": int(torch.distributed.get_rank(tp_group)), + "baseline": { + "loss": baseline_loss, + "metrics_keys": sorted(list(baseline_metrics.keys())), + "logits_local_grad": baseline_grad, + }, + "candidate": { + "loss": candidate_loss, + "metrics_keys": sorted(list(candidate_metrics.keys())), + "logits_local_grad": candidate_grad, + }, + } + + +SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN = ( + f"{SequencePackingLossWrapperBaselineActor.__module__}.SequencePackingLossWrapperBaselineActor" +) + +@pytest.fixture +def register_sequence_packing_loss_wrapper_baseline_actor(): + """Register the actor in ACTOR_ENVIRONMENT_REGISTRY for RayWorkerGroup.""" + original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get( + SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN + ) + ACTOR_ENVIRONMENT_REGISTRY[SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN] = ( + PY_EXECUTABLES.MCORE + ) + yield SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN + + if SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: + if original_registry_value is None: + del ACTOR_ENVIRONMENT_REGISTRY[SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN] + else: + ACTOR_ENVIRONMENT_REGISTRY[SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN] = ( + original_registry_value + ) + + +@pytest.fixture(scope="function") +def cluster_fixture(request): + """Create and teardown a virtual cluster for CP/TP tests.""" + cp_size = int(request.node.callspec.params["cp_size"]) + tp_size = int(request.node.callspec.params["tp_size"]) + world_size = cp_size * tp_size + + if not torch.cuda.is_available() or torch.cuda.device_count() < world_size: + pytest.skip( + f"Not enough GPUs available. Need {world_size}, got {torch.cuda.device_count()}" + ) + + if not ray.is_initialized(): + from nemo_rl.distributed.virtual_cluster import init_ray + + init_ray() + + cluster_name = f"test-seq-pack-fusion-cp{cp_size}-tp{tp_size}" + cluster = RayVirtualCluster( + name=cluster_name, bundle_ct_per_node_list=[world_size], use_gpus=True + ) + yield cluster + cluster.shutdown() + + +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize("tp_size", [1, 2]) +def test_sequence_packing_loss_wrapper_baseline_cp_tp( + cluster_fixture, register_sequence_packing_loss_wrapper_baseline_actor, cp_size, tp_size +): + """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. + + Verifies that the fused wrapper produces identical: + - loss values + - backward gradients w.r.t. vocab-parallel logits + for different CP and TP configurations. + """ + cluster = cluster_fixture + actor_fqn = register_sequence_packing_loss_wrapper_baseline_actor + world_size = cp_size * tp_size + + sharding_layout = [ + [cp_rank * tp_size + tp_rank for tp_rank in range(tp_size)] + for cp_rank in range(cp_size) + ] + sharding = NamedSharding(layout=sharding_layout, names=["cp", "tp"]) + builder = RayWorkerBuilder(actor_fqn, cp_size=cp_size, tp_size=tp_size) + + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + + futures = worker_group.run_all_workers_single_data( + "run_compare_sequence_packing_wrappers" + ) + results = ray.get(futures) + + if not isinstance(results, list): + results = [results] + + for r in results: + rank = r["rank"] + # Forward: loss values must match + torch.testing.assert_close( + r["baseline"]["loss"], + r["candidate"]["loss"], + atol=1e-5, + rtol=1e-5, + msg=f"Loss mismatch on rank {rank}", + ) + # Backward: gradients w.r.t. logits must match + torch.testing.assert_close( + r["baseline"]["logits_local_grad"], + r["candidate"]["logits_local_grad"], + atol=1e-5, + rtol=1e-5, + msg=f"Gradient mismatch on rank {rank}", + ) + + worker_group.shutdown(force=True) + From 1586bdf1ccbeb1c41fa4d3d5fc7047878acbce00 Mon Sep 17 00:00:00 2001 From: mloh Date: Mon, 9 Feb 2026 20:20:04 -0800 Subject: [PATCH 03/18] Expand test case to include 8 gpus Signed-off-by: mloh --- .../test_sequence_packing_fusion.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py index a774977f93..b3d92cd9c4 100644 --- a/tests/unit/algorithms/test_sequence_packing_fusion.py +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -340,20 +340,24 @@ def register_sequence_packing_loss_wrapper_baseline_actor(): @pytest.fixture(scope="function") def cluster_fixture(request): """Create and teardown a virtual cluster for CP/TP tests.""" - cp_size = int(request.node.callspec.params["cp_size"]) - tp_size = int(request.node.callspec.params["tp_size"]) + cp_size, tp_size = request.node.callspec.params["cp_tp"] world_size = cp_size * tp_size - if not torch.cuda.is_available() or torch.cuda.device_count() < world_size: - pytest.skip( - f"Not enough GPUs available. Need {world_size}, got {torch.cuda.device_count()}" - ) - if not ray.is_initialized(): from nemo_rl.distributed.virtual_cluster import init_ray - init_ray() + # Check available GPUs via Ray cluster resources (works across multi-node), + # falling back to local torch.cuda.device_count() if Ray has no GPU info. + available_gpus = int(ray.cluster_resources().get("GPU", 0)) + if available_gpus == 0: + available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + + if available_gpus < world_size: + pytest.skip( + f"Not enough GPUs available. Need {world_size}, got {available_gpus}" + ) + cluster_name = f"test-seq-pack-fusion-cp{cp_size}-tp{tp_size}" cluster = RayVirtualCluster( name=cluster_name, bundle_ct_per_node_list=[world_size], use_gpus=True @@ -362,10 +366,20 @@ def cluster_fixture(request): cluster.shutdown() -@pytest.mark.parametrize("cp_size", [1, 2]) -@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize( + "cp_tp", + [ + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (2, 4), + (4, 2), + ], + ids=lambda cp_tp: f"cp{cp_tp[0]}_tp{cp_tp[1]}", +) def test_sequence_packing_loss_wrapper_baseline_cp_tp( - cluster_fixture, register_sequence_packing_loss_wrapper_baseline_actor, cp_size, tp_size + cluster_fixture, register_sequence_packing_loss_wrapper_baseline_actor, cp_tp ): """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. @@ -374,6 +388,7 @@ def test_sequence_packing_loss_wrapper_baseline_cp_tp( - backward gradients w.r.t. vocab-parallel logits for different CP and TP configurations. """ + cp_size, tp_size = cp_tp cluster = cluster_fixture actor_fqn = register_sequence_packing_loss_wrapper_baseline_actor world_size = cp_size * tp_size From 2029231c2fb44da81b8f8bcadbf708a15fac0345 Mon Sep 17 00:00:00 2001 From: mloh Date: Tue, 10 Feb 2026 10:13:53 -0800 Subject: [PATCH 04/18] Cache packed input ids Signed-off-by: mloh --- nemo_rl/algorithms/loss/loss_functions.py | 14 +++++++ nemo_rl/distributed/model_utils.py | 3 +- .../test_sequence_packing_fusion.py | 37 ++++++++++++++----- 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index c72269eee1..1ae6982f7a 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -19,6 +19,20 @@ from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType, LossType from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict +<<<<<<< HEAD:nemo_rl/algorithms/loss/loss_functions.py +======= +from nemo_rl.distributed.model_utils import ( + ChunkedDistributedEntropy, + ChunkedDistributedGatherLogprob, + _get_tokens_on_this_cp_rank, + allgather_cp_sharded_tensor, + from_parallel_logits_to_logprobs, + from_parallel_logits_to_logprobs_packed_sequences, + gather_logits_at_global_indices, + get_logprobs_from_vocab_parallel_logits, +) +from nemo_rl.utils.nsys import wrap_with_nvtx_name +>>>>>>> d3f2b224 (Cache packed input ids):nemo_rl/algorithms/loss_functions.py Tensor = TypeVar("Tensor", bound=torch.Tensor) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index a5837e7753..658da5f0ae 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -23,6 +23,7 @@ need_top_k_or_top_p_filtering, ) +from nemo_rl.utils.nsys import wrap_with_nvtx_name @torch.no_grad() def _compute_distributed_log_softmax( @@ -929,7 +930,7 @@ def from_parallel_logits_to_logprobs( return logprobs[:, :-1] - +@wrap_with_nvtx_name("from_parallel_logits_to_logprobs_packed_sequences") def from_parallel_logits_to_logprobs_packed_sequences( vocab_parallel_logits: torch.Tensor, target: torch.Tensor, diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py index b3d92cd9c4..87bf29392f 100644 --- a/tests/unit/algorithms/test_sequence_packing_fusion.py +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -235,26 +235,31 @@ def make_logits_and_packed_logits(): "loss_cfg": loss_cfg, "cu_seqlens": cu_seqlens, "cu_seqlens_padded": cu_seqlens_padded, + "packed_input_ids": _packed_input_ids, # [1, T_packed] from _pack_sequences_for_megatron "data_dict": data_dict, "global_valid_seqs": global_valid_seqs, "global_valid_toks": global_valid_toks, "make_logits_and_packed_logits": make_logits_and_packed_logits, } - def run_compare_sequence_packing_wrappers(self): + def run_compare_sequence_packing_wrappers(self, use_cached_packed_input_ids: bool = False): """ Compare helper (for when your candidate/fused wrapper exists): - Builds inputs ONCE - Runs baseline wrapper and candidate wrapper on identical inputs - Returns loss/metrics + max grad for each - Assumes candidate_wrapper_fqn points to a class with the same constructor as - SequencePackingLossWrapper: (loss_fn, cu_seqlens_q, cu_seqlens_q_padded). + Args: + use_cached_packed_input_ids: If True, store pre-packed input_ids in data dict + so the fused wrapper skips _pack_input_ids. If False, the fused wrapper + packs on the fly (fallback path). """ rank, _my_cp_rank, my_tp_rank, cp_group, tp_group = self._setup_process_groups() tc = self._build_test_case(cp_group=cp_group, my_tp_rank=my_tp_rank) base_loss_fn = ClippedPGLossFn(tc["loss_cfg"]) + data_dict = tc["data_dict"] + # Instantiate wrappers baseline_wrapper = SequencePackingLossWrapper( loss_fn=base_loss_fn, @@ -268,11 +273,11 @@ def run_compare_sequence_packing_wrappers(self): cu_seqlens_q_padded=tc["cu_seqlens_padded"], ) - # Baseline run (fresh logits) + # Baseline run (fresh logits) — uses the original data_dict without packed_input_ids baseline_logits, baseline_packed_logits = tc["make_logits_and_packed_logits"]() baseline_loss, baseline_metrics = baseline_wrapper( baseline_packed_logits, - tc["data_dict"], + data_dict, tc["global_valid_seqs"], tc["global_valid_toks"], vocab_parallel_rank=my_tp_rank, @@ -283,10 +288,16 @@ def run_compare_sequence_packing_wrappers(self): baseline_grad = baseline_logits.grad.clone() # Candidate run (fresh logits, identical values) + # Optionally add pre-packed input_ids to data dict for the fused wrapper only + candidate_data_dict = data_dict + if use_cached_packed_input_ids: + candidate_data_dict = BatchedDataDict(dict(data_dict)) + candidate_data_dict["packed_input_ids"] = tc["packed_input_ids"] + candidate_logits, candidate_packed_logits = tc["make_logits_and_packed_logits"]() candidate_loss, candidate_metrics = candidate_wrapper( candidate_packed_logits, - tc["data_dict"], + candidate_data_dict, tc["global_valid_seqs"], tc["global_valid_toks"], vocab_parallel_rank=my_tp_rank, @@ -378,15 +389,22 @@ def cluster_fixture(request): ], ids=lambda cp_tp: f"cp{cp_tp[0]}_tp{cp_tp[1]}", ) +@pytest.mark.parametrize( + "use_cached_packed_input_ids", + [False, True], + ids=["pack_on_the_fly", "cached_packed_input_ids"], +) def test_sequence_packing_loss_wrapper_baseline_cp_tp( - cluster_fixture, register_sequence_packing_loss_wrapper_baseline_actor, cp_tp + cluster_fixture, register_sequence_packing_loss_wrapper_baseline_actor, + cp_tp, use_cached_packed_input_ids, ): """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. Verifies that the fused wrapper produces identical: - loss values - backward gradients w.r.t. vocab-parallel logits - for different CP and TP configurations. + for different CP and TP configurations, and both with and without + pre-packed input_ids cached in the data dict. """ cp_size, tp_size = cp_tp cluster = cluster_fixture @@ -408,7 +426,8 @@ def test_sequence_packing_loss_wrapper_baseline_cp_tp( ) futures = worker_group.run_all_workers_single_data( - "run_compare_sequence_packing_wrappers" + "run_compare_sequence_packing_wrappers", + use_cached_packed_input_ids=use_cached_packed_input_ids, ) results = ray.get(futures) From 6ba0a31310a290eee16760ae2d2ea12ab764c9a9 Mon Sep 17 00:00:00 2001 From: mloh Date: Tue, 10 Feb 2026 16:36:33 -0800 Subject: [PATCH 05/18] Fix lint issue Signed-off-by: mloh --- .../test_sequence_packing_fusion.py | 65 +++++++++++++------ 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py index 87bf29392f..017e2e2a1c 100644 --- a/tests/unit/algorithms/test_sequence_packing_fusion.py +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -32,8 +32,8 @@ from nemo_rl.algorithms.loss_functions import ( ClippedPGLossFn, - SequencePackingLossWrapper, SequencePackingFusionLossWrapper, + SequencePackingLossWrapper, ) from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.named_sharding import NamedSharding @@ -58,7 +58,7 @@ def _setup_process_groups(self): rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) assert world_size == self.cp_size * self.tp_size, ( - f"Expected WORLD_SIZE={self.cp_size*self.tp_size}, got {world_size}." + f"Expected WORLD_SIZE={self.cp_size * self.tp_size}, got {world_size}." ) # --------------------------------------------------------------------- @@ -73,12 +73,16 @@ def _setup_process_groups(self): # CP groups: one per tp_rank, varying cp coordinate for tp_rank in range(self.tp_size): - ranks = [cp_rank * self.tp_size + tp_rank for cp_rank in range(self.cp_size)] + ranks = [ + cp_rank * self.tp_size + tp_rank for cp_rank in range(self.cp_size) + ] cp_groups.append(torch.distributed.new_group(ranks=ranks)) # TP groups: one per cp_rank, varying tp coordinate for cp_rank in range(self.cp_size): - ranks = [cp_rank * self.tp_size + tp_rank for tp_rank in range(self.tp_size)] + ranks = [ + cp_rank * self.tp_size + tp_rank for tp_rank in range(self.tp_size) + ] tp_groups.append(torch.distributed.new_group(ranks=ranks)) my_tp_rank = rank % self.tp_size @@ -109,14 +113,23 @@ def _build_test_case(self, cp_group, my_tp_rank: int): # Variable lengths, but <= max_seq_len seq_lengths = torch.tensor( - [max_seq_len // 4, max_seq_len // 2, max_seq_len // 3, max_seq_len * 3 // 4], + [ + max_seq_len // 4, + max_seq_len // 2, + max_seq_len // 3, + max_seq_len * 3 // 4, + ], dtype=torch.int32, device=device, ) # Input ids + masks - input_ids = torch.zeros(batch_size, max_seq_len, dtype=torch.long, device=device) - token_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.float32, device=device) + input_ids = torch.zeros( + batch_size, max_seq_len, dtype=torch.long, device=device, + ) + token_mask = torch.zeros( + batch_size, max_seq_len, dtype=torch.float32, device=device, + ) for i in range(batch_size): L = int(seq_lengths[i].item()) input_ids[i, :L] = torch.randint(0, vocab_size_total, (L,), device=device) @@ -166,7 +179,11 @@ def _build_test_case(self, cp_group, my_tp_rank: int): # --------------------------------------------------------------------- # Global logits (same across ranks), then slice by TP rank full_logits = torch.randn( - batch_size, max_seq_len, vocab_size_total, device=device, dtype=torch.float32 + batch_size, + max_seq_len, + vocab_size_total, + device=device, + dtype=torch.float32, ) def make_logits_and_packed_logits(): @@ -196,7 +213,8 @@ def make_logits_and_packed_logits(): tmp[:, :seq_len, :] = logits_local[i : i + 1, :seq_len, :] packed_logits[ :, - run_seq // self.cp_size : (run_seq + padded_seq_len) // self.cp_size, + run_seq // self.cp_size : (run_seq + padded_seq_len) + // self.cp_size, :, ] = _get_tokens_on_this_cp_rank( tmp, torch.distributed.get_rank(cp_group), self.cp_size @@ -242,7 +260,9 @@ def make_logits_and_packed_logits(): "make_logits_and_packed_logits": make_logits_and_packed_logits, } - def run_compare_sequence_packing_wrappers(self, use_cached_packed_input_ids: bool = False): + def run_compare_sequence_packing_wrappers( + self, use_cached_packed_input_ids: bool = False, + ): """ Compare helper (for when your candidate/fused wrapper exists): - Builds inputs ONCE @@ -294,7 +314,9 @@ def run_compare_sequence_packing_wrappers(self, use_cached_packed_input_ids: boo candidate_data_dict = BatchedDataDict(dict(data_dict)) candidate_data_dict["packed_input_ids"] = tc["packed_input_ids"] - candidate_logits, candidate_packed_logits = tc["make_logits_and_packed_logits"]() + candidate_logits, candidate_packed_logits = tc[ + "make_logits_and_packed_logits" + ]() candidate_loss, candidate_metrics = candidate_wrapper( candidate_packed_logits, candidate_data_dict, @@ -324,9 +346,7 @@ def run_compare_sequence_packing_wrappers(self, use_cached_packed_input_ids: boo } -SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN = ( - f"{SequencePackingLossWrapperBaselineActor.__module__}.SequencePackingLossWrapperBaselineActor" -) +SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN = f"{SequencePackingLossWrapperBaselineActor.__module__}.SequencePackingLossWrapperBaselineActor" @pytest.fixture def register_sequence_packing_loss_wrapper_baseline_actor(): @@ -341,12 +361,13 @@ def register_sequence_packing_loss_wrapper_baseline_actor(): if SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: if original_registry_value is None: - del ACTOR_ENVIRONMENT_REGISTRY[SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN] + del ACTOR_ENVIRONMENT_REGISTRY[ + SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN + ] else: - ACTOR_ENVIRONMENT_REGISTRY[SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN] = ( - original_registry_value - ) - + ACTOR_ENVIRONMENT_REGISTRY[ + SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN + ] = original_registry_value @pytest.fixture(scope="function") def cluster_fixture(request): @@ -395,8 +416,10 @@ def cluster_fixture(request): ids=["pack_on_the_fly", "cached_packed_input_ids"], ) def test_sequence_packing_loss_wrapper_baseline_cp_tp( - cluster_fixture, register_sequence_packing_loss_wrapper_baseline_actor, - cp_tp, use_cached_packed_input_ids, + cluster_fixture, + register_sequence_packing_loss_wrapper_baseline_actor, + cp_tp, + use_cached_packed_input_ids, ): """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. From 6cd1f384a46a8e7de962488d72d4f7e9deff5049 Mon Sep 17 00:00:00 2001 From: mloh Date: Wed, 11 Feb 2026 10:42:03 -0800 Subject: [PATCH 06/18] Fix formatting issue Signed-off-by: mloh --- nemo_rl/algorithms/loss/loss_functions.py | 14 -------- nemo_rl/distributed/model_utils.py | 2 ++ .../test_sequence_packing_fusion.py | 35 ++++++++++++------- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 1ae6982f7a..c72269eee1 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -19,20 +19,6 @@ from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType, LossType from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict -<<<<<<< HEAD:nemo_rl/algorithms/loss/loss_functions.py -======= -from nemo_rl.distributed.model_utils import ( - ChunkedDistributedEntropy, - ChunkedDistributedGatherLogprob, - _get_tokens_on_this_cp_rank, - allgather_cp_sharded_tensor, - from_parallel_logits_to_logprobs, - from_parallel_logits_to_logprobs_packed_sequences, - gather_logits_at_global_indices, - get_logprobs_from_vocab_parallel_logits, -) -from nemo_rl.utils.nsys import wrap_with_nvtx_name ->>>>>>> d3f2b224 (Cache packed input ids):nemo_rl/algorithms/loss_functions.py Tensor = TypeVar("Tensor", bound=torch.Tensor) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 658da5f0ae..7ee8917fa2 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -25,6 +25,7 @@ from nemo_rl.utils.nsys import wrap_with_nvtx_name + @torch.no_grad() def _compute_distributed_log_softmax( vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup @@ -930,6 +931,7 @@ def from_parallel_logits_to_logprobs( return logprobs[:, :-1] + @wrap_with_nvtx_name("from_parallel_logits_to_logprobs_packed_sequences") def from_parallel_logits_to_logprobs_packed_sequences( vocab_parallel_logits: torch.Tensor, diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py index 017e2e2a1c..cccc43dd0a 100644 --- a/tests/unit/algorithms/test_sequence_packing_fusion.py +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -31,9 +31,9 @@ import torch from nemo_rl.algorithms.loss_functions import ( - ClippedPGLossFn, + ClippedPGLossFn, SequencePackingFusionLossWrapper, - SequencePackingLossWrapper, + SequencePackingLossWrapper, ) from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.named_sharding import NamedSharding @@ -125,10 +125,16 @@ def _build_test_case(self, cp_group, my_tp_rank: int): # Input ids + masks input_ids = torch.zeros( - batch_size, max_seq_len, dtype=torch.long, device=device, + batch_size, + max_seq_len, + dtype=torch.long, + device=device, ) token_mask = torch.zeros( - batch_size, max_seq_len, dtype=torch.float32, device=device, + batch_size, + max_seq_len, + dtype=torch.float32, + device=device, ) for i in range(batch_size): L = int(seq_lengths[i].item()) @@ -179,10 +185,10 @@ def _build_test_case(self, cp_group, my_tp_rank: int): # --------------------------------------------------------------------- # Global logits (same across ranks), then slice by TP rank full_logits = torch.randn( - batch_size, - max_seq_len, - vocab_size_total, - device=device, + batch_size, + max_seq_len, + vocab_size_total, + device=device, dtype=torch.float32, ) @@ -261,7 +267,8 @@ def make_logits_and_packed_logits(): } def run_compare_sequence_packing_wrappers( - self, use_cached_packed_input_ids: bool = False, + self, + use_cached_packed_input_ids: bool = False, ): """ Compare helper (for when your candidate/fused wrapper exists): @@ -286,7 +293,7 @@ def run_compare_sequence_packing_wrappers( cu_seqlens_q=tc["cu_seqlens"], cu_seqlens_q_padded=tc["cu_seqlens_padded"], ) - + candidate_wrapper = SequencePackingFusionLossWrapper( loss_fn=base_loss_fn, cu_seqlens_q=tc["cu_seqlens"], @@ -348,6 +355,7 @@ def run_compare_sequence_packing_wrappers( SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN = f"{SequencePackingLossWrapperBaselineActor.__module__}.SequencePackingLossWrapperBaselineActor" + @pytest.fixture def register_sequence_packing_loss_wrapper_baseline_actor(): """Register the actor in ACTOR_ENVIRONMENT_REGISTRY for RayWorkerGroup.""" @@ -369,6 +377,7 @@ def register_sequence_packing_loss_wrapper_baseline_actor(): SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN ] = original_registry_value + @pytest.fixture(scope="function") def cluster_fixture(request): """Create and teardown a virtual cluster for CP/TP tests.""" @@ -377,6 +386,7 @@ def cluster_fixture(request): if not ray.is_initialized(): from nemo_rl.distributed.virtual_cluster import init_ray + init_ray() # Check available GPUs via Ray cluster resources (works across multi-node), @@ -416,9 +426,9 @@ def cluster_fixture(request): ids=["pack_on_the_fly", "cached_packed_input_ids"], ) def test_sequence_packing_loss_wrapper_baseline_cp_tp( - cluster_fixture, + cluster_fixture, register_sequence_packing_loss_wrapper_baseline_actor, - cp_tp, + cp_tp, use_cached_packed_input_ids, ): """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. @@ -477,4 +487,3 @@ def test_sequence_packing_loss_wrapper_baseline_cp_tp( ) worker_group.shutdown(force=True) - From 50ba5c878f7a3fd3800514062307ca290ab83d9c Mon Sep 17 00:00:00 2001 From: mloh Date: Wed, 11 Feb 2026 12:53:03 -0800 Subject: [PATCH 07/18] Add checks and provide meaningful assertion error Signed-off-by: mloh --- nemo_rl/distributed/model_utils.py | 1 - tests/unit/algorithms/test_sequence_packing_fusion.py | 1 - 2 files changed, 2 deletions(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 7ee8917fa2..8a10ac5104 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -932,7 +932,6 @@ def from_parallel_logits_to_logprobs( return logprobs[:, :-1] -@wrap_with_nvtx_name("from_parallel_logits_to_logprobs_packed_sequences") def from_parallel_logits_to_logprobs_packed_sequences( vocab_parallel_logits: torch.Tensor, target: torch.Tensor, diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py index cccc43dd0a..664a8d36cb 100644 --- a/tests/unit/algorithms/test_sequence_packing_fusion.py +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -442,7 +442,6 @@ def test_sequence_packing_loss_wrapper_baseline_cp_tp( cp_size, tp_size = cp_tp cluster = cluster_fixture actor_fqn = register_sequence_packing_loss_wrapper_baseline_actor - world_size = cp_size * tp_size sharding_layout = [ [cp_rank * tp_size + tp_rank for tp_rank in range(tp_size)] From 3b723cd84b65a19e5e806caaf3c504fcb64d954e Mon Sep 17 00:00:00 2001 From: mloh Date: Fri, 20 Feb 2026 13:40:43 -0800 Subject: [PATCH 08/18] Compute rolled target once in fusion path Signed-off-by: mloh --- nemo_rl/algorithms/loss/loss_functions.py | 1 + nemo_rl/distributed/model_utils.py | 63 ++++++++++--------- .../test_sequence_packing_fusion.py | 38 +++-------- 3 files changed, 43 insertions(+), 59 deletions(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index c72269eee1..e138121493 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -115,6 +115,7 @@ class ClippedPGLossFn(LossFunction): input_type = LossInputType.LOGPROB def __init__(self, cfg: ClippedPGLossConfig): + self.chunk_size: Optional[int] = None self.ratio_clip_min = cfg["ratio_clip_min"] self.ratio_clip_max = cfg["ratio_clip_max"] self.ratio_clip_c = cfg["ratio_clip_c"] # set to None to disable dual-clipping diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 8a10ac5104..ba7810d790 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -944,16 +944,19 @@ def from_parallel_logits_to_logprobs_packed_sequences( cp_group: Optional[torch.distributed.ProcessGroup] = None, chunk_size: Optional[int] = None, sampling_params: Optional[TrainingSamplingParams] = None, + target_is_pre_rolled: bool = False, ) -> torch.Tensor: """Get log probabilities from TP sharded vocab logits for packed sequences. Args: vocab_parallel_logits (torch.Tensor): Packed logits tensor with shape [1, T // CP, vocab_size//TP] where T is the total number of tokens across all packed sequences. - target (torch.Tensor): Packed target token indices with shape [1, T]. - NOTE: Must be the unmodified targets as this function will shift them internally. - cu_seqlens (torch.Tensor): Cumulative sequence lengths tensor with shape [batch_size + 1]. - cu_seqlens[i] indicates the start position of sequence i in the packed format. + target (torch.Tensor): Packed target token indices. + If target_is_pre_rolled=False: shape [1, T] — unmodified targets, rolled internally. + If target_is_pre_rolled=True: shape [1, T // CP] — pre-rolled and pre-CP-sharded. + cu_seqlens_padded (torch.Tensor): Cumulative sequence lengths tensor with shape [batch_size + 1]. + cu_seqlens_padded[i] indicates the start position of sequence i in the packed format + (full, not CP-adjusted). unpacked_seqlen (int): The length of the unpacked sequence tensor. vocab_start_index (int): Starting vocabulary index for this worker's partition. vocab_end_index (int): Ending vocabulary index for this worker's partition. @@ -961,44 +964,48 @@ def from_parallel_logits_to_logprobs_packed_sequences( inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities. + sampling_params (TrainingSamplingParams, optional): Sampling parameters for Top-k/Top-p filtering. + target_is_pre_rolled (bool): If True, target is already shifted and CP-sharded to match + vocab_parallel_logits shape, skipping the internal per-sequence roll+CP-shard loop. Returns: torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. The total length is reduced by batch_size due to target shifting (one token per sequence). """ - # Remove batch dimension to work with [T, vocab_size] and [T] - vocab_parallel_logits = vocab_parallel_logits.squeeze(0) - target = target.squeeze(0) - batch_size = cu_seqlens_padded.shape[0] - 1 cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) - cp_rank = 0 if cp_group is None else torch.distributed.get_rank(cp_group) - # Roll each sequence individually - rolled_targets = torch.zeros( - target.shape[0] // cp_size, dtype=target.dtype, device=target.device - ) - for i in range(batch_size): - start_idx = cu_seqlens_padded[i].item() - end_idx = cu_seqlens_padded[i + 1].item() + if not target_is_pre_rolled: + # Roll each sequence individually and CP-shard the targets + # Remove batch dimension to work with [T, vocab_size] and [T] + vocab_parallel_logits = vocab_parallel_logits.squeeze(0) + target = target.squeeze(0) + cp_rank = 0 if cp_group is None else torch.distributed.get_rank(cp_group) - # Get the sequence targets and roll by -1 - seq_targets = target[start_idx:end_idx] - rolled_seq_targets = seq_targets.roll(shifts=-1, dims=0) - rolled_targets[start_idx // cp_size : end_idx // cp_size] = ( - _get_tokens_on_this_cp_rank(rolled_seq_targets, cp_rank, cp_size, seq_dim=0) + rolled_targets = torch.zeros( + target.shape[0] // cp_size, dtype=target.dtype, device=target.device ) + for i in range(batch_size): + start_idx = cu_seqlens_padded[i].item() + end_idx = cu_seqlens_padded[i + 1].item() + + seq_targets = target[start_idx:end_idx] + rolled_seq_targets = seq_targets.roll(shifts=-1, dims=0) + rolled_targets[start_idx // cp_size : end_idx // cp_size] = ( + _get_tokens_on_this_cp_rank( + rolled_seq_targets, cp_rank, cp_size, seq_dim=0 + ) + ) - # Add batch dimension back for DistributedLogprob - rolled_targets = rolled_targets.unsqueeze(0) - vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) + target = rolled_targets.unsqueeze(0) + vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) # Apply distributed log probability computation if need_top_k_or_top_p_filtering(sampling_params): if chunk_size is not None: probs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore vocab_parallel_logits, - rolled_targets, + target, group, sampling_params.top_k, sampling_params.top_p, @@ -1008,7 +1015,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( else: probs: torch.Tensor = DistributedLogprobWithSampling.apply( # type: ignore vocab_parallel_logits, - rolled_targets, + target, group, sampling_params.top_k, sampling_params.top_p, @@ -1018,7 +1025,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( if chunk_size is not None: probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore vocab_parallel_logits, - rolled_targets, + target, vocab_start_index, vocab_end_index, chunk_size, @@ -1028,7 +1035,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( else: probs: torch.Tensor = DistributedLogprob.apply( # type: ignore vocab_parallel_logits, - rolled_targets, + target, vocab_start_index, vocab_end_index, group, diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py index 664a8d36cb..479c7d624f 100644 --- a/tests/unit/algorithms/test_sequence_packing_fusion.py +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -259,27 +259,17 @@ def make_logits_and_packed_logits(): "loss_cfg": loss_cfg, "cu_seqlens": cu_seqlens, "cu_seqlens_padded": cu_seqlens_padded, - "packed_input_ids": _packed_input_ids, # [1, T_packed] from _pack_sequences_for_megatron "data_dict": data_dict, "global_valid_seqs": global_valid_seqs, "global_valid_toks": global_valid_toks, "make_logits_and_packed_logits": make_logits_and_packed_logits, } - def run_compare_sequence_packing_wrappers( - self, - use_cached_packed_input_ids: bool = False, - ): - """ - Compare helper (for when your candidate/fused wrapper exists): - - Builds inputs ONCE - - Runs baseline wrapper and candidate wrapper on identical inputs - - Returns loss/metrics + max grad for each - - Args: - use_cached_packed_input_ids: If True, store pre-packed input_ids in data dict - so the fused wrapper skips _pack_input_ids. If False, the fused wrapper - packs on the fly (fallback path). + def run_compare_sequence_packing_wrappers(self): + """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. + + Builds inputs once, runs both wrappers on identical inputs, + and returns loss/metrics + max grad for each. """ rank, _my_cp_rank, my_tp_rank, cp_group, tp_group = self._setup_process_groups() tc = self._build_test_case(cp_group=cp_group, my_tp_rank=my_tp_rank) @@ -315,18 +305,12 @@ def run_compare_sequence_packing_wrappers( baseline_grad = baseline_logits.grad.clone() # Candidate run (fresh logits, identical values) - # Optionally add pre-packed input_ids to data dict for the fused wrapper only - candidate_data_dict = data_dict - if use_cached_packed_input_ids: - candidate_data_dict = BatchedDataDict(dict(data_dict)) - candidate_data_dict["packed_input_ids"] = tc["packed_input_ids"] - candidate_logits, candidate_packed_logits = tc[ "make_logits_and_packed_logits" ]() candidate_loss, candidate_metrics = candidate_wrapper( candidate_packed_logits, - candidate_data_dict, + data_dict, tc["global_valid_seqs"], tc["global_valid_toks"], vocab_parallel_rank=my_tp_rank, @@ -420,24 +404,17 @@ def cluster_fixture(request): ], ids=lambda cp_tp: f"cp{cp_tp[0]}_tp{cp_tp[1]}", ) -@pytest.mark.parametrize( - "use_cached_packed_input_ids", - [False, True], - ids=["pack_on_the_fly", "cached_packed_input_ids"], -) def test_sequence_packing_loss_wrapper_baseline_cp_tp( cluster_fixture, register_sequence_packing_loss_wrapper_baseline_actor, cp_tp, - use_cached_packed_input_ids, ): """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. Verifies that the fused wrapper produces identical: - loss values - backward gradients w.r.t. vocab-parallel logits - for different CP and TP configurations, and both with and without - pre-packed input_ids cached in the data dict. + for different CP and TP configurations. """ cp_size, tp_size = cp_tp cluster = cluster_fixture @@ -459,7 +436,6 @@ def test_sequence_packing_loss_wrapper_baseline_cp_tp( futures = worker_group.run_all_workers_single_data( "run_compare_sequence_packing_wrappers", - use_cached_packed_input_ids=use_cached_packed_input_ids, ) results = ray.get(futures) From 275902ef87a500b4ebe7f4ffee59db783119c6a3 Mon Sep 17 00:00:00 2001 From: mloh Date: Sat, 21 Feb 2026 13:27:30 -0800 Subject: [PATCH 09/18] Add sanity check to ensure gradient is not None in unit test Signed-off-by: mloh --- tests/unit/algorithms/test_sequence_packing_fusion.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py index 479c7d624f..4c03294485 100644 --- a/tests/unit/algorithms/test_sequence_packing_fusion.py +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -444,6 +444,11 @@ def test_sequence_packing_loss_wrapper_baseline_cp_tp( for r in results: rank = r["rank"] + # Sanity: gradients must be non-None and non-zero (autograd is connected) + for label in ("baseline", "candidate"): + grad = r[label]["logits_local_grad"] + assert grad is not None, f"{label} grad is None on rank {rank}" + assert grad.abs().sum() > 0, f"{label} grad is all zeros on rank {rank}" # Forward: loss values must match torch.testing.assert_close( r["baseline"]["loss"], From 1adc2420ffcf129d2f8cea7741bf27353285ee8b Mon Sep 17 00:00:00 2001 From: mloh Date: Tue, 10 Mar 2026 17:04:54 -0700 Subject: [PATCH 10/18] Refactor seqpack fusion wrapper to follow the new interface Signed-off-by: mloh --- nemo_rl/algorithms/loss/__init__.py | 6 +- nemo_rl/algorithms/loss/loss_functions.py | 1 - nemo_rl/algorithms/loss/utils.py | 86 +++++++++++++++++++++++ nemo_rl/algorithms/loss/wrapper.py | 82 +++++---------------- nemo_rl/models/megatron/train.py | 16 ++++- 5 files changed, 124 insertions(+), 67 deletions(-) diff --git a/nemo_rl/algorithms/loss/__init__.py b/nemo_rl/algorithms/loss/__init__.py index ed64591544..a2d404fdaa 100644 --- a/nemo_rl/algorithms/loss/__init__.py +++ b/nemo_rl/algorithms/loss/__init__.py @@ -26,7 +26,10 @@ PreferenceLossDataDict, PreferenceLossFn, ) -from nemo_rl.algorithms.loss.utils import prepare_loss_input +from nemo_rl.algorithms.loss.utils import ( + prepare_loss_input, + prepare_packed_loss_input, +) from nemo_rl.algorithms.loss.wrapper import ( SequencePackingFusionLossWrapper, SequencePackingLossWrapper, @@ -47,6 +50,7 @@ "PreferenceLossDataDict", "PreferenceLossFn", "prepare_loss_input", + "prepare_packed_loss_input", "SequencePackingFusionLossWrapper", "SequencePackingLossWrapper", "wrap_loss_fn_with_input_preparation", diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index e138121493..c72269eee1 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -115,7 +115,6 @@ class ClippedPGLossFn(LossFunction): input_type = LossInputType.LOGPROB def __init__(self, cfg: ClippedPGLossConfig): - self.chunk_size: Optional[int] = None self.ratio_clip_min = cfg["ratio_clip_min"] self.ratio_clip_max = cfg["ratio_clip_max"] self.ratio_clip_c = cfg["ratio_clip_c"] # set to None to disable dual-clipping diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 70edc220e0..d8e2030270 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -24,6 +24,7 @@ from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( + from_parallel_logits_to_logprobs_packed_sequences, get_distillation_topk_logprobs_from_logits, get_next_token_logprobs_from_logits, ) @@ -119,3 +120,88 @@ def prepare_loss_input( raise ValueError(f"Unknown loss function input type: {loss_fn.input_type}") return loss_input, data + +def _pack_input_ids( + input_ids: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_q_padded: torch.Tensor, +) -> torch.Tensor: + """Pack input_ids from [B, S] to [1, T_packed] using sequence boundaries. + + Each sequence i is placed at cu_seqlens_q_padded[i] in the packed tensor, + with actual_len = cu_seqlens_q[i+1] - cu_seqlens_q[i] tokens copied. + """ + batch_size = input_ids.shape[0] + total_packed_len = int(cu_seqlens_q_padded[-1].item()) + packed = torch.zeros( + 1, total_packed_len, dtype=input_ids.dtype, device=input_ids.device + ) + for i in range(batch_size): + actual_len = int((cu_seqlens_q[i + 1] - cu_seqlens_q[i]).item()) + packed_start = int(cu_seqlens_q_padded[i].item()) + packed[0, packed_start : packed_start + actual_len] = input_ids[ + i, :actual_len + ] + return packed + + +def prepare_packed_loss_input( + logits: torch.Tensor, + data: BatchedDataDict[Any], + loss_fn: LossFunction, + cu_seqlens_q: torch.Tensor, + cu_seqlens_q_padded: torch.Tensor, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +) -> dict[str, Any]: + """Prepare loss input from packed logits in a single fused pass. + + Unlike prepare_loss_input which operates on a single (unpacked) sequence, + this function computes log probabilities from packed logits across all + sequences at once using from_parallel_logits_to_logprobs_packed_sequences. + + Currently only supports LossInputType.LOGPROB. + + Args: + logits: Packed logits from the model [1, T_packed // CP, V // TP]. + data: Microbatch data (unpacked, [B, S]). + loss_fn: Loss function (must have input_type == LossInputType.LOGPROB). + cu_seqlens_q: Unpadded cumulative sequence lengths [B+1]. + cu_seqlens_q_padded: Padded cumulative sequence lengths [B+1]. + vocab_parallel_rank: Vocab parallel rank. + vocab_parallel_group: Vocab parallel group. + context_parallel_group: Context parallel group. + + Returns: + Loss input dict with key "next_token_logprobs". + """ + if loss_fn.input_type != LossInputType.LOGPROB: + raise ValueError( + f"prepare_packed_loss_input only supports LossInputType.LOGPROB, " + f"got {loss_fn.input_type}. Use SequencePackingLossWrapper with " + f"prepare_loss_input for other types." + ) + assert vocab_parallel_group is not None, ( + "prepare_packed_loss_input requires vocab_parallel_group (Megatron TP)." + ) + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided with vocab_parallel_group." + ) + + packed_input_ids = _pack_input_ids(data["input_ids"], cu_seqlens_q, cu_seqlens_q_padded) + unpacked_seqlen = data["input_ids"].shape[1] + + logprobs = from_parallel_logits_to_logprobs_packed_sequences( + logits.to(torch.float32), + packed_input_ids, + cu_seqlens_q_padded, + unpacked_seqlen, + vocab_start_index=vocab_parallel_rank * logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * logits.shape[-1], + group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + ) + + return {"next_token_logprobs": logprobs} diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py index 9319e3b0db..0fa3cac2d3 100644 --- a/nemo_rl/algorithms/loss/wrapper.py +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -18,11 +18,8 @@ import torch import torch.distributed -from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import ( - from_parallel_logits_to_logprobs_packed_sequences, -) Tensor = TypeVar("Tensor", bound=torch.Tensor) @@ -159,38 +156,29 @@ class SequencePackingFusionLossWrapper: """Fused sequence packing loss wrapper that processes all sequences in one forward pass. Unlike SequencePackingLossWrapper which iterates over sequences one at a time, - this wrapper computes log probabilities from packed logits in a single shot using - from_parallel_logits_to_logprobs_packed_sequences, then calls the loss function - with the pre-computed logprobs. + this wrapper calls prepare_fn once on the packed logits to compute log + probabilities in a single shot, then calls the loss function once with the + pre-computed result. This avoids per-sequence kernel launches and TP/CP communication overhead while producing numerically identical results. - Requirements: - - The wrapped loss_fn must have input_type == LossInputType.LOGPROB. - - vocab_parallel_group and vocab_parallel_rank must be provided (Megatron TP). + The prepare_fn should be prepare_packed_loss_input (from nemo_rl.algorithms.loss.utils), + which currently only supports LossInputType.LOGPROB. """ def __init__( self, loss_fn: LossFunction, + prepare_fn: Callable[..., Any], cu_seqlens_q: Tensor, cu_seqlens_q_padded: Optional[Tensor] = None, vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ): - assert loss_fn.input_type == LossInputType.LOGPROB, ( - f"SequencePackingFusionLossWrapper only supports LossInputType.LOGPROB, " - f"got {loss_fn.input_type}. Use SequencePackingLossWrapper for other types." - ) - assert vocab_parallel_group is not None, ( - "SequencePackingFusionLossWrapper requires vocab_parallel_group (Megatron TP)." - ) - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided with vocab_parallel_group." - ) self.loss_fn = loss_fn + self.prepare_fn = prepare_fn self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_q_padded = ( cu_seqlens_q_padded if cu_seqlens_q_padded is not None else cu_seqlens_q @@ -199,27 +187,6 @@ def __init__( self.vocab_parallel_group = vocab_parallel_group self.context_parallel_group = context_parallel_group - def _pack_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - """Pack input_ids from [B, S] to [1, T_packed] using sequence boundaries. - - Each sequence i is placed at cu_seqlens_q_padded[i] in the packed tensor, - with actual_len = cu_seqlens_q[i+1] - cu_seqlens_q[i] tokens copied. - """ - batch_size = input_ids.shape[0] - total_packed_len = int(self.cu_seqlens_q_padded[-1].item()) - packed = torch.zeros( - 1, total_packed_len, dtype=input_ids.dtype, device=input_ids.device - ) - for i in range(batch_size): - actual_len = int( - (self.cu_seqlens_q[i + 1] - self.cu_seqlens_q[i]).item() - ) - packed_start = int(self.cu_seqlens_q_padded[i].item()) - packed[0, packed_start : packed_start + actual_len] = input_ids[ - i, :actual_len - ] - return packed - def __call__( self, next_token_logits: Tensor, @@ -227,34 +194,23 @@ def __call__( global_valid_seqs: Tensor | None, global_valid_toks: Tensor | None, ) -> tuple[Tensor, dict[str, Any]]: - """Compute loss for all packed sequences in one forward pass. - - 1. Pack input_ids from [B, S] to [1, T_packed] to match packed logit layout. - 2. Compute logprobs from packed logits via - from_parallel_logits_to_logprobs_packed_sequences -> [B, S-1]. - 3. Call the loss function with the pre-computed logprobs. - """ - packed_input_ids = self._pack_input_ids(data["input_ids"]) - unpacked_seqlen = data["input_ids"].shape[1] - - logprobs = from_parallel_logits_to_logprobs_packed_sequences( - next_token_logits.to(torch.float32), - packed_input_ids, - self.cu_seqlens_q_padded, - unpacked_seqlen, - vocab_start_index=self.vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(self.vocab_parallel_rank + 1) - * next_token_logits.shape[-1], - group=self.vocab_parallel_group, - inference_only=False, - cp_group=self.context_parallel_group, + """Compute loss for all packed sequences in one forward pass.""" + loss_input = self.prepare_fn( + logits=next_token_logits, + data=data, + loss_fn=self.loss_fn, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_q_padded=self.cu_seqlens_q_padded, + vocab_parallel_rank=self.vocab_parallel_rank, + vocab_parallel_group=self.vocab_parallel_group, + context_parallel_group=self.context_parallel_group, ) return self.loss_fn( data=data, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, - next_token_logprobs=logprobs, + **loss_input, ) diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index cfbb913395..676060dc97 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -34,8 +34,10 @@ need_top_k_or_top_p_filtering, ) from nemo_rl.algorithms.loss import ( + SequencePackingFusionLossWrapper, SequencePackingLossWrapper, prepare_loss_input, + prepare_packed_loss_input, wrap_loss_fn_with_input_preparation, ) from nemo_rl.algorithms.loss.interfaces import LossFunction @@ -322,9 +324,19 @@ def __call__( # wrap loss function with loss input preparation pack_sequences = self.cfg["sequence_packing"]["enabled"] if pack_sequences and packed_seq_params is not None: - loss_fn_wrapped = SequencePackingLossWrapper( + fuse_loss = self.cfg.get("sequence_packing", {}).get( + "fuse_loss", False + ) + if fuse_loss: + wrapper_cls = SequencePackingFusionLossWrapper + prepare_fn = prepare_packed_loss_input + else: + wrapper_cls = SequencePackingLossWrapper + prepare_fn = prepare_loss_input_wrapped + + loss_fn_wrapped = wrapper_cls( loss_fn=self.loss_fn, - prepare_fn=prepare_loss_input_wrapped, + prepare_fn=prepare_fn, cu_seqlens_q=packed_seq_params.cu_seqlens_q, cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, vocab_parallel_rank=get_tensor_model_parallel_rank(), From 4a7d3a29e04bd075e1c406a3cae3d2f5e42aeb18 Mon Sep 17 00:00:00 2001 From: mloh Date: Tue, 10 Mar 2026 17:09:08 -0700 Subject: [PATCH 11/18] Fix format issue Signed-off-by: mloh --- nemo_rl/algorithms/loss/utils.py | 8 ++++---- nemo_rl/models/megatron/train.py | 4 +--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index d8e2030270..959f7d52e4 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -139,9 +139,7 @@ def _pack_input_ids( for i in range(batch_size): actual_len = int((cu_seqlens_q[i + 1] - cu_seqlens_q[i]).item()) packed_start = int(cu_seqlens_q_padded[i].item()) - packed[0, packed_start : packed_start + actual_len] = input_ids[ - i, :actual_len - ] + packed[0, packed_start : packed_start + actual_len] = input_ids[i, :actual_len] return packed @@ -189,7 +187,9 @@ def prepare_packed_loss_input( "vocab_parallel_rank must be provided with vocab_parallel_group." ) - packed_input_ids = _pack_input_ids(data["input_ids"], cu_seqlens_q, cu_seqlens_q_padded) + packed_input_ids = _pack_input_ids( + data["input_ids"], cu_seqlens_q, cu_seqlens_q_padded + ) unpacked_seqlen = data["input_ids"].shape[1] logprobs = from_parallel_logits_to_logprobs_packed_sequences( diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 676060dc97..a8a54027d8 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -324,9 +324,7 @@ def __call__( # wrap loss function with loss input preparation pack_sequences = self.cfg["sequence_packing"]["enabled"] if pack_sequences and packed_seq_params is not None: - fuse_loss = self.cfg.get("sequence_packing", {}).get( - "fuse_loss", False - ) + fuse_loss = self.cfg.get("sequence_packing", {}).get("fuse_loss", False) if fuse_loss: wrapper_cls = SequencePackingFusionLossWrapper prepare_fn = prepare_packed_loss_input From 5d82fa7f6cdb65f4b8c6f189fdad78a83493c31f Mon Sep 17 00:00:00 2001 From: mloh Date: Tue, 10 Mar 2026 17:40:39 -0700 Subject: [PATCH 12/18] Refactor test case for the new interface Signed-off-by: mloh --- .../test_sequence_packing_fusion.py | 623 +++++++----------- 1 file changed, 232 insertions(+), 391 deletions(-) diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py index 4c03294485..d31a0214e3 100644 --- a/tests/unit/algorithms/test_sequence_packing_fusion.py +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -12,384 +12,271 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Unit tests to ensure SequencePackingFusionLossWrapper works as SequencePackingLossWrapper. - - Without explicitly calling the loss_fn sequence by sequence. +Unit tests to ensure SequencePackingFusionLossWrapper produces identical results +to SequencePackingLossWrapper. -During the forward pass, compare the loss and metrics from the two wrappers. -During the backward pass, compare the gradients from the two wrappers. +Uses distributed_test_runner (torch.multiprocessing.spawn) instead of Ray actors +so that pytest + code coverage work correctly. -For parallelism, check for CP and TP. - -For loss function, right now only supports: -- ClippedPGLossFn +For loss function, currently only supports ClippedPGLossFn. """ -import os +import functools import pytest -import ray import torch -from nemo_rl.algorithms.loss_functions import ( +from nemo_rl.algorithms.loss import ( ClippedPGLossFn, SequencePackingFusionLossWrapper, SequencePackingLossWrapper, + prepare_loss_input, + prepare_packed_loss_input, ) from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.named_sharding import NamedSharding -from nemo_rl.distributed.ray_actor_environment_registry import ( - ACTOR_ENVIRONMENT_REGISTRY, - PY_EXECUTABLES, -) -from nemo_rl.distributed.virtual_cluster import RayVirtualCluster -from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup -@ray.remote(num_gpus=1) -class SequencePackingLossWrapperBaselineActor: - def __init__(self, cp_size: int, tp_size: int): - self.cp_size = cp_size - self.tp_size = tp_size - self.env_vars = dict(os.environ) +def _setup_2d_process_groups(rank, world_size, cp_size, tp_size): + """Create 2D (cp, tp) process groups. - def _setup_process_groups(self): - torch.distributed.init_process_group(backend="nccl") + Rank layout (outer cp, inner tp): + [[0, 1, ..., tp_size-1], + [tp_size, ..., 2*tp_size-1], + ...] + """ + cp_groups = [] + tp_groups = [] - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - assert world_size == self.cp_size * self.tp_size, ( - f"Expected WORLD_SIZE={self.cp_size * self.tp_size}, got {world_size}." - ) + for tp_rank in range(tp_size): + ranks = [cp_rank * tp_size + tp_rank for cp_rank in range(cp_size)] + cp_groups.append(torch.distributed.new_group(ranks=ranks)) - # --------------------------------------------------------------------- - # Create 2D (cp, tp) process groups. - # Rank layout (outer cp, inner tp): - # [[0, 1, ..., tp_size-1], - # [tp_size, ..., 2*tp_size-1], - # ...] - # --------------------------------------------------------------------- - cp_groups: list[torch.distributed.ProcessGroup] = [] - tp_groups: list[torch.distributed.ProcessGroup] = [] - - # CP groups: one per tp_rank, varying cp coordinate - for tp_rank in range(self.tp_size): - ranks = [ - cp_rank * self.tp_size + tp_rank for cp_rank in range(self.cp_size) - ] - cp_groups.append(torch.distributed.new_group(ranks=ranks)) + for cp_rank in range(cp_size): + ranks = [cp_rank * tp_size + tp_rank for tp_rank in range(tp_size)] + tp_groups.append(torch.distributed.new_group(ranks=ranks)) - # TP groups: one per cp_rank, varying tp coordinate - for cp_rank in range(self.cp_size): - ranks = [ - cp_rank * self.tp_size + tp_rank for tp_rank in range(self.tp_size) - ] - tp_groups.append(torch.distributed.new_group(ranks=ranks)) - - my_tp_rank = rank % self.tp_size - my_cp_rank = rank // self.tp_size - cp_group = cp_groups[my_tp_rank] - tp_group = tp_groups[my_cp_rank] - return rank, my_cp_rank, my_tp_rank, cp_group, tp_group - - def _build_test_case(self, cp_group, my_tp_rank: int): - from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank - from nemo_rl.models.megatron.data import _pack_sequences_for_megatron - - # --------------------------------------------------------------------- - # Build a small packed batch. - # --------------------------------------------------------------------- - device = torch.device("cuda") - torch.manual_seed(42) # For reproducibility / determinism - - batch_size = 4 - max_seq_len = 512 - # Ensure CP load balancing requirement: divisible by (2 * cp_size) - if max_seq_len % (2 * self.cp_size) != 0: - max_seq_len = (max_seq_len // (2 * self.cp_size) + 1) * (2 * self.cp_size) - - vocab_size_total = 512 - assert vocab_size_total % self.tp_size == 0 - vocab_size_local = vocab_size_total // self.tp_size - - # Variable lengths, but <= max_seq_len - seq_lengths = torch.tensor( - [ - max_seq_len // 4, - max_seq_len // 2, - max_seq_len // 3, - max_seq_len * 3 // 4, - ], - dtype=torch.int32, - device=device, - ) + my_tp_rank = rank % tp_size + my_cp_rank = rank // tp_size + cp_group = cp_groups[my_tp_rank] + tp_group = tp_groups[my_cp_rank] + return my_cp_rank, my_tp_rank, cp_group, tp_group - # Input ids + masks - input_ids = torch.zeros( - batch_size, - max_seq_len, - dtype=torch.long, - device=device, - ) - token_mask = torch.zeros( - batch_size, - max_seq_len, - dtype=torch.float32, - device=device, - ) - for i in range(batch_size): - L = int(seq_lengths[i].item()) - input_ids[i, :L] = torch.randint(0, vocab_size_total, (L,), device=device) - token_mask[i, :L] = 1.0 - - sample_mask = torch.ones(batch_size, dtype=torch.float32, device=device) - - # Stable-ish random tensors (avoid extreme ratios/NaNs in unit test) - advantages = 0.1 * torch.randn(batch_size, max_seq_len, device=device) - prev_logprobs = 0.1 * torch.randn(batch_size, max_seq_len, device=device) - generation_logprobs = 0.1 * torch.randn(batch_size, max_seq_len, device=device) - reference_policy_logprobs = generation_logprobs.clone() - - data_dict = BatchedDataDict( - { - "input_ids": input_ids, - "input_lengths": seq_lengths, - "token_mask": token_mask, - "sample_mask": sample_mask, - "advantages": advantages, - "prev_logprobs": prev_logprobs, - "generation_logprobs": generation_logprobs, - "reference_policy_logprobs": reference_policy_logprobs, - } - ) - - # Packed sequence metadata (CP-aware) - pad_to_multiple = self.cp_size * 2 - ( - _packed_input_ids, - _packed_input_ids_cp, - _packed_seq_params, - cu_seqlens, - cu_seqlens_padded, - ) = _pack_sequences_for_megatron( - input_ids, - seq_lengths, - pad_individual_seqs_to_multiple_of=pad_to_multiple, - pad_packed_seq_to=max_seq_len * batch_size if self.cp_size > 1 else None, - cp_rank=torch.distributed.get_rank(cp_group), - cp_size=self.cp_size, - ) - assert cu_seqlens_padded is not None - - # --------------------------------------------------------------------- - # Create vocab-parallel logits, then pack + CP-shard them into [1, T//CP, V//TP] - # --------------------------------------------------------------------- - # Global logits (same across ranks), then slice by TP rank - full_logits = torch.randn( - batch_size, - max_seq_len, - vocab_size_total, - device=device, - dtype=torch.float32, - ) - def make_logits_and_packed_logits(): - logits_local = ( - full_logits[ - :, - :, - my_tp_rank * vocab_size_local : (my_tp_rank + 1) * vocab_size_local, - ] - .clone() - .detach() - .requires_grad_(True) - ) +def _build_test_case(cp_size, tp_size, my_tp_rank, cp_group): + """Build a small packed batch with CP-aware packing.""" + from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank + from nemo_rl.models.megatron.data import _pack_sequences_for_megatron - total_padded_tokens = int(cu_seqlens_padded[-1].item()) - packed_logits = torch.zeros( - 1, total_padded_tokens // self.cp_size, vocab_size_local, device=device - ) + device = torch.device("cuda") + torch.manual_seed(42) - run_seq = 0 - for i in range(batch_size): - seq_len = int(seq_lengths[i].item()) - padded_seq_len = int( - (cu_seqlens_padded[i + 1] - cu_seqlens_padded[i]).item() - ) - tmp = torch.zeros(1, padded_seq_len, vocab_size_local, device=device) - tmp[:, :seq_len, :] = logits_local[i : i + 1, :seq_len, :] - packed_logits[ - :, - run_seq // self.cp_size : (run_seq + padded_seq_len) - // self.cp_size, - :, - ] = _get_tokens_on_this_cp_rank( - tmp, torch.distributed.get_rank(cp_group), self.cp_size - ) - run_seq += padded_seq_len - - return logits_local, packed_logits - - # --------------------------------------------------------------------- - # Loss: SequencePackingLossWrapper + ClippedPGLossFn - # --------------------------------------------------------------------- - loss_cfg = { - # From examples/configs/grpo_math_1B_megatron.yaml (loss_fn section) - "reference_policy_kl_penalty": 0.01, - "ratio_clip_min": 0.2, - "ratio_clip_max": 0.2, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "token_level_loss": True, - "ratio_clip_c": None, - # Required by ClippedPGLossConfig but not in that YAML - "reference_policy_kl_type": "k3", - "kl_input_clamp_value": 20.0, - "kl_output_clamp_value": 10.0, - "truncated_importance_sampling_ratio": None, - "sequence_level_importance_ratios": False, - "force_on_policy_ratio": False, - } + batch_size = 4 + max_seq_len = 512 + if max_seq_len % (2 * cp_size) != 0: + max_seq_len = (max_seq_len // (2 * cp_size) + 1) * (2 * cp_size) - # Global normalization factors (token-level loss uses global_valid_toks) - valid_toks = int(torch.clamp(seq_lengths - 1, min=0).sum().item()) - global_valid_toks = torch.tensor(valid_toks, dtype=torch.float32, device=device) - global_valid_seqs = torch.tensor(batch_size, dtype=torch.float32, device=device) - - return { - "loss_cfg": loss_cfg, - "cu_seqlens": cu_seqlens, - "cu_seqlens_padded": cu_seqlens_padded, - "data_dict": data_dict, - "global_valid_seqs": global_valid_seqs, - "global_valid_toks": global_valid_toks, - "make_logits_and_packed_logits": make_logits_and_packed_logits, - } + vocab_size_total = 512 + assert vocab_size_total % tp_size == 0 + vocab_size_local = vocab_size_total // tp_size - def run_compare_sequence_packing_wrappers(self): - """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. + seq_lengths = torch.tensor( + [max_seq_len // 4, max_seq_len // 2, max_seq_len // 3, max_seq_len * 3 // 4], + dtype=torch.int32, + device=device, + ) - Builds inputs once, runs both wrappers on identical inputs, - and returns loss/metrics + max grad for each. - """ - rank, _my_cp_rank, my_tp_rank, cp_group, tp_group = self._setup_process_groups() - tc = self._build_test_case(cp_group=cp_group, my_tp_rank=my_tp_rank) - base_loss_fn = ClippedPGLossFn(tc["loss_cfg"]) + input_ids = torch.zeros(batch_size, max_seq_len, dtype=torch.long, device=device) + token_mask = torch.zeros( + batch_size, max_seq_len, dtype=torch.float32, device=device + ) + for i in range(batch_size): + L = int(seq_lengths[i].item()) + input_ids[i, :L] = torch.randint(0, vocab_size_total, (L,), device=device) + token_mask[i, :L] = 1.0 + + sample_mask = torch.ones(batch_size, dtype=torch.float32, device=device) + advantages = 0.1 * torch.randn(batch_size, max_seq_len, device=device) + prev_logprobs = 0.1 * torch.randn(batch_size, max_seq_len, device=device) + generation_logprobs = 0.1 * torch.randn(batch_size, max_seq_len, device=device) + reference_policy_logprobs = generation_logprobs.clone() + + data_dict = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": seq_lengths, + "token_mask": token_mask, + "sample_mask": sample_mask, + "advantages": advantages, + "prev_logprobs": prev_logprobs, + "generation_logprobs": generation_logprobs, + "reference_policy_logprobs": reference_policy_logprobs, + } + ) - data_dict = tc["data_dict"] + pad_to_multiple = cp_size * 2 + ( + _packed_input_ids, + _packed_input_ids_cp, + _packed_seq_params, + cu_seqlens, + cu_seqlens_padded, + ) = _pack_sequences_for_megatron( + input_ids, + seq_lengths, + pad_individual_seqs_to_multiple_of=pad_to_multiple, + pad_packed_seq_to=max_seq_len * batch_size if cp_size > 1 else None, + cp_rank=torch.distributed.get_rank(cp_group), + cp_size=cp_size, + ) + assert cu_seqlens_padded is not None - # Instantiate wrappers - baseline_wrapper = SequencePackingLossWrapper( - loss_fn=base_loss_fn, - cu_seqlens_q=tc["cu_seqlens"], - cu_seqlens_q_padded=tc["cu_seqlens_padded"], - ) + full_logits = torch.randn( + batch_size, max_seq_len, vocab_size_total, device=device, dtype=torch.float32 + ) - candidate_wrapper = SequencePackingFusionLossWrapper( - loss_fn=base_loss_fn, - cu_seqlens_q=tc["cu_seqlens"], - cu_seqlens_q_padded=tc["cu_seqlens_padded"], + def make_logits_and_packed_logits(): + logits_local = ( + full_logits[ + :, + :, + my_tp_rank * vocab_size_local : (my_tp_rank + 1) * vocab_size_local, + ] + .clone() + .detach() + .requires_grad_(True) ) - # Baseline run (fresh logits) — uses the original data_dict without packed_input_ids - baseline_logits, baseline_packed_logits = tc["make_logits_and_packed_logits"]() - baseline_loss, baseline_metrics = baseline_wrapper( - baseline_packed_logits, - data_dict, - tc["global_valid_seqs"], - tc["global_valid_toks"], - vocab_parallel_rank=my_tp_rank, - vocab_parallel_group=tp_group, - context_parallel_group=cp_group, + total_padded_tokens = int(cu_seqlens_padded[-1].item()) + packed_logits = torch.zeros( + 1, total_padded_tokens // cp_size, vocab_size_local, device=device ) - (baseline_loss / self.cp_size).backward() - baseline_grad = baseline_logits.grad.clone() - - # Candidate run (fresh logits, identical values) - candidate_logits, candidate_packed_logits = tc[ - "make_logits_and_packed_logits" - ]() - candidate_loss, candidate_metrics = candidate_wrapper( - candidate_packed_logits, - data_dict, - tc["global_valid_seqs"], - tc["global_valid_toks"], - vocab_parallel_rank=my_tp_rank, - vocab_parallel_group=tp_group, - context_parallel_group=cp_group, - ) - (candidate_loss / self.cp_size).backward() - candidate_grad = candidate_logits.grad.clone() - - return { - "rank": rank, - "cp_rank": int(torch.distributed.get_rank(cp_group)), - "tp_rank": int(torch.distributed.get_rank(tp_group)), - "baseline": { - "loss": baseline_loss, - "metrics_keys": sorted(list(baseline_metrics.keys())), - "logits_local_grad": baseline_grad, - }, - "candidate": { - "loss": candidate_loss, - "metrics_keys": sorted(list(candidate_metrics.keys())), - "logits_local_grad": candidate_grad, - }, - } - - -SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN = f"{SequencePackingLossWrapperBaselineActor.__module__}.SequencePackingLossWrapperBaselineActor" + run_seq = 0 + for i in range(batch_size): + seq_len = int(seq_lengths[i].item()) + padded_seq_len = int( + (cu_seqlens_padded[i + 1] - cu_seqlens_padded[i]).item() + ) + tmp = torch.zeros(1, padded_seq_len, vocab_size_local, device=device) + tmp[:, :seq_len, :] = logits_local[i : i + 1, :seq_len, :] + packed_logits[ + :, + run_seq // cp_size : (run_seq + padded_seq_len) // cp_size, + :, + ] = _get_tokens_on_this_cp_rank( + tmp, torch.distributed.get_rank(cp_group), cp_size + ) + run_seq += padded_seq_len + + return logits_local, packed_logits + + loss_cfg = { + "reference_policy_kl_penalty": 0.01, + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "token_level_loss": True, + "ratio_clip_c": None, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, + "truncated_importance_sampling_ratio": None, + "sequence_level_importance_ratios": False, + "force_on_policy_ratio": False, + } + + valid_toks = int(torch.clamp(seq_lengths - 1, min=0).sum().item()) + global_valid_toks = torch.tensor(valid_toks, dtype=torch.float32, device=device) + global_valid_seqs = torch.tensor(batch_size, dtype=torch.float32, device=device) + + return { + "loss_cfg": loss_cfg, + "cu_seqlens": cu_seqlens, + "cu_seqlens_padded": cu_seqlens_padded, + "data_dict": data_dict, + "global_valid_seqs": global_valid_seqs, + "global_valid_toks": global_valid_toks, + "make_logits_and_packed_logits": make_logits_and_packed_logits, + } + + +def _run_compare_sequence_packing_wrappers(rank, world_size, cp_size, tp_size): + """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. -@pytest.fixture -def register_sequence_packing_loss_wrapper_baseline_actor(): - """Register the actor in ACTOR_ENVIRONMENT_REGISTRY for RayWorkerGroup.""" - original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get( - SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN + Verifies that the fused wrapper produces identical loss values and + backward gradients w.r.t. vocab-parallel logits. + """ + _my_cp_rank, my_tp_rank, cp_group, tp_group = _setup_2d_process_groups( + rank, world_size, cp_size, tp_size ) - ACTOR_ENVIRONMENT_REGISTRY[SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN] = ( - PY_EXECUTABLES.MCORE + tc = _build_test_case(cp_size, tp_size, my_tp_rank, cp_group) + base_loss_fn = ClippedPGLossFn(tc["loss_cfg"]) + data_dict = tc["data_dict"] + + baseline_wrapper = SequencePackingLossWrapper( + loss_fn=base_loss_fn, + prepare_fn=prepare_loss_input, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, ) - yield SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN - - if SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: - if original_registry_value is None: - del ACTOR_ENVIRONMENT_REGISTRY[ - SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN - ] - else: - ACTOR_ENVIRONMENT_REGISTRY[ - SEQUENCE_PACKING_LOSS_WRAPPER_BASELINE_ACTOR_FQN - ] = original_registry_value - - -@pytest.fixture(scope="function") -def cluster_fixture(request): - """Create and teardown a virtual cluster for CP/TP tests.""" - cp_size, tp_size = request.node.callspec.params["cp_tp"] - world_size = cp_size * tp_size - - if not ray.is_initialized(): - from nemo_rl.distributed.virtual_cluster import init_ray - init_ray() - - # Check available GPUs via Ray cluster resources (works across multi-node), - # falling back to local torch.cuda.device_count() if Ray has no GPU info. - available_gpus = int(ray.cluster_resources().get("GPU", 0)) - if available_gpus == 0: - available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + candidate_wrapper = SequencePackingFusionLossWrapper( + loss_fn=base_loss_fn, + prepare_fn=prepare_packed_loss_input, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) - if available_gpus < world_size: - pytest.skip( - f"Not enough GPUs available. Need {world_size}, got {available_gpus}" - ) + # Baseline run + baseline_logits, baseline_packed_logits = tc["make_logits_and_packed_logits"]() + baseline_loss, _baseline_metrics = baseline_wrapper( + baseline_packed_logits, + data_dict, + tc["global_valid_seqs"], + tc["global_valid_toks"], + ) + (baseline_loss / cp_size).backward() + baseline_grad = baseline_logits.grad.clone() + + # Candidate run (fresh logits, identical values) + candidate_logits, candidate_packed_logits = tc["make_logits_and_packed_logits"]() + candidate_loss, _candidate_metrics = candidate_wrapper( + candidate_packed_logits, + data_dict, + tc["global_valid_seqs"], + tc["global_valid_toks"], + ) + (candidate_loss / cp_size).backward() + candidate_grad = candidate_logits.grad.clone() + + # Sanity: gradients must be non-None and non-zero + assert baseline_grad.abs().sum() > 0, f"baseline grad is all zeros on rank {rank}" + assert candidate_grad.abs().sum() > 0, f"candidate grad is all zeros on rank {rank}" + + # Forward: loss values must match + torch.testing.assert_close( + baseline_loss, + candidate_loss, + atol=1e-5, + rtol=1e-5, + msg=f"Loss mismatch on rank {rank}", + ) - cluster_name = f"test-seq-pack-fusion-cp{cp_size}-tp{tp_size}" - cluster = RayVirtualCluster( - name=cluster_name, bundle_ct_per_node_list=[world_size], use_gpus=True + # Backward: gradients w.r.t. logits must match + torch.testing.assert_close( + baseline_grad, + candidate_grad, + atol=1e-5, + rtol=1e-5, + msg=f"Gradient mismatch on rank {rank}", ) - yield cluster - cluster.shutdown() @pytest.mark.parametrize( @@ -404,11 +291,7 @@ def cluster_fixture(request): ], ids=lambda cp_tp: f"cp{cp_tp[0]}_tp{cp_tp[1]}", ) -def test_sequence_packing_loss_wrapper_baseline_cp_tp( - cluster_fixture, - register_sequence_packing_loss_wrapper_baseline_actor, - cp_tp, -): +def test_sequence_packing_fusion_vs_baseline(distributed_test_runner, cp_tp): """Compare SequencePackingFusionLossWrapper vs SequencePackingLossWrapper. Verifies that the fused wrapper produces identical: @@ -417,53 +300,11 @@ def test_sequence_packing_loss_wrapper_baseline_cp_tp( for different CP and TP configurations. """ cp_size, tp_size = cp_tp - cluster = cluster_fixture - actor_fqn = register_sequence_packing_loss_wrapper_baseline_actor - - sharding_layout = [ - [cp_rank * tp_size + tp_rank for tp_rank in range(tp_size)] - for cp_rank in range(cp_size) - ] - sharding = NamedSharding(layout=sharding_layout, names=["cp", "tp"]) - builder = RayWorkerBuilder(actor_fqn, cp_size=cp_size, tp_size=tp_size) - - worker_group = RayWorkerGroup( - cluster=cluster, - remote_worker_builder=builder, - workers_per_node=None, - sharding_annotations=sharding, - ) + world_size = cp_size * tp_size - futures = worker_group.run_all_workers_single_data( - "run_compare_sequence_packing_wrappers", + test_fn = functools.partial( + _run_compare_sequence_packing_wrappers, + cp_size=cp_size, + tp_size=tp_size, ) - results = ray.get(futures) - - if not isinstance(results, list): - results = [results] - - for r in results: - rank = r["rank"] - # Sanity: gradients must be non-None and non-zero (autograd is connected) - for label in ("baseline", "candidate"): - grad = r[label]["logits_local_grad"] - assert grad is not None, f"{label} grad is None on rank {rank}" - assert grad.abs().sum() > 0, f"{label} grad is all zeros on rank {rank}" - # Forward: loss values must match - torch.testing.assert_close( - r["baseline"]["loss"], - r["candidate"]["loss"], - atol=1e-5, - rtol=1e-5, - msg=f"Loss mismatch on rank {rank}", - ) - # Backward: gradients w.r.t. logits must match - torch.testing.assert_close( - r["baseline"]["logits_local_grad"], - r["candidate"]["logits_local_grad"], - atol=1e-5, - rtol=1e-5, - msg=f"Gradient mismatch on rank {rank}", - ) - - worker_group.shutdown(force=True) + distributed_test_runner(test_fn, world_size=world_size) From 449db303dee8a63390fb85d084d4a511f725a1a8 Mon Sep 17 00:00:00 2001 From: mloh Date: Tue, 10 Mar 2026 21:25:24 -0700 Subject: [PATCH 13/18] Avoid double packing logits Signed-off-by: mloh --- nemo_rl/algorithms/loss/utils.py | 43 +++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 959f7d52e4..523891e15c 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -24,6 +24,7 @@ from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( + _get_tokens_on_this_cp_rank, from_parallel_logits_to_logprobs_packed_sequences, get_distillation_topk_logprobs_from_logits, get_next_token_logprobs_from_logits, @@ -125,22 +126,25 @@ def _pack_input_ids( input_ids: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_q_padded: torch.Tensor, + cp_size: int = 1, ) -> torch.Tensor: """Pack input_ids from [B, S] to [1, T_packed] using sequence boundaries. - Each sequence i is placed at cu_seqlens_q_padded[i] in the packed tensor, - with actual_len = cu_seqlens_q[i+1] - cu_seqlens_q[i] tokens copied. + When cp_size > 1, input_ids is [B, S // cp_size] (already CP-sharded) and + offsets are divided by cp_size to produce [1, T_packed // cp_size]. """ batch_size = input_ids.shape[0] - total_packed_len = int(cu_seqlens_q_padded[-1].item()) + total_packed_len = int(cu_seqlens_q_padded[-1].item()) // cp_size packed = torch.zeros( - 1, total_packed_len, dtype=input_ids.dtype, device=input_ids.device + total_packed_len, dtype=input_ids.dtype, device=input_ids.device ) for i in range(batch_size): - actual_len = int((cu_seqlens_q[i + 1] - cu_seqlens_q[i]).item()) - packed_start = int(cu_seqlens_q_padded[i].item()) - packed[0, packed_start : packed_start + actual_len] = input_ids[i, :actual_len] - return packed + actual_len = int( + (cu_seqlens_q[i + 1] - cu_seqlens_q[i]).item() + ) // cp_size + packed_start = int(cu_seqlens_q_padded[i].item()) // cp_size + packed[packed_start : packed_start + actual_len] = input_ids[i, :actual_len] + return packed.unsqueeze(0) def prepare_packed_loss_input( @@ -187,14 +191,28 @@ def prepare_packed_loss_input( "vocab_parallel_rank must be provided with vocab_parallel_group." ) - packed_input_ids = _pack_input_ids( - data["input_ids"], cu_seqlens_q, cu_seqlens_q_padded + input_ids = data["input_ids"] + unpacked_seqlen = input_ids.shape[1] + cp_size = ( + 1 + if context_parallel_group is None + else torch.distributed.get_world_size(context_parallel_group) + ) + cp_rank = ( + 0 + if context_parallel_group is None + else torch.distributed.get_rank(context_parallel_group) + ) + + rolled_ids = input_ids.roll(-1, dims=1) + rolled_ids = _get_tokens_on_this_cp_rank(rolled_ids, cp_rank, cp_size, seq_dim=1) + packed_rolled_targets = _pack_input_ids( + rolled_ids, cu_seqlens_q, cu_seqlens_q_padded, cp_size ) - unpacked_seqlen = data["input_ids"].shape[1] logprobs = from_parallel_logits_to_logprobs_packed_sequences( logits.to(torch.float32), - packed_input_ids, + packed_rolled_targets, cu_seqlens_q_padded, unpacked_seqlen, vocab_start_index=vocab_parallel_rank * logits.shape[-1], @@ -202,6 +220,7 @@ def prepare_packed_loss_input( group=vocab_parallel_group, inference_only=False, cp_group=context_parallel_group, + target_is_pre_rolled=True, ) return {"next_token_logprobs": logprobs} From ca11e36114fb52522097361137d5bb314d7b5689 Mon Sep 17 00:00:00 2001 From: mloh Date: Tue, 10 Mar 2026 21:26:40 -0700 Subject: [PATCH 14/18] Fix lint issue Signed-off-by: mloh --- nemo_rl/algorithms/loss/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 523891e15c..18ff327f62 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -139,9 +139,7 @@ def _pack_input_ids( total_packed_len, dtype=input_ids.dtype, device=input_ids.device ) for i in range(batch_size): - actual_len = int( - (cu_seqlens_q[i + 1] - cu_seqlens_q[i]).item() - ) // cp_size + actual_len = int((cu_seqlens_q[i + 1] - cu_seqlens_q[i]).item()) // cp_size packed_start = int(cu_seqlens_q_padded[i].item()) // cp_size packed[packed_start : packed_start + actual_len] = input_ids[i, :actual_len] return packed.unsqueeze(0) From 95745ce2a8eda29b5afdfa36d03a8d0f028c4c27 Mon Sep 17 00:00:00 2001 From: mloh Date: Wed, 11 Mar 2026 15:14:02 -0700 Subject: [PATCH 15/18] Add unit test for from_parallel_logits_to_logprobs_packed_sequences Signed-off-by: mloh --- tests/unit/distributed/test_model_utils.py | 197 +++++++++++++++++++++ 1 file changed, 197 insertions(+) diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index aad98d1da4..6e80aeed33 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import os import pytest @@ -243,6 +244,202 @@ def test_from_parallel_logits_to_logprobs_packed_sequences( cluster.shutdown() +# --------------------------------------------------------------------------- +# distributed_test_runner-based packed-sequences tests (coverage-friendly) +# --------------------------------------------------------------------------- + + +def _run_packed_sequences_equivalence(rank, world_size, tp_size, cp_size, chunk_size): + """Test from_parallel_logits_to_logprobs_packed_sequences with coverage. + + Uses _pack_input_ids to build packed targets and compares: + 1. target_is_pre_rolled=False against the unpacked baseline (CP=1 only) + 2. target_is_pre_rolled=True against target_is_pre_rolled=False + with variable-length sequences. + """ + from nemo_rl.algorithms.loss.utils import _pack_input_ids + + # Build 2-D process groups: inner=TP, outer=CP + tp_groups = [] + cp_groups = [] + for cp_r in range(cp_size): + ranks = [cp_r * tp_size + tp_r for tp_r in range(tp_size)] + tp_groups.append(torch.distributed.new_group(ranks=ranks)) + for tp_r in range(tp_size): + ranks = [cp_r * tp_size + tp_r for cp_r in range(cp_size)] + cp_groups.append(torch.distributed.new_group(ranks=ranks)) + + my_tp_rank = rank % tp_size + my_cp_rank = rank // tp_size + tp_group = tp_groups[my_cp_rank] + cp_group = cp_groups[my_tp_rank] if cp_size > 1 else None + my_cp_rank_val = 0 if cp_group is None else torch.distributed.get_rank(cp_group) + + batch_size = 4 + vocab_size = 1024 + vocab_part_size = vocab_size // tp_size + vocab_start_index = my_tp_rank * vocab_part_size + vocab_end_index = (my_tp_rank + 1) * vocab_part_size + + # Variable-length sequences + raw_seq_lengths = [24, 48, 16, 40] + max_seq_len = max(raw_seq_lengths) + + if cp_size > 1 and max_seq_len % (2 * cp_size) != 0: + max_seq_len = (max_seq_len // (2 * cp_size) + 1) * (2 * cp_size) + raw_seq_lengths = [min(l, max_seq_len) for l in raw_seq_lengths] + + pad_to = 2 * cp_size if cp_size > 1 else 1 + padded_seq_lengths = [ + ((l + pad_to - 1) // pad_to) * pad_to for l in raw_seq_lengths + ] + + # Build cu_seqlens / cu_seqlens_padded + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") + for i in range(batch_size): + cu_seqlens[i + 1] = cu_seqlens[i] + raw_seq_lengths[i] + cu_seqlens_padded[i + 1] = cu_seqlens_padded[i] + padded_seq_lengths[i] + + total_padded = int(cu_seqlens_padded[-1].item()) + + torch.manual_seed(42) + unpacked_logits_full = torch.randn( + batch_size, max_seq_len, vocab_size, device="cuda" + ) + input_ids = torch.randint(0, vocab_size, (batch_size, max_seq_len), device="cuda") + + unpacked_logits_local = unpacked_logits_full[ + :, :, vocab_start_index:vocab_end_index + ] + + # --- Pack logits: [B, S, V_local] -> [1, T_padded // CP, V_local] --- + # Each sequence is individually padded and CP-sharded (matching production). + packed_logits = torch.zeros( + 1, total_padded // cp_size, vocab_part_size, device="cuda" + ) + for i in range(batch_size): + sl = raw_seq_lengths[i] + psl = padded_seq_lengths[i] + padded_seq = torch.zeros(1, psl, vocab_part_size, device="cuda") + padded_seq[:, :sl, :] = unpacked_logits_local[i : i + 1, :sl, :] + offset = int(cu_seqlens_padded[i].item()) + if cp_size > 1: + sharded = _get_tokens_on_this_cp_rank(padded_seq, my_cp_rank_val, cp_size) + packed_logits[:, offset // cp_size : (offset + psl) // cp_size, :] = sharded + else: + packed_logits[:, offset : offset + psl, :] = padded_seq + + # --- Path 1: target_is_pre_rolled=False --- + # Pack raw (unrolled) input_ids to [1, T_padded] using _pack_input_ids. + packed_target_raw = _pack_input_ids(input_ids, cu_seqlens, cu_seqlens_padded) + + logprobs_not_pre_rolled = from_parallel_logits_to_logprobs_packed_sequences( + packed_logits, + packed_target_raw, + cu_seqlens_padded, + max_seq_len, + vocab_start_index, + vocab_end_index, + tp_group, + cp_group=cp_group, + chunk_size=chunk_size, + target_is_pre_rolled=False, + ) + + # --- Path 2: target_is_pre_rolled=True --- + # Prepare pre-rolled targets per-sequence (matching the internal logic of + # the target_is_pre_rolled=False path): each packed sequence is rolled and + # CP-sharded independently from its own padded_seq_len. + pre_rolled_packed = torch.zeros( + total_padded // cp_size, dtype=input_ids.dtype, device="cuda" + ) + for i in range(batch_size): + start = int(cu_seqlens_padded[i].item()) + end = int(cu_seqlens_padded[i + 1].item()) + seq_tokens = packed_target_raw[0, start:end] + rolled = seq_tokens.roll(shifts=-1, dims=0) + sharded = _get_tokens_on_this_cp_rank( + rolled, my_cp_rank_val, cp_size, seq_dim=0 + ) + pre_rolled_packed[start // cp_size : end // cp_size] = sharded + packed_target_pre_rolled = pre_rolled_packed.unsqueeze(0) + + logprobs_pre_rolled = from_parallel_logits_to_logprobs_packed_sequences( + packed_logits, + packed_target_pre_rolled, + cu_seqlens_padded, + max_seq_len, + vocab_start_index, + vocab_end_index, + tp_group, + cp_group=cp_group, + chunk_size=chunk_size, + target_is_pre_rolled=True, + ) + + # Both paths must produce identical results + for i in range(batch_size): + valid_len = raw_seq_lengths[i] - 1 + torch.testing.assert_close( + logprobs_pre_rolled[i, :valid_len], + logprobs_not_pre_rolled[i, :valid_len], + rtol=1e-5, + atol=1e-5, + msg=f"pre_rolled vs not_pre_rolled mismatch on rank {rank}, seq {i}", + ) + + # --- Also compare against the unpacked baseline --- + # The unpacked function CP-shards each row from max_seq_len, which matches + # the packed per-sequence CP-sharding only when CP=1. + if cp_size == 1: + baseline_logprobs = from_parallel_logits_to_logprobs( + unpacked_logits_local, + input_ids, + vocab_start_index, + vocab_end_index, + tp_group, + cp_group=cp_group, + ) + for i in range(batch_size): + valid_len = raw_seq_lengths[i] - 1 + torch.testing.assert_close( + logprobs_not_pre_rolled[i, :valid_len], + baseline_logprobs[i, :valid_len], + rtol=1e-5, + atol=1e-5, + msg=f"packed vs unpacked mismatch on rank {rank}, seq {i}", + ) + + +@pytest.mark.parametrize( + "tp_size, cp_size, chunk_size", + [ + (2, 1, None), + (1, 2, None), + (2, 1, 8), + (1, 2, 8), + ], + ids=lambda v: str(v), +) +def test_packed_sequences_with_distributed_runner( + distributed_test_runner, tp_size, cp_size, chunk_size +): + """Test from_parallel_logits_to_logprobs_packed_sequences using distributed_test_runner. + + Covers both target_is_pre_rolled paths, variable-length sequences, and chunk_size, + with proper code coverage tracking (unlike Ray-based tests). + """ + world_size = tp_size * cp_size + test_fn = functools.partial( + _run_packed_sequences_equivalence, + tp_size=tp_size, + cp_size=cp_size, + chunk_size=chunk_size, + ) + distributed_test_runner(test_fn, world_size=world_size) + + @ray.remote(num_gpus=1) class AllGatherCPTestActor: def __init__(self, cp_size): From 15144b1e0d2ce2585d29414158242867bbfba1d2 Mon Sep 17 00:00:00 2001 From: mloh Date: Thu, 12 Mar 2026 09:03:33 -0700 Subject: [PATCH 16/18] Modify prepare_packed_loss_input's return parameter to match prepare_loss_input Signed-off-by: mloh --- nemo_rl/algorithms/loss/utils.py | 6 +++--- nemo_rl/algorithms/loss/wrapper.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 18ff327f62..4af0823646 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -154,7 +154,7 @@ def prepare_packed_loss_input( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, -) -> dict[str, Any]: +) -> tuple[dict[str, Any], BatchedDataDict[Any]]: """Prepare loss input from packed logits in a single fused pass. Unlike prepare_loss_input which operates on a single (unpacked) sequence, @@ -174,7 +174,7 @@ def prepare_packed_loss_input( context_parallel_group: Context parallel group. Returns: - Loss input dict with key "next_token_logprobs". + tuple(loss_input, maybe_updated_data) """ if loss_fn.input_type != LossInputType.LOGPROB: raise ValueError( @@ -221,4 +221,4 @@ def prepare_packed_loss_input( target_is_pre_rolled=True, ) - return {"next_token_logprobs": logprobs} + return {"next_token_logprobs": logprobs}, data diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py index 0fa3cac2d3..a28bb18a19 100644 --- a/nemo_rl/algorithms/loss/wrapper.py +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -195,7 +195,7 @@ def __call__( global_valid_toks: Tensor | None, ) -> tuple[Tensor, dict[str, Any]]: """Compute loss for all packed sequences in one forward pass.""" - loss_input = self.prepare_fn( + loss_input, prepared_data = self.prepare_fn( logits=next_token_logits, data=data, loss_fn=self.loss_fn, @@ -207,7 +207,7 @@ def __call__( ) return self.loss_fn( - data=data, + data=prepared_data, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, **loss_input, From f956b2f5be190b197eadc8ec405f41b1a797d188 Mon Sep 17 00:00:00 2001 From: mloh Date: Thu, 12 Mar 2026 09:05:43 -0700 Subject: [PATCH 17/18] Fix ruff format issue Signed-off-by: mloh --- nemo_rl/algorithms/loss/utils.py | 1 + nemo_rl/distributed/model_utils.py | 2 -- nemo_rl/models/megatron/train.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 4af0823646..7296350e9b 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -122,6 +122,7 @@ def prepare_loss_input( return loss_input, data + def _pack_input_ids( input_ids: torch.Tensor, cu_seqlens_q: torch.Tensor, diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index ba7810d790..2575fd891d 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -23,8 +23,6 @@ need_top_k_or_top_p_filtering, ) -from nemo_rl.utils.nsys import wrap_with_nvtx_name - @torch.no_grad() def _compute_distributed_log_softmax( diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index a8a54027d8..7ea195d6a3 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -331,7 +331,7 @@ def __call__( else: wrapper_cls = SequencePackingLossWrapper prepare_fn = prepare_loss_input_wrapped - + loss_fn_wrapped = wrapper_cls( loss_fn=self.loss_fn, prepare_fn=prepare_fn, From c4921303edb810754f1ba5634f8b97aba16a06f6 Mon Sep 17 00:00:00 2001 From: mloh Date: Thu, 12 Mar 2026 15:22:43 -0700 Subject: [PATCH 18/18] Add sampling to fusion wrapper Signed-off-by: mloh --- nemo_rl/algorithms/loss/utils.py | 72 +++++++++-- nemo_rl/models/megatron/train.py | 4 +- .../test_sequence_packing_fusion.py | 121 ++++++++++++++++++ tests/unit/distributed/test_model_utils.py | 22 +--- 4 files changed, 194 insertions(+), 25 deletions(-) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 7296350e9b..ad92522db0 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -127,12 +127,26 @@ def _pack_input_ids( input_ids: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_q_padded: torch.Tensor, + cp_rank: int = 0, cp_size: int = 1, + roll_shift: int = 0, ) -> torch.Tensor: - """Pack input_ids from [B, S] to [1, T_packed] using sequence boundaries. + """Pack input_ids from [B, S] to [1, T_packed // CP] using sequence boundaries. - When cp_size > 1, input_ids is [B, S // cp_size] (already CP-sharded) and - offsets are divided by cp_size to produce [1, T_packed // cp_size]. + Each sequence is individually padded to its padded length (from + cu_seqlens_q_padded), optionally rolled, and CP-sharded at that padded + length before being placed into the packed output. This matches how + Megatron packs and CP-shards sequences in _pack_sequences_for_megatron. + + Args: + input_ids: Unpacked input IDs [B, S]. + cu_seqlens_q: Unpadded cumulative sequence lengths [B+1]. + cu_seqlens_q_padded: Padded cumulative sequence lengths [B+1]. + cp_rank: Context parallelism rank. + cp_size: Context parallelism size. + roll_shift: If non-zero, roll each padded sequence by this amount + before CP-sharding. Use -1 to build shifted targets for + next-token prediction. """ batch_size = input_ids.shape[0] total_packed_len = int(cu_seqlens_q_padded[-1].item()) // cp_size @@ -140,9 +154,17 @@ def _pack_input_ids( total_packed_len, dtype=input_ids.dtype, device=input_ids.device ) for i in range(batch_size): - actual_len = int((cu_seqlens_q[i + 1] - cu_seqlens_q[i]).item()) // cp_size - packed_start = int(cu_seqlens_q_padded[i].item()) // cp_size - packed[packed_start : packed_start + actual_len] = input_ids[i, :actual_len] + actual_len = int((cu_seqlens_q[i + 1] - cu_seqlens_q[i]).item()) + padded_len = int((cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i]).item()) + packed_start = int(cu_seqlens_q_padded[i].item()) + seq = torch.zeros(padded_len, dtype=input_ids.dtype, device=input_ids.device) + seq[:actual_len] = input_ids[i, :actual_len] + if roll_shift != 0: + seq = seq.roll(shifts=roll_shift, dims=0) + sharded = _get_tokens_on_this_cp_rank(seq, cp_rank, cp_size, seq_dim=0) + packed[packed_start // cp_size : (packed_start + padded_len) // cp_size] = ( + sharded + ) return packed.unsqueeze(0) @@ -155,6 +177,7 @@ def prepare_packed_loss_input( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ) -> tuple[dict[str, Any], BatchedDataDict[Any]]: """Prepare loss input from packed logits in a single fused pass. @@ -173,6 +196,7 @@ def prepare_packed_loss_input( vocab_parallel_rank: Vocab parallel rank. vocab_parallel_group: Vocab parallel group. context_parallel_group: Context parallel group. + sampling_params: Sampling parameters. Returns: tuple(loss_input, maybe_updated_data) @@ -203,10 +227,13 @@ def prepare_packed_loss_input( else torch.distributed.get_rank(context_parallel_group) ) - rolled_ids = input_ids.roll(-1, dims=1) - rolled_ids = _get_tokens_on_this_cp_rank(rolled_ids, cp_rank, cp_size, seq_dim=1) packed_rolled_targets = _pack_input_ids( - rolled_ids, cu_seqlens_q, cu_seqlens_q_padded, cp_size + input_ids, + cu_seqlens_q, + cu_seqlens_q_padded, + cp_rank=cp_rank, + cp_size=cp_size, + roll_shift=-1, ) logprobs = from_parallel_logits_to_logprobs_packed_sequences( @@ -219,7 +246,34 @@ def prepare_packed_loss_input( group=vocab_parallel_group, inference_only=False, cp_group=context_parallel_group, + sampling_params=sampling_params, target_is_pre_rolled=True, ) + # Match prepare_loss_input behavior for top-k/top-p filtered training: + # use filtered curr_logprobs for actor loss, but keep unfiltered values for KL. + if need_top_k_or_top_p_filtering(sampling_params): + mask = data["token_mask"] * data["sample_mask"].unsqueeze(-1) + logprobs = mask_out_neg_inf_logprobs(logprobs, mask[:, 1:], "curr_logprobs") + + if ( + hasattr(loss_fn, "reference_policy_kl_penalty") + and loss_fn.reference_policy_kl_penalty != 0 + ): + data["curr_logprobs_unfiltered"] = ( + from_parallel_logits_to_logprobs_packed_sequences( + logits.to(torch.float32), + packed_rolled_targets, + cu_seqlens_q_padded, + unpacked_seqlen, + vocab_start_index=vocab_parallel_rank * logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * logits.shape[-1], + group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + sampling_params=None, + target_is_pre_rolled=True, + ) + ) + return {"next_token_logprobs": logprobs}, data diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 7ea195d6a3..883aa44ad7 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -327,7 +327,9 @@ def __call__( fuse_loss = self.cfg.get("sequence_packing", {}).get("fuse_loss", False) if fuse_loss: wrapper_cls = SequencePackingFusionLossWrapper - prepare_fn = prepare_packed_loss_input + prepare_fn = partial( + prepare_packed_loss_input, sampling_params=self.sampling_params + ) else: wrapper_cls = SequencePackingLossWrapper prepare_fn = prepare_loss_input_wrapped diff --git a/tests/unit/algorithms/test_sequence_packing_fusion.py b/tests/unit/algorithms/test_sequence_packing_fusion.py index d31a0214e3..cb887b2b37 100644 --- a/tests/unit/algorithms/test_sequence_packing_fusion.py +++ b/tests/unit/algorithms/test_sequence_packing_fusion.py @@ -26,6 +26,7 @@ import pytest import torch +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.algorithms.loss import ( ClippedPGLossFn, SequencePackingFusionLossWrapper, @@ -279,6 +280,101 @@ def _run_compare_sequence_packing_wrappers(rank, world_size, cp_size, tp_size): ) +def _run_compare_sequence_packing_wrappers_with_sampling( + rank, world_size, cp_size, tp_size +): + """Compare fused vs unfused wrappers with sampling params enabled.""" + _my_cp_rank, my_tp_rank, cp_group, tp_group = _setup_2d_process_groups( + rank, world_size, cp_size, tp_size + ) + tc = _build_test_case(cp_size, tp_size, my_tp_rank, cp_group) + base_loss_fn = ClippedPGLossFn(tc["loss_cfg"]) + data_dict = tc["data_dict"] + + sampling_params = TrainingSamplingParams(top_k=8, top_p=0.9, temperature=1.0) + prepare_loss_input_wrapped = functools.partial( + prepare_loss_input, sampling_params=sampling_params + ) + prepare_packed_loss_input_wrapped = functools.partial( + prepare_packed_loss_input, sampling_params=sampling_params + ) + + baseline_wrapper = SequencePackingLossWrapper( + loss_fn=base_loss_fn, + prepare_fn=prepare_loss_input_wrapped, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) + + candidate_wrapper = SequencePackingFusionLossWrapper( + loss_fn=base_loss_fn, + prepare_fn=prepare_packed_loss_input_wrapped, + cu_seqlens_q=tc["cu_seqlens"], + cu_seqlens_q_padded=tc["cu_seqlens_padded"], + vocab_parallel_rank=my_tp_rank, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, + ) + + # Baseline run + baseline_logits, baseline_packed_logits = tc["make_logits_and_packed_logits"]() + baseline_loss, baseline_metrics = baseline_wrapper( + baseline_packed_logits, + data_dict, + tc["global_valid_seqs"], + tc["global_valid_toks"], + ) + (baseline_loss / cp_size).backward() + baseline_grad = baseline_logits.grad.clone() + + # Candidate run (fresh logits, identical values) + candidate_logits, candidate_packed_logits = tc["make_logits_and_packed_logits"]() + candidate_loss, candidate_metrics = candidate_wrapper( + candidate_packed_logits, + data_dict, + tc["global_valid_seqs"], + tc["global_valid_toks"], + ) + (candidate_loss / cp_size).backward() + candidate_grad = candidate_logits.grad.clone() + + # Sanity: gradients must be non-None and non-zero + assert baseline_grad.abs().sum() > 0, f"baseline grad is all zeros on rank {rank}" + assert candidate_grad.abs().sum() > 0, f"candidate grad is all zeros on rank {rank}" + + # Forward: loss values must match + torch.testing.assert_close( + baseline_loss, + candidate_loss, + atol=1e-5, + rtol=1e-5, + msg=f"Loss mismatch with sampling params on rank {rank}", + ) + + # Metrics parity under sampling params + assert set(baseline_metrics.keys()) == set(candidate_metrics.keys()) + for k in baseline_metrics: + torch.testing.assert_close( + torch.as_tensor(baseline_metrics[k], device="cuda"), + torch.as_tensor(candidate_metrics[k], device="cuda"), + atol=1e-5, + rtol=1e-5, + msg=f"Metric mismatch for key={k} on rank {rank}", + ) + + # Backward: gradients w.r.t. logits must match + torch.testing.assert_close( + baseline_grad, + candidate_grad, + atol=1e-5, + rtol=1e-5, + msg=f"Gradient mismatch with sampling params on rank {rank}", + ) + + @pytest.mark.parametrize( "cp_tp", [ @@ -308,3 +404,28 @@ def test_sequence_packing_fusion_vs_baseline(distributed_test_runner, cp_tp): tp_size=tp_size, ) distributed_test_runner(test_fn, world_size=world_size) + + +@pytest.mark.parametrize( + "cp_tp", + [ + (1, 1), + (1, 2), + (2, 1), + (2, 2), + ], + ids=lambda cp_tp: f"sampling_cp{cp_tp[0]}_tp{cp_tp[1]}", +) +def test_sequence_packing_fusion_vs_baseline_with_sampling_params( + distributed_test_runner, cp_tp +): + """Compare fused vs unfused wrappers with top-k/top-p sampling params.""" + cp_size, tp_size = cp_tp + world_size = cp_size * tp_size + + test_fn = functools.partial( + _run_compare_sequence_packing_wrappers_with_sampling, + cp_size=cp_size, + tp_size=tp_size, + ) + distributed_test_runner(test_fn, world_size=world_size) diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index 6e80aeed33..6c602440ff 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -348,22 +348,14 @@ def _run_packed_sequences_equivalence(rank, world_size, tp_size, cp_size, chunk_ ) # --- Path 2: target_is_pre_rolled=True --- - # Prepare pre-rolled targets per-sequence (matching the internal logic of - # the target_is_pre_rolled=False path): each packed sequence is rolled and - # CP-sharded independently from its own padded_seq_len. - pre_rolled_packed = torch.zeros( - total_padded // cp_size, dtype=input_ids.dtype, device="cuda" + packed_target_pre_rolled = _pack_input_ids( + input_ids, + cu_seqlens, + cu_seqlens_padded, + cp_rank=my_cp_rank_val, + cp_size=cp_size, + roll_shift=-1, ) - for i in range(batch_size): - start = int(cu_seqlens_padded[i].item()) - end = int(cu_seqlens_padded[i + 1].item()) - seq_tokens = packed_target_raw[0, start:end] - rolled = seq_tokens.roll(shifts=-1, dims=0) - sharded = _get_tokens_on_this_cp_rank( - rolled, my_cp_rank_val, cp_size, seq_dim=0 - ) - pre_rolled_packed[start // cp_size : end // cp_size] = sharded - packed_target_pre_rolled = pre_rolled_packed.unsqueeze(0) logprobs_pre_rolled = from_parallel_logits_to_logprobs_packed_sequences( packed_logits,