Add distributed checkpoint support for non-colocated MiMo #4020
yashaswikarnati merged 1 commit into NVIDIA:main
Conversation
This PR has been automatically converted to draft because all PRs must start as drafts. When you are ready for review, click Ready for Review to begin the review process. See the contribution guide for more details.
/claude review
Review comment on the changed lines:

`info.optimizer.load_state_dict(module_sd)`

`def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):`
Critical bug: The old sharded_state_dict method (the simple 4-line version that was already in the file) still exists below this new implementation. In Python, the last definition of a method wins, so this entire new method is dead code — the old simple version at line ~231 silently shadows it.
The old method needs to be deleted for this PR to have any effect.
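The shadowing behavior the reviewer describes can be demonstrated with a minimal, self-contained sketch (class and method names here are illustrative, not the actual Megatron-LM code): when a class body defines the same method name twice, Python binds the name to the last definition and the earlier one becomes unreachable dead code.

```python
# Minimal demonstration of the shadowing bug described above.
# All names are hypothetical; only the Python semantics are the point.
class Checkpointer:
    def sharded_state_dict(self):
        # "new" TP-aware implementation added by the PR
        return {"version": "new", "tp_aware": True}

    def sharded_state_dict(self):
        # old simple version left in the file below the new one --
        # this later binding silently wins in the class namespace
        return {"version": "old"}


# The class dict holds only the last definition, so the new
# implementation is never called.
print(Checkpointer().sharded_state_dict())  # -> {'version': 'old'}
```

Deleting the later (old) definition is the fix, since class bodies are executed top to bottom and each `def` simply rebinds the method name.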
/claude review
Fixes three issues with dist_checkpointing in non-colocated MiMo:

1. `sharded_state_dict()` on MimoModel and ModalitySubmodules now injects `dp_cp_group` from each module's `pg_collection`, bypassing `parallel_state` global fallbacks that crash in non-colocated mode.
2. `MimoOptimizer.sharded_state_dict()` extracts `param_groups` and `grad_scaler` as `ShardedObject`s routed through distributed save, fixing the issue where `common.pt` is only written by global rank 0 (the encoder rank) and LLM optimizer metadata was lost.
3. ModalitySubmodules gains `sharded_state_dict()` for TP-aware checkpointing (previously all tensors were treated as TP-replicated).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
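The first fix can be sketched as a pattern, independent of Megatron-LM internals: rather than letting sharding logic fall back to a process-global `parallel_state`, each module injects its own group from its `pg_collection` into the `sharded_state_dict()` kwargs. Everything below is a hypothetical mock (the class, the string "groups", and the returned dict shape are all assumptions for illustration, not the real API):

```python
# Hedged sketch of the per-module process-group injection pattern.
# Real Megatron-LM uses actual torch.distributed process groups and
# ShardedTensor metadata; strings stand in for both here.
class ModalitySubmodule:
    def __init__(self, name, pg_collection):
        self.name = name
        # per-module process groups, e.g. built from this module's rank set
        self.pg_collection = pg_collection

    def sharded_state_dict(self, prefix="", **kwargs):
        # Inject this module's own dp_cp_group so downstream sharding
        # never consults a global parallel_state fallback (which crashes
        # when encoder and LLM live on disjoint rank sets).
        kwargs.setdefault("dp_cp_group", self.pg_collection["dp_cp"])
        return {f"{prefix}{self.name}.weight": ("sharded", kwargs["dp_cp_group"])}


# Encoder and LLM carry different dp_cp groups; each entry in the merged
# state dict is tagged with the group of the module that produced it.
enc = ModalitySubmodule("encoder", {"dp_cp": "encoder_dp_cp"})
llm = ModalitySubmodule("llm", {"dp_cp": "llm_dp_cp"})
sd = {**enc.sharded_state_dict(), **llm.sharded_state_dict()}
```

The design point is that `setdefault` lets a caller that already holds a group pass it through, while standalone calls still get the module-local group instead of a global default.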
/ok to test 69a49c2
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23552500799
Summary
Stacked on #4019 (MimoOptimizer). Implements NMFW-33.
Fixes distributed checkpointing for non-colocated MiMo where encoder and LLM run on separate rank sets with per-module process groups:
- Inject `dp_cp_group` from each module's `pg_collection`, bypassing `parallel_state` global fallbacks that crash in non-colocated mode
- Extract `param_groups` and `grad_scaler` as `ShardedObject`s routed through distributed save, fixing the issue where `common.pt` is only written by global rank 0 (encoder rank) and LLM optimizer metadata was lost
- Use each module's `pg_collection` for the correct `tp_group`

Test plan
🤖 Generated with Claude Code