fix(mimo): Fix for MIMO Slice Batch #2956
Merged
kamran-nvidia merged 1 commit into mimo/phase5-checkpointing-rebuild on Mar 23, 2026
Conversation
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
Contributor
It is an important fix, thanks. Have you run tests to check the validity of the fix? Also, adding unit tests might be helpful.

Contributor (Author)
@aroshanghias-nvd Yes, added the loss curves for some configs (see the PR description) with DP=1, 2, and 4.
aroshanghias-nvd approved these changes on Mar 23, 2026
Contributor
aroshanghias-nvd left a comment:
Looks good to me.
Merged commit cd3f7fc into mimo/phase5-checkpointing-rebuild
2 checks passed
aroshanghias-nvd pushed a commit that referenced this pull request on Mar 25, 2026
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
kamran-nvidia added a commit that referenced this pull request on Mar 25, 2026
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
What does this PR do?
MIMO Heterogeneous DP Data Loading Fix
Problem
Heterogeneous MIMO data loading was broken for any config with DP > 1. Each module (vision encoder, LLM) independently created a DistributedSampler using its own module-local DP rank/size. With asymmetric DP (e.g. LLM DP=1, Vision DP=3), the samplers produced different sample orderings per module. The BridgeCommunicator would then fan-in mismatched vision activations into the LLM, silently training on garbage.
The symptom was either:
- RuntimeError: get_batch returned None (the iterator was exhausted because cfg.data_parallel_size=1 made the microbatch calculator compute the wrong num_microbatches), or
- silently mismatched samples fanned into the LLM (different vision samples reaching the same LLM DP replica k).
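To make the mismatch concrete, here is a minimal pure-Python sketch of strided index sharding, the scheme torch's DistributedSampler uses when shuffling is off. The sample count and DP sizes are illustrative, not taken from a real config:

```python
# Strided index sharding, as torch.utils.data.DistributedSampler does with
# shuffle off: rank r of n replicas gets indices r, r+n, r+2n, ...
def shard_indices(num_samples, dp_rank, dp_size):
    return list(range(num_samples))[dp_rank::dp_size]

# LLM module (DP=1): its single rank iterates every sample in order.
llm_rank0 = shard_indices(6, dp_rank=0, dp_size=1)     # [0, 1, 2, 3, 4, 5]

# Vision module (DP=3): each rank iterates a strided subset.
vision_rank0 = shard_indices(6, dp_rank=0, dp_size=3)  # [0, 3]

# At its second step, the LLM expects sample 1, but vision rank 0 produces
# sample 3: the fan-in silently pairs activations from different samples.
```

With per-module samplers the two modules disagree from the second microbatch onward, which is exactly why sharding has to move out of the sampler and into the forward step.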
Changes (4 files, +180 / -98)
src/megatron/bridge/data/mimo/dp_utils.py
- Extracted shared helpers _find_rank_module() and _needs_data_for_module() from the original monolithic get_mimo_dp_info().
- get_mimo_dp_info(): now documented as returning module-local DP rank/size, used only by slice_batch_for_mimo for per-module sub-sharding. Explicitly warns against using it for sampler construction.
- get_mimo_sampling_info() (new): returns dp_size=1, dp_rank=0 for all data-loading ranks. This disables DP sharding at the sampler level so every rank loads the same global micro-batch; per-module DP sharding is deferred to slice_batch_for_mimo in the forward step.
- slice_batch_for_mimo(): added recursive handling for nested dicts (modality_inputs → encoder → kwargs) and improved error messages requiring MBS % module_dp == 0.

src/megatron/bridge/data/mimo/loaders.py
- Switched sampler construction from get_mimo_dp_info to get_mimo_sampling_info so all data-loading ranks get identical batches.
- Added an upfront validation loop: asserts micro_batch_size % dp == 0 for every module's DP size, failing fast instead of crashing deep in the forward pass with a shape error.
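A hedged sketch of the recursive sub-sharding behavior described above, with plain lists standing in for tensors. The function name mirrors the PR's slice_batch_for_mimo, but this is an illustration of the technique, not the actual Megatron-Bridge implementation:

```python
# Illustrative sketch (not the real implementation): recursively slice every
# leaf of a possibly nested batch dict along the batch dimension.
def slice_batch_for_mimo(batch, dp_rank, dp_size):
    if isinstance(batch, dict):
        # Nested dicts (e.g. modality_inputs -> encoder -> kwargs) recurse.
        return {k: slice_batch_for_mimo(v, dp_rank, dp_size)
                for k, v in batch.items()}
    mbs = len(batch)
    if mbs % dp_size != 0:
        # Mirrors the improved error message: MBS % module_dp == 0 required.
        raise ValueError(
            f"micro_batch_size={mbs} is not divisible by module DP={dp_size}")
    shard = mbs // dp_size
    return batch[dp_rank * shard:(dp_rank + 1) * shard]
```

Slicing contiguously (rather than strided, as a sampler would) keeps each module's shard aligned with the same positions of the shared global micro-batch.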
src/megatron/bridge/training/mimo_step.py
- _get_module_dp_info() (new): derives module-local DP rank/size from mimo_model.mimo_config.module_to_grid_map at forward-step time. Returns (0, 1) in colocated mode (no grids).
- forward_step(): after get_batch() returns the global micro-batch, calls slice_batch_for_mimo(data_batch, dp_rank, dp_size) to sub-shard it for the current module's DP group. This is the critical wiring that was previously missing.
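A self-contained sketch of this forward-step wiring. All names, signatures, and the grid representation here are illustrative assumptions, not the exact Megatron-Bridge API:

```python
# Hypothetical wiring sketch: every rank holds the same global micro-batch;
# the forward step derives module-local DP coordinates and sub-shards.
def _get_module_dp_info(module_to_grid_map, module_name, rank):
    """Return module-local (dp_rank, dp_size); (0, 1) in colocated mode."""
    if not module_to_grid_map:          # no grids -> colocated mode
        return 0, 1
    grid = module_to_grid_map[module_name]   # list of global ranks
    return grid.index(rank), len(grid)

def _slice_for_module(batch, dp_rank, dp_size):
    shard = len(batch) // dp_size
    return batch[dp_rank * shard:(dp_rank + 1) * shard]

def forward_step(global_batch, module_to_grid_map, module_name, rank):
    dp_rank, dp_size = _get_module_dp_info(module_to_grid_map,
                                           module_name, rank)
    # Sub-shard the shared global micro-batch for this module's DP group.
    return _slice_for_module(global_batch, dp_rank, dp_size)
```

With asymmetric grids (say vision on ranks 0-2, LLM on rank 3), the vision ranks each get a third of the micro-batch while the single LLM rank sees all of it, yet every shard comes from the same globally consistent batch.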
src/megatron/bridge/data/mimo/__init__.py
- Exports get_mimo_sampling_info and slice_batch_for_mimo.

Design
Constraints
- micro_batch_size must be divisible by every module's data_parallel_size. Enforced by an assertion in build_mimo_data_loaders.
- global_batch_size must be divisible by micro_batch_size.

Gradient norm note
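A fail-fast check mirroring the two constraints above might look like the following sketch. The function name and signature are hypothetical (the real assertion lives in build_mimo_data_loaders):

```python
# Hypothetical validation sketch for the two divisibility constraints.
def validate_mimo_batch_config(micro_batch_size, global_batch_size,
                               module_dp_sizes):
    for name, dp in module_dp_sizes.items():
        if micro_batch_size % dp != 0:
            raise ValueError(
                f"micro_batch_size={micro_batch_size} is not divisible by "
                f"data_parallel_size={dp} of module '{name}'")
    if global_batch_size % micro_batch_size != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not divisible by "
            f"micro_batch_size={micro_batch_size}")
```

Failing here at loader-build time surfaces a readable config error instead of a shape mismatch deep in the forward pass.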
With use_distributed_optimizer=True, gradient norms scale linearly with DP size (DP=2 → 2× the norm of DP=1). This is expected Megatron behavior: the schedule normalizes loss by num_tokens × num_microbatches, making per-rank gradients batch-size-independent, and the distributed optimizer's reduce_scatter SUMs (not averages) across DP ranks. Standard DDP (use_distributed_optimizer=False) averages instead, giving consistent norms across DP sizes.
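A toy numeric illustration of why the summed reduction scales norms with DP size while averaging does not. This is pure Python arithmetic under the assumption stated above (per-rank gradients are batch-size-independent); it does not touch actual Megatron internals:

```python
import math

def grad_norm(vec):
    return math.sqrt(sum(g * g for g in vec))

def reduce_dp(per_rank_grads, op):
    """Combine per-rank gradient vectors by SUM or AVERAGE."""
    dim = len(per_rank_grads[0])
    summed = [sum(g[i] for g in per_rank_grads) for i in range(dim)]
    if op == "sum":              # distributed optimizer: reduce_scatter SUMs
        return summed
    n = len(per_rank_grads)      # standard DDP: averages
    return [s / n for s in summed]

# Each rank sees the same batch-size-independent gradient, e.g. [1.0, 1.0].
dp1 = [[1.0, 1.0]]
dp2 = [[1.0, 1.0], [1.0, 1.0]]
# SUM: DP=2 norm is 2x the DP=1 norm. AVERAGE: norms match across DP sizes.
```

So the DP=2 loss curves matching DP=1 up to this expected norm scaling is consistent with a correct fix, not a regression.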
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Additional Information