diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py
index 49e2fd42116..b1c12f521c3 100644
--- a/megatron/core/models/mimo/model/base.py
+++ b/megatron/core/models/mimo/model/base.py
@@ -6,12 +6,14 @@
 
 import torch
 
+from megatron.core.distributed import DistributedDataParallel
 from megatron.core.models.mimo.config import MimoModelConfig
 from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, ModuleLayout, RankRole
 from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.spec_utils import build_module
+from megatron.core.transformer.utils import sharded_state_dict_default
 from megatron.core.utils import unwrap_model
 
 logger = logging.getLogger(__name__)
@@ -88,6 +90,44 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) -
         self._initialize_submodules()
         self._initialize_language_model()
 
+    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
+        """Build sharded state dict, bypassing parallel_state global fallbacks.
+
+        Iterates modality_submodules manually (ModuleDict lacks sharded_state_dict)
+        and injects dp_cp_group from each module's pg_collection.
+        """
+        sharded_sd = {}
+        for name, module in self.named_children():
+            if name == 'modality_submodules':
+                # Unwrap DDP, call ModalitySubmodules.sharded_state_dict directly
+                # (which injects dp_cp_group from its pg_collection)
+                for mod_name, mod in module.items():
+                    is_ddp = isinstance(mod, DistributedDataParallel)
+                    inner = mod.module if is_ddp else mod
+                    child_prefix = f'{prefix}{name}.{mod_name}.'
+                    if is_ddp:
+                        child_prefix += 'module.'
+                    sharded_sd.update(
+                        inner.sharded_state_dict(child_prefix, sharded_offsets, metadata)
+                    )
+            else:
+                # Inject dp_cp_group from pg_collection for language_model
+                inner = module.module if isinstance(module, DistributedDataParallel) else module
+                pg = getattr(inner, 'pg_collection', None)
+                mod_metadata = metadata
+                if pg is not None:
+                    assert (
+                        hasattr(pg, 'dp_cp') and pg.dp_cp is not None
+                    ), f"pg_collection on '{name}' is missing dp_cp group"
+                    mod_metadata = dict(metadata) if metadata else {}
+                    mod_metadata['dp_cp_group'] = pg.dp_cp
+                sharded_sd.update(
+                    sharded_state_dict_default(
+                        module, f'{prefix}{name}.', sharded_offsets, mod_metadata
+                    )
+                )
+        return sharded_sd
+
     def align_embeddings_by_token_positions(
         self,
         modality_embeddings: Dict[str, torch.Tensor],  # [num_embeddings, hidden_dim]
diff --git a/megatron/core/models/mimo/optimizer.py b/megatron/core/models/mimo/optimizer.py
index c59025cbce3..1a79c1f91ff 100644
--- a/megatron/core/models/mimo/optimizer.py
+++ b/megatron/core/models/mimo/optimizer.py
@@ -4,11 +4,13 @@
 
 from __future__ import annotations
 
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import torch
 
+from megatron.core.dist_checkpointing.mapping import ShardedObject
 from megatron.core.optimizer.clip_grads import clip_grad_by_total_norm_fp32
 from megatron.core.optimizer.optimizer import MegatronOptimizer
 from megatron.core.optimizer.optimizer_config import OptimizerConfig
@@ -136,17 +138,48 @@ def state_dict(self):
         }
 
     def load_state_dict(self, state_dict: Dict):
+        """Load per-module optimizer state dicts.
+
+        Reassembles param_groups and grad_scaler that were extracted and saved
+        as ShardedObjects by sharded_state_dict(), then delegates to each
+        per-module optimizer's load_state_dict.
+        """
         for name, info in self.module_infos.items():
-            if info.is_active and info.optimizer and state_dict.get(name):
-                info.optimizer.load_state_dict(state_dict[name])
+            if not (info.is_active and info.optimizer):
+                continue
+            module_sd = state_dict.get(name)
+            if module_sd is None:
+                continue
+
+            for sub_sd, inner_opt in _iter_optimizer_sub_dicts(module_sd, info.optimizer):
+                _restore_param_groups(sub_sd, inner_opt, name)
+                _restore_grad_scaler(sub_sd)
+
+            info.optimizer.load_state_dict(module_sd)
 
     def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
+        """Build sharded state dict, routing param_groups and grad_scaler
+        through distributed save as ShardedObjects (common.pt is rank-0 only,
+        which misses LLM optimizer state in non-colocated mode).
+        """
         sharded_state = {}
         for name, info in self.module_infos.items():
             if info.is_active and info.optimizer:
-                sharded_state[name] = info.optimizer.sharded_state_dict(
+                module_sd = info.optimizer.sharded_state_dict(
                     model_sharded_state_dict, is_loading, **kwargs
                 )
+                replica_id = _get_replica_id(info.pg_collection)
+
+                for idx, (sub_sd, _) in enumerate(
+                    _iter_optimizer_sub_dicts(module_sd, info.optimizer)
+                ):
+                    suffix = f'.{idx}' if idx > 0 else ''
+                    _extract_param_groups(sub_sd, name, suffix, replica_id)
+                    _extract_grad_scaler(sub_sd, name, suffix, replica_id)
+
+                sharded_state[name] = module_sd
+            else:
+                sharded_state[name] = {}
         return sharded_state
 
     def reload_model_params(self, state_dict=None):
@@ -154,6 +187,99 @@ def reload_model_params(self, state_dict=None):
             opt.reload_model_params(state_dict)
 
+
+def _iter_optimizer_sub_dicts(module_sd, optimizer):
+    """Yield (sub_state_dict, inner_optimizer) pairs.
+
+    For a single optimizer, yields (module_sd, optimizer) once.
+    For ChainedOptimizer with N>1 inner optimizers, yields
+    (module_sd[i], chained_optimizers[i]) for each.
+ """ + from megatron.core.optimizer.optimizer import ChainedOptimizer + + if isinstance(optimizer, ChainedOptimizer) and len(optimizer.chained_optimizers) > 1: + for idx, inner_opt in enumerate(optimizer.chained_optimizers): + yield module_sd[idx], inner_opt + else: + yield module_sd, optimizer + + +def _extract_param_groups(sub_sd, module_name, suffix, replica_id): + """Save: extract param_groups from optimizer sub-dict into a ShardedObject.""" + opt_sub = sub_sd.get('optimizer') + if isinstance(opt_sub, dict) and 'param_groups' in opt_sub: + pg = deepcopy(opt_sub['param_groups']) + for group in pg: + group['params'] = [] + sub_sd[f'_mimo_param_groups{suffix}'] = ShardedObject( + f'optimizer.mimo.{module_name}{suffix}.param_groups', + pg, + (1,), + (0,), + replica_id=replica_id, + ) + del opt_sub['param_groups'] + + +def _extract_grad_scaler(sub_sd, module_name, suffix, replica_id): + """Save: extract grad_scaler into a ShardedObject.""" + if 'grad_scaler' in sub_sd and sub_sd['grad_scaler'] is not None: + sub_sd[f'_mimo_grad_scaler{suffix}'] = ShardedObject( + f'optimizer.mimo.{module_name}{suffix}.grad_scaler', + sub_sd.pop('grad_scaler'), + (1,), + (0,), + replica_id=replica_id, + ) + + +def _restore_param_groups(sub_sd, inner_optimizer, module_name): + """Load: restore param_groups with current param IDs from the inner optimizer.""" + # Find the _mimo_param_groups key (may have a suffix for chained optimizers) + pg_key = None + for k in list(sub_sd.keys()): + if k.startswith('_mimo_param_groups'): + pg_key = k + break + if pg_key is None: + return + + loaded_pg = sub_sd.pop(pg_key) + # Get current param IDs from the inner torch optimizer's state_dict + current_pg = inner_optimizer.optimizer.state_dict()['param_groups'] + if len(loaded_pg) != len(current_pg): + raise ValueError( + f"Optimizer '{module_name}': checkpoint has {len(loaded_pg)} param_groups " + f"but current optimizer has {len(current_pg)}" + ) + for loaded_g, current_g in zip(loaded_pg, current_pg): + loaded_g['params'] = current_g['params'] + sub_sd['optimizer']['param_groups'] = loaded_pg + + +def _restore_grad_scaler(sub_sd): + """Load: restore grad_scaler from ShardedObject key.""" + for k in list(sub_sd.keys()): + if k.startswith('_mimo_grad_scaler'): + sub_sd['grad_scaler'] = sub_sd.pop(k) + break + + +def _get_replica_id(pg_collection: Optional[ProcessGroupCollection]) -> tuple: + """Build replica_id tuple for ShardedObject deduplication. + + Includes pp_rank so only one PP stage writes the metadata, + and dp_rank so only dp_rank=0 writes (others are replicas). + """ + assert pg_collection is not None, "pg_collection required for checkpoint replica_id" + assert ( + hasattr(pg_collection, 'pp') and pg_collection.pp is not None + ), "pg_collection.pp must be set for checkpoint deduplication" + assert ( + hasattr(pg_collection, 'dp') and pg_collection.dp is not None + ), "pg_collection.dp must be set for checkpoint deduplication" + return (0, pg_collection.pp.rank(), pg_collection.dp.rank()) + + def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection: """Create ProcessGroupCollection from HyperCommGrid for optimizer use. 
@@ -164,6 +290,7 @@ def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection:
     grid.create_pg(["dp"])
     grid.create_pg(["dp", "cp"])
     grid.create_pg(["tp"])
+    grid.create_pg(["pp"])
     grid.create_pg(["tp", "pp"])
     grid.create_pg(["tp", "ep", "pp"])
     grid.create_pg(["dp", "ep"])
@@ -183,10 +310,11 @@
     """
     pg = ProcessGroupCollection()
 
-    # Core groups needed by optimizer
+    # Core groups needed by optimizer and checkpointing
    pg.dp = grid.get_pg("dp")
     pg.dp_cp = grid.get_pg(["dp", "cp"])
     pg.tp = grid.get_pg("tp")
+    pg.pp = grid.get_pg("pp")
     pg.mp = grid.get_pg(["tp", "pp"])
 
     # Expert groups
diff --git a/megatron/core/models/mimo/submodules/base.py b/megatron/core/models/mimo/submodules/base.py
index 58f61f81d3c..3b54fd737f2 100644
--- a/megatron/core/models/mimo/submodules/base.py
+++ b/megatron/core/models/mimo/submodules/base.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
+from megatron.core.transformer.utils import sharded_state_dict_default
 
 # Initialize logger
 logger = logging.getLogger(__name__)
@@ -44,6 +45,7 @@ def __init__(
         output_projections: Optional[List[nn.Module]] = None,
         is_first_stage: bool = True,
         is_last_stage: bool = True,
+        pg_collection=None,
         **kwargs,
     ) -> None:
         """Initialize the modality submodules.
@@ -55,14 +57,14 @@ def __init__(
             output_projections: List of output projection modules
             is_first_stage: Whether this is the first PP stage for this module
             is_last_stage: Whether this is the last PP stage for this module
+            pg_collection: Process group collection for this module
         """
         super().__init__()
         self.encoders = nn.ModuleDict(encoders or {})
         self.decoders = nn.ModuleDict(decoders or {})
         self.input_projections = nn.ModuleList(input_projections or [])
         self.output_projections = nn.ModuleList(output_projections or [])
-
-        # Stage info for multi-module pipeline parallelism (immutable after init)
+        self.pg_collection = pg_collection
         self._is_first_stage: bool = is_first_stage
         self._is_last_stage: bool = is_last_stage
 
@@ -73,6 +75,29 @@ def __init__(
                 stacklevel=2,
             )
 
+    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
+        """Iterate into ModuleDict/ModuleList children for TP-aware checkpointing.
+
+        Injects dp_cp_group from pg_collection into metadata to avoid
+        parallel_state global fallback in ensure_metadata_has_dp_cp_group.
+ """ + if self.pg_collection is not None: + assert ( + hasattr(self.pg_collection, 'dp_cp') and self.pg_collection.dp_cp is not None + ), "pg_collection is missing dp_cp group" + metadata = dict(metadata) if metadata else {} + metadata['dp_cp_group'] = self.pg_collection.dp_cp + + sharded_sd = {} + for name, container in self.named_children(): + for sub_name, module in container.named_children(): + sharded_sd.update( + sharded_state_dict_default( + module, f'{prefix}{name}.{sub_name}.', sharded_offsets, metadata + ) + ) + return sharded_sd + @property def is_first_stage(self) -> bool: """Whether this is the first pipeline stage for this module.""" diff --git a/megatron/core/models/vision/multimodal_projector.py b/megatron/core/models/vision/multimodal_projector.py index ab4d2f8cd41..cb54891228e 100644 --- a/megatron/core/models/vision/multimodal_projector.py +++ b/megatron/core/models/vision/multimodal_projector.py @@ -4,11 +4,12 @@ import torch from megatron.core.fp8_utils import get_fp8_context +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.typed_torch import apply_module, not_none -from megatron.core.utils import make_viewless_tensor +from megatron.core.utils import get_tensor_model_parallel_group_if_none, make_viewless_tensor class MultimodalProjector(MegatronModule): @@ -32,9 +33,12 @@ def __init__( projector_type: str, input_size: int, tp_group: Optional[torch.distributed.ProcessGroup] = None, + pg_collection: Optional[ProcessGroupCollection] = None, ): super().__init__(config=config) self.projector_type = projector_type + tp_group = pg_collection.tp if pg_collection is not None else tp_group + self.tp_group = get_tensor_model_parallel_group_if_none(tp_group) assert submodules is not None, "MLPSubmodules must be provided" diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index 930861868d0..4dbf2f39ca0 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -284,6 +284,7 @@ def get_vision_submodules_spec( return ModuleSpec( module=VisionModalitySubmodules, + params={"pg_collection": pg_collection}, submodules={ "encoders": {"clip_encoder": vision_encoder_spec}, "input_projections": [vision_projection_spec], diff --git a/tests/unit_tests/models/test_mimo_checkpoint.py b/tests/unit_tests/models/test_mimo_checkpoint.py new file mode 100644 index 00000000000..3dc75a05a87 --- /dev/null +++ b/tests/unit_tests/models/test_mimo_checkpoint.py @@ -0,0 +1,273 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Integration tests for MIMO distributed checkpoint save/load in non-colocated mode. 
+
+Run with 8 GPUs:
+    uv run python -m torch.distributed.run --nproc-per-node=8 \
+        -m pytest tests/unit_tests/models/test_mimo_checkpoint.py -v -s
+"""
+
+import os
+import shutil
+import tempfile
+
+import pytest
+import torch
+import torch.distributed as dist
+from packaging import version
+
+from megatron.core.dist_checkpointing import load, save
+from megatron.core.dist_checkpointing.validation import StrictHandling
+from megatron.core.models.mimo.optimizer import get_mimo_optimizer
+from megatron.core.optimizer.optimizer_config import OptimizerConfig
+from tests.unit_tests.models.test_mimo_1f1b_schedule import (
+    create_all_embedding_groups,
+    create_hypercomm_grid,
+    destroy_all_grids,
+    get_mimo_model,
+    is_rank_in_grid,
+)
+from tests.unit_tests.test_utilities import Utils
+
+ENCODER_NAME = "images"
+
+
+def _get_shared_tmpdir():
+    """Create a shared temp directory across all ranks."""
+    tmpdir_list = [None]
+    if dist.get_rank() == 0:
+        tmpdir_list[0] = tempfile.mkdtemp(prefix="mimo_ckpt_test_")
+    dist.broadcast_object_list(tmpdir_list, src=0)
+    return tmpdir_list[0]
+
+
+def _cleanup_tmpdir(tmpdir):
+    """Clean up temp directory (rank 0 only)."""
+    dist.barrier()
+    if dist.get_rank() == 0:
+        shutil.rmtree(tmpdir, ignore_errors=True)
+
+
+def _randomize_params(model, seed):
+    """Set all model parameters to deterministic random values."""
+    torch.manual_seed(seed)
+    with torch.no_grad():
+        for p in model.parameters():
+            p.random_()
+
+
+def _create_model_and_optimizer(encoder_grid, llm_grid, hidden_size, num_layers, vocab_size, seed):
+    """Create MIMO model with DDP + optimizer, do a fake step to populate optimizer state.
+
+    Caller must call create_all_embedding_groups() before this function.
+    """
+    torch.manual_seed(seed)
+
+    mimo_model, _, _, _, _ = get_mimo_model(
+        encoder_name=ENCODER_NAME,
+        encoder_grid=encoder_grid,
+        llm_grid=llm_grid,
+        hidden_size=hidden_size,
+        num_layers=num_layers,
+        vocab_size=vocab_size,
+        seq_len=64,
+    )
+    _randomize_params(mimo_model, seed)
+
+    # Use Float16Optimizer (not DistributedOptimizer) to exercise the MIMO-specific
+    # param_groups/grad_scaler extraction in sharded_state_dict. DistributedOptimizer
+    # handles its own checkpointing internally and our code is transparent to it.
+    opt_config = OptimizerConfig(
+        optimizer='adam',
+        lr=1e-4,
+        weight_decay=0.01,
+        clip_grad=1.0,
+        bf16=True,
+        use_distributed_optimizer=False,
+    )
+    optimizer = get_mimo_optimizer(mimo_model, opt_config)
+
+    # Fake backward + step to populate optimizer state (Adam m/v)
+    for param in mimo_model.parameters():
+        param.grad = torch.randn_like(param)
+    optimizer.step()
+
+    return mimo_model, optimizer
+
+
+def run_checkpoint_test(
+    encoder_tp,
+    encoder_pp,
+    encoder_dp,
+    encoder_offset,
+    llm_tp,
+    llm_pp,
+    llm_dp,
+    llm_offset,
+    hidden_size=256,
+    num_layers=2,
+    vocab_size=1000,
+):
+    """Save model + optimizer checkpoint, load into fresh instances, verify match."""
+    # Clear NVTE env vars that the conftest set_env fixture sets to '0'.
+    # GPTModel (LanguageModule) asserts these are unset or match the attention backend.
+    os.environ.pop('NVTE_FLASH_ATTN', None)
+    os.environ.pop('NVTE_FUSED_ATTN', None)
+    os.environ.pop('NVTE_UNFUSED_ATTN', None)
+
+    encoder_grid = create_hypercomm_grid(
+        offset=encoder_offset, tp=encoder_tp, cp=1, pp=encoder_pp, dp=encoder_dp
+    )
+    llm_grid = create_hypercomm_grid(offset=llm_offset, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp)
+    create_all_embedding_groups([encoder_grid, llm_grid])
+
+    # --- Create model A + optimizer, snapshot state ---
+    model_a, optimizer_a = _create_model_and_optimizer(
+        encoder_grid, llm_grid, hidden_size, num_layers, vocab_size, seed=1
+    )
+    params_a = {name: p.clone() for name, p in model_a.named_parameters()}
+
+    ckpt_dir = _get_shared_tmpdir()
+    try:
+        model_ckpt = os.path.join(ckpt_dir, 'model')
+        optim_ckpt = os.path.join(ckpt_dir, 'optimizer')
+        if dist.get_rank() == 0:
+            os.makedirs(model_ckpt)
+            os.makedirs(optim_ckpt)
+        dist.barrier()
+
+        # Save model
+        save(model_a.sharded_state_dict(), model_ckpt)
+
+        # Save optimizer (needs fresh model sharded_state_dict since save() consumes tensor refs)
+        optim_sd_a = optimizer_a.sharded_state_dict(model_a.sharded_state_dict(), is_loading=False)
+        save(optim_sd_a, optim_ckpt, validate_access_integrity=False)
+
+        dist.barrier()
+
+        # --- Create model B + optimizer with different weights (reuse same grids) ---
+        model_b, optimizer_b = _create_model_and_optimizer(
+            encoder_grid, llm_grid, hidden_size, num_layers, vocab_size, seed=2
+        )
+
+        # Load model
+        model_sd_b = model_b.sharded_state_dict()
+        loaded_model_sd, missing, unexpected = load(
+            model_sd_b, model_ckpt, strict=StrictHandling.RETURN_ALL
+        )
+        real_missing = [k for k in missing if '_extra_state' not in k]
+        real_unexpected = [k for k in unexpected if '_extra_state' not in k]
+        assert not real_missing, f"Missing keys: {real_missing}"
+        assert not real_unexpected, f"Unexpected keys: {real_unexpected}"
+        model_b.load_state_dict(loaded_model_sd)
+
+        # Load optimizer
+        optim_sd_b = optimizer_b.sharded_state_dict(model_b.sharded_state_dict(), is_loading=True)
+        loaded_optim_sd = load(optim_sd_b, optim_ckpt, validate_access_integrity=False)
+        optimizer_b.load_state_dict(loaded_optim_sd)
+
+        # --- Verify model params match ---
+        mismatches = [
+            name
+            for name, p in model_b.named_parameters()
+            if name in params_a and not torch.equal(p, params_a[name])
+        ]
+        assert not mismatches, f"Model param mismatch after load: {mismatches}"
+
+        # --- Verify optimizer state matches (param_groups + Adam m/v tensors) ---
+        for name, info_b in optimizer_b.module_infos.items():
+            if not (info_b.is_active and info_b.optimizer):
+                continue
+            info_a = optimizer_a.module_infos[name]
+            sd_a = info_a.optimizer.state_dict()
+            sd_b = info_b.optimizer.state_dict()
+
+            # Verify param_groups
+            pg_a = sd_a.get('optimizer', {}).get('param_groups', [])
+            pg_b = sd_b.get('optimizer', {}).get('param_groups', [])
+            assert len(pg_a) == len(pg_b), f"Optimizer {name}: param_groups count mismatch"
+            for i, (ga, gb) in enumerate(zip(pg_a, pg_b)):
+                assert ga['lr'] == gb['lr'], f"Optimizer {name} group[{i}]: lr mismatch"
+
+            # Verify Adam state tensors (exp_avg, exp_avg_sq)
+            state_a = sd_a.get('optimizer', {}).get('state', {})
+            state_b = sd_b.get('optimizer', {}).get('state', {})
+            for param_id in state_a:
+                if param_id not in state_b:
+                    continue
+                for key in ('exp_avg', 'exp_avg_sq'):
+                    if key in state_a[param_id] and key in state_b[param_id]:
+                        assert torch.equal(
+                            state_a[param_id][key], state_b[param_id][key]
+                        ), f"Optimizer {name} param {param_id} {key} mismatch"
+
+    finally:
+        _cleanup_tmpdir(ckpt_dir)
+
+
+@pytest.mark.skipif(
+    version.parse(torch.__version__) < version.parse('2.3.0'),
+    reason="Device mesh requires PyTorch 2.3+",
+)
+class TestMimoCheckpoint:
+    """Distributed checkpoint save/load tests for non-colocated MiMo (8 GPUs)."""
+
+    @classmethod
+    def setup_class(cls):
+        Utils.initialize_distributed()
+        cls.world_size = dist.get_world_size()
+
+    @classmethod
+    def teardown_class(cls):
+        Utils.destroy_model_parallel()
+
+    def teardown_method(self):
+        destroy_all_grids()
+
+    def test_encoder_tp2_llm_tp2_pp3(self):
+        if self.world_size != 8:
+            pytest.skip(f"Requires 8 GPUs, got {self.world_size}")
+        run_checkpoint_test(
+            encoder_tp=2,
+            encoder_pp=1,
+            encoder_dp=1,
+            encoder_offset=0,
+            llm_tp=2,
+            llm_pp=3,
+            llm_dp=1,
+            llm_offset=2,
+            hidden_size=256,
+            num_layers=3,
+        )
+
+    def test_encoder_tp1_llm_pp7(self):
+        if self.world_size != 8:
+            pytest.skip(f"Requires 8 GPUs, got {self.world_size}")
+        run_checkpoint_test(
+            encoder_tp=1,
+            encoder_pp=1,
+            encoder_dp=1,
+            encoder_offset=0,
+            llm_tp=1,
+            llm_pp=7,
+            llm_dp=1,
+            llm_offset=1,
+            hidden_size=256,
+            num_layers=7,
+        )
+
+    def test_encoder_tp2_pp2_llm_tp2_pp2(self):
+        if self.world_size != 8:
+            pytest.skip(f"Requires 8 GPUs, got {self.world_size}")
+        run_checkpoint_test(
+            encoder_tp=2,
+            encoder_pp=2,
+            encoder_dp=1,
+            encoder_offset=0,
+            llm_tp=2,
+            llm_pp=2,
+            llm_dp=1,
+            llm_offset=4,
+            hidden_size=256,
+            num_layers=2,
+        )