
Add MimoOptimizer for heterogeneous parallelism#4018

Closed
yashaswikarnati wants to merge 35 commits into NVIDIA:main from yashaswikarnati:yash/mimo-optimizer

Conversation

@yashaswikarnati
Contributor

Summary

Replaces #3212 (closed when its base branch, pull-request/3211, was deleted after #3211 was merged).

  • MimoOptimizer class managing per-module MegatronOptimizer instances with heterogeneous parallelism (different DP/TP/PP per module)
  • Global gradient norm via all_reduce MAX across module boundaries
  • Module-aware gradient clipping using the global norm
  • Module-keyed state dicts for checkpointing
  • Fix grad norm computation: intra_dist_opt group now spans full module world (["tp", "cp", "ep", "pp", "dp"]) instead of just ["dp", "cp"], matching standard Megatron's intra_distributed_optimizer_instance_group
  • Assert num_distributed_optimizer_instances == 1 (multi-instance not yet supported)

Test plan

  • Unit tests pass (test_mimo_optimizer.py)
  • 2-GPU integration test (test_baseline_2gpu)
  • 4-GPU integration test (test_lm_pp3_4gpu)
  • 8-GPU integration tests (test_encoder_tp2_llm_tp2_pp3_8gpu, test_full_pp_8gpu)

🤖 Generated with Claude Code

yashaswikarnati and others added 30 commits January 27, 2026 11:48
- Rename ProcessGroupCollectionWrapper to MultiModuleProcessGroupCollection
- Rename language_model field to language_model_module_name for clarity
- Add language_model_module_name param to backward_step_multimodule
- Use functools.partial to bind param, keeping signature consistent
- Add type hints to _ensure_3d_tensor and _restore_tensor_shape
- Move is_multimodule check earlier for validation and backward selection
Introduce data classes to manage rank roles in multi-module PP setups:
- ModuleStageInfo: tracks first/last stage position within a module
- RankRole: tracks which modules a rank participates in and their stages

These classes enable selective module initialization and stage-aware
forward passes when different modules run on separate PP grids.

Signed-off-by: ykarnati <ykarnati@nvidia.com>
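The two data classes above might look roughly like this. A hypothetical sketch: only the class names and their stated responsibilities come from the commit message; all field and method names beyond that are assumptions.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass(frozen=True)
class ModuleStageInfo:
    """First/last PP-stage position of a rank within one module (assumed fields)."""
    is_first_stage: bool
    is_last_stage: bool


@dataclass
class RankRole:
    """Which modules this rank participates in, with per-module stage info."""
    modules: Dict[str, ModuleStageInfo] = field(default_factory=dict)

    def participates_in(self, module_name: str) -> bool:
        # Used for selective module initialization on this rank.
        return module_name in self.modules


role = RankRole({"vision_encoder": ModuleStageInfo(True, False)})
print(role.participates_in("vision_encoder"))  # -> True
print(role.participates_in("language"))        # -> False
```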
Enable modality submodules to operate in multi-stage PP configurations:
- Add is_first_stage/is_last_stage as immutable properties
- First stage: runs encoder on raw inputs
- Intermediate stages: pass through hidden states
- Last stage: applies input projection before language model

Update from_spec() to pass stage info through constructor for proper
initialization based on pipeline position.

Signed-off-by: ykarnati <ykarnati@nvidia.com>
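The three-way stage behavior described above (encode on the first stage, pass through on intermediate stages, project on the last) can be sketched like this. Names and the callable-based wiring are hypothetical; the real submodules operate on torch tensors inside the PP schedule.

```python
class StageAwareSubmodule:
    """Minimal sketch of a stage-aware modality submodule (hypothetical)."""

    def __init__(self, is_first_stage, is_last_stage, encode=None, project=None):
        self.is_first_stage = is_first_stage
        self.is_last_stage = is_last_stage
        self._encode = encode    # runs only on the first stage
        self._project = project  # runs only on the last stage

    def forward(self, raw_inputs=None, hidden_states=None):
        if self.is_first_stage:
            # First stage: run the encoder on raw inputs.
            hidden_states = self._encode(raw_inputs)
        if self.is_last_stage:
            # Last stage: apply input projection before the language model.
            return self._project(hidden_states)
        # Intermediate stage: pass hidden states through unchanged.
        return hidden_states


mid = StageAwareSubmodule(is_first_stage=False, is_last_stage=False)
print(mid.forward(hidden_states=[1, 2, 3]))  # -> [1, 2, 3]
```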
Add support for running encoder and language modules on separate PP grids:
- Determine rank role based on module_to_grid_map configuration
- Selective module initialization based on role (encoder-only or LM-only)
- Stage-aware forward dispatching based on role
- Validate grid map configuration requires language_module_key

The forward pass now routes to _forward_encoders or _forward_language_module
based on the rank's assigned role in the multi-module PP setup.

Signed-off-by: ykarnati <ykarnati@nvidia.com>
Add comprehensive tests for multi-module PP functionality:
- test_mimo_role.py: RankRole and ModuleStageInfo data classes
- test_mimo_1f1b_schedule.py: 1F1B schedule with multi-module PP
- Update existing tests for stage-aware submodule behavior

Tests validate role determination, selective initialization, and
stage-aware forward passes for both encoder-only and language-only ranks.

Signed-off-by: ykarnati <ykarnati@nvidia.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolve merge conflicts in 8 files after syncing fork with upstream
(including merged NVIDIA#3129). Take main's improvements for multimodule
communicator, p2p_communication, process_groups_config, and schedule
tests. Keep both non-colocated PP changes and main's new features.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- base.py: Remove redundant conditionals in _initialize_submodules,
  simplify forward() dispatch with guard-first pattern, collapse
  _forward_encoders to single-expression conditionals, return None
  from _determine_role when rank is in no grid
- submodules/base.py: Promote encode, combine_embeddings,
  project_embeddings, and forward from abstract to concrete methods,
  fix missing f-prefix in error message, fix project_embeddings to
  always combine before projecting
- submodules/vision.py, audio.py: Remove duplicate implementations,
  keep only __init__ (with projection assertions) and decode
- config/role.py: Add __post_init__ validation for language_module_name
- from_spec docstring: Document is_first_stage controls output projections

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add _make_vlm/_make_avlm/_make_input_ids/_make_position_ids helpers
  to eliminate repeated 7-arg factory calls and tensor construction
- Move device to setup_method, remove 7 duplicate torch.device() lines
- Delete dead module-level AudioEncoderWrapper (duplicate of inner class)
- Simplify test_state_dict to any() one-liners
- Remove redundant assert-not-None before shape checks
- Fix hardcoded batch_size=2 to use self.batch_size
- Remove test-internal setup assertions that can never fail
- Add img_h/img_w/patch_dim attrs to TestMimoModelNonColocated

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add colocated flag and RankRole.all_modules() factory so
  _determine_role always returns a RankRole (never None)
- Remove all `if self.role is not None` guards from _initialize_submodules,
  _initialize_language_model, and forward()
- forward() checks self.role.colocated instead of self.role is None
- Rank-not-in-any-grid now raises RuntimeError immediately in
  _determine_role instead of returning None and failing later
- Simplify _forward_encoders: pass both encoder_inputs and hidden_states
  to submodule, let its is_first_stage flag decide which to use
- Update test_role_determination to assert colocated role properties

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
MIMO always has exactly one language model, so the key doesn't need
to be configurable. This removes:
- language_module_key field from MimoModelConfig
- language_module_name field from RankRole
- Validation that language_module_key is set
- The or "_language" fallback hack in _determine_role

Replaced with a single constant LANGUAGE_MODULE_KEY = "language" in
config/role.py, used consistently across base.py and tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Adds optimizer support for MIMO models where different modules
(encoder, LLM) may have different DP/TP/PP configurations.

Key features:
- MimoOptimizer class managing per-module optimizers
- True global gradient norm via all_reduce MAX across module boundaries
- Module-aware checkpointing (state_dict keyed by module name)
- Simple API: get_mimo_optimizer(mimo_model, config)
- Rename _get_pg_collection_from_grid to _get_pg_collection_for_optimizer
- Remove embedding group creation (not needed by optimizer)
- Fix mp group to use ["tp", "pp"] instead of just "tp"
- Add missing optimizer groups: tp_ep_pp, expt_dp
- Update test to create required process groups
- Replace .get() with direct indexing for modality_submodules access
- When is_active=True, module MUST exist; using .get() silently hid bugs
- Direct indexing raises KeyError immediately if module missing
- Add is_current_rank_in_grid() method to HyperCommGrid
- Format code for consistency

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
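The `get_mimo_optimizer(mimo_model, config)` entry point mentioned above plausibly fans out one optimizer per active module. A hedged sketch of that shape: the `optimizer_factory` parameter, the mapping interface on the model, and the returned plain dict are all assumptions, not the actual implementation (which returns a MimoOptimizer wrapper).

```python
def get_mimo_optimizer(mimo_model, config, optimizer_factory):
    """Build one optimizer per module and collect them by module name."""
    per_module = {
        name: optimizer_factory(module, config)
        for name, module in mimo_model.items()  # assumed mapping interface
    }
    # The real code wraps these in a MimoOptimizer that also handles the
    # cross-module grad-norm reduction and module-keyed state dicts.
    return per_module


opts = get_mimo_optimizer(
    {"vision_encoder": object(), "language": object()},
    config={"lr": 1e-4},
    optimizer_factory=lambda module, cfg: ("adam", cfg["lr"]),
)
print(sorted(opts))  # -> ['language', 'vision_encoder']
```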
- Replace colocated bool with PipelineMode enum (UNIFIED, NON_COLOCATED,
  COLOCATED) for clear forward path dispatch
- Move _determine_role and _validate_grid_map from MimoModel to
  RankRole.from_grid_map classmethod — MimoModel no longer knows
  about grids
- Rename LANGUAGE_MODULE_KEY to MIMO_LANGUAGE_MODULE_KEY
- Type module_to_grid_map as Dict[str, HyperCommGrid] instead of
  Dict[str, Any]
- Remove torch.distributed import from base.py (moved to role.py)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
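The enum-based dispatch described above might look like the following. The three value names come from the commit message; the dispatch body and docstrings are a hypothetical illustration.

```python
from enum import Enum, auto


class PipelineMode(Enum):
    UNIFIED = auto()        # single grid, all modules together
    COLOCATED = auto()      # modules share the same ranks
    NON_COLOCATED = auto()  # modules on separate PP grids


def pick_forward_path(mode: PipelineMode) -> str:
    """Choose a forward path based on the pipeline mode (illustrative)."""
    if mode is PipelineMode.NON_COLOCATED:
        return "role-based dispatch (_forward_encoders / _forward_language_module)"
    return "standard colocated forward"


print(pick_forward_path(PipelineMode.NON_COLOCATED))
```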
Better describes the spatial arrangement of modules across ranks
without overloading "pipeline" which has specific meaning in Megatron.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
base.py:
- Fix set_input_tensor: unwrap schedule's list wrapper before checking
  dict type, and unwrap single-element lists in dict values (P2P recv
  returns [tensor] for VPP compat)
- Fix set_input_tensor for DDP: use unwrap_model to call
  set_input_tensor on underlying GPTModel through DDP wrapper
- Remove pipeline_model_parallel_size == 1 assertion (contradicts
  non-colocated PP goal)

test_mimo_1f1b_schedule.py:
- Convert from standalone script to pytest class (TestMimo1F1BSchedule)
- Add grid tracking + cleanup (destroy_all_grids, teardown_method)
- Fix dist.new_group desync: create_all_embedding_groups upfront so
  all ranks participate in collective new_group calls
- Fix embedding groups: set embd=None for encoder ranks (no shared
  word embeddings to sync in finalize_model_grads)
- Fix NVTE env vars: clear conftest's NVTE_FLASH_ATTN=0 before
  GPTModel creation (LanguageModule asserts these match backend)
- Use MIMO_LANGUAGE_MODULE_KEY, 6-dim grid shape with expt_dp
- Cache pg_collections to avoid PG leaks in finalize_grads_func
- Add BridgeCommunicator.destroy_broadcast_pgs() to teardown

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
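The double unwrapping in the `set_input_tensor` fix above can be sketched as a standalone helper. A hypothetical sketch with string stand-ins for tensors: the schedule wraps the argument in a single-element list, and P2P recv yields `[tensor]` values inside the dict for VPP compatibility.

```python
def unwrap_input_tensor(input_tensor):
    """Unwrap the schedule's list wrapper, then single-element lists in dict values."""
    # Unwrap the schedule's outer single-element list wrapper.
    if isinstance(input_tensor, list) and len(input_tensor) == 1:
        input_tensor = input_tensor[0]
    # Unwrap single-element lists in dict values (P2P recv returns [tensor]).
    if isinstance(input_tensor, dict):
        return {
            key: val[0] if isinstance(val, list) and len(val) == 1 else val
            for key, val in input_tensor.items()
        }
    return input_tensor


wrapped = [{"vision_encoder": ["t0"], "language": "t1"}]
print(unwrap_input_tensor(wrapped))  # -> {'vision_encoder': 't0', 'language': 't1'}
```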
yashaswikarnati and others added 5 commits March 22, 2026 22:29
# Conflicts:
#	tests/unit_tests/models/test_mimo_1f1b_schedule.py
After merging PR3211 into PR3212, the optimizer code referenced the
removed `config.language_module_key` attribute. Update to use the
constant `MIMO_LANGUAGE_MODULE_KEY` from role.py, remove the stale
kwarg from test_mimo_optimizer.py, and add composite process groups
(tp-pp, tp-ep-pp, dp-ep) required by the optimizer to the 1F1B test
grid setup.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Two fixes for the MIMO non-colocated pipeline tests:

1. HyperCommGrid.is_current_rank_in_grid() returned numpy.bool_ (from
   np.prod) instead of Python bool, causing `is True` checks to fail
   in the distributed optimizer test.

2. test_lm_pp3_4gpu used num_layers=2 with llm_pp=3, violating the
   Megatron constraint that num_layers must be divisible by pp_size.
   This caused an assertion failure on LLM ranks while the encoder
   rank waited at a barrier, appearing as a deadlock. Fixed by
   changing num_layers to 3.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
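Fix 1 above is easy to reproduce in isolation: NumPy reductions and comparisons return `numpy.bool_`, which is not the Python `bool` singleton, so identity checks like `is True` fail even when the value is truthy.

```python
import numpy as np

# np.prod returns a NumPy scalar; comparing it yields numpy.bool_, not bool.
in_grid = np.prod([1, 1, 1]) == 1
print(isinstance(in_grid, bool))  # -> False (numpy.bool_ is not a bool subclass)
print(in_grid is True)            # -> False, so `is True` checks silently fail

# The fix: convert to a Python bool before any identity check.
print(bool(in_grid) is True)      # -> True
```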
The intra_dist_opt group was set to ["dp", "cp"] which only spans
data-parallel ranks. This meant the grad norm all-reduce in
get_grad_norm_fp32 missed TP/PP/EP ranks that hold different parameter
shards, producing an incomplete norm and incorrect gradient clipping.

Changed to ["tp", "cp", "ep", "pp", "dp"] (full module world) to match
standard Megatron's intra_distributed_optimizer_instance_group which
spans all ranks when num_distributed_optimizer_instances == 1.

Also added assertion that num_distributed_optimizer_instances == 1,
since the MIMO optimizer does not yet support multiple instances.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
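The difference between the two group spans above can be checked with a quick size calculation: a process group over a subset of grid dimensions has as many ranks as the product of those dimensions' sizes. The dimension sizes below are made up for illustration; only the dimension-name lists come from the commit message.

```python
from math import prod

# Hypothetical per-module grid sizes; module world = 2 * 1 * 1 * 2 * 2 = 8 ranks.
dims = {"tp": 2, "cp": 1, "ep": 1, "pp": 2, "dp": 2}


def group_size(dim_names):
    """Number of ranks in a group spanning the given grid dimensions."""
    return prod(dims[d] for d in dim_names)


old_group = group_size(["dp", "cp"])                    # -> 2, misses tp/pp/ep ranks
new_group = group_size(["tp", "cp", "ep", "pp", "dp"])  # -> 8, the full module world
print(old_group, new_group, new_group == prod(dims.values()))
```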
yashaswikarnati requested review from a team as code owners March 24, 2026 18:17
@copy-pr-bot

copy-pr-bot bot commented Mar 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

svcnvidia-nemo-ci marked this pull request as draft March 24, 2026 18:17
@github-actions
Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.
