From bfb508df460f7ea25c4ee514d236127b76186024 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Tue, 24 Mar 2026 11:44:33 -0700 Subject: [PATCH] Add MimoOptimizer for heterogeneous parallelism Adds optimizer support for MIMO models where different modules (encoder, LLM) can have different DP/TP/PP configurations. - MimoOptimizer class managing per-module MegatronOptimizer instances - Global gradient norm via all_reduce MAX across module boundaries - Module-aware gradient clipping using the global norm - Module-keyed state dicts for checkpointing - intra_dist_opt group spans full module world ["tp","cp","ep","pp","dp"] matching standard Megatron's intra_distributed_optimizer_instance_group - Assert num_distributed_optimizer_instances == 1 (multi-instance not yet supported) - HyperCommGrid.is_current_rank_in_grid() helper - Optimizer integrated into existing 1F1B schedule tests (8-GPU) Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/hyper_comm_grid.py | 9 + megatron/core/models/mimo/__init__.py | 3 + megatron/core/models/mimo/optimizer.py | 251 ++++++++++++++++++ .../models/test_mimo_1f1b_schedule.py | 25 ++ 4 files changed, 288 insertions(+) create mode 100644 megatron/core/models/mimo/optimizer.py diff --git a/megatron/core/hyper_comm_grid.py b/megatron/core/hyper_comm_grid.py index 9b5cc6cfaf5..4b860396c4e 100644 --- a/megatron/core/hyper_comm_grid.py +++ b/megatron/core/hyper_comm_grid.py @@ -262,3 +262,12 @@ def _order_dims(self, dims: Union[str, list[str]]) -> Tuple[list[str], str]: unique_group_key = "-".join(ordered_dims) return ordered_dims, unique_group_key + + def is_current_rank_in_grid(self) -> bool: + """Check if the current rank belongs to this grid. + + Returns: + True if the current rank is within [rank_offset, rank_offset + size). + """ + rank = dist.get_rank() + return bool(self.rank_offset <= rank < self.rank_offset + self.size) diff --git a/megatron/core/models/mimo/__init__.py b/megatron/core/models/mimo/__init__.py index 204851c444b..779bf921e1c 100644 --- a/megatron/core/models/mimo/__init__.py +++ b/megatron/core/models/mimo/__init__.py @@ -2,6 +2,7 @@ from megatron.core.models.mimo.config.base_configs import MimoModelConfig from megatron.core.models.mimo.model import MimoModel +from megatron.core.models.mimo.optimizer import MimoOptimizer, get_mimo_optimizer from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules from megatron.core.models.mimo.submodules.base import ModalitySubmodules from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules @@ -9,6 +10,8 @@ __all__ = [ 'MimoModelConfig', 'MimoModel', + 'MimoOptimizer', + 'get_mimo_optimizer', # Submodule classes 'ModalitySubmodules', 'VisionModalitySubmodules', diff --git a/megatron/core/models/mimo/optimizer.py b/megatron/core/models/mimo/optimizer.py new file mode 100644 index 00000000000..c59025cbce3 --- /dev/null +++ b/megatron/core/models/mimo/optimizer.py @@ -0,0 +1,251 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +"""Optimizer for MIMO models with heterogeneous parallelism.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import torch + +from megatron.core.optimizer.clip_grads import clip_grad_by_total_norm_fp32 +from megatron.core.optimizer.optimizer import MegatronOptimizer +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.process_groups_config import ProcessGroupCollection + +if TYPE_CHECKING: + from megatron.core.hyper_comm_grid import HyperCommGrid + + +@dataclass +class ModuleOptimizerInfo: + """Optimizer info for a single module.""" + + optimizer: Optional[MegatronOptimizer] + grid: Optional[HyperCommGrid] + pg_collection: Optional[ProcessGroupCollection] + is_active: bool + + +class MimoOptimizer(MegatronOptimizer): + """ + Optimizer for MimoModel with heterogeneous parallelism. + + Each module gets its own optimizer. Global gradient norm is computed + across all modules via all_reduce MAX. + """ + + def __init__(self, module_infos: Dict[str, ModuleOptimizerInfo], config: OptimizerConfig): + self.module_infos = module_infos + self.config = config + self._active_optimizers: List[MegatronOptimizer] = [ + info.optimizer + for info in module_infos.values() + if info.is_active and info.optimizer is not None + ] + self.is_stub_optimizer = len(self._active_optimizers) == 0 + self.optimizer = None # Base class compat + + @torch.no_grad() + def prepare_grads(self) -> bool: + found_inf = False + for opt in self._active_optimizers: + found_inf |= opt.prepare_grads() + return found_inf + + @torch.no_grad() + def get_grad_norm(self) -> float: + """Compute global gradient norm across all modules via all_reduce MAX.""" + num_modules = len(self.module_infos) + norm_sq = torch.zeros(num_modules, device="cuda", dtype=torch.float32) + + for i, (name, info) in enumerate(sorted(self.module_infos.items())): + if info.is_active and info.optimizer: + module_norm = info.optimizer.get_grad_norm() or 0.0 + norm_sq[i] = module_norm**2 + + torch.distributed.all_reduce(norm_sq, op=torch.distributed.ReduceOp.MAX) + return torch.sqrt(norm_sq.sum()).item() + + @torch.no_grad() + def step(self) -> Tuple[bool, Optional[float], Optional[int]]: + found_inf = self.prepare_grads() + # Synchronize found_inf across all ranks to prevent deadlock: + # if encoder ranks detect inf but LLM ranks don't, the early return + # would skip the all_reduce in get_grad_norm(), causing a hang. 
+ found_inf_tensor = torch.tensor([found_inf], dtype=torch.float32, device="cuda") + torch.distributed.all_reduce(found_inf_tensor, op=torch.distributed.ReduceOp.MAX) + found_inf = found_inf_tensor.item() > 0 + if found_inf: + return False, None, None + + grad_norm = self.get_grad_norm() + + # Clip with global norm + for opt in self._active_optimizers: + if getattr(opt, "is_stub_optimizer", False): + continue + params = opt.get_parameters() + if params and opt.config.clip_grad > 0.0: + clip_grad_by_total_norm_fp32( + params, + max_norm=opt.config.clip_grad, + total_norm=grad_norm, + use_decoupled_grad=opt.config.use_precision_aware_optimizer, + ) + + num_zeros = self.count_zeros() if self.config.log_num_zeros_in_grad else None + success = self.step_with_ready_grads() + + return success, grad_norm, num_zeros + + @torch.no_grad() + def step_with_ready_grads(self) -> bool: + success = True + for opt in self._active_optimizers: + success &= opt.step_with_ready_grads() + return success + + def zero_grad(self, set_to_none: bool = True): + for opt in self._active_optimizers: + opt.zero_grad(set_to_none) + + def get_loss_scale(self) -> torch.Tensor: + if self._active_optimizers: + return self._active_optimizers[0].get_loss_scale() + return torch.tensor([1.0], dtype=torch.float32, device="cuda") + + def count_zeros(self) -> int: + return sum(opt.count_zeros() for opt in self._active_optimizers) + + @property + def param_groups(self) -> List[dict]: + """Combined param groups from all active module optimizers.""" + groups = [] + for opt in self._active_optimizers: + groups.extend(opt.param_groups) + return groups + + # Checkpointing + + def state_dict(self): + return { + name: info.optimizer.state_dict() if info.is_active and info.optimizer else None + for name, info in self.module_infos.items() + } + + def load_state_dict(self, state_dict: Dict): + for name, info in self.module_infos.items(): + if info.is_active and info.optimizer and state_dict.get(name): + info.optimizer.load_state_dict(state_dict[name]) + + def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs): + sharded_state = {} + for name, info in self.module_infos.items(): + if info.is_active and info.optimizer: + sharded_state[name] = info.optimizer.sharded_state_dict( + model_sharded_state_dict, is_loading, **kwargs + ) + return sharded_state + + def reload_model_params(self, state_dict=None): + for opt in self._active_optimizers: + opt.reload_model_params(state_dict) + + +def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection: + """Create ProcessGroupCollection from HyperCommGrid for optimizer use. + + Only fetches process groups required by the optimizer. Assumes all groups + are pre-created in the grid via grid.create_pg() - does not create any new groups. + + The following groups must be pre-created in the grid before calling this function: + grid.create_pg(["dp"]) + grid.create_pg(["dp", "cp"]) + grid.create_pg(["tp"]) + grid.create_pg(["tp", "pp"]) + grid.create_pg(["tp", "ep", "pp"]) + grid.create_pg(["dp", "ep"]) + grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) + + Args: + grid: HyperCommGrid with pre-created process groups. 
+ + Returns: + ProcessGroupCollection containing optimizer-required groups: + - dp: Data parallel group + - dp_cp: Data parallel with context parallel + - tp: Tensor parallel group + - mp: Model parallel group (tp × pp) + - tp_ep_pp: Expert tensor-model-pipeline group + - expt_dp: Expert data parallel group + """ + pg = ProcessGroupCollection() + + # Core groups needed by optimizer + pg.dp = grid.get_pg("dp") + pg.dp_cp = grid.get_pg(["dp", "cp"]) + pg.tp = grid.get_pg("tp") + pg.mp = grid.get_pg(["tp", "pp"]) + + # Expert groups + pg.tp_ep_pp = grid.get_pg(["tp", "ep", "pp"]) + pg.expt_dp = grid.get_pg(["dp", "ep"]) + + # Distributed optimizer grad stats group: must span all dimensions so grad norm + # and found-inf all-reduces see every unique gradient shard. TP/PP/EP ranks hold + # different parameters, DP ranks hold different optimizer shards after reduce-scatter. + # This mirrors standard Megatron's intra_distributed_optimizer_instance_group which + # spans the full world when num_distributed_optimizer_instances == 1. + pg.intra_dist_opt = grid.get_pg(["tp", "cp", "ep", "pp", "dp"]) + + return pg + + +def get_mimo_optimizer(mimo_model: "MimoModel", config: OptimizerConfig) -> MimoOptimizer: + """Create optimizer for MimoModel with heterogeneous parallelism.""" + from megatron.core.optimizer import get_megatron_optimizer + + grid_map = mimo_model.mimo_config.module_to_grid_map + from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY + + lang_key = MIMO_LANGUAGE_MODULE_KEY + + module_infos: Dict[str, ModuleOptimizerInfo] = {} + + for module_name, grid in grid_map.items(): + is_active = grid.is_current_rank_in_grid() + + optimizer = None + pg_collection = _get_pg_collection_for_optimizer(grid) + + if is_active: + if module_name == lang_key: + module = mimo_model.language_model + else: + module = mimo_model.modality_submodules[module_name] + + if module is not None: + assert ( + not hasattr(module, 'ddp_config') + or module.ddp_config is None + or module.ddp_config.num_distributed_optimizer_instances == 1 + ), ( + "MIMO optimizer does not yet support " + "num_distributed_optimizer_instances > 1. " + f"Module '{module_name}' has " + f"{module.ddp_config.num_distributed_optimizer_instances} instances." 
+ ) + optimizer = get_megatron_optimizer( + config=config, + model_chunks=[module], + pg_collection=pg_collection, + use_gloo_process_groups=False, + ) + + module_infos[module_name] = ModuleOptimizerInfo( + optimizer=optimizer, grid=grid, pg_collection=pg_collection, is_active=is_active + ) + + return MimoOptimizer(module_infos, config) diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index 1a6b29cc58a..930861868d0 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -24,8 +24,10 @@ from megatron.core.models.mimo.config.base_configs import MimoModelConfig from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY from megatron.core.models.mimo.model.base import MimoModel +from megatron.core.models.mimo.optimizer import get_mimo_optimizer from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.optimizer.optimizer_config import OptimizerConfig from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage @@ -73,6 +75,11 @@ def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): grid.create_pg(["dp", "cp"]) grid.create_pg(["ep"]) grid.create_pg(["expt_dp"]) + # Required by _get_pg_collection_for_optimizer + grid.create_pg(["tp", "pp"]) + grid.create_pg(["tp", "ep", "pp"]) + grid.create_pg(["dp", "ep"]) + grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) _active_grids.append(grid) return grid @@ -505,6 +512,17 @@ def finalize_grads_func(*args, **kwargs): else loss ) + # Create optimizer + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-4, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + optimizer = get_mimo_optimizer(mimo_model, opt_config) + communicator = MultiModulePipelineCommunicator( module_to_grid_map, topology, mimo_model.config, dim_mapping={'s': 0, 'h': 2, 'b': 1} ) @@ -557,6 +575,8 @@ def loss_func(loss_mask, output_tensor): output_tensor, loss_mask = model(**batch) return output_tensor, partial(loss_func, loss_mask) + optimizer.zero_grad() + losses = schedule.forward_backward_pipelining_without_interleaving( forward_step_func=step_func, data_iterator=data_iterator, @@ -569,6 +589,11 @@ def loss_func(loss_mask, output_tensor): pg_collection=pg_collection, ) + # Optimizer step with global gradient clipping + success, grad_norm, num_zeros = optimizer.step() + assert success, "Optimizer step failed" + assert grad_norm is not None and grad_norm > 0, f"Expected positive grad norm, got {grad_norm}" + # Verify results on last LLM stage if is_rank_in_grid(llm_grid) and is_pp_last_stage(llm_grid.get_pg("pp")): assert len(losses) > 0, "Expected losses on last LLM stage"