
[DO NOT MERGE] Combined MiMo non-colocated changes for MBridge integration #4022

Draft
yashaswikarnati wants to merge 3 commits into NVIDIA:main from yashaswikarnati:yash/mcore-mimo-combined

Conversation

@yashaswikarnati
Contributor

DO NOT MERGE — This is a temporary integration branch for MBridge PRs.

Combines all pending MiMo non-colocated MCore changes so MBridge can bump its MCore submodule to a single ref:

This branch will be closed once the individual PRs are merged to main.

Used by MBridge PRs:

_prepare_tensor_for_comm() always inserted a singleton dim at position
-1 (the end), regardless of dim_mapping. With SBH format, the bridge
operates on dim_mapping['b']=1, but after unsqueeze(-1), dim 1 is the
hidden dimension, not batch. This caused incorrect cat/split operations
when DP sizes differed between modules (fan-in/fan-out).

Fix: add tensor_ndim parameter to BridgeCommunicator. For 2D tensors
[B*S, H], batch is folded into dim 0, so fan-in/fan-out uses cat/split
at dim 0 directly — no unsqueeze/squeeze needed. Each bridge gets
tensor_ndim from module_output_ndim config in the communicator.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
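
The tensor_ndim-aware fan-in/fan-out described above can be sketched roughly as follows (hypothetical helper names, numpy stand-ins for the actual BridgeCommunicator code):

```python
import numpy as np

# Hedged sketch, not the Megatron-Core implementation: for 2D [B*S, H]
# tensors the batch is folded into dim 0, so cat/split happens at dim 0
# directly with no unsqueeze/squeeze bookkeeping; only higher-rank
# layouts (e.g. SBH) use the mapped batch dim.

def fan_out(tensor, num_ranks, tensor_ndim, batch_dim):
    dim = 0 if tensor_ndim == 2 else batch_dim  # 2D: batch lives in dim 0
    return np.array_split(tensor, num_ranks, axis=dim)

def fan_in(chunks, tensor_ndim, batch_dim):
    dim = 0 if tensor_ndim == 2 else batch_dim
    return np.concatenate(chunks, axis=dim)

# 2D case: [B*S=8, H=16]; fan-out/fan-in round-trips at dim 0.
flat = np.arange(8 * 16, dtype=np.float32).reshape(8, 16)
assert np.array_equal(fan_in(fan_out(flat, 2, 2, None), 2, None), flat)

# 3D SBH case: [S=4, B=2, H=8]; batch_dim = dim_mapping['b'] = 1.
sbh = np.arange(4 * 2 * 8, dtype=np.float32).reshape(4, 2, 8)
assert np.array_equal(fan_in(fan_out(sbh, 2, 3, 1), 3, 1), sbh)
```
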
@copy-pr-bot

copy-pr-bot bot commented Mar 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yashaswikarnati yashaswikarnati force-pushed the yash/mcore-mimo-combined branch 2 times, most recently from e337e71 to f8cd5d9 Compare March 25, 2026 05:39
yashaswikarnati and others added 2 commits March 24, 2026 23:38
Fixes three issues with dist_checkpointing in non-colocated MiMo:

1. sharded_state_dict() on MimoModel and ModalitySubmodules now injects
   dp_cp_group from each module's pg_collection, bypassing parallel_state
   global fallbacks that crash in non-colocated mode.

2. MimoOptimizer.sharded_state_dict() extracts param_groups and
   grad_scaler as ShardedObjects routed through the distributed save,
   fixing the issue where common.pt was only written by global rank 0
   (the encoder rank), losing the LLM optimizer metadata.

3. ModalitySubmodules gains sharded_state_dict() for TP-aware
   checkpointing (previously all tensors were treated as TP-replicated).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
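
Fix 2's idea can be sketched as follows (a hedged, self-contained sketch with a stand-in type; not the Megatron-Core ShardedObject API):

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical stand-in for a ShardedObject: an opaque payload keyed so the
# distributed save path carries it, instead of assuming global rank 0 (in
# non-colocated mode, an encoder rank with no LLM optimizer state) writes it
# to common.pt.
@dataclass
class ShardedObjectStub:
    key: str
    data: Any

def optimizer_sharded_state_dict(opt_state: dict, prefix: str = "optimizer."):
    """Wrap non-tensor optimizer metadata as per-key sharded objects so each
    module's ranks save their own metadata alongside their tensor shards."""
    sharded = {}
    for name in ("param_groups", "grad_scaler"):
        if name in opt_state:
            sharded[prefix + name] = ShardedObjectStub(prefix + name,
                                                       opt_state[name])
    return sharded

sd = optimizer_sharded_state_dict(
    {"param_groups": [{"lr": 1e-4}], "grad_scaler": {"scale": 2.0}})
```
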
@yashaswikarnati yashaswikarnati force-pushed the yash/mcore-mimo-combined branch from f8cd5d9 to 10b3ddd Compare March 25, 2026 06:41
