Add distributed checkpoint support for non-colocated MiMo#4020

Merged
yashaswikarnati merged 1 commit into NVIDIA:main from yashaswikarnati:yash/mimo-checkpoint-pr
Mar 25, 2026

Conversation

@yashaswikarnati
Contributor

Summary

Stacked on #4019 (MimoOptimizer). Implements NMFW-33.

Fixes distributed checkpointing for non-colocated MiMo where encoder and LLM run on separate rank sets with per-module process groups:

  • MimoModel.sharded_state_dict(): injects dp_cp_group from each module's pg_collection, bypassing parallel_state global fallbacks that crash in non-colocated mode
  • MimoOptimizer.sharded_state_dict(): extracts param_groups and grad_scaler as ShardedObjects routed through distributed save, fixing the issue where common.pt is only written by global rank 0 (encoder rank) and LLM optimizer metadata was lost
  • ModalitySubmodules.sharded_state_dict(): enables TP-aware checkpointing (previously all tensors treated as TP-replicated)
  • MultimodalProjector: accepts pg_collection for correct tp_group
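The core idea behind the first fix can be sketched in a few lines. This is a hypothetical, simplified illustration (the `PGCollection` and `Module` stubs below are stand-ins, not the Megatron-LM classes): each module tags its sharded entries with its own `dp_cp_group` instead of consulting uninitialized `parallel_state` globals.

```python
# Hypothetical sketch of per-module process-group injection. In non-colocated
# mode, parallel_state globals are not valid for a module hosted on a
# different rank set, so each module's own pg_collection must supply the
# dp_cp_group used when building its sharded state dict.
from dataclasses import dataclass


@dataclass
class PGCollection:
    # Stand-in for a per-module process-group bundle; a real implementation
    # would hold torch.distributed process groups, not strings.
    dp_cp_group: str


class Module:
    def __init__(self, name, pg_collection):
        self.name = name
        self.pg_collection = pg_collection
        self.params = {f"{name}.weight": [0.0]}

    def sharded_state_dict(self):
        # Tag every entry with the module-local group rather than a global.
        group = self.pg_collection.dp_cp_group
        return {k: {"data": v, "dp_cp_group": group}
                for k, v in self.params.items()}


class MimoModel:
    def __init__(self, modules):
        self.modules = modules

    def sharded_state_dict(self):
        sd = {}
        for m in self.modules:
            sd.update(m.sharded_state_dict())
        return sd


model = MimoModel([
    Module("encoder", PGCollection(dp_cp_group="encoder_dp_cp")),
    Module("llm", PGCollection(dp_cp_group="llm_dp_cp")),
])
sd = model.sharded_state_dict()
```

Each entry now carries the group of the module that owns it, so the encoder and LLM rank sets can each resolve their own shards without touching global state.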

Test plan

  • Save/load roundtrip tests for model params and optimizer state
  • 4-GPU: LLM PP=3
  • 8-GPU: Encoder TP=2, LLM TP=2 PP=3
  • 8-GPU: Encoder TP=1, LLM TP=1 PP=7
  • 8-GPU: Encoder TP=2 PP=2, LLM TP=2 PP=2

🤖 Generated with Claude Code

@yashaswikarnati yashaswikarnati requested review from a team as code owners March 24, 2026 19:21
@copy-pr-bot

copy-pr-bot bot commented Mar 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft March 24, 2026 19:21
@github-actions
Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

@yashaswikarnati yashaswikarnati force-pushed the yash/mimo-checkpoint-pr branch 2 times, most recently from 392dd1f to b7050e4 on March 25, 2026 00:56
@yashaswikarnati yashaswikarnati marked this pull request as ready for review March 25, 2026 00:57
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team March 25, 2026 00:57
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Final Review PR is in the "final review" stage label Mar 25, 2026
@yashaswikarnati
Contributor Author

/claude review


Inline review comment on:

    info.optimizer.load_state_dict(module_sd)

    def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
Contributor


Critical bug: The old sharded_state_dict method (the simple 4-line version that was already in the file) still exists below this new implementation. In Python, the last definition of a method wins, so this entire new method is dead code — the old simple version at line ~231 silently shadows it.

The old method needs to be deleted for this PR to have any effect.
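The shadowing behavior the reviewer describes is a general Python rule: a class body executes top to bottom, so a second `def` with the same name rebinds the attribute and the earlier definition is discarded. A minimal demonstration (the class and method names here are illustrative, not the PR's code):

```python
# In a Python class body, the last definition of a name wins; an earlier
# method of the same name is silently replaced, producing dead code.
class Checkpointable:
    def sharded_state_dict(self):
        return "new implementation"

    def sharded_state_dict(self):  # later duplicate silently shadows the above
        return "old implementation"


result = Checkpointable().sharded_state_dict()
```

No warning or error is raised at class-definition time, which is why the dead new implementation went unnoticed until review.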

@yashaswikarnati yashaswikarnati force-pushed the yash/mimo-checkpoint-pr branch from b7050e4 to 09c3fd5 on March 25, 2026 05:39
@yashaswikarnati
Contributor Author

/claude review

Contributor

@claude claude bot left a comment


LGTM

Fixes three issues with dist_checkpointing in non-colocated MiMo:

1. sharded_state_dict() on MimoModel and ModalitySubmodules now injects
   dp_cp_group from each module's pg_collection, bypassing parallel_state
   global fallbacks that crash in non-colocated mode.

2. MimoOptimizer.sharded_state_dict() extracts param_groups and
   grad_scaler as ShardedObjects routed through distributed save,
   fixing the issue where common.pt is only written by global rank 0
   (the encoder rank) and LLM optimizer metadata was lost.

3. ModalitySubmodules gains sharded_state_dict() for TP-aware
   checkpointing (previously all tensors treated as TP-replicated).
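The `common.pt` failure mode in point 2 can be illustrated with a toy save path. This is a hypothetical mock, not the Megatron-LM `dist_checkpointing` API: it contrasts a rank-0-only write (which drops metadata from the LLM rank set) with keyed per-rank entries in the distributed save, in the spirit of `ShardedObject` routing.

```python
# Hypothetical illustration of the common.pt bug and its fix. In
# non-colocated MiMo, global rank 0 is an encoder rank, so a rank-0-only
# write of shared optimizer metadata silently drops the LLM side.
def save_common_pt(rank, optimizer_meta, store):
    # Old behavior: only global rank 0 writes common.pt.
    if rank == 0:
        store["common.pt"] = optimizer_meta


def save_sharded_objects(rank, optimizer_meta, store):
    # New behavior: each rank contributes its param_groups / grad_scaler
    # as keyed entries routed through the distributed save.
    for key, obj in optimizer_meta.items():
        store[f"optimizer/{key}/rank{rank}"] = obj


store_old, store_new = {}, {}
# Rank 0 hosts the encoder; rank 4 (illustrative) hosts the LLM.
save_common_pt(0, {"param_groups": "encoder_groups"}, store_old)
save_common_pt(4, {"param_groups": "llm_groups"}, store_old)       # dropped!
save_sharded_objects(0, {"param_groups": "encoder_groups"}, store_new)
save_sharded_objects(4, {"param_groups": "llm_groups"}, store_new)  # kept
```

With per-object routing, both rank sets' optimizer metadata survives the save and can be restored on load.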

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yashaswikarnati yashaswikarnati force-pushed the yash/mimo-checkpoint-pr branch from 09c3fd5 to 69a49c2 on March 25, 2026 06:40
Contributor

@dimapihtar dimapihtar left a comment


LGTM. Thank you!

@yashaswikarnati
Contributor Author

/ok to test 69a49c2

@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the Final Review PR is in the "final review" stage label Mar 25, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Approved All necessary approvals have been made label Mar 25, 2026
@yashaswikarnati yashaswikarnati added this pull request to the merge queue Mar 25, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23552500799

Merged via the queue into NVIDIA:main with commit 1df5591 Mar 25, 2026
67 of 69 checks passed
@yashaswikarnati yashaswikarnati deleted the yash/mimo-checkpoint-pr branch March 25, 2026 17:19

Labels

Approved (All necessary approvals have been made), complexity: medium

5 participants