
fix(mimo): Fix for MIMO Slice Batch#2956

Merged
kamran-nvidia merged 1 commit into mimo/phase5-checkpointing-rebuild from
kamran/mimo_slice_batch_fix
Mar 23, 2026
Conversation

@kamran-nvidia
Contributor

@kamran-nvidia kamran-nvidia commented Mar 23, 2026

What does this PR do ?

MIMO Heterogeneous DP Data Loading Fix

[W&B chart (3/23/2026): training loss curves across DP configurations]

Problem

Heterogeneous MIMO data loading was broken for any config with DP > 1.
Each module (vision encoder, LLM) independently created a DistributedSampler
using its own module-local DP rank/size. With asymmetric DP (e.g. LLM DP=1,
Vision DP=3), the samplers produced different sample orderings per module.
The BridgeCommunicator would then fan-in mismatched vision activations into
the LLM — silently training on garbage.

The symptom was either:

  • RuntimeError: get_batch returned None (iterator exhausted because
    cfg.data_parallel_size=1 made the microbatch calculator compute the wrong
    num_microbatches), or
  • Silent data mismatch (vision DP replica k processes different samples than
    LLM DP replica k).

Changes (4 files, +180 / -98)

src/megatron/bridge/data/mimo/dp_utils.py

  • Extracted shared helpers _find_rank_module() and _needs_data_for_module()
    from the original monolithic get_mimo_dp_info().

  • get_mimo_dp_info() — now documented as returning module-local DP
    rank/size, used only by slice_batch_for_mimo for per-module sub-sharding.
    Explicitly warns against using it for sampler construction.

  • get_mimo_sampling_info() (new) — returns dp_size=1, dp_rank=0 for all
    data-loading ranks. This disables DP sharding at the sampler level so every
    rank loads the same global micro-batch. Per-module DP sharding is deferred
    to slice_batch_for_mimo in the forward step.

  • slice_batch_for_mimo() — added recursive handling for nested dicts
    (modality_inputs → encoder → kwargs) and improved error messages requiring
    MBS % module_dp == 0.
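
A minimal sketch of the recursive slicing helper described above, using plain Python sequences in place of tensors; the names follow the PR description, but the real dp_utils.py implementation may differ in details:

```python
# Hypothetical sketch of slice_batch_for_mimo with recursive nested-dict
# handling. Plain sequences stand in for tensors; not the actual
# megatron.bridge implementation.
def slice_batch_for_mimo(batch, dp_rank, dp_size):
    """Sub-shard every sequence in a (possibly nested) batch dict along
    the batch dimension for one module's DP group."""
    if dp_size == 1:
        return batch  # nothing to shard
    out = {}
    for key, value in batch.items():
        if isinstance(value, dict):
            # Recurse into nested dicts (e.g. modality_inputs -> encoder -> kwargs).
            out[key] = slice_batch_for_mimo(value, dp_rank, dp_size)
        elif hasattr(value, "__len__"):
            mbs = len(value)
            if mbs % dp_size != 0:
                raise ValueError(
                    f"micro_batch_size ({mbs}) must be divisible by the "
                    f"module's DP size ({dp_size}) for key '{key}'"
                )
            shard = mbs // dp_size
            out[key] = value[dp_rank * shard:(dp_rank + 1) * shard]
        else:
            out[key] = value  # scalars/metadata pass through unchanged
    return out
```

With MBS=6 and a vision grid of DP=3, vision rank 1 keeps samples 2–3 of the global micro-batch, while an LLM with DP=1 keeps all six.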

src/megatron/bridge/data/mimo/loaders.py

  • Switched sampler construction from get_mimo_dp_info to
    get_mimo_sampling_info so all data-loading ranks get identical batches.

  • Added an upfront validation loop: asserts micro_batch_size % dp == 0 for
    every module's DP size, failing fast instead of crashing deep in the forward
    pass with a shape error.
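
The fail-fast check could look roughly like this; the function name and the module-to-DP-size mapping are illustrative assumptions, not the actual build_mimo_data_loaders signature:

```python
# Illustrative fail-fast validation; not the real loaders.py code.
def validate_mimo_batch_config(micro_batch_size, module_dp_sizes):
    """Raise at data-loader build time if any module's DP group cannot
    evenly shard the micro-batch."""
    for module_name, dp in module_dp_sizes.items():
        if micro_batch_size % dp != 0:
            raise ValueError(
                f"micro_batch_size={micro_batch_size} is not divisible by "
                f"data_parallel_size={dp} of module '{module_name}'; "
                "per-module slicing would produce uneven shards."
            )
```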

src/megatron/bridge/training/mimo_step.py

  • _get_module_dp_info() (new) — derives module-local DP rank/size from
    mimo_model.mimo_config.module_to_grid_map at forward-step time. Returns
    (0, 1) in colocated mode (no grids).

  • forward_step() — after get_batch() returns the global micro-batch,
    calls slice_batch_for_mimo(data_batch, dp_rank, dp_size) to sub-shard it
    for the current module's DP group. This is the critical wiring that was
    previously missing.
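
A simplified sketch of this wiring, assuming a dict-shaped config and with get_batch and the slicing inlined as stand-ins for the real Megatron-Bridge objects:

```python
# Simplified sketch of the forward-step wiring above. module_to_grid_map
# follows the PR description; get_batch and the grid dicts are stand-ins.
def _get_module_dp_info(mimo_config, module_name):
    """Module-local (dp_rank, dp_size); (0, 1) in colocated mode (no grids)."""
    grid = (mimo_config.get("module_to_grid_map") or {}).get(module_name)
    if grid is None:
        return 0, 1
    return grid["dp_rank"], grid["dp_size"]

def forward_step(get_batch, mimo_config, module_name):
    # Every data-loading rank receives the same global micro-batch.
    batch = get_batch()
    # Sub-shard it for this module's DP group -- the previously missing step.
    dp_rank, dp_size = _get_module_dp_info(mimo_config, module_name)
    if dp_size > 1:
        shard = len(batch["tokens"]) // dp_size
        batch = {k: v[dp_rank * shard:(dp_rank + 1) * shard]
                 for k, v in batch.items()}
    return batch
```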

src/megatron/bridge/data/mimo/__init__.py

  • Exported get_mimo_sampling_info and slice_batch_for_mimo.

Design

 ┌─────────────────────────────────────────────────────┐
 │  DataLoader (sampler dp_size=1)                     │
 │  All data-loading ranks get identical micro-batches  │
 └────────────────────────┬────────────────────────────┘
                          │  global micro-batch (MBS samples)
              ┌───────────┴───────────┐
              ▼                       ▼
     Vision ranks                 LLM ranks
  slice_batch_for_mimo()     slice_batch_for_mimo()
  (dp_rank, dp_size from     (dp_rank, dp_size from
   vision grid)               LLM grid)
              │                       │
              ▼                       ▼
     MBS/vision_dp samples     MBS/llm_dp samples
              │                       │
              └───── BridgeCommunicator fan-in ──→ LLM receives
                     activations aligned with its text tokens ✓

Constraints

  • micro_batch_size must be divisible by every module's data_parallel_size.
    Enforced by an assertion in build_mimo_data_loaders.
  • global_batch_size must be divisible by micro_batch_size.

Gradient norm note

With use_distributed_optimizer=True, gradient norms scale linearly with DP
size (DP=2 → 2× the norm of DP=1). This is expected Megatron behavior: the
schedule normalizes loss by num_tokens × num_microbatches, making per-rank
gradients batch-size-independent, and the distributed optimizer's
reduce_scatter SUMs (not averages) across DP ranks. Standard DDP
(use_distributed_optimizer=False) averages instead, giving consistent norms
across DP sizes.
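
A toy arithmetic sketch (not Megatron code) of why the two reduction modes diverge:

```python
# Toy model of the scaling claim above: with identical per-rank gradients,
# a SUM reduction (distributed-optimizer reduce_scatter) scales the norm
# linearly with DP size, while DDP-style averaging keeps it constant.
def reduced_grad_norm(dp_size, mode="sum"):
    per_rank_grad = 1.0          # per-rank gradients are batch-size-independent
    total = per_rank_grad * dp_size  # SUM across DP ranks
    if mode == "mean":
        total /= dp_size         # standard DDP averages instead
    return abs(total)
```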

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Mar 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@kamran-nvidia kamran-nvidia marked this pull request as ready for review March 23, 2026 17:53
@aroshanghias-nvd
Contributor

aroshanghias-nvd commented Mar 23, 2026

This is an important fix, thanks. Have you run tests to check the validity of the fix? Adding unit tests would also be helpful.

@kamran-nvidia
Contributor Author

kamran-nvidia commented Mar 23, 2026

@aroshanghias-nvd Yes, I added the loss curves for some configs (see the PR description) with DP=1, 2, and 4.

Contributor

@aroshanghias-nvd aroshanghias-nvd left a comment


Looks good to me.

@kamran-nvidia kamran-nvidia merged commit cd3f7fc into mimo/phase5-checkpointing-rebuild Mar 23, 2026
2 checks passed
@kamran-nvidia kamran-nvidia deleted the kamran/mimo_slice_batch_fix branch March 23, 2026 19:53
aroshanghias-nvd pushed a commit that referenced this pull request Mar 25, 2026
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
kamran-nvidia added a commit that referenced this pull request Mar 25, 2026
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
