
[sync] Update RNG sharding to include EP rank#2092

Merged
ananthsub merged 4 commits into NVIDIA-NeMo:main from ananthsub:sync-2658
Jan 29, 2026

Conversation

@ananthsub
Contributor

@ananthsub ananthsub commented Jan 27, 2026

What does this PR do ?

Sync with changes from NVIDIA/Megatron-LM#2658 and NVIDIA/Megatron-LM#2641

Changelog

  • Update RNG sharding to include EP rank, and fix CUDA RNG tracker

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

  • Improvements

    • Enhanced checkpoint handling for Expert Parallelism configurations in distributed training.
    • Implemented graph-safe CUDA RNG state management during checkpoint loading.
  • Tests

    • Added comprehensive test coverage for RNG state collection with varying parallelism configurations.


@copy-pr-bot

copy-pr-bot bot commented Jan 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ananthsub
Contributor Author

/ok to test efa8548

@ananthsub ananthsub requested a review from yaoyu-33 January 29, 2026 10:19
@ananthsub
Contributor Author

/ok to test b3e472d

@ananthsub ananthsub marked this pull request as ready for review January 29, 2026 10:46
@coderabbitai
Contributor

coderabbitai bot commented Jan 29, 2026

📝 Walkthrough

Walkthrough

Enhanced RNG state handling in checkpointing to support Expert Parallelism (EP): get_rng_state now accepts a ProcessGroupCollection parameter, RNG states are sharded across PP, TP, and DP when EP > 1, and CUDA RNG tracker states are converted to a graph-safe form before being applied during checkpoint loading.

Changes

  • RNG State Checkpointing — src/megatron/bridge/training/checkpointing.py
    Imports get_pg_size from megatron.core.utils. Updates the get_rng_state signature to accept a pg_collection parameter. Implements EP-aware RNG sharding logic: when EP > 1, shards across PP, TP, and DP; otherwise maintains the prior PP, TP sharding with DP as the replica_id. Adds graph-safe RNG state loading by acquiring the CUDA RNG tracker, determining graph safety, and converting rng_tracker_states via tensor_parallel.convert_cuda_rng_state() before application.
  • RNG State Checkpointing Tests — tests/unit_tests/training/test_checkpointing.py
    Adds unit tests for RNG state collection covering the EP scenarios: EP > 1 (sharded by PP, TP, DP) and EP = 1 (sharded by PP, TP). Includes a test for the EP group being None, and validates ShardedObject metadata (global_shape, global_offset, replica_id) and correct RNG state gathering. Tests verify that get_pg_size is invoked with the appropriate EP objects.
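To make the sharding decision above concrete, here is a hypothetical sketch. The names (shard_rng_state, ShardedRngState) and the simplified dataclass are stand-ins for illustration only, not Megatron's actual ShardedObject or the function in checkpointing.py; the point is the metadata split: with EP > 1 the DP rank becomes a sharding axis, otherwise it stays a replica index.

```python
# Hedged sketch of the EP-aware RNG sharding decision (not Megatron's real API).
from dataclasses import dataclass
from typing import Tuple


@dataclass
class ShardedRngState:
    """Simplified stand-in for a ShardedObject's key metadata."""
    key: str
    global_shape: Tuple[int, ...]
    global_offset: Tuple[int, ...]
    replica_id: int


def shard_rng_state(ep_size: int, pp_rank: int, pp_size: int,
                    tp_rank: int, tp_size: int,
                    dp_rank: int, dp_size: int) -> ShardedRngState:
    """Shard RNG state across PP/TP, and additionally across DP when EP > 1."""
    if ep_size > 1:
        # With expert parallelism, DP ranks carry distinct expert-related RNG
        # streams, so DP must be a sharding axis, not a replica dimension.
        return ShardedRngState(
            key="rng_state",
            global_shape=(pp_size, tp_size, dp_size),
            global_offset=(pp_rank, tp_rank, dp_rank),
            replica_id=0,
        )
    # Without EP, every DP rank holds an identical RNG copy: shard by PP/TP
    # only, and mark the DP rank as the replica index.
    return ShardedRngState(
        key="rng_state",
        global_shape=(pp_size, tp_size),
        global_offset=(pp_rank, tp_rank),
        replica_id=dp_rank,
    )


ep = shard_rng_state(ep_size=2, pp_rank=1, pp_size=2,
                     tp_rank=0, tp_size=2, dp_rank=3, dp_size=4)
no_ep = shard_rng_state(ep_size=1, pp_rank=1, pp_size=2,
                        tp_rank=0, tp_size=2, dp_rank=3, dp_size=4)
print(ep.global_offset, ep.replica_id)        # (1, 0, 3) 0
print(no_ep.global_offset, no_ep.replica_id)  # (1, 0) 3
```

This mirrors the assertions the new unit tests are described as making on global_shape, global_offset, and replica_id for the EP > 1 and EP = 1 cases.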

Sequence Diagram

sequenceDiagram
    participant CL as Checkpoint Loader
    participant PGC as ProcessGroupCollection
    participant GSDD as Graph Safety Detector
    participant TP as tensor_parallel
    participant CRT as CUDA RNG Tracker
    
    CL->>PGC: Query EP size via get_pg_size
    PGC-->>CL: Return EP size
    
    alt EP > 1
        CL->>CL: Shard RNG states by PP, TP, DP
    else EP ≤ 1
        CL->>CL: Shard RNG states by PP, TP (DP as replica_id)
    end
    
    CL->>CRT: Acquire CUDA RNG tracker
    CRT-->>CL: Return tracker instance
    
    CL->>GSDD: Determine graph_safety status
    GSDD-->>CL: Return is_graph_safe flag
    
    CL->>TP: convert_cuda_rng_state(rng_tracker_states)
    TP-->>CL: Return converted states
    
    CL->>CRT: Set converted RNG states
    CRT-->>CL: States applied successfully
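The tail of the diagram (tracker acquisition, graph-safety check, conversion, application) can be sketched as a small stand-in. Everything here is hypothetical scaffolding: FakeRngTracker and fake_convert are stubs for Megatron's CUDA RNG tracker and tensor_parallel.convert_cuda_rng_state(), used only to show the order of operations.

```python
# Hedged, stubbed sketch of the graph-safe RNG load flow (not Megatron's real API).
class FakeRngTracker:
    """Stub tracker; the real one stores per-name CUDA generator states."""
    def __init__(self, is_graph_safe: bool):
        self.is_graph_safe = is_graph_safe
        self.states = None

    def set_states(self, states):
        self.states = states


def fake_convert(state, to_graph_safe: bool):
    # Stand-in for the conversion step: tag the payload with its target format.
    return {"payload": state, "graph_safe": to_graph_safe}


def load_rng_tracker_states(tracker, saved_states, convert_fn):
    # 1. Determine whether the live tracker is graph-safe.
    is_graph_safe = tracker.is_graph_safe
    # 2. Convert each saved state to the tracker's format before applying,
    #    since the checkpoint may have been written by the other kind.
    converted = {name: convert_fn(state, to_graph_safe=is_graph_safe)
                 for name, state in saved_states.items()}
    # 3. Apply the converted states to the tracker.
    tracker.set_states(converted)
    return converted


tracker = FakeRngTracker(is_graph_safe=True)
out = load_rng_tracker_states(tracker, {"model-parallel-rng": b"\x01"}, fake_convert)
print(out["model-parallel-rng"]["graph_safe"])  # True
```

The design point, per the walkthrough, is that conversion happens before application, so a checkpoint written by a non-graph-safe tracker can be restored into a graph-safe one (and vice versa).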

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Possibly related PRs

🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Test Results For Major Changes ⚠️ Warning — The PR description lacks test results or testing information despite breaking changes to the get_rng_state() function signature and significant RNG state handling refactoring for Expert Parallelism support. Resolution: update the PR description to document test execution results, convergence validation, and performance impact assessment.

✅ Passed checks (3 passed)

  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title accurately describes the main change: updating RNG sharding to include the EP (expert parallel) rank, the primary focus of the changeset.
  • Docstring Coverage ✅ Passed — Docstring coverage is 100.00%, above the required threshold of 80.00%.


@ananthsub
Contributor Author

/ok to test 3bbb050

@ananthsub ananthsub enabled auto-merge (squash) January 29, 2026 18:50
@ananthsub ananthsub merged commit a44f04c into NVIDIA-NeMo:main Jan 29, 2026
83 of 85 checks passed
conver334 pushed a commit to conver334/Megatron-Bridge that referenced this pull request Jan 30, 2026