feat(mimo): phase 2 - model provider, DDP wrapping, process groups #2004
yaoyu-33 merged 6 commits into NVIDIA-NeMo:main
Conversation
📝 Walkthrough

This pull request introduces a comprehensive MIMO (Multi-Input Multi-Output) model provider framework enabling heterogeneous per-module parallelism. It adds configuration structures, distributed communication grid builders, model providers, vision-language model support, and DDP wrapping utilities, consolidated through updated package exports and integration points.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant MimoProvider as MimoModelProvider
    participant Config as MimoParallelismConfig
    participant Builder as mimo_builder
    participant Infrastructure as HyperCommGrid/<br/>ProcessGroups
    participant Provider as ModelProvider<br/>Internal
    participant Model as MimoModel
    participant DDP as wrap_mimo_model<br/>_distributed
    User->>MimoProvider: create provider with specs & config
    User->>MimoProvider: finalize()
    MimoProvider->>Config: finalize(world_size)
    Config->>Config: derive data_parallel_size
    Config->>Config: validate deployment mode
    User->>MimoProvider: provide()
    MimoProvider->>MimoProvider: build_infra()
    MimoProvider->>Builder: build_hypercomm_grids(config)
    Builder->>Infrastructure: create grids per module
    Infrastructure-->>Builder: Dict[module → HyperCommGrid]
    Builder->>Builder: _default_topology()
    Builder->>Infrastructure: build_colocated_comm_config()
    Infrastructure-->>Builder: topology & pg_collections
    MimoProvider->>Provider: inject pg_collections into specs
    MimoProvider->>Provider: create MimoModel or MimoStubModel
    Provider-->>Model: model instance
    User->>MimoProvider: provide_distributed_model()
    MimoProvider->>DDP: wrap_mimo_model_distributed(...)
    DDP->>Infrastructure: check rank in grid
    DDP->>DDP: wrap language_model with DDP
    DDP->>DDP: wrap modality_submodules with DDP
    Infrastructure-->>DDP: DistributedDataParallel-wrapped modules
    DDP-->>Model: in-place wrapped MimoModel
    Model-->>User: ready for distributed training
```
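The finalize step in the diagram can be pictured with a toy re-enactment: `finalize()` derives the data-parallel size from the world size before `provide()` builds the per-module grids. Field and method names below mirror the diagram only; the real `MimoParallelismConfig` in this PR is richer and also validates the deployment mode.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MimoParallelismConfig:
    # Toy stand-in for the PR's config class; only the finalize() arithmetic
    # from the diagram ("derive data_parallel_size") is modeled here.
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    data_parallel_size: Optional[int] = None

    def finalize(self, world_size: int) -> None:
        model_ranks = (
            self.tensor_model_parallel_size * self.pipeline_model_parallel_size
        )
        if world_size % model_ranks != 0:
            raise ValueError("world size must be divisible by tp * pp")
        self.data_parallel_size = world_size // model_ranks

cfg = MimoParallelismConfig(
    tensor_model_parallel_size=2, pipeline_model_parallel_size=2
)
cfg.finalize(world_size=16)
print(cfg.data_parallel_size)  # 4
```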
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 8
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/mimo/__init__.py`:
- Around line 14-21: The __all__ list in the module contains exported names that
are not alphabetically sorted which triggers Ruff RUF022; update the __all__
assignment in the module to have the exported symbols ("MimoParallelismConfig",
"ModuleParallelismConfig", "MimoModelProvider", "MimoModelInfra",
"MimoStubModel", "LlavaMimoProvider") ordered alphabetically (case-sensitive) so
the list is sorted and the linter error is resolved.
In `@src/megatron/bridge/models/mimo/mimo_provider.py`:
- Around line 291-294: The for-loop in mimo_provider.py uses an unused variable
encoder_name; rename it to _ (or _encoder_name) to silence lint warnings without
changing behavior: update the loop header "for encoder_name, encoder_spec in
spec.submodules['encoders'].items()" to "for _, encoder_spec in
spec.submodules['encoders'].items()" while leaving the body that sets
encoder_spec.params and encoder_spec.params['pg_collection'] unchanged.
- Line 22: Remove the unused import DistributedDataParallel from the top import
line that currently reads "from megatron.core.distributed import
DistributedDataParallel, DistributedDataParallelConfig"; keep only
DistributedDataParallelConfig since that is the only symbol referenced (e.g., as
a type hint in the mimo provider code), so replace the import to import just
DistributedDataParallelConfig and run a quick lint to confirm no other usages of
DistributedDataParallel remain.
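A minimal sketch of the loop rename suggested above; `SimpleNamespace` stands in for the real encoder spec objects, and `pg_collection` is a placeholder for the actual process-group collection:

```python
from types import SimpleNamespace

# Stand-ins for the real spec structure; only the loop shape matters here.
spec = SimpleNamespace(
    submodules={"encoders": {"vision": SimpleNamespace(params={})}}
)
pg_collection = object()

# Binding the unused dict key to "_" silences the unused-variable lint
# without changing behavior; the body still updates each encoder's params.
for _, encoder_spec in spec.submodules["encoders"].items():
    encoder_spec.params["pg_collection"] = pg_collection
```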
In `@tests/unit_tests/models/mimo/test_mimo_provider.py`:
- Around line 100-113: The test test_provide_signature_matches_mixin has unused
patched arguments and unused local assignments causing lint warnings; update the
signature to mark unused patches with underscores (e.g., rename mock_build_grids
and mock_mimo_model to _mock_build_grids, _mock_mimo_model) or prefix them with
_, and remove or use any unused local variables (for example, if language_spec
or result are only for clarity, either assert on result or change its name to
_result) so no unused-local lint errors remain while keeping the call to
provider.provide(pre_process=True, post_process=False, vp_stage=0) intact.
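The signature fix can be illustrated as follows; `os.getcwd` / `os.getpid` are stand-in patch targets, since the real test patches the grid builder and `MimoModel`:

```python
from unittest.mock import patch

# Patched arguments the test body never touches get a leading underscore so
# unused-argument lints stay quiet. Note @patch stacking: the decorator
# closest to the function supplies the first mock argument.
@patch("os.getcwd")
@patch("os.getpid")
def test_provide_signature_matches_mixin(_mock_getpid, _mock_getcwd):
    # The real test calls provider.provide(pre_process=True,
    # post_process=False, vp_stage=0) here and asserts on the result.
    return True

assert test_provide_signature_matches_mixin()
```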
In `@tests/unit_tests/training/mimo/test_mimo_config.py`:
- Around line 81-82: The tests instantiate ModuleParallelismConfig using invalid
kwarg names (tensor_parallel, data_parallel, rank_offset) which causes
import-time errors; open the ModuleParallelismConfig constructor to confirm its
actual parameter names or use positional arguments, then update the "encoder"
and "language_module" instantiations (and the other occurrences at the mentioned
ranges) to use the correct argument names or positional order with the same
values (1,2,0 and 1,4,4) so the test imports succeed; ensure you update every
occurrence in this test file where ModuleParallelismConfig is constructed.
In `@tests/unit_tests/training/mimo/test_mimo_ddp.py`:
- Around line 194-196: The test assigns the unused variable result when calling
wrap_mimo_model_distributed; remove the unused assignment and call
wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config,
grids, pg_collections) directly (do the same for the other occurrence where
result is assigned at the later call). This removes the unused local and
satisfies linting while keeping the call and side effects intact.
- Around line 104-112: The _create_mimo_parallelism_config helper is passing
incorrect kwarg names to ModuleParallelismConfig; replace
tensor_parallel=config.get("tp", 1) with
tensor_model_parallel_size=config.get("tp", 1) and
data_parallel=config.get("dp", 1) with data_parallel_size=config.get("dp", 1)
(leave rank_offset as-is) in the _create_mimo_parallelism_config function so
ModuleParallelismConfig is constructed with the expected parameter names.
- Line 6: The file test_mimo_ddp.py contains an unused import "pytest" — remove
the unused import statement (the line importing pytest) from the top of the
module so only necessary imports (e.g., unittest.mock) remain; ensure no other
references to pytest exist in functions like the test cases in this file before
committing the change.
/ok to test bda8972
Add DDP wrapping utilities, embedding group support for PP > 1, and improved validation for heterogeneous deployment.

Key changes:
- Add wrap_mimo_model_distributed() for rank-aware DDP wrapping of MIMO submodules
- Add embedding group helpers to mimo_builder.py: populate_embedding_and_position_groups(), is_pp_first_stage(), is_pp_last_stage(), is_current_rank_in_grid()
- Improve gap detection in MimoParallelismConfig._validate_heterogeneous()
- Extend _get_pg_collections_from_grids() to populate pos_embd and embd process groups
- Set variable_seq_lengths=True in provide_distributed_model()
- Update copyright headers to 2026
- Add comprehensive unit tests for all new functionality

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
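The rank-aware wrapping described in this commit can be pictured with a toy sketch: each submodule is DDP-wrapped only when the current rank participates in that module's grid. `Grid` and the `wrap` callable are stand-ins here, not the Megatron-Core classes, and the function signature is illustrative only.

```python
class Grid:
    # Toy stand-in for a per-module HyperCommGrid: just a set of ranks.
    def __init__(self, ranks):
        self.ranks = set(ranks)

def is_current_rank_in_grid(grid, rank):
    return rank in grid.ranks

def wrap_mimo_model_distributed(model, grids, rank, wrap):
    # Wrap each named submodule in place, skipping modules whose grid
    # does not contain this rank (heterogeneous deployment).
    for name, module in model.items():
        if is_current_rank_in_grid(grids[name], rank):
            model[name] = wrap(module)
    return model

model = {"language_model": "lm", "vision_encoder": "enc"}
grids = {"language_model": Grid(range(8)), "vision_encoder": Grid(range(4))}
wrapped = wrap_mimo_model_distributed(
    model, grids, rank=6, wrap=lambda m: f"DDP({m})"
)
print(wrapped)  # {'language_model': 'DDP(lm)', 'vision_encoder': 'enc'}
```

Rank 6 sits inside the language model's 8-rank grid but outside the encoder's 4-rank grid, so only the language model gets wrapped.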
Tests were importing from mimo_builder, which no longer defines these functions. Update tests to import from megatron.core.pipeline_parallel.utils directly.

Fix mock setup: patch torch.distributed.is_initialized to True so get_pg_rank uses group.rank() instead of always returning 0. Update None-group tests to assert True (MCore returns rank=0, size=1 for None).

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
What does this PR do?
Implements Phase 2 of MIMO upstreaming: adds DDP wrapping utilities, embedding group support for PP > 1, and improved validation for heterogeneous deployment.
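The PP > 1 stage helpers can be pictured with a toy rank layout (tensor-parallel ranks fastest-varying, an assumption for illustration); the real is_pp_first_stage()/is_pp_last_stage() consult each module's grid process groups rather than doing index arithmetic.

```python
def pp_stage(rank: int, tp_size: int, pp_size: int) -> int:
    # Assumed layout: consecutive tp_size ranks share a pipeline stage.
    return (rank // tp_size) % pp_size

def is_pp_first_stage(rank: int, tp_size: int, pp_size: int) -> bool:
    # First stage holds the input embedding.
    return pp_stage(rank, tp_size, pp_size) == 0

def is_pp_last_stage(rank: int, tp_size: int, pp_size: int) -> bool:
    # Last stage holds the output layer; first and last together form the
    # embedding group that populate_embedding_and_position_groups() builds.
    return pp_stage(rank, tp_size, pp_size) == pp_size - 1
```

With tp=2, pp=4 over 8 ranks, ranks 0-1 are the first stage and ranks 6-7 the last; those two stages are the ones that need a shared embedding process group.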
Changelog
- wrap_mimo_model_distributed() for rank-aware DDP wrapping of MIMO submodules
- New helpers in mimo_builder.py: populate_embedding_and_position_groups() for PP > 1 support, is_pp_first_stage() / is_pp_last_stage() helpers, is_current_rank_in_grid() for rank participation checks
- Improved gap detection in MimoParallelismConfig._validate_heterogeneous()
- _get_pg_collections_from_grids() extended to populate pos_embd and embd process groups

Files Changed
- models/mimo/mimo_builder.py
- models/mimo/mimo_provider.py
- training/mimo_config.py
- training/mimo_ddp.py
- tests/

GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI.
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Tests