
perf: Fuse sequence packing for loss function #1904

Merged
terrykong merged 18 commits into main from mloh/seqpack_fusion
Mar 17, 2026

Conversation

@nujoug
Contributor

@nujoug nujoug commented Feb 10, 2026

What does this PR do ?

Apply a single call of the loss function to all packed sequences instead of calling it for each sequence individually.
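In effect, the per-sequence loop of log-softmax/gather calls becomes one pass over the whole packed buffer. A minimal NumPy sketch of the two paths (function names are illustrative, not the actual NeMo-RL API):

```python
import numpy as np

def per_sequence_logprobs(logits, targets, cu_seqlens):
    # Baseline path: slice out each sequence and compute its token
    # log-probabilities separately (one log-softmax + gather per sequence).
    out = []
    for i in range(len(cu_seqlens) - 1):
        s, e = cu_seqlens[i], cu_seqlens[i + 1]
        seq = logits[s:e]
        logprobs = seq - np.log(np.exp(seq).sum(axis=-1, keepdims=True))
        out.append(logprobs[np.arange(e - s), targets[s:e]])
    return np.concatenate(out)

def fused_logprobs(logits, targets):
    # Fused path: one log-softmax and one gather over the packed buffer,
    # avoiding the per-sequence loop and its kernel launches.
    logprobs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return logprobs[np.arange(len(targets)), targets]
```

Both paths produce identical values; the fused path simply batches the work.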

Issues

Issue #1247

Usage

Set the flag policy.sequence_packing.fuse_loss to true to turn on this feature.
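In YAML form (the surrounding keys are illustrative assumptions; only the fuse_loss flag itself comes from this PR):

```yaml
policy:
  sequence_packing:
    enabled: true      # assumed existing key
    fuse_loss: true    # new flag introduced by this PR
```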

Results

Observed up to a 15% speedup in policy training FLOPs.

[W&B chart: training speedup (3/10/2026)]

Validation results show similar accuracy curve

[W&B chart: validation accuracy (3/10/2026)]

Additional Information

Check out this report for a more detailed analysis.

Summary by CodeRabbit

  • New Features

    • Introduced fused sequence packing loss computation, enabling single-pass loss calculation instead of per-sequence processing. Users can now enable loss fusion via the fuse_loss configuration flag to improve training performance.
  • Tests

    • Added comprehensive validation tests for fused loss wrapper across multiple distributed training configurations.

@nujoug nujoug force-pushed the mloh/seqpack_fusion branch 2 times, most recently from 197b783 to c6d4fbb on February 11, 2026 00:37
@nujoug nujoug marked this pull request as ready for review February 11, 2026 18:42
@nujoug nujoug requested review from a team as code owners February 11, 2026 18:42
@nujoug nujoug requested a review from guyueh1 February 11, 2026 18:42
@coderabbitai
Contributor

coderabbitai bot commented Feb 11, 2026

📝 Walkthrough

Introduces a fused sequence packing loss computation workflow through a new SequencePackingFusionLossWrapper that processes packed sequences in one pass, alongside updated model utilities supporting pre-rolled targets and new loss preparation functions to enable optimized distributed loss calculation.

Changes

  • Core Model Utilities (nemo_rl/distributed/model_utils.py)
    Modified from_parallel_logits_to_logprobs_packed_sequences to accept a target_is_pre_rolled parameter and renamed cu_seqlens to cu_seqlens_padded. Added conditional control flow that bypasses the internal rolling/CP-shard logic when targets are pre-rolled and otherwise performs the standard roll and CP-shard operations.
  • Loss Computation Infrastructure (nemo_rl/algorithms/loss/utils.py, nemo_rl/algorithms/loss/wrapper.py)
    Added a prepare_packed_loss_input utility that computes logprobs via fused packed-sequence processing. Introduced a SequencePackingFusionLossWrapper class that performs a one-shot forward pass on packed logits, eliminating per-sequence processing loops and the associated kernel launches.
  • API Exports (nemo_rl/algorithms/loss/__init__.py)
    Updated the public API to expose prepare_packed_loss_input and SequencePackingFusionLossWrapper alongside existing exports.
  • Training Integration (nemo_rl/models/megatron/train.py)
    Modified LossPostProcessor.__call__ to select between SequencePackingFusionLossWrapper and the standard SequencePackingLossWrapper based on the fuse_loss configuration flag, with the corresponding loss preparation function.
  • Testing (tests/unit/algorithms/test_sequence_packing_fusion.py)
    New test module validating SequencePackingFusionLossWrapper against SequencePackingLossWrapper across multiple CP/TP distributed configurations (1x1, 1x2, 2x1, 2x2, 2x4, 4x2), verifying forward-loss and backward-gradient consistency.
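The target_is_pre_rolled / cu_seqlens_padded change above can be illustrated with a small sketch (a NumPy stand-in; the helper below is a simplified assumption, not the NeMo-RL implementation): next-token targets are the input ids shifted left by one within each sequence's boundaries, never wrapping across packed neighbors.

```python
import numpy as np

def roll_targets_per_sequence(input_ids, cu_seqlens_padded):
    # Shift each packed sequence left by one to form next-token targets;
    # the roll wraps within the sequence, not across packing boundaries.
    targets = np.empty_like(input_ids)
    for i in range(len(cu_seqlens_padded) - 1):
        s, e = cu_seqlens_padded[i], cu_seqlens_padded[i + 1]
        targets[s:e] = np.roll(input_ids[s:e], -1)
    return targets

# Two packed sequences, [1, 2, 3] and [4, 5, 6], in one flat buffer:
print(roll_targets_per_sequence(np.array([1, 2, 3, 4, 5, 6]),
                                [0, 3, 6]))  # [2 3 1 5 6 4]
```

When targets arrive pre-rolled, this step (and the matching CP-shard) can be skipped, which is what the new target_is_pre_rolled parameter controls.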

Sequence Diagram(s)

sequenceDiagram
    participant Train as Training Loop
    participant StdWrap as SequencePackingLossWrapper
    participant FusedWrap as SequencePackingFusionLossWrapper
    participant PrepStd as prepare_loss_input
    participant PrepFused as prepare_packed_loss_input
    participant LossFn as Loss Function
    
    rect rgba(200, 150, 255, 0.5)
    Note over Train,LossFn: Standard Sequence Packing Path
    Train->>StdWrap: forward(logits, data)
    loop For each sequence
        StdWrap->>PrepStd: per-sequence processing
        PrepStd->>StdWrap: loss_input
    end
    StdWrap->>LossFn: loss_input
    LossFn->>StdWrap: loss
    end
    
    rect rgba(150, 200, 255, 0.5)
    Note over Train,LossFn: Fused Sequence Packing Path
    Train->>FusedWrap: forward(logits, data)
    FusedWrap->>PrepFused: packed logits + packed sequences
    PrepFused->>PrepFused: fused pass via from_parallel_logits_to_logprobs_packed_sequences
    PrepFused->>FusedWrap: loss_input (all sequences)
    FusedWrap->>LossFn: loss_input
    LossFn->>FusedWrap: loss
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Title check: Passed. The title 'perf: Fuse sequence packing for loss function' clearly and concisely describes the main change: fusing sequence packing operations in loss computation for performance improvement.
  • Docstring Coverage: Passed. Docstring coverage is 81.25%, which meets the required threshold of 80.00%.
  • Test Results For Major Changes: Passed. The PR includes comprehensive testing with a new test module (310 lines), documented performance metrics (up to 15% speedup), and accuracy validation curves demonstrating no regression.
  • Description Check: Passed. Check skipped because CodeRabbit's high-level summary is enabled.


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/loss_functions.py`:
- Around line 1134-1144: The call to
from_parallel_logits_to_logprobs_packed_sequences uses vocab_parallel_rank
without guarding for None, which will raise a TypeError when None is provided;
add an explicit assertion (like the one used in _compute_curr_logprobs) that
vocab_parallel_rank is not None before computing
vocab_start_index/vocab_end_index, and raise a clear error message if it is None
so callers know the parallel-vocab requirement; update the block around the
curr_logprobs computation (the call to
from_parallel_logits_to_logprobs_packed_sequences and the variables
vocab_parallel_rank/vocab_parallel_group) to perform this check first.

In `@nemo_rl/models/megatron/common.py`:
- Around line 130-148: Add the new `fuse_loss` key to the SequencePackingConfig
TypedDict in nemo_rl/models/policy/__init__.py with a short docstring describing
purpose, valid values (bool), and the recommended default; remove the in-code
default by updating the usage in the megatron wrapper to read fuse_loss =
policy_cfg.get("sequence_packing", {}).get("fuse_loss") (no fallback to False)
so the code no longer embeds a default; place the default value for
sequence_packing.fuse_loss into the exemplar YAMLs under examples/configs/*.yaml
(so YAML is the single source of truth); keep the existing behavior that selects
SequencePackingFusionLossWrapper vs SequencePackingLossWrapper and the
conditional data_dict["packed_input_ids"] assignment, but rely on the
YAML-provided default rather than .get(..., False).

In `@tests/unit/algorithms/test_sequence_packing_fusion.py`:
- Line 445: Remove the unused variable assignment to `world_size = cp_size *
tp_size` in the test (the `world_size` local is never referenced); simply delete
that line (or if the value was intended to be used, replace the assignment by
using `world_size` where needed) so that `world_size`, `cp_size`, and `tp_size`
are not assigned without use in the test `test_sequence_packing_fusion.py`.
🧹 Nitpick comments (1)
nemo_rl/algorithms/loss_functions.py (1)

1147-1149: Duck-typing on _compute_loss_from_logprobs — consider a runtime check or Protocol.

self.loss_fn._compute_loss_from_logprobs(...) will raise AttributeError if the wrapped loss_fn doesn't implement this method (e.g., NLLLoss, DPOLossFn). While the docstring documents the requirement, a hasattr check in __init__ would produce a clear error at construction time rather than mid-training.

♻️ Proposed guard in __init__
     def __init__(
         self,
         loss_fn: LossFunction,
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_q_padded: Optional[torch.Tensor] = None,
     ):
+        if not hasattr(loss_fn, "_compute_loss_from_logprobs"):
+            raise TypeError(
+                f"{type(loss_fn).__name__} does not implement _compute_loss_from_logprobs. "
+                "SequencePackingFusionLossWrapper requires a loss function with this method "
+                "(e.g., ClippedPGLossFn)."
+            )
         self.loss_fn = loss_fn

@nujoug nujoug force-pushed the mloh/seqpack_fusion branch from 69e416a to 371c308 on February 20, 2026 21:40
@nujoug nujoug added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Feb 20, 2026
@nujoug
Contributor Author

nujoug commented Feb 21, 2026

@CodeRabbit review

@coderabbitai
Contributor

coderabbitai bot commented Feb 21, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/models/megatron/train.py (1)

313-328: 🛠️ Refactor suggestion | 🟠 Major

Add fuse_loss to SequencePackingConfig TypedDict and document in exemplar YAMLs.

The fuse_loss key is used in code (line 314–316) but not declared in the SequencePackingConfig TypedDict definition. Per coding guidelines, new config keys must be documented in the TypedDict with type and NotRequired marker where appropriate, and reflected in exemplar YAMLs under examples/configs/*.yaml with recommended defaults.

Additionally, the self.cfg is not None guard on line 314 is redundant—self.cfg is already dereferenced without a None check on line 311.

Minor cleanup for the redundant check
-            fuse_loss = self.cfg is not None and self.cfg.get(
-                "sequence_packing", {}
-            ).get("fuse_loss", None)
+            fuse_loss = self.cfg.get("sequence_packing", {}).get("fuse_loss", None)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/models/megatron/train.py` around lines 313 - 328, Add the missing
fuse_loss entry to the SequencePackingConfig TypedDict (as NotRequired[bool])
and update the example config YAMLs under examples/configs/*.yaml to document
the key and a recommended default (false); then remove the redundant self.cfg is
not None guard in the fuse_loss conditional in the megatron train logic where
fuse_loss is read (the code that picks SequencePackingFusionLossWrapper vs
SequencePackingLossWrapper and constructs the wrapper with
cu_seqlens_q/cu_seqlens_q_padded) so the config is consistently typed and
documented and the conditional is simplified.
🧹 Nitpick comments (2)
nemo_rl/algorithms/loss_functions.py (2)

1188-1252: Verify that input_ids.roll(-1, dims=1) on padded [B, S] is equivalent to per-sequence rolling.

Line 1228 rolls the entire [B, S] tensor row-wise, which wraps the first token of each row into the last position (including padding positions). The non-fusion SequencePackingLossWrapper path instead rolls each sequence individually within its actual length boundaries via from_parallel_logits_to_logprobs_packed_sequences (model_utils.py, lines 599-600).

The semantic difference is at position seq_len - 1 of each sequence: the fusion path places input_ids[i, seq_len] (a padding zero) there, while the non-fusion path wraps input_ids[packed_start] (the first real token). However, position seq_len - 1 is excluded from the output at line 662 (probs[start_idx : end_idx - 1]), so the difference is harmless.

This is a subtle correctness invariant worth a brief inline comment.

💡 Suggested inline comment
         # Roll targets on [B, S] (each row shifts independently), then CP-shard and pack.
+        # NOTE: Full-row roll wraps padding into the last real position, but that
+        # position is excluded by from_parallel_logits_to_logprobs_packed_sequences
+        # (which drops the last token per sequence), so the result is equivalent to
+        # per-sequence rolling done by the non-fused path.
         rolled_ids = input_ids.roll(-1, dims=1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/algorithms/loss_functions.py` around lines 1188 - 1252,
SequencePackingFusionLossWrapper.__call__ currently uses input_ids.roll(-1,
dims=1) which wraps padded zeros into the last token position for each row
(after _get_tokens_on_this_cp_rank and before packing), differing from the
per-sequence roll used by from_parallel_logits_to_logprobs_packed_sequences in
the non-fusion path; add a concise inline comment near the input_ids.roll call
explaining this semantic difference and why it is safe (the seq_len-1 position
is later excluded by the unpacking logic / probs[start_idx:end_idx-1]),
referencing input_ids.roll, _get_tokens_on_this_cp_rank, and
from_parallel_logits_to_logprobs_packed_sequences so future readers understand
the invariant.

128-128: chunk_size is never set from configuration.

self.chunk_size is initialized to None and never populated from ClippedPGLossConfig or any external source. While the attribute exists as a hook for SequencePackingFusionLossWrapper (line 1246: getattr(self.loss_fn, "chunk_size", None)), it will always be None in practice unless something externally mutates it. If chunked logprob computation during training loss is intended to be supported, consider wiring this through the config or the caller.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/algorithms/loss_functions.py` at line 128, self.chunk_size is
initialized to None in the ClippedPGLoss/constructor but never populated from
ClippedPGLossConfig or any caller, so SequencePackingFusionLossWrapper's
getattr(self.loss_fn, "chunk_size", None) will always return None; update the
constructor or config wiring to read a chunk_size value from ClippedPGLossConfig
(or an explicit constructor param) and assign it to self.chunk_size (e.g.,
accept chunk_size in ClippedPGLossConfig or the ClippedPGLoss __init__ and set
self.chunk_size = config.chunk_size) so that downstream callers like
SequencePackingFusionLossWrapper and loss_fn can detect and use chunked logprob
computation.
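The roll-equivalence invariant raised in the first nitpick can be checked numerically (a standalone NumPy sketch; shapes and names are illustrative): the whole-row roll and the per-sequence roll differ only at each sequence's last position, which the unpacking step drops.

```python
import numpy as np

# Batch of two padded sequences, lengths 3 and 2, padded to S = 4.
input_ids = np.array([[11, 12, 13, 0],
                      [21, 22, 0, 0]])
seq_lens = [3, 2]

rolled_full = np.roll(input_ids, -1, axis=1)  # fused path: roll each full row

for row, length in enumerate(seq_lens):
    # Non-fused path: roll only within the sequence's actual length.
    per_seq = np.roll(input_ids[row, :length], -1)
    # Positions 0..length-2 agree; position length-1 differs (padding zero
    # vs. wrapped first token) but is excluded when the last token of each
    # sequence is dropped from the output.
    assert np.array_equal(rolled_full[row, :length - 1], per_seq[:length - 1])
    assert rolled_full[row, length - 1] != per_seq[length - 1]
```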
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/unit/algorithms/test_sequence_packing_fusion.py`:
- Around line 195-230: The packing currently breaks autograd because you assign
into leaf tensors created with torch.zeros (packed_logits and tmp) using
in-place slices; instead, build packed_logits from tensors derived from
logits_local so gradients flow: in make_logits_and_packed_logits, stop creating
tmp and packed_logits as torch.zeros and avoid in-place slice assignments — for
each sequence produce a slice (e.g., tmp_slice = logits_local[i:i+1, :seq_len,
:] padded via differentiable ops or use indexing + torch.nn.functional.pad), run
it through _get_tokens_on_this_cp_rank, collect those outputs in a list, then
torch.cat the list into packed_logits (or set packed_logits = torch.cat(...)) so
the result is connected to logits_local; alternatively ensure the final
packed_logits has requires_grad_(True) and is produced by non-inplace ops
referencing logits_local rather than assigning into zero tensors.


@nujoug nujoug added the CI:L2 Run doctests, unit tests, functional tests, and convergence tests label Feb 21, 2026
@guyueh1 guyueh1 requested a review from youngeunkwon0405 March 2, 2026 19:07
@guyueh1
Contributor

guyueh1 commented Mar 2, 2026

The changes look good, I'll approve once you resolve the merge conflict.
@terrykong can you check the codecov? I think there are already unit tests and E2E tests covering this.

@guyueh1 guyueh1 added the Performance Related to improving performance label Mar 5, 2026
@nujoug nujoug force-pushed the mloh/seqpack_fusion branch from 4233c4f to 3ea5f1f on March 11, 2026 00:41
@copy-pr-bot

copy-pr-bot bot commented Mar 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@nujoug
Contributor Author

nujoug commented Mar 11, 2026

@CodeRabbit review

@coderabbitai
Contributor

coderabbitai bot commented Mar 11, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

nujoug added 15 commits March 12, 2026 08:26
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
…loss_input

Signed-off-by: mloh <mloh@nvidia.com>
@nujoug nujoug force-pushed the mloh/seqpack_fusion branch from 2113bf9 to 15144b1 on March 12, 2026 16:04
Signed-off-by: mloh <mloh@nvidia.com>
@nujoug nujoug added CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Mar 12, 2026
Signed-off-by: mloh <mloh@nvidia.com>
@nujoug
Contributor Author

nujoug commented Mar 12, 2026

/ok to test c492130

@guyueh1
Contributor

guyueh1 commented Mar 12, 2026

@terrykong please review

Collaborator

thanks, will review soon

@terrykong terrykong self-requested a review March 14, 2026 17:27
@terrykong terrykong merged commit 5f9d5cf into main Mar 17, 2026
31 checks passed
@terrykong terrykong deleted the mloh/seqpack_fusion branch March 17, 2026 07:59
nbasyl pushed a commit that referenced this pull request Mar 17, 2026
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: Shih-Yang Liu <shihyangl@nvidia.com>
nbasyl pushed a commit that referenced this pull request Mar 18, 2026
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: Shih-Yang Liu <shihyangl@nvidia.com>
@anwithk anwithk added this to the v0.6 Release milestone Mar 20, 2026

Labels

CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) Performance Related to improving performance

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants