Skip to content
9 changes: 9 additions & 0 deletions megatron/core/hyper_comm_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,12 @@ def _order_dims(self, dims: Union[str, list[str]]) -> Tuple[list[str], str]:

unique_group_key = "-".join(ordered_dims)
return ordered_dims, unique_group_key

def is_current_rank_in_grid(self) -> bool:
    """Check if the current rank belongs to this grid.

    Returns:
        True if the current rank is within [rank_offset, rank_offset + size).
    """
    current_rank = dist.get_rank()
    lower = self.rank_offset
    upper = lower + self.size
    return lower <= current_rank < upper
3 changes: 3 additions & 0 deletions megatron/core/models/mimo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from megatron.core.models.mimo.config.base_configs import MimoModelConfig
from megatron.core.models.mimo.model import MimoModel
from megatron.core.models.mimo.optimizer import MimoOptimizer, get_mimo_optimizer
from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules
from megatron.core.models.mimo.submodules.base import ModalitySubmodules
from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules

__all__ = [
'MimoModelConfig',
'MimoModel',
'MimoOptimizer',
'get_mimo_optimizer',
# Submodule classes
'ModalitySubmodules',
'VisionModalitySubmodules',
Expand Down
225 changes: 225 additions & 0 deletions megatron/core/models/mimo/optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Optimizer for MIMO models with heterogeneous parallelism."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

from megatron.core.optimizer.clip_grads import clip_grad_by_total_norm_fp32
from megatron.core.optimizer.optimizer import MegatronOptimizer
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from megatron.core.process_groups_config import ProcessGroupCollection


@dataclass
class ModuleOptimizerInfo:
    """Optimizer info for a single module.

    One instance exists per MIMO module (language model or modality
    submodule); see get_mimo_optimizer, which builds one per grid entry.
    """

    # Per-module optimizer; None when the current rank does not own the
    # module (is_active=False) or the owned module object is None.
    optimizer: Optional[MegatronOptimizer]
    # HyperCommGrid describing which ranks host this module and its
    # parallelism layout (typed Any to avoid a hard import dependency).
    grid: Any  # HyperCommGrid
    # Process groups the module's optimizer uses (dp, dp_cp, tp, mp, ...).
    pg_collection: Optional[ProcessGroupCollection]
    # True when the current rank falls inside this module's grid.
    is_active: bool


class MimoOptimizer(MegatronOptimizer):
    """
    Optimizer for MimoModel with heterogeneous parallelism.

    Each module gets its own optimizer, created only on ranks that own the
    module. Global gradient norm is computed across all modules via
    all_reduce MAX so that gradient clipping uses one consistent norm on
    every rank.
    """

    def __init__(self, module_infos: Dict[str, ModuleOptimizerInfo], config: OptimizerConfig):
        """Wrap one optimizer per MIMO module.

        Args:
            module_infos: Mapping of module name -> ModuleOptimizerInfo.
                Entries for modules this rank does not own carry
                optimizer=None / is_active=False.
            config: Shared optimizer configuration (clip_grad,
                log_num_zeros_in_grad, ...).
        """
        self.module_infos = module_infos
        self.config = config
        # Only optimizers for modules owned by this rank do local work
        # (prepare/clip/step/zero_grad).
        self._active_optimizers: List[MegatronOptimizer] = [
            info.optimizer
            for info in module_infos.values()
            if info.is_active and info.optimizer is not None
        ]
        # A rank with no active module still joins the collectives in
        # get_grad_norm but performs no local optimizer work.
        self.is_stub_optimizer = len(self._active_optimizers) == 0
        self.optimizer = None  # Base class compat

    @torch.no_grad()
    def prepare_grads(self) -> bool:
        """Run pre-step gradient processing on all active optimizers.

        Returns:
            True if any active optimizer reported inf/NaN gradients.
        """
        found_inf = False
        for opt in self._active_optimizers:
            found_inf |= opt.prepare_grads()
        return found_inf

    @torch.no_grad()
    def get_grad_norm(self) -> float:
        """Compute global gradient norm across all modules via all_reduce MAX.

        Each rank writes the squared norm of every module it owns into a
        per-module slot (slots for unowned modules stay 0); a MAX
        all-reduce then propagates the owners' values to all ranks, and
        the global norm is sqrt of the summed squared norms.
        """
        num_modules = len(self.module_infos)
        # NOTE(review): device is hard-coded to "cuda"; assumes the default
        # process group backend supports CUDA tensors (e.g. NCCL) — confirm.
        norm_sq = torch.zeros(num_modules, device="cuda", dtype=torch.float32)

        # sorted() gives every rank the same module -> slot assignment,
        # independent of dict insertion order.
        for i, (name, info) in enumerate(sorted(self.module_infos.items())):
            if info.is_active and info.optimizer:
                # `or 0.0` guards against get_grad_norm() returning None.
                module_norm = info.optimizer.get_grad_norm() or 0.0
                norm_sq[i] = module_norm**2

        # MAX keeps the owning ranks' values and discards the zeros
        # contributed by ranks that do not own a given module.
        torch.distributed.all_reduce(norm_sq, op=torch.distributed.ReduceOp.MAX)
        return torch.sqrt(norm_sq.sum()).item()

    @torch.no_grad()
    def step(self) -> Tuple[bool, Optional[float], Optional[int]]:
        """Full step: grad preparation, global clipping, parameter update.

        Returns:
            (success, grad_norm, num_zeros). When inf/NaN gradients are
            found the step is skipped and (False, None, None) is returned;
            num_zeros is None unless config.log_num_zeros_in_grad is set.
        """
        found_inf = self.prepare_grads()
        if found_inf:
            return False, None, None

        grad_norm = self.get_grad_norm()

        # Clip with global norm
        for opt in self._active_optimizers:
            if getattr(opt, "is_stub_optimizer", False):
                continue
            params = opt.get_parameters()
            if params and opt.config.clip_grad > 0.0:
                # total_norm is the MIMO-global norm, so clipping is
                # consistent across modules/ranks even though each
                # sub-optimizer clips only its own parameters.
                clip_grad_by_total_norm_fp32(
                    params,
                    max_norm=opt.config.clip_grad,
                    total_norm=grad_norm,
                    use_decoupled_grad=opt.config.use_precision_aware_optimizer,
                )

        num_zeros = self.count_zeros() if self.config.log_num_zeros_in_grad else None
        success = self.step_with_ready_grads()

        return success, grad_norm, num_zeros

    @torch.no_grad()
    def step_with_ready_grads(self) -> bool:
        """Apply parameter updates on all active optimizers.

        Returns:
            True only if every active optimizer's step succeeded.
        """
        success = True
        for opt in self._active_optimizers:
            success &= opt.step_with_ready_grads()
        return success

    def zero_grad(self, set_to_none: bool = True):
        """Zero gradients on all active optimizers."""
        for opt in self._active_optimizers:
            opt.zero_grad(set_to_none)

    def get_loss_scale(self) -> torch.Tensor:
        """Return the loss scale of the first active optimizer.

        Ranks with no active module report a neutral scale of 1.0.
        """
        if self._active_optimizers:
            return self._active_optimizers[0].get_loss_scale()
        return torch.tensor([1.0], dtype=torch.float32, device="cuda")

    def count_zeros(self) -> int:
        """Total count of zero gradient elements across active optimizers."""
        return sum(opt.count_zeros() for opt in self._active_optimizers)

    @property
    def param_groups(self) -> List[dict]:
        """Concatenated param groups from all active optimizers."""
        groups = []
        for opt in self._active_optimizers:
            groups.extend(opt.param_groups)
        return groups

    # Checkpointing

    def state_dict(self):
        """Per-module optimizer state keyed by module name.

        Inactive modules map to None so the dict shape is rank-independent.
        """
        return {
            name: info.optimizer.state_dict() if info.is_active and info.optimizer else None
            for name, info in self.module_infos.items()
        }

    def load_state_dict(self, state_dict: Dict):
        """Restore per-module optimizer state; missing/None entries are skipped."""
        for name, info in self.module_infos.items():
            if info.is_active and info.optimizer and state_dict.get(name):
                info.optimizer.load_state_dict(state_dict[name])

    def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
        """Per-module sharded state dicts for distributed checkpointing.

        Only modules active on this rank contribute an entry.
        """
        sharded_state = {}
        for name, info in self.module_infos.items():
            if info.is_active and info.optimizer:
                sharded_state[name] = info.optimizer.sharded_state_dict(
                    model_sharded_state_dict, is_loading, **kwargs
                )
        return sharded_state

    def reload_model_params(self, state_dict=None):
        """Refresh each active optimizer's copy of the model parameters."""
        for opt in self._active_optimizers:
            opt.reload_model_params(state_dict)


def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection:
    """Create ProcessGroupCollection from HyperCommGrid for optimizer use.

    Only fetches the process groups the optimizer needs. All groups are
    assumed to be pre-created in the grid via grid.create_pg() — nothing
    new is created here. Required pre-created groups:
        grid.create_pg(["dp"])
        grid.create_pg(["dp", "cp"])
        grid.create_pg(["tp"])
        grid.create_pg(["tp", "pp"])
        grid.create_pg(["tp", "ep", "pp"])
        grid.create_pg(["dp", "ep"])

    Args:
        grid: HyperCommGrid with pre-created process groups.

    Returns:
        ProcessGroupCollection with the optimizer-required groups:
            dp (data parallel), dp_cp (data + context parallel),
            tp (tensor parallel), mp (tp × pp model parallel),
            tp_ep_pp (expert tensor-model-pipeline), expt_dp (expert data
            parallel), and intra_dist_opt (distributed-optimizer group).
    """
    collection = ProcessGroupCollection()

    # Attribute name -> grid dims to fetch. intra_dist_opt is the same as
    # dp_cp while num_distributed_optimizer_instances == 1.
    # FIXME: Yash - handle multiple optimizer instances
    group_specs = (
        ("dp", "dp"),
        ("dp_cp", ["dp", "cp"]),
        ("tp", "tp"),
        ("mp", ["tp", "pp"]),
        ("tp_ep_pp", ["tp", "ep", "pp"]),
        ("expt_dp", ["dp", "ep"]),
        ("intra_dist_opt", ["dp", "cp"]),
    )
    for attr_name, dims in group_specs:
        setattr(collection, attr_name, grid.get_pg(dims))

    return collection


def get_mimo_optimizer(mimo_model: "MimoModel", config: OptimizerConfig) -> MimoOptimizer:
    """Create optimizer for MimoModel with heterogeneous parallelism.

    Builds one MegatronOptimizer per module on the ranks that own it and
    wraps them all in a single MimoOptimizer. Ranks that own no module get
    a stub wrapper that still participates in global-norm collectives.
    """
    # Local imports avoid a circular dependency at module import time.
    from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY
    from megatron.core.optimizer import get_megatron_optimizer

    infos: Dict[str, ModuleOptimizerInfo] = {}

    for name, grid in mimo_model.mimo_config.module_to_grid_map.items():
        active = grid.is_current_rank_in_grid()
        # Process groups are resolved on every rank so checkpoint/collective
        # metadata stays consistent even where the module is not owned.
        pgs = _get_pg_collection_for_optimizer(grid)
        module_optimizer = None

        if active:
            if name == MIMO_LANGUAGE_MODULE_KEY:
                owned_module = mimo_model.language_model
            else:
                owned_module = mimo_model.modality_submodules[name]

            if owned_module is not None:
                module_optimizer = get_megatron_optimizer(
                    config=config,
                    model_chunks=[owned_module],
                    pg_collection=pgs,
                    use_gloo_process_groups=False,
                )

        infos[name] = ModuleOptimizerInfo(
            optimizer=module_optimizer, grid=grid, pg_collection=pgs, is_active=active
        )

    return MimoOptimizer(infos, config)
31 changes: 30 additions & 1 deletion tests/unit_tests/models/test_mimo_1f1b_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from megatron.core.models.mimo.config.base_configs import MimoModelConfig
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY
from megatron.core.models.mimo.model.base import MimoModel
from megatron.core.models.mimo.optimizer import get_mimo_optimizer
from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator
from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator
Expand Down Expand Up @@ -73,6 +75,9 @@ def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1):
grid.create_pg(["dp", "cp"])
grid.create_pg(["ep"])
grid.create_pg(["expt_dp"])
grid.create_pg(["tp", "pp"])
grid.create_pg(["tp", "ep", "pp"])
grid.create_pg(["dp", "ep"])
_active_grids.append(grid)
return grid

Expand Down Expand Up @@ -505,6 +510,20 @@ def finalize_grads_func(*args, **kwargs):
else loss
)

# Create MimoOptimizer
logger.info(f"[Rank {dist.get_rank()}] Creating MimoOptimizer...")
opt_config = OptimizerConfig(
optimizer='adam',
lr=1e-4,
weight_decay=0.01,
clip_grad=1.0,
bf16=True,
use_distributed_optimizer=True,
)
optimizer = get_mimo_optimizer(mimo_model, opt_config)
logger.info(f"[Rank {dist.get_rank()}] MimoOptimizer created with {len(optimizer._active_optimizers)} active optimizers")

logger.info(f"[Rank {dist.get_rank()}] Creating communicator...")
communicator = MultiModulePipelineCommunicator(
module_to_grid_map, topology, mimo_model.config, dim_mapping={'s': 0, 'h': 2, 'b': 1}
)
Expand Down Expand Up @@ -557,6 +576,11 @@ def loss_func(loss_mask, output_tensor):
output_tensor, loss_mask = model(**batch)
return output_tensor, partial(loss_func, loss_mask)

logger.info(f"[Rank {dist.get_rank()}] Running 1F1B schedule with {num_microbatches} microbatches...")

# Zero gradients before forward/backward
optimizer.zero_grad()

losses = schedule.forward_backward_pipelining_without_interleaving(
forward_step_func=step_func,
data_iterator=data_iterator,
Expand All @@ -569,6 +593,11 @@ def loss_func(loss_mask, output_tensor):
pg_collection=pg_collection,
)

# Optimizer step with global gradient clipping
logger.info(f"[Rank {dist.get_rank()}] Running optimizer step...")
success, grad_norm, num_zeros = optimizer.step()
logger.info(f"[Rank {dist.get_rank()}] Optimizer step: success={success}, grad_norm={grad_norm}")

# Verify results on last LLM stage
if is_rank_in_grid(llm_grid) and is_pp_last_stage(llm_grid.get_pg("pp")):
assert len(losses) > 0, "Expected losses on last LLM stage"
Expand Down Expand Up @@ -639,7 +668,7 @@ def test_lm_pp3_4gpu(self):
llm_dp=1,
llm_offset=1,
hidden_size=256,
num_layers=2,
num_layers=3,
vocab_size=1000,
seq_length=64,
micro_batch_size=2,
Expand Down
Loading