
feat(mimo): phase 2 - model provider, DDP wrapping, process groups#2004

Merged
yaoyu-33 merged 6 commits into NVIDIA-NeMo:main from aroshanghias-nvd:mimo/phase2-model
Mar 12, 2026

Conversation


aroshanghias-nvd (Contributor) commented Jan 20, 2026

What does this PR do?

Implements Phase 2 of MIMO upstreaming: adds DDP wrapping utilities, embedding group support for PP > 1, and improved validation for heterogeneous deployment.

Changelog

  • Add wrap_mimo_model_distributed() for rank-aware DDP wrapping of MIMO submodules
  • Add embedding group helpers to mimo_builder.py:
    • populate_embedding_and_position_groups() for PP > 1 support
    • is_pp_first_stage() / is_pp_last_stage() helpers
    • is_current_rank_in_grid() for rank participation checks
  • Improve gap detection in MimoParallelismConfig._validate_heterogeneous():
    • Error for gaps between modules (likely misconfiguration)
    • Warning for leading unused ranks (could be intentional)
  • Extend _get_pg_collections_from_grids() to populate pos_embd and embd process groups
  • Add comprehensive unit tests for all new functionality
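The gap-detection rule above (error for interior gaps, warning for leading unused ranks) can be sketched as a small standalone check. The module map shape (name mapped to `(rank_offset, num_ranks)`) is illustrative and not the real `MimoParallelismConfig` API:

```python
def check_rank_coverage(modules):
    """Toy gap check: `modules` maps name -> (rank_offset, num_ranks).

    Returns ("warning", msg) for leading unused ranks (possibly intentional)
    and ("error", msg) for gaps between modules (likely misconfiguration).
    """
    spans = sorted((off, off + n, name) for name, (off, n) in modules.items())
    issues = []
    if spans and spans[0][0] > 0:
        issues.append(("warning", f"ranks 0..{spans[0][0] - 1} unused before {spans[0][2]}"))
    for (_, end, prev), (start, _, cur) in zip(spans, spans[1:]):
        if start > end:
            issues.append(("error", f"gap between {prev} and {cur}: ranks {end}..{start - 1} unused"))
    return issues
```

A contiguous layout starting at rank 0 yields no issues; a layout that skips ranks in the middle yields an error, matching the severity split described in the changelog.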

Files Changed

  • models/mimo/mimo_builder.py: Added embedding group and rank participation helpers
  • models/mimo/mimo_provider.py: Imports helpers from mimo_builder, populates embedding groups in pg_collections
  • training/mimo_config.py: Improved gap detection (error for middle gaps, warning for leading)
  • training/mimo_ddp.py: New file with DDP wrapping utilities for MIMO models
  • tests/: Added tests for embedding groups, gap detection, and DDP wrapping

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g. Numba, Pynini, Apex)
    • N/A - no new optional dependencies

Additional Information

Summary by CodeRabbit

  • New Features

    • Added Multi-Module Multimodal (MIMO) model support with heterogeneous parallelism across modules.
    • Added LLaVA-style Vision-Language model provider for multimodal inference.
    • Added distributed training infrastructure for MIMO models.
  • Tests

    • Added comprehensive test suite for MIMO configuration, providers, and distributed training.



copy-pr-bot bot commented Jan 20, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@aroshanghias-nvd aroshanghias-nvd force-pushed the mimo/phase2-model branch 3 times, most recently from 8903f7c to c10a632 Compare January 28, 2026 18:59

coderabbitai bot commented Jan 28, 2026

Walkthrough

This pull request introduces a comprehensive MIMO (Multi-Input Multi-Output) model provider framework enabling heterogeneous per-module parallelism. It adds configuration structures, distributed communication grid builders, model providers, vision-language model support, and DDP wrapping utilities, consolidated through updated package exports and integration points.

Changes

  • MIMO Configuration (src/megatron/bridge/models/mimo/mimo_config.py): Introduces ModuleParallelismConfig and MimoParallelismConfig dataclasses with finalize() methods, world-size derivation for data parallelism, and multi-mode validation (colocated, homogeneous, heterogeneous) including rank-offset overlap detection and gap warnings.
  • MIMO Infrastructure & Builder (src/megatron/bridge/models/mimo/mimo_builder.py): Adds build_hypercomm_grids(), _default_topology(), and build_colocated_comm_config() to construct HyperCommGrids per module, embedding/position group helpers, and rank-stage predicates (is_pp_first_stage, is_pp_last_stage, is_current_rank_in_grid).
  • MIMO Model Providers (src/megatron/bridge/models/mimo/mimo_provider.py): Introduces MimoModelProvider orchestrating per-module heterogeneous parallelism, the MimoModelInfra metadata container, and MimoStubModel for non-participating ranks; includes build_infra(), ProcessGroup injection into specs, provide(), and provide_distributed_model() with optional FP16/BF16 and parameter freezing.
  • Vision-Language Model Provider (src/megatron/bridge/models/mimo/llava_provider.py): Adds the LlavaMimoProvider dataclass extending MimoModelProvider with vision_encoder_module/params, a 2-layer MLP projector, a default Vicuna-7B language config with RMSNorm/SiLU/gated GELU, and image special token mapping.
  • MIMO DDP Wrapping (src/megatron/bridge/training/mimo_ddp.py): Introduces the wrap_mimo_model_distributed() utility for conditional DDP wrapping of language models and modality submodules based on grid membership and rank participation.
  • Package Integration (src/megatron/bridge/models/mimo/__init__.py): Exports the public API: MimoParallelismConfig, ModuleParallelismConfig, MimoModelProvider, MimoModelInfra, MimoStubModel, LlavaMimoProvider.
  • Training Config Integration (src/megatron/bridge/training/config.py): Adds MimoModelProvider to the ConfigContainer model field union.
  • Unit Test Suite (tests/unit_tests/models/mimo/test_mimo_provider.py): 736 lines of tests covering provider initialization, infra building, parallelism handling, non-participating ranks, injection utilities, parameter freezing, and embedding group logic.
  • Configuration Tests (tests/unit_tests/training/mimo/test_mimo_config.py): 125 lines testing finalize(), world-size derivation, deployment mode validation (colocated/homogeneous/heterogeneous), rank-offset overlap/gap detection, and idle-rank warnings.
  • DDP Tests (tests/unit_tests/training/mimo/test_mimo_ddp.py): 361 lines testing rank grid participation, selective DDP wrapping for language models and modality submodules, heterogeneous rank ranges, and argument propagation.
  • Test Package Init (tests/unit_tests/models/mimo/__init__.py): Package initializer with copyright notice.
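The rank-stage predicates listed above can be illustrated with a minimal sketch. The explicit `(rank, offset, size)` arguments are stand-ins for the grid and process-group objects the real helpers inspect:

```python
def is_current_rank_in_grid(rank, rank_offset, grid_size):
    # A rank participates in a module's grid iff it falls in the module's range.
    return rank_offset <= rank < rank_offset + grid_size

def is_pp_first_stage(pp_rank):
    # The first pipeline stage holds the input embeddings.
    return pp_rank == 0

def is_pp_last_stage(pp_rank, pp_size):
    # The last pipeline stage holds the output layer; with PP > 1 the first
    # and last stages need a shared embedding group to keep tied embedding
    # weights in sync, which is what the embedding group helpers provide.
    return pp_rank == pp_size - 1
```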

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant MimoProvider as MimoModelProvider
    participant Config as MimoParallelismConfig
    participant Builder as mimo_builder
    participant Infrastructure as HyperCommGrid/<br/>ProcessGroups
    participant Provider as ModelProvider<br/>Internal
    participant Model as MimoModel
    participant DDP as wrap_mimo_model<br/>_distributed

    User->>MimoProvider: create provider with specs & config
    User->>MimoProvider: finalize()
    MimoProvider->>Config: finalize(world_size)
    Config->>Config: derive data_parallel_size
    Config->>Config: validate deployment mode
    
    User->>MimoProvider: provide()
    MimoProvider->>MimoProvider: build_infra()
    MimoProvider->>Builder: build_hypercomm_grids(config)
    Builder->>Infrastructure: create grids per module
    Infrastructure-->>Builder: Dict[module → HyperCommGrid]
    Builder->>Builder: _default_topology()
    Builder->>Infrastructure: build_colocated_comm_config()
    Infrastructure-->>Builder: topology & pg_collections
    
    MimoProvider->>Provider: inject pg_collections into specs
    MimoProvider->>Provider: create MimoModel or MimoStubModel
    Provider-->>Model: model instance
    
    User->>MimoProvider: provide_distributed_model()
    MimoProvider->>DDP: wrap_mimo_model_distributed(...)
    DDP->>Infrastructure: check rank in grid
    DDP->>DDP: wrap language_model with DDP
    DDP->>DDP: wrap modality_submodules with DDP
    Infrastructure-->>DDP: DistributedDataParallel-wrapped modules
    DDP-->>Model: in-place wrapped MimoModel
    Model-->>User: ready for distributed training
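The wrapping step at the end of the diagram can be illustrated with a toy, framework-free sketch: each submodule is DDP-wrapped only when the current rank falls inside that module's grid. `ToyDDP` and the `(rank_offset, size)` grid tuples are stand-ins for the real Megatron DDP wrapper and HyperCommGrid objects:

```python
class ToyDDP:
    """Stand-in for DistributedDataParallel; just records what was wrapped."""
    def __init__(self, module):
        self.module = module


def wrap_selectively(submodules, grids, current_rank):
    """submodules: name -> module; grids: name -> (rank_offset, size)."""
    wrapped = {}
    for name, module in submodules.items():
        offset, size = grids[name]
        if offset <= current_rank < offset + size:
            wrapped[name] = ToyDDP(module)  # this rank trains the submodule
        else:
            wrapped[name] = module          # non-participating rank: left as-is
    return wrapped
```

On rank 0 with grids `{"vision": (0, 2), "llm": (2, 4)}`, only the vision submodule is wrapped; ranks 2 through 5 would wrap only the language model, mirroring the per-rank membership check in the diagram.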

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • cuichenx
  • skyw
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Test Results For Major Changes ⚠️ Warning: PR introduces major distributed training infrastructure changes (1,222 lines), but tests contain unvalidated failures with incorrect parameter names blocking test collection in test_mimo_config.py and test_mimo_ddp.py, and the PR description lacks test results confirmation. Resolution: fix parameter names to tensor_model_parallel_size and data_parallel_size, resolve lint violations, validate all tests pass locally, and update the PR description with test results, coverage metrics, and multi-GPU integration validation.

✅ Passed checks (3 passed)

  • Title check ✅: The title clearly identifies the main feature additions (MIMO model provider, DDP wrapping, and process groups), which are the primary changes in this Phase 2 implementation.
  • Docstring Coverage ✅: Docstring coverage is 82.80%, which meets the required threshold of 80.00%.
  • Description Check ✅: Check skipped; CodeRabbit's high-level summary is enabled.



coderabbitai bot (Contributor) left a comment


Actionable comments posted: 8

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/mimo/__init__.py`:
- Around line 14-21: The __all__ list in the module contains exported names that
are not alphabetically sorted which triggers Ruff RUF022; update the __all__
assignment in the module to have the exported symbols ("MimoParallelismConfig",
"ModuleParallelismConfig", "MimoModelProvider", "MimoModelInfra",
"MimoStubModel", "LlavaMimoProvider") ordered alphabetically (case-sensitive) so
the list is sorted and the linter error is resolved.

In `@src/megatron/bridge/models/mimo/mimo_provider.py`:
- Around line 291-294: The for-loop in mimo_provider.py uses an unused variable
encoder_name; rename it to _ (or _encoder_name) to silence lint warnings without
changing behavior: update the loop header "for encoder_name, encoder_spec in
spec.submodules['encoders'].items()" to "for _, encoder_spec in
spec.submodules['encoders'].items()" while leaving the body that sets
encoder_spec.params and encoder_spec.params['pg_collection'] unchanged.
- Line 22: Remove the unused import DistributedDataParallel from the top import
line that currently reads "from megatron.core.distributed import
DistributedDataParallel, DistributedDataParallelConfig"; keep only
DistributedDataParallelConfig since that is the only symbol referenced (e.g., as
a type hint in the mimo provider code), so replace the import to import just
DistributedDataParallelConfig and run a quick lint to confirm no other usages of
DistributedDataParallel remain.

In `@tests/unit_tests/models/mimo/test_mimo_provider.py`:
- Around line 100-113: The test test_provide_signature_matches_mixin has unused
patched arguments and unused local assignments causing lint warnings; update the
signature to mark unused patches with underscores (e.g., rename mock_build_grids
and mock_mimo_model to _mock_build_grids, _mock_mimo_model) or prefix them with
_, and remove or use any unused local variables (for example, if language_spec
or result are only for clarity, either assert on result or change its name to
_result) so no unused-local lint errors remain while keeping the call to
provider.provide(pre_process=True, post_process=False, vp_stage=0) intact.

In `@tests/unit_tests/training/mimo/test_mimo_config.py`:
- Around line 81-82: The tests instantiate ModuleParallelismConfig using invalid
kwarg names (tensor_parallel, data_parallel, rank_offset) which causes
import-time errors; open the ModuleParallelismConfig constructor to confirm its
actual parameter names or use positional arguments, then update the "encoder"
and "language_module" instantiations (and the other occurrences at the mentioned
ranges) to use the correct argument names or positional order with the same
values (1,2,0 and 1,4,4) so the test imports succeed; ensure you update every
occurrence in this test file where ModuleParallelismConfig is constructed.

In `@tests/unit_tests/training/mimo/test_mimo_ddp.py`:
- Around line 194-196: The test assigns the unused variable result when calling
wrap_mimo_model_distributed; remove the unused assignment and call
wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config,
grids, pg_collections) directly (do the same for the other occurrence where
result is assigned at the later call). This removes the unused local and
satisfies linting while keeping the call and side effects intact.
- Around line 104-112: The _create_mimo_parallelism_config helper is passing
incorrect kwarg names to ModuleParallelismConfig; replace
tensor_parallel=config.get("tp", 1) with
tensor_model_parallel_size=config.get("tp", 1) and
data_parallel=config.get("dp", 1) with data_parallel_size=config.get("dp", 1)
(leave rank_offset as-is) in the _create_mimo_parallelism_config function so
ModuleParallelismConfig is constructed with the expected parameter names.
- Line 6: The file test_mimo_ddp.py contains an unused import "pytest" — remove
the unused import statement (the line importing pytest) from the top of the
module so only necessary imports (e.g., unittest.mock) remain; ensure no other
references to pytest exist in functions like the test cases in this file before
committing the change.

@aroshanghias-nvd aroshanghias-nvd force-pushed the mimo/phase2-model branch 2 times, most recently from 16f5bd4 to bda8972 Compare January 30, 2026 22:02
aroshanghias-nvd (Contributor, Author) commented:

/ok to test bda8972

Add DDP wrapping utilities, embedding group support for PP > 1, and improved validation for heterogeneous deployment.

Key changes:
- Add wrap_mimo_model_distributed() for rank-aware DDP wrapping of MIMO submodules
- Add embedding group helpers to mimo_builder.py: populate_embedding_and_position_groups(), is_pp_first_stage(), is_pp_last_stage(), is_current_rank_in_grid()
- Improve gap detection in MimoParallelismConfig._validate_heterogeneous()
- Extend _get_pg_collections_from_grids() to populate pos_embd and embd process groups
- Set variable_seq_lengths=True in provide_distributed_model()
- Update copyright headers to 2026
- Add comprehensive unit tests for all new functionality

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
Tests were importing from mimo_builder which no longer defines these
functions. Update tests to import from megatron.core.pipeline_parallel.utils
directly. Fix mock setup: patch torch.distributed.is_initialized to True
so get_pg_rank uses group.rank() instead of always returning 0. Update
None-group tests to assert True (MCore returns rank=0, size=1 for None).

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>

Labels

ready-to-merge PR is approved, current, and only waiting for CI to pass before merge

Projects

None yet


4 participants