Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions megatron/core/models/mimo/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

import torch

from megatron.core.distributed import DistributedDataParallel
from megatron.core.models.mimo.config import MimoModelConfig
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, ModuleLayout, RankRole
from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.utils import sharded_state_dict_default
from megatron.core.utils import unwrap_model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,6 +90,44 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) -
self._initialize_submodules()
self._initialize_language_model()

def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
    """Build the sharded state dict without parallel_state global fallbacks.

    ``modality_submodules`` is walked by hand because ``ModuleDict`` has no
    ``sharded_state_dict``; each ``ModalitySubmodules`` injects its own
    ``dp_cp_group`` from its ``pg_collection``. For any other child (the
    language model), ``dp_cp_group`` is injected here from the child's
    ``pg_collection`` before delegating to ``sharded_state_dict_default``.
    """
    sharded_sd = {}
    for child_name, child in self.named_children():
        if child_name != 'modality_submodules':
            # Language model path: unwrap DDP and inject dp_cp_group
            # from the module's own pg_collection when available.
            unwrapped = (
                child.module if isinstance(child, DistributedDataParallel) else child
            )
            pg_collection = getattr(unwrapped, 'pg_collection', None)
            child_metadata = metadata
            if pg_collection is not None:
                assert (
                    getattr(pg_collection, 'dp_cp', None) is not None
                ), f"pg_collection on '{child_name}' is missing dp_cp group"
                child_metadata = dict(metadata) if metadata else {}
                child_metadata['dp_cp_group'] = pg_collection.dp_cp
            sharded_sd.update(
                sharded_state_dict_default(
                    child, f'{prefix}{child_name}.', sharded_offsets, child_metadata
                )
            )
            continue
        # modality_submodules: call each submodule's sharded_state_dict
        # directly (it injects dp_cp_group from its pg_collection itself).
        for mod_name, mod in child.items():
            wrapped = isinstance(mod, DistributedDataParallel)
            target = mod.module if wrapped else mod
            mod_prefix = f'{prefix}{child_name}.{mod_name}.'
            if wrapped:
                # Keep the 'module.' segment so keys match the DDP wrapper.
                mod_prefix += 'module.'
            sharded_sd.update(
                target.sharded_state_dict(mod_prefix, sharded_offsets, metadata)
            )
    return sharded_sd

def align_embeddings_by_token_positions(
self,
modality_embeddings: Dict[str, torch.Tensor], # [num_embeddings, hidden_dim]
Expand Down
136 changes: 132 additions & 4 deletions megatron/core/models/mimo/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import torch

from megatron.core.dist_checkpointing.mapping import ShardedObject
from megatron.core.optimizer.clip_grads import clip_grad_by_total_norm_fp32
from megatron.core.optimizer.optimizer import MegatronOptimizer
from megatron.core.optimizer.optimizer_config import OptimizerConfig
Expand Down Expand Up @@ -136,24 +138,148 @@ def state_dict(self):
}

def load_state_dict(self, state_dict: Dict):
    """Load per-module optimizer state dicts.

    Reassembles the ``param_groups`` and ``grad_scaler`` entries that
    ``sharded_state_dict`` extracted into ShardedObjects at save time, then
    delegates to each per-module optimizer's ``load_state_dict``.
    """
    for name, info in self.module_infos.items():
        if not info.is_active or not info.optimizer:
            continue
        module_sd = state_dict.get(name)
        if module_sd is None:
            continue

        # Undo the save-time ShardedObject extraction for every inner
        # optimizer (ChainedOptimizer may carry several sub-dicts).
        for sub_sd, inner_opt in _iter_optimizer_sub_dicts(module_sd, info.optimizer):
            _restore_param_groups(sub_sd, inner_opt, name)
            _restore_grad_scaler(sub_sd)

        info.optimizer.load_state_dict(module_sd)

def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
    """Build the sharded state dict, routing ``param_groups`` and
    ``grad_scaler`` through distributed save as ShardedObjects (common.pt
    is rank-0 only, which would miss LLM optimizer state in
    non-colocated mode).
    """
    sharded_state = {}
    for name, info in self.module_infos.items():
        if not (info.is_active and info.optimizer):
            # Inactive module on this rank: keep an empty placeholder.
            sharded_state[name] = {}
            continue

        module_sd = info.optimizer.sharded_state_dict(
            model_sharded_state_dict, is_loading, **kwargs
        )
        replica_id = _get_replica_id(info.pg_collection)

        # Each inner optimizer (one, or several under ChainedOptimizer)
        # gets its param_groups/grad_scaler pulled out as ShardedObjects.
        sub_dicts = _iter_optimizer_sub_dicts(module_sd, info.optimizer)
        for idx, (sub_sd, _unused) in enumerate(sub_dicts):
            key_suffix = '' if idx == 0 else f'.{idx}'
            _extract_param_groups(sub_sd, name, key_suffix, replica_id)
            _extract_grad_scaler(sub_sd, name, key_suffix, replica_id)

        sharded_state[name] = module_sd
    return sharded_state

