Add MimoOptimizer for heterogeneous parallelism #4018
Closed
yashaswikarnati wants to merge 35 commits into NVIDIA:main
Conversation
- Rename ProcessGroupCollectionWrapper to MultiModuleProcessGroupCollection
- Rename language_model field to language_model_module_name for clarity
- Add language_model_module_name param to backward_step_multimodule
- Use functools.partial to bind the param, keeping the signature consistent
- Add type hints to _ensure_3d_tensor and _restore_tensor_shape
- Move the is_multimodule check earlier for validation and backward selection
Introduce data classes to manage rank roles in multi-module PP setups:
- ModuleStageInfo: tracks first/last stage position within a module
- RankRole: tracks which modules a rank participates in and their stages

These classes enable selective module initialization and stage-aware forward passes when different modules run on separate PP grids.

Signed-off-by: ykarnati <ykarnati@nvidia.com>
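A minimal sketch of what the two data classes might look like (field and method names beyond those stated in the commit message are assumptions, not the actual implementation):

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass(frozen=True)
class ModuleStageInfo:
    """Whether this rank holds the first and/or last PP stage of one module."""
    is_first_stage: bool
    is_last_stage: bool


@dataclass
class RankRole:
    """Maps each module this rank participates in to its stage info."""
    module_stages: Dict[str, ModuleStageInfo] = field(default_factory=dict)

    def participates_in(self, module_name: str) -> bool:
        # Selective initialization can skip modules the rank has no role in.
        return module_name in self.module_stages
```

With such classes, an encoder-only rank would carry stage info for the encoder module and report `participates_in("language") == False`, letting construction skip the language model entirely.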
Enable modality submodules to operate in multi-stage PP configurations:
- Add is_first_stage/is_last_stage as immutable properties
- First stage: runs the encoder on raw inputs
- Intermediate stages: pass hidden states through
- Last stage: applies the input projection before the language model

Update from_spec() to pass stage info through the constructor for proper initialization based on pipeline position.

Signed-off-by: ykarnati <ykarnati@nvidia.com>
Add support for running encoder and language modules on separate PP grids:
- Determine the rank role based on the module_to_grid_map configuration
- Selective module initialization based on role (encoder-only or LM-only)
- Stage-aware forward dispatching based on role
- Validate that the grid map configuration includes language_module_key

The forward pass now routes to _forward_encoders or _forward_language_module based on the rank's assigned role in the multi-module PP setup.

Signed-off-by: ykarnati <ykarnati@nvidia.com>
Add comprehensive tests for multi-module PP functionality:
- test_mimo_role.py: RankRole and ModuleStageInfo data classes
- test_mimo_1f1b_schedule.py: 1F1B schedule with multi-module PP
- Update existing tests for stage-aware submodule behavior

Tests validate role determination, selective initialization, and stage-aware forward passes for both encoder-only and language-only ranks.

Signed-off-by: ykarnati <ykarnati@nvidia.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolve merge conflicts in 8 files after syncing fork with upstream (including merged NVIDIA#3129). Take main's improvements for multimodule communicator, p2p_communication, process_groups_config, and schedule tests. Keep both non-colocated PP changes and main's new features. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- base.py: Remove redundant conditionals in _initialize_submodules, simplify forward() dispatch with a guard-first pattern, collapse _forward_encoders to single-expression conditionals, return None from _determine_role when the rank is in no grid
- submodules/base.py: Promote encode, combine_embeddings, project_embeddings, and forward from abstract to concrete methods; fix a missing f-prefix in an error message; fix project_embeddings to always combine before projecting
- submodules/vision.py, audio.py: Remove duplicate implementations, keeping only __init__ (with projection assertions) and decode
- config/role.py: Add __post_init__ validation for language_module_name
- from_spec docstring: Document that is_first_stage controls output projections

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add _make_vlm/_make_avlm/_make_input_ids/_make_position_ids helpers to eliminate repeated 7-arg factory calls and tensor construction
- Move device to setup_method, removing 7 duplicate torch.device() lines
- Delete the dead module-level AudioEncoderWrapper (a duplicate of the inner class)
- Simplify test_state_dict to any() one-liners
- Remove the redundant assert-not-None before shape checks
- Fix hardcoded batch_size=2 to use self.batch_size
- Remove test-internal setup assertions that can never fail
- Add img_h/img_w/patch_dim attrs to TestMimoModelNonColocated

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add a colocated flag and a RankRole.all_modules() factory so _determine_role always returns a RankRole (never None)
- Remove all `if self.role is not None` guards from _initialize_submodules, _initialize_language_model, and forward()
- forward() checks self.role.colocated instead of `self.role is None`
- A rank not in any grid now raises RuntimeError immediately in _determine_role instead of returning None and failing later
- Simplify _forward_encoders: pass both encoder_inputs and hidden_states to the submodule, letting its is_first_stage flag decide which to use
- Update test_role_determination to assert colocated role properties

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
MIMO always has exactly one language model, so the key doesn't need to be configurable. This removes:
- The language_module_key field from MimoModelConfig
- The language_module_name field from RankRole
- The validation that language_module_key is set
- The `or "_language"` fallback hack in _determine_role

These are replaced with a single constant LANGUAGE_MODULE_KEY = "language" in config/role.py, used consistently across base.py and tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Adds optimizer support for MIMO models where different modules (encoder, LLM) may have different DP/TP/PP configurations. Key features:
- MimoOptimizer class managing per-module optimizers
- True global gradient norm via all_reduce MAX across module boundaries
- Module-aware checkpointing (state_dict keyed by module name)
- Simple API: get_mimo_optimizer(mimo_model, config)
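A minimal sketch of the per-module wrapping and module-keyed checkpointing (this is not the real MimoOptimizer API; the distributed MAX all-reduce for the global grad norm is omitted here since it requires an initialized process group):

```python
class MimoOptimizer:
    """Sketch: one wrapped optimizer per module, stepped together.

    The real class additionally combines per-module gradient norms into a
    true global norm with an all_reduce MAX across module boundaries.
    """

    def __init__(self, optimizers_by_module):
        # e.g. {"vision": <optimizer>, "language": <optimizer>}, where each
        # module may live on a different DP/TP/PP layout.
        self.optimizers_by_module = dict(optimizers_by_module)

    def step(self):
        # Each module steps its own optimizer over its own parameter shards.
        for opt in self.optimizers_by_module.values():
            opt.step()

    def state_dict(self):
        # Checkpoint state is keyed by module name so heterogeneous
        # layouts can be saved and restored module by module.
        return {name: opt.state_dict()
                for name, opt in self.optimizers_by_module.items()}
```

Keying checkpoint state by module name means an encoder-only rank and an LM-only rank can each save and load just the slice of optimizer state they own.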
- Rename _get_pg_collection_from_grid to _get_pg_collection_for_optimizer
- Remove embedding group creation (not needed by the optimizer)
- Fix the mp group to use ["tp", "pp"] instead of just "tp"
- Add missing optimizer groups: tp_ep_pp, expt_dp
- Update the test to create the required process groups
- Replace .get() with direct indexing for modality_submodules access
- When is_active=True, the module MUST exist; using .get() silently hid bugs
- Direct indexing raises KeyError immediately if the module is missing
- Add an is_current_rank_in_grid() method to HyperCommGrid
- Format code for consistency

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
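A two-line illustration of the fail-fast argument (the dict contents are hypothetical):

```python
# When a module is required, indexing surfaces a missing key immediately,
# while .get() silently returns None and defers the failure downstream.
submodules = {"vision": "vision_module"}

assert submodules.get("audio") is None  # bug hidden: None flows onward
missing = None
try:
    submodules["audio"]                 # bug surfaced: KeyError here
except KeyError as err:
    missing = err.args[0]
assert missing == "audio"
```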
- Replace the colocated bool with a PipelineMode enum (UNIFIED, NON_COLOCATED, COLOCATED) for clear forward-path dispatch
- Move _determine_role and _validate_grid_map from MimoModel to a RankRole.from_grid_map classmethod, so MimoModel no longer knows about grids
- Rename LANGUAGE_MODULE_KEY to MIMO_LANGUAGE_MODULE_KEY
- Type module_to_grid_map as Dict[str, HyperCommGrid] instead of Dict[str, Any]
- Remove the torch.distributed import from base.py (moved to role.py)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
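The enum-over-bool change can be sketched like this (member values and the dispatch function are illustrative; only the three mode names come from the commit message):

```python
from enum import Enum, auto


class PipelineMode(Enum):
    UNIFIED = auto()
    COLOCATED = auto()
    NON_COLOCATED = auto()


def forward_path(mode: PipelineMode) -> str:
    # Illustrative dispatch only: the real forward() routes to
    # _forward_encoders / _forward_language_module by rank role when
    # modules run on disjoint PP grids.
    if mode is PipelineMode.NON_COLOCATED:
        return "role-based"
    return "single-path"
```

A three-valued enum makes the dispatch self-documenting, where a single `colocated` bool could not distinguish the unified case from the colocated multi-grid case.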
Better describes the spatial arrangement of modules across ranks without overloading "pipeline", which has a specific meaning in Megatron.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
base.py:
- Fix set_input_tensor: unwrap the schedule's list wrapper before checking dict type, and unwrap single-element lists in dict values (P2P recv returns [tensor] for VPP compat)
- Fix set_input_tensor for DDP: use unwrap_model to call set_input_tensor on the underlying GPTModel through the DDP wrapper
- Remove the pipeline_model_parallel_size == 1 assertion (it contradicts the non-colocated PP goal)

test_mimo_1f1b_schedule.py:
- Convert from a standalone script to a pytest class (TestMimo1F1BSchedule)
- Add grid tracking + cleanup (destroy_all_grids, teardown_method)
- Fix dist.new_group desync: call create_all_embedding_groups upfront so all ranks participate in the collective new_group calls
- Fix embedding groups: set embd=None for encoder ranks (no shared word embeddings to sync in finalize_model_grads)
- Fix NVTE env vars: clear conftest's NVTE_FLASH_ATTN=0 before GPTModel creation (LanguageModule asserts these match the backend)
- Use MIMO_LANGUAGE_MODULE_KEY and a 6-dim grid shape with expt_dp
- Cache pg_collections to avoid PG leaks in finalize_grads_func
- Add BridgeCommunicator.destroy_broadcast_pgs() to teardown

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
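The set_input_tensor unwrapping logic can be sketched as follows (the helper name and exact structure are assumptions based on the commit message, not the actual code):

```python
def unwrap_input_tensor(wrapped):
    """Hypothetical sketch of the unwrap order fixed above.

    The schedule wraps its payload in a single-element list, and P2P recv
    wraps each dict value in [tensor] for VPP compatibility, so the outer
    list must be removed BEFORE checking whether the payload is a dict.
    """
    # Unwrap the schedule's list wrapper first.
    if isinstance(wrapped, list) and len(wrapped) == 1:
        wrapped = wrapped[0]
    if isinstance(wrapped, dict):
        # Then unwrap single-element lists inside dict values.
        return {k: v[0] if isinstance(v, list) and len(v) == 1 else v
                for k, v in wrapped.items()}
    return wrapped
```

Checking `isinstance(wrapped, dict)` before stripping the outer list would always fail for a dict payload, since the schedule delivers it as `[{...}]`.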
# Conflicts:
#   tests/unit_tests/models/test_mimo_1f1b_schedule.py
After merging PR #3211 into PR #3212, the optimizer code referenced the removed `config.language_module_key` attribute. Update it to use the constant `MIMO_LANGUAGE_MODULE_KEY` from role.py, remove the stale kwarg from test_mimo_optimizer.py, and add the composite process groups (tp-pp, tp-ep-pp, dp-ep) required by the optimizer to the 1F1B test grid setup.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Two fixes for the MIMO non-colocated pipeline tests:
1. HyperCommGrid.is_current_rank_in_grid() returned numpy.bool_ (from np.prod) instead of a Python bool, causing `is True` checks to fail in the distributed optimizer test.
2. test_lm_pp3_4gpu used num_layers=2 with llm_pp=3, violating the Megatron constraint that num_layers must be divisible by pp_size. This caused an assertion failure on LLM ranks while the encoder rank waited at a barrier, appearing as a deadlock. Fixed by changing num_layers to 3.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
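Fix 1 is easy to reproduce in isolation: NumPy comparisons yield `numpy.bool_`, which is not the Python `True` singleton, so identity checks fail until the value is wrapped in `bool()`:

```python
import numpy as np

# np.prod returns a NumPy scalar, so the comparison produces numpy.bool_:
flag = np.prod([2, 2]) == 4

assert flag              # truthiness is fine...
assert flag is not True  # ...but identity against the True singleton fails
assert bool(flag) is True  # wrapping in bool() restores a Python bool
```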
The intra_dist_opt group was set to ["dp", "cp"], which only spans data-parallel ranks. This meant the grad-norm all-reduce in get_grad_norm_fp32 missed TP/PP/EP ranks that hold different parameter shards, producing an incomplete norm and incorrect gradient clipping.

Changed to ["tp", "cp", "ep", "pp", "dp"] (the full module world) to match standard Megatron's intra_distributed_optimizer_instance_group, which spans all ranks when num_distributed_optimizer_instances == 1. Also added an assertion that num_distributed_optimizer_instances == 1, since the MIMO optimizer does not yet support multiple instances.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
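Why the group must span the full module world can be shown with the norm arithmetic alone (a sketch with hypothetical numbers, standing in for the distributed all-reduce of per-shard sums of squares):

```python
import math


def module_grad_norm(shard_sq_sums):
    # Each rank holds different parameter shards under TP/PP/EP, so the
    # true L2 norm needs the sum of squares from EVERY rank in the
    # module's world; reducing over only DP/CP ranks drops terms.
    return math.sqrt(sum(shard_sq_sums))
```

For example, two shards with sums of squares 9 and 16 give a true norm of 5; a reduce that only saw the first shard would report 3 and clip gradients too weakly.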
Summary
Replaces #3212 (closed when the base branch pull-request/3211 was deleted after #3211 merged).

- True global gradient norm via all_reduce MAX across module boundaries
- The intra_dist_opt group now spans the full module world (["tp", "cp", "ep", "pp", "dp"]) instead of just ["dp", "cp"], matching standard Megatron's intra_distributed_optimizer_instance_group
- Asserts num_distributed_optimizer_instances == 1 (multi-instance not yet supported)

Test plan
- test_mimo_optimizer.py
- test_baseline_2gpu
- test_lm_pp3_4gpu
- test_encoder_tp2_llm_tp2_pp3_8gpu, test_full_pp_8gpu

🤖 Generated with Claude Code