9 changes: 9 additions & 0 deletions megatron/core/hyper_comm_grid.py
@@ -262,3 +262,12 @@ def _order_dims(self, dims: Union[str, list[str]]) -> Tuple[list[str], str]:

unique_group_key = "-".join(ordered_dims)
return ordered_dims, unique_group_key

def is_current_rank_in_grid(self) -> bool:
"""Check if the current rank belongs to this grid.

Returns:
True if the current rank is within [rank_offset, rank_offset + size).
"""
rank = dist.get_rank()
return bool(self.rank_offset <= rank < self.rank_offset + self.size)
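
A minimal sketch of how this new helper is meant to be used (the names below are illustrative assumptions, not part of this change): a rank only builds a module, and later its optimizer, when it falls inside that module's grid.

# Sketch: `vision_grid` is assumed to be a HyperCommGrid constructed elsewhere with
# its own rank_offset and size; `build_vision_module` is a hypothetical helper.
if vision_grid.is_current_rank_in_grid():
    vision_module = build_vision_module()  # this rank owns a slice of the module
else:
    vision_module = None  # rank is outside [rank_offset, rank_offset + size)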
3 changes: 3 additions & 0 deletions megatron/core/models/mimo/__init__.py
@@ -2,13 +2,16 @@

from megatron.core.models.mimo.config.base_configs import MimoModelConfig
from megatron.core.models.mimo.model import MimoModel
from megatron.core.models.mimo.optimizer import MimoOptimizer, get_mimo_optimizer
from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules
from megatron.core.models.mimo.submodules.base import ModalitySubmodules
from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules

__all__ = [
'MimoModelConfig',
'MimoModel',
'MimoOptimizer',
'get_mimo_optimizer',
# Submodule classes
'ModalitySubmodules',
'VisionModalitySubmodules',
251 changes: 251 additions & 0 deletions megatron/core/models/mimo/optimizer.py
@@ -0,0 +1,251 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

"""Optimizer for MIMO models with heterogeneous parallelism."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import torch

from megatron.core.optimizer.clip_grads import clip_grad_by_total_norm_fp32
from megatron.core.optimizer.optimizer import MegatronOptimizer
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from megatron.core.process_groups_config import ProcessGroupCollection

if TYPE_CHECKING:
from megatron.core.hyper_comm_grid import HyperCommGrid


@dataclass
class ModuleOptimizerInfo:
"""Optimizer info for a single module."""

optimizer: Optional[MegatronOptimizer]
grid: Optional[HyperCommGrid]
pg_collection: Optional[ProcessGroupCollection]
is_active: bool


class MimoOptimizer(MegatronOptimizer):
"""
Optimizer for MimoModel with heterogeneous parallelism.

Each module gets its own optimizer. The global gradient norm is computed
across all modules by all-reducing per-module norms with MAX and combining them.
"""

def __init__(self, module_infos: Dict[str, ModuleOptimizerInfo], config: OptimizerConfig):
self.module_infos = module_infos
self.config = config
self._active_optimizers: List[MegatronOptimizer] = [
info.optimizer
for info in module_infos.values()
if info.is_active and info.optimizer is not None
]
self.is_stub_optimizer = len(self._active_optimizers) == 0
self.optimizer = None # Base class compat

@torch.no_grad()
def prepare_grads(self) -> bool:
found_inf = False
for opt in self._active_optimizers:
found_inf |= opt.prepare_grads()
return found_inf

@torch.no_grad()
def get_grad_norm(self) -> float:
"""Compute global gradient norm across all modules via all_reduce MAX."""
num_modules = len(self.module_infos)
norm_sq = torch.zeros(num_modules, device="cuda", dtype=torch.float32)

for i, (name, info) in enumerate(sorted(self.module_infos.items())):
if info.is_active and info.optimizer:
module_norm = info.optimizer.get_grad_norm() or 0.0
norm_sq[i] = module_norm**2

torch.distributed.all_reduce(norm_sq, op=torch.distributed.ReduceOp.MAX)
return torch.sqrt(norm_sq.sum()).item()
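
# Worked example (illustrative module names and numbers): with module_infos sorted
# as ["llm", "vision"], a vision-grid rank computes a local norm of 3.0 and fills
# norm_sq = [0.0, 9.0]; an LLM-grid rank computes 4.0 and fills [16.0, 0.0]. After
# the MAX all_reduce every rank holds [16.0, 9.0], so the global norm is
# sqrt(16.0 + 9.0) = 5.0 on every rank.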

@torch.no_grad()
def step(self) -> Tuple[bool, Optional[float], Optional[int]]:
found_inf = self.prepare_grads()
# Synchronize found_inf across all ranks to prevent deadlock:
# if encoder ranks detect inf but LLM ranks don't, the early return
# would skip the all_reduce in get_grad_norm(), causing a hang.
found_inf_tensor = torch.tensor([found_inf], dtype=torch.float32, device="cuda")
torch.distributed.all_reduce(found_inf_tensor, op=torch.distributed.ReduceOp.MAX)
found_inf = found_inf_tensor.item() > 0
if found_inf:
return False, None, None

grad_norm = self.get_grad_norm()

# Clip with global norm
for opt in self._active_optimizers:
if getattr(opt, "is_stub_optimizer", False):
continue
params = opt.get_parameters()
if params and opt.config.clip_grad > 0.0:
clip_grad_by_total_norm_fp32(
params,
max_norm=opt.config.clip_grad,
total_norm=grad_norm,
use_decoupled_grad=opt.config.use_precision_aware_optimizer,
)

num_zeros = self.count_zeros() if self.config.log_num_zeros_in_grad else None
success = self.step_with_ready_grads()

return success, grad_norm, num_zeros

@torch.no_grad()
def step_with_ready_grads(self) -> bool:
success = True
for opt in self._active_optimizers:
success &= opt.step_with_ready_grads()
return success

def zero_grad(self, set_to_none: bool = True):
for opt in self._active_optimizers:
opt.zero_grad(set_to_none)

def get_loss_scale(self) -> torch.Tensor:
if self._active_optimizers:
return self._active_optimizers[0].get_loss_scale()
return torch.tensor([1.0], dtype=torch.float32, device="cuda")

def count_zeros(self) -> int:
return sum(opt.count_zeros() for opt in self._active_optimizers)

@property
def param_groups(self) -> List[dict]:
"""Combined param groups from all active module optimizers."""
groups = []
for opt in self._active_optimizers:
groups.extend(opt.param_groups)
return groups

# Checkpointing

def state_dict(self):
return {
name: info.optimizer.state_dict() if info.is_active and info.optimizer else None
for name, info in self.module_infos.items()
}
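
# Illustrative checkpoint layout (module names are examples only): the dict is keyed
# by module name, with None recorded for modules this rank does not own, e.g.
# {"language_model": {...optimizer state...}, "vision": {...}, "audio": None}.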

def load_state_dict(self, state_dict: Dict):
for name, info in self.module_infos.items():
if info.is_active and info.optimizer and state_dict.get(name):
info.optimizer.load_state_dict(state_dict[name])

def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
sharded_state = {}
for name, info in self.module_infos.items():
if info.is_active and info.optimizer:
sharded_state[name] = info.optimizer.sharded_state_dict(
model_sharded_state_dict, is_loading, **kwargs
)
return sharded_state

def reload_model_params(self, state_dict=None):
for opt in self._active_optimizers:
opt.reload_model_params(state_dict)


def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection:
"""Create ProcessGroupCollection from HyperCommGrid for optimizer use.

Only fetches the process groups required by the optimizer. Assumes all groups
have been pre-created in the grid via grid.create_pg(); it does not create any new groups.

The following groups must be pre-created in the grid before calling this function:
grid.create_pg(["dp"])
grid.create_pg(["dp", "cp"])
grid.create_pg(["tp"])
grid.create_pg(["tp", "pp"])
grid.create_pg(["tp", "ep", "pp"])
grid.create_pg(["dp", "ep"])
grid.create_pg(["tp", "cp", "ep", "pp", "dp"])

Args:
grid: HyperCommGrid with pre-created process groups.

Returns:
ProcessGroupCollection containing optimizer-required groups:
- dp: Data parallel group
- dp_cp: Data parallel with context parallel
- tp: Tensor parallel group
- mp: Model parallel group (tp × pp)
- tp_ep_pp: Expert tensor-model-pipeline group
- expt_dp: Expert data parallel group
"""
pg = ProcessGroupCollection()

# Core groups needed by optimizer
pg.dp = grid.get_pg("dp")
pg.dp_cp = grid.get_pg(["dp", "cp"])
pg.tp = grid.get_pg("tp")
pg.mp = grid.get_pg(["tp", "pp"])

# Expert groups
pg.tp_ep_pp = grid.get_pg(["tp", "ep", "pp"])
pg.expt_dp = grid.get_pg(["dp", "ep"])

# Distributed optimizer grad stats group: must span all dimensions so grad norm
# and found-inf all-reduces see every unique gradient shard: TP/PP/EP ranks hold
# different parameters, and DP ranks hold different optimizer shards after reduce-scatter.
# This mirrors standard Megatron's intra_distributed_optimizer_instance_group which
# spans the full world when num_distributed_optimizer_instances == 1.
pg.intra_dist_opt = grid.get_pg(["tp", "cp", "ep", "pp", "dp"])

return pg
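
# Expected call pattern (sketch): `grid` is an already-constructed HyperCommGrid for
# one module; all optimizer-facing groups are created up front, and this helper then
# only fetches them:
#
#     for dims in (["dp"], ["dp", "cp"], ["tp"], ["tp", "pp"],
#                  ["tp", "ep", "pp"], ["dp", "ep"],
#                  ["tp", "cp", "ep", "pp", "dp"]):
#         grid.create_pg(dims)
#     pg_collection = _get_pg_collection_for_optimizer(grid)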


def get_mimo_optimizer(mimo_model: "MimoModel", config: OptimizerConfig) -> MimoOptimizer:
"""Create optimizer for MimoModel with heterogeneous parallelism."""
from megatron.core.optimizer import get_megatron_optimizer

grid_map = mimo_model.mimo_config.module_to_grid_map
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY

lang_key = MIMO_LANGUAGE_MODULE_KEY

module_infos: Dict[str, ModuleOptimizerInfo] = {}

for module_name, grid in grid_map.items():
is_active = grid.is_current_rank_in_grid()

optimizer = None
pg_collection = _get_pg_collection_for_optimizer(grid)

if is_active:
if module_name == lang_key:
module = mimo_model.language_model
else:
module = mimo_model.modality_submodules[module_name]

if module is not None:
assert (
not hasattr(module, 'ddp_config')
or module.ddp_config is None
or module.ddp_config.num_distributed_optimizer_instances == 1
), (
"MIMO optimizer does not yet support "
"num_distributed_optimizer_instances > 1. "
f"Module '{module_name}' has "
f"{module.ddp_config.num_distributed_optimizer_instances} instances."
)
optimizer = get_megatron_optimizer(
config=config,
model_chunks=[module],
pg_collection=pg_collection,
use_gloo_process_groups=False,
)

module_infos[module_name] = ModuleOptimizerInfo(
optimizer=optimizer, grid=grid, pg_collection=pg_collection, is_active=is_active
)

return MimoOptimizer(module_infos, config)
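
The unit-test changes below exercise the intended call pattern; a condensed sketch of the same flow, assuming `mimo_model` and the pipeline schedule are set up elsewhere (configuration values are illustrative):

opt_config = OptimizerConfig(
    optimizer='adam', lr=1e-4, clip_grad=1.0, bf16=True, use_distributed_optimizer=True
)
optimizer = get_mimo_optimizer(mimo_model, opt_config)
optimizer.zero_grad()
# ... run the pipelined forward/backward to populate gradients ...
success, grad_norm, num_zeros = optimizer.step()  # clips with the global norm, then steps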
25 changes: 25 additions & 0 deletions tests/unit_tests/models/test_mimo_1f1b_schedule.py
@@ -24,8 +24,10 @@
from megatron.core.models.mimo.config.base_configs import MimoModelConfig
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY
from megatron.core.models.mimo.model.base import MimoModel
from megatron.core.models.mimo.optimizer import get_mimo_optimizer
from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator
from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
@@ -73,6 +75,11 @@ def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1):
grid.create_pg(["dp", "cp"])
grid.create_pg(["ep"])
grid.create_pg(["expt_dp"])
# Required by _get_pg_collection_for_optimizer
grid.create_pg(["tp", "pp"])
grid.create_pg(["tp", "ep", "pp"])
grid.create_pg(["dp", "ep"])
grid.create_pg(["tp", "cp", "ep", "pp", "dp"])
_active_grids.append(grid)
return grid

@@ -505,6 +512,17 @@ def finalize_grads_func(*args, **kwargs):
else loss
)

# Create optimizer
opt_config = OptimizerConfig(
optimizer='adam',
lr=1e-4,
weight_decay=0.01,
clip_grad=1.0,
bf16=True,
use_distributed_optimizer=True,
)
optimizer = get_mimo_optimizer(mimo_model, opt_config)

communicator = MultiModulePipelineCommunicator(
module_to_grid_map, topology, mimo_model.config, dim_mapping={'s': 0, 'h': 2, 'b': 1}
)
@@ -557,6 +575,8 @@ def loss_func(loss_mask, output_tensor):
output_tensor, loss_mask = model(**batch)
return output_tensor, partial(loss_func, loss_mask)

optimizer.zero_grad()

losses = schedule.forward_backward_pipelining_without_interleaving(
forward_step_func=step_func,
data_iterator=data_iterator,
@@ -569,6 +589,11 @@ def loss_func(loss_mask, output_tensor):
pg_collection=pg_collection,
)

# Optimizer step with global gradient clipping
success, grad_norm, num_zeros = optimizer.step()
assert success, "Optimizer step failed"
assert grad_norm is not None and grad_norm > 0, f"Expected positive grad norm, got {grad_norm}"

# Verify results on last LLM stage
if is_rank_in_grid(llm_grid) and is_pp_last_stage(llm_grid.get_pg("pp")):
assert len(losses) > 0, "Expected losses on last LLM stage"