def reload_model_params(self, state_dict=None):
    """Reload model parameters into every active per-module optimizer."""
    for optimizer in self._active_optimizers:
        optimizer.reload_model_params(state_dict)


def _iter_optimizer_sub_dicts(module_sd, optimizer):
    """Yield (sub_state_dict, inner_optimizer) pairs.

    A plain optimizer yields ``(module_sd, optimizer)`` exactly once. A
    ChainedOptimizer with more than one inner optimizer yields
    ``(module_sd[i], chained_optimizers[i])`` for each inner optimizer.
    """
    # Local import avoids an import cycle with the optimizer module.
    from megatron.core.optimizer.optimizer import ChainedOptimizer

    is_multi_chain = (
        isinstance(optimizer, ChainedOptimizer)
        and len(optimizer.chained_optimizers) > 1
    )
    if not is_multi_chain:
        yield module_sd, optimizer
        return
    for idx, inner in enumerate(optimizer.chained_optimizers):
        yield module_sd[idx], inner


def _extract_param_groups(sub_sd, module_name, suffix, replica_id):
    """Save: move ``param_groups`` out of the optimizer sub-dict into a ShardedObject."""
    opt_sub = sub_sd.get('optimizer')
    if not isinstance(opt_sub, dict) or 'param_groups' not in opt_sub:
        return
    # Param IDs are rank-local and re-assigned on load, so strip them
    # from the copy that goes into the checkpoint.
    groups = deepcopy(opt_sub.pop('param_groups'))
    for group in groups:
        group['params'] = []
    sub_sd[f'_mimo_param_groups{suffix}'] = ShardedObject(
        f'optimizer.mimo.{module_name}{suffix}.param_groups',
        groups,
        (1,),
        (0,),
        replica_id=replica_id,
    )


def _extract_grad_scaler(sub_sd, module_name, suffix, replica_id):
    """Save: move ``grad_scaler`` into a ShardedObject for distributed save."""
    scaler = sub_sd.get('grad_scaler')
    if scaler is None:
        # Absent or explicitly None: nothing to extract.
        return
    del sub_sd['grad_scaler']
    sub_sd[f'_mimo_grad_scaler{suffix}'] = ShardedObject(
        f'optimizer.mimo.{module_name}{suffix}.grad_scaler',
        scaler,
        (1,),
        (0,),
        replica_id=replica_id,
    )


def _restore_param_groups(sub_sd, inner_optimizer, module_name):
"""Load: restore param_groups with current param IDs from the inner optimizer."""
# Find the _mimo_param_groups key (may have a suffix for chained optimizers)
pg_key = None
for k in list(sub_sd.keys()):
if k.startswith('_mimo_param_groups'):
pg_key = k
break
if pg_key is None:
return

loaded_pg = sub_sd.pop(pg_key)
# Get current param IDs from the inner torch optimizer's state_dict
current_pg = inner_optimizer.optimizer.state_dict()['param_groups']
if len(loaded_pg) != len(current_pg):
raise ValueError(
f"Optimizer '{module_name}': checkpoint has {len(loaded_pg)} param_groups "
f"but current optimizer has {len(current_pg)}"
)
for loaded_g, current_g in zip(loaded_pg, current_pg):
loaded_g['params'] = current_g['params']
sub_sd['optimizer']['param_groups'] = loaded_pg


def _restore_grad_scaler(sub_sd):
"""Load: restore grad_scaler from ShardedObject key."""
for k in list(sub_sd.keys()):
if k.startswith('_mimo_grad_scaler'):
sub_sd['grad_scaler'] = sub_sd.pop(k)
break


def _get_replica_id(pg_collection: Optional[ProcessGroupCollection]) -> tuple:
"""Build replica_id tuple for ShardedObject deduplication.

Includes pp_rank so only one PP stage writes the metadata,
and dp_rank so only dp_rank=0 writes (others are replicas).
"""
assert pg_collection is not None, "pg_collection required for checkpoint replica_id"
assert (
hasattr(pg_collection, 'pp') and pg_collection.pp is not None
), "pg_collection.pp must be set for checkpoint deduplication"
assert (
hasattr(pg_collection, 'dp') and pg_collection.dp is not None
), "pg_collection.dp must be set for checkpoint deduplication"
return (0, pg_collection.pp.rank(), pg_collection.dp.rank())


def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection:
"""Create ProcessGroupCollection from HyperCommGrid for optimizer use.

Expand All @@ -164,6 +290,7 @@ def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection:
grid.create_pg(["dp"])
grid.create_pg(["dp", "cp"])
grid.create_pg(["tp"])
grid.create_pg(["pp"])
grid.create_pg(["tp", "pp"])
grid.create_pg(["tp", "ep", "pp"])
grid.create_pg(["dp", "ep"])
Expand All @@ -183,10 +310,11 @@ def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection:
"""
pg = ProcessGroupCollection()

# Core groups needed by optimizer
# Core groups needed by optimizer and checkpointing
pg.dp = grid.get_pg("dp")
pg.dp_cp = grid.get_pg(["dp", "cp"])
pg.tp = grid.get_pg("tp")
pg.pp = grid.get_pg("pp")
pg.mp = grid.get_pg(["tp", "pp"])

# Expert groups
Expand Down
29 changes: 27 additions & 2 deletions megatron/core/models/mimo/submodules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn as nn

from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.utils import sharded_state_dict_default

# Initialize logger
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(
output_projections: Optional[List[nn.Module]] = None,
is_first_stage: bool = True,
is_last_stage: bool = True,
pg_collection=None,
**kwargs,
) -> None:
"""Initialize the modality submodules.
Expand All @@ -55,14 +57,14 @@ def __init__(
output_projections: List of output projection modules
is_first_stage: Whether this is the first PP stage for this module
is_last_stage: Whether this is the last PP stage for this module
pg_collection: Process group collection for this module
"""
super().__init__()
self.encoders = nn.ModuleDict(encoders or {})
self.decoders = nn.ModuleDict(decoders or {})
self.input_projections = nn.ModuleList(input_projections or [])
self.output_projections = nn.ModuleList(output_projections or [])

# Stage info for multi-module pipeline parallelism (immutable after init)
self.pg_collection = pg_collection
self._is_first_stage: bool = is_first_stage
self._is_last_stage: bool = is_last_stage

Expand All @@ -73,6 +75,29 @@ def __init__(
stacklevel=2,
)

def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
    """Recurse into ModuleDict/ModuleList children for TP-aware checkpointing.

    ``dp_cp_group`` from ``pg_collection`` is injected into metadata so that
    ``ensure_metadata_has_dp_cp_group`` never falls back to the
    parallel_state globals.
    """
    if self.pg_collection is not None:
        assert (
            getattr(self.pg_collection, 'dp_cp', None) is not None
        ), "pg_collection is missing dp_cp group"
        # Copy before mutating so the caller's metadata dict is untouched.
        metadata = dict(metadata) if metadata else {}
        metadata['dp_cp_group'] = self.pg_collection.dp_cp

    sharded_sd = {}
    # Children are the containers built in __init__ (encoders, decoders,
    # input/output projections); descend one level into each.
    for container_name, container in self.named_children():
        for child_name, child in container.named_children():
            child_prefix = f'{prefix}{container_name}.{child_name}.'
            sharded_sd.update(
                sharded_state_dict_default(child, child_prefix, sharded_offsets, metadata)
            )
    return sharded_sd

@property
def is_first_stage(self) -> bool:
"""Whether this is the first pipeline stage for this module."""
Expand Down
6 changes: 5 additions & 1 deletion megatron/core/models/vision/multimodal_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import torch

from megatron.core.fp8_utils import get_fp8_context
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.typed_torch import apply_module, not_none
from megatron.core.utils import make_viewless_tensor
from megatron.core.utils import get_tensor_model_parallel_group_if_none, make_viewless_tensor


class MultimodalProjector(MegatronModule):
Expand All @@ -32,9 +33,12 @@ def __init__(
projector_type: str,
input_size: int,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
pg_collection: Optional[ProcessGroupCollection] = None,
):
super().__init__(config=config)
self.projector_type = projector_type
tp_group = pg_collection.tp if pg_collection is not None else tp_group
self.tp_group = get_tensor_model_parallel_group_if_none(tp_group)

assert submodules is not None, "MLPSubmodules must be provided"

Expand Down
Loading