
Conversation

@guyueh1
Contributor

@guyueh1 guyueh1 commented Nov 25, 2025

What does this PR do ?

Consolidates the sequence padding logic for the Mcore sequence packing case and fixes FP8-related padding bugs.

The sequence packing now relies on three hyperparameters:

  • pad_individual_seq_to_multiple_of: pad each individual sequence length to a multiple of this value, so that TP and CP splitting divides every sequence evenly.
  • pad_packed_seq_to_multiple_of: pad the packed sequence length to a multiple of this value, so that FP8 linear kernels receive properly shaped inputs.
  • pad_full_seq_to: pad every packed sequence to this fixed length; used when PP > 1, since all sequences in a batch must share the same length.
    This functionality is now maintained in a separate utility function.
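
As a rough illustration of how the three knobs compose (a hypothetical sketch with made-up names, not the actual NeMo RL implementation):

```python
def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple of `multiple`."""
    if value % multiple == 0:
        return value
    return (value // multiple + 1) * multiple

def padding_targets(seq_lens, pad_individual_seq_to_multiple_of,
                    pad_packed_seq_to_multiple_of, pad_full_seq_to=None):
    """Compute per-sequence padded lengths and the final packed length."""
    # 1) Pad each sequence so TP/CP splitting divides it evenly.
    padded = [round_up(n, pad_individual_seq_to_multiple_of) for n in seq_lens]
    # 2) Pad the total packed length for FP8-friendly kernel shapes.
    packed = round_up(sum(padded), pad_packed_seq_to_multiple_of)
    # 3) With PP > 1, force every batch to the same full length.
    if pad_full_seq_to is not None:
        packed = max(packed, pad_full_seq_to)
    return padded, packed
```

For example, `padding_targets([5, 7], 4, 16)` pads the sequences to `[8, 8]` with a packed length of 16.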

closes #1551


Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.


Summary by CodeRabbit

  • Refactor
    • Centralized sequence padding parameter calculations for improved consistency across training and inference workflows. Introduced helper functions to derive padding configurations based on model settings, replacing previously scattered padding logic. Enhanced padding handling for context parallelism, tensor model parallelism, and precision settings, ensuring more robust sequence packing behavior.


Signed-off-by: root <[email protected]>
@guyueh1 guyueh1 requested review from a team as code owners November 25, 2025 22:39
@guyueh1 guyueh1 added the CI:L2 Run doctests, unit tests, functional tests, and convergence tests label Nov 25, 2025
@coderabbitai
Contributor

coderabbitai bot commented Nov 25, 2025

📝 Walkthrough

Walkthrough

The changes consolidate and enhance padding logic for Megatron sequence packing to address FP8 alignment requirements. A new centralized helper function derives padding parameters from Megatron configuration (including FP8 settings and parallelism sizes), and these parameters are threaded through the policy worker workflow to ensure proper tensor dimension alignment.

Changes

  • Core Megatron Padding Helpers (nemo_rl/models/megatron/common.py): Added a _round_up_to_multiple() utility. Introduced _get_pack_sequence_parameters_for_megatron() to centralize derivation of padding parameters from the Megatron config. Extended the _pack_sequences_for_megatron() signature with a pad_packed_seq_to_multiple_of parameter. Refactored padding logic to use a centralized needs_padding flag and account for FP8, context parallelism, tensor model parallelism, and sequence parallelism multipliers. Updated forward_step_arbitrary_loss() to accept and propagate pad_packed_seq_to_multiple_of.
  • Policy Configuration (nemo_rl/models/policy/lm_policy.py): Removed the dynamic FP8-specific adjustment of sequence padding via math.lcm(16, sequence_length_pad_multiple) and the math module import. Padding now relies on the centralized helper in common.py.
  • Policy Execution (nemo_rl/models/policy/megatron_policy_worker.py): Imported and integrated the _get_pack_sequence_parameters_for_megatron() helper. Replaced ad-hoc padding computations with centralized parameter derivation across the training and inference (get_logprobs, get_topk_logits) code paths. Propagated pad_packed_seq_to_multiple_of, pad_factor, and pad_full_seq_to through forward_step_fn calls and packing operations, threading the new padding parameters through every code branch where sequence packing is enabled.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Policy Worker as megatron_policy_worker
    participant Common as common.py
    participant PackSeq as _pack_sequences_for_megatron()

    Policy Worker->>Common: _get_pack_sequence_parameters_for_megatron(megatron_cfg, max_seq_len)
    Note over Common: Derive padding params from:<br/>- Megatron config<br/>- FP8 settings<br/>- CP/TP/SP sizes
    Common-->>Policy Worker: (pad_individual, pad_packed_multiple, pad_packed_to)

    Policy Worker->>Policy Worker: Compute pad_factor from params

    alt sequence_packing enabled
        Policy Worker->>PackSeq: Call with pad_packed_seq_to_multiple_of
        PackSeq->>Common: _round_up_to_multiple(pad_to, pad_multiple)
        Common-->>PackSeq: Rounded pad target
        PackSeq->>PackSeq: Apply padding to achieve<br/>FP8-compliant dimensions
        PackSeq-->>Policy Worker: Padded sequences
    else sequence_packing disabled
        Policy Worker->>Policy Worker: Use standard padding
    end
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Key areas requiring extra attention:
    • nemo_rl/models/megatron/common.py: The reworked padding logic with needs_padding flag and interaction between multiple padding parameters (FP8, CP, TP, SP multipliers). Verify that FP8 alignment requirements (last dim divisible by 16, product of other dims divisible by 8) are correctly computed and applied.
    • nemo_rl/models/policy/megatron_policy_worker.py: Multiple code paths now use the centralized helper. Confirm that pad_packed_seq_to_multiple_of is correctly threaded through training, inference, and edge cases (different parallelism configurations, disabled FP8, etc.).
    • Boundary conditions: Verify handling of sequences that are already aligned, very small sequences, and interaction with existing padding logic when sequence packing is disabled.
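
The FP8 alignment constraint mentioned above (last dimension divisible by 16, product of the remaining dimensions divisible by 8) can be expressed as a small check. This is a hypothetical helper for illustration, not part of Transformer Engine or NeMo RL:

```python
from math import prod

def satisfies_fp8_gemm_alignment(shape: tuple[int, ...]) -> bool:
    """Check the FP8 GEMM shape constraint described in the review notes:
    last dim % 16 == 0 and product of the remaining dims % 8 == 0."""
    *leading, last = shape
    return last % 16 == 0 and prod(leading) % 8 == 0

# The failing shape from issue #1551: 4092 is not divisible by 8.
# satisfies_fp8_gemm_alignment((4092, 4096)) -> False
# After padding the token dimension up to a multiple of 8:
# satisfies_fp8_gemm_alignment((4096, 4096)) -> True
```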

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 70.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Test Results For Major Changes: ⚠️ Warning. The PR lacks documented test results for the critical FP8 + Sequence Parallelism bug fix and contains uninitialized variable bugs suggesting insufficient testing. Resolution: document specific test commands, execution results, and evidence validating the FP8 + Sequence Parallelism fix and that all code paths are tested.

✅ Passed checks (4 passed)

  • Title check: ✅ Passed. The title 'fix: Fix the sequence padding for FP8 case' accurately reflects the main change: consolidating sequence padding logic and fixing FP8-related padding bugs identified in issue #1551.
  • Linked Issues check: ✅ Passed. The PR addresses issue #1551 by introducing proper sequence padding parameters (pad_individual_seq_to_multiple_of, pad_packed_seq_to_multiple_of, pad_full_seq_to) and consolidating padding logic to fix FP8 + Sequence Parallelism failures due to tensor dimension misalignment.
  • Out of Scope Changes check: ✅ Passed. All changes are directly related to fixing FP8 sequence padding: helper functions for padding calculations, updates to packing logic, and removal of redundant FP8 adjustment logic. No unrelated modifications detected.
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fa379ff and 92c2227.

📒 Files selected for processing (3)
  • nemo_rl/models/megatron/common.py (9 hunks)
  • nemo_rl/models/policy/lm_policy.py (0 hunks)
  • nemo_rl/models/policy/megatron_policy_worker.py (8 hunks)
💤 Files with no reviewable changes (1)
  • nemo_rl/models/policy/lm_policy.py
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/models/megatron/common.py
  • nemo_rl/models/policy/megatron_policy_worker.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/models/megatron/common.py
  • nemo_rl/models/policy/megatron_policy_worker.py
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • nemo_rl/models/megatron/common.py
  • nemo_rl/models/policy/megatron_policy_worker.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • nemo_rl/models/megatron/common.py
  • nemo_rl/models/policy/megatron_policy_worker.py
🧠 Learnings (1)
📚 Learning: 2025-09-18T10:13:58.767Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: nemo_rl/algorithms/loss_functions.py:900-909
Timestamp: 2025-09-18T10:13:58.767Z
Learning: In NeMo RL codebase, when using device_mesh.get_group("tp"), the returned tp_group object has its own .rank() and .size() methods, which are the correct APIs to use (not torch.distributed.get_rank(tp_group)).

Applied to files:

  • nemo_rl/models/policy/megatron_policy_worker.py
🧬 Code graph analysis (1)
nemo_rl/models/policy/megatron_policy_worker.py (2)
nemo_rl/models/megatron/common.py (1)
  • _get_pack_sequence_parameters_for_megatron (229-281)
nemo_rl/distributed/batched_data_dict.py (1)
  • get_microbatch_iterator_for_packable_sequences_len (798-800)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: sphinx-build / Build docs
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Post automodel integration comment / Comment on PR
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (10)
nemo_rl/models/policy/megatron_policy_worker.py (3)

115-119: LGTM on import addition.

The new import for _get_pack_sequence_parameters_for_megatron is correctly added alongside related imports from nemo_rl.models.megatron.common.


1020-1047: LGTM on centralized padding parameter retrieval in train().

The padding parameters are correctly initialized with defaults for non-packing paths (lines 1022-1024) and only overridden when sequence packing is enabled (lines 1040-1047). This ensures the forward step always has valid values.


1068-1070: Padding parameter correctly threaded through forward step.

The pad_packed_seq_to_multiple_of parameter is now passed to forward_step, ensuring FP8 alignment requirements are honored during training.

nemo_rl/models/megatron/common.py (7)

36-41: LGTM on _round_up_to_multiple helper.

The implementation correctly rounds up to the nearest multiple. The conditional check if value % multiple != 0 avoids unnecessary computation when the value is already aligned.


74-88: LGTM on needs_padding flag and initial rounding.

The centralized needs_padding flag correctly captures when any padding is required. Rounding pad_packed_seq_to to a multiple of pad_packed_seq_to_multiple_of upfront (lines 85-88) ensures FP8 alignment is satisfied when PP forces a specific packed length.


104-120: Correct handling of padded cumulative sequence lengths.

The logic properly updates cu_seqlens_padded based on the padding requirements:

  • Individual sequences are padded to pad_factor (line 106)
  • The final packed length is either set to pad_packed_seq_to or rounded up to pad_packed_seq_to_multiple_of (lines 115-120)
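
Sketched in isolation, the cu_seqlens_padded update described above might look like the following. Names and structure here are hypothetical; the real logic lives inside _pack_sequences_for_megatron():

```python
def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple of `multiple`."""
    if value % multiple == 0:
        return value
    return (value // multiple + 1) * multiple

def build_cu_seqlens_padded(seq_lens, pad_factor,
                            pad_packed_seq_to_multiple_of,
                            pad_packed_seq_to=None):
    """Cumulative padded offsets for a packed batch of sequences."""
    cu = [0]
    # Each individual sequence is padded to a multiple of pad_factor.
    for n in seq_lens:
        cu.append(cu[-1] + round_up(n, pad_factor))
    # The final boundary is either forced (PP > 1) or rounded up for FP8.
    if pad_packed_seq_to is not None:
        cu[-1] = pad_packed_seq_to
    else:
        cu[-1] = round_up(cu[-1], pad_packed_seq_to_multiple_of)
    return cu
```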

144-159: Complex padding logic for last sequence element is correct.

When b == batch_size - 1 and padding is needed, the code correctly computes the padded length to satisfy both individual sequence padding and the overall packed sequence alignment requirements.


188-205: Padding logic when pad_factor == 1 is correct.

When no individual sequence padding is needed but packed sequence padding is required (e.g., FP8 without CP/SP), this path correctly pads the entire packed sequence to satisfy alignment.


229-281: Well-designed centralized padding parameter calculation.

The _get_pack_sequence_parameters_for_megatron function correctly derives padding requirements:

  • Individual sequence padding accounts for CP (×cp_size×2) and SP (×tp_size)
  • Packed sequence padding for FP8 uses 128 for blockwise, 16 otherwise, with CP/SP multipliers
  • PP requires padding to max sequence length in batch

This directly addresses the FP8 dimension divisibility issue in #1551.
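
One plausible reading of the multiplier derivation summarized above, as a standalone sketch. This is not the actual _get_pack_sequence_parameters_for_megatron() from common.py; in particular, combining the FP8 base multiple with the CP/SP multiple via lcm is an assumption on our part:

```python
from math import lcm

def derive_pack_sequence_parameters(cp_size, tp_size, sequence_parallel,
                                    fp8, fp8_recipe_is_blockwise,
                                    pp_size, max_seq_len=None):
    """Derive (pad_individual, pad_packed, pad_full) padding parameters."""
    # CP splits every sequence into 2 * cp_size chunks; SP additionally
    # requires divisibility by tp_size.
    pad_individual = 1
    if cp_size > 1:
        pad_individual *= cp_size * 2
    if sequence_parallel:
        pad_individual *= tp_size
    # FP8 kernels want the packed length aligned to 128 (blockwise recipe)
    # or 16 (other recipes), combined here with the CP/SP multiple.
    pad_packed = pad_individual
    if fp8:
        base = 128 if fp8_recipe_is_blockwise else 16
        pad_packed = lcm(base, pad_individual)
    # PP > 1 requires all packed batches to share one full length.
    pad_full = max_seq_len if pp_size > 1 else None
    return pad_individual, pad_packed, pad_full
```

For example, with CP=2, TP=4, sequence parallelism on, and non-blockwise FP8, every sequence would be padded to a multiple of 16 and the packed length to a multiple of 16 as well.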


358-426: LGTM on forward_step_arbitrary_loss updates.

The new pad_packed_seq_to_multiple_of parameter is correctly added to the function signature and passed through to _pack_sequences_for_megatron.

Contributor

@terrykong terrykong left a comment


thanks for the fix!

root added 2 commits November 26, 2025 10:50
Signed-off-by: root <[email protected]>
Signed-off-by: root <[email protected]>
@guyueh1 guyueh1 requested a review from a team as a code owner November 26, 2025 19:03
@guyueh1
Contributor Author

guyueh1 commented Nov 26, 2025

Addressed the CodeRabbit comments.

@guyueh1 guyueh1 added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Nov 26, 2025
terrykong
terrykong previously approved these changes Nov 26, 2025
@terrykong terrykong enabled auto-merge (squash) November 26, 2025 19:08
Signed-off-by: root <[email protected]>
@guyueh1 guyueh1 added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Nov 27, 2025
@guyueh1 guyueh1 requested a review from terrykong November 27, 2025 22:26
@terrykong terrykong merged commit 25ff3f6 into NVIDIA-NeMo:main Nov 28, 2025
40 of 42 checks passed


Development

Successfully merging this pull request may close these issues.

FP8 Training + Sequence Parallelism fails with Transformer Engine AssertionError: dims=[4092, 4096] not divisible by 8
