diff --git a/src/megatron/bridge/data/mimo/__init__.py b/src/megatron/bridge/data/mimo/__init__.py index 045d87152e..8713d2c68c 100644 --- a/src/megatron/bridge/data/mimo/__init__.py +++ b/src/megatron/bridge/data/mimo/__init__.py @@ -1,16 +1,16 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. """MIMO multi-encoder data loading utilities.""" -from megatron.bridge.data.mimo.dataset import MimoDataset -from megatron.bridge.data.mimo.collate import mimo_collate_fn -from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info -from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders - # Providers from megatron.bridge.data.mimo.base_provider import MimoDatasetProvider +from megatron.bridge.data.mimo.collate import mimo_collate_fn +from megatron.bridge.data.mimo.dataset import MimoDataset +from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info, get_mimo_sampling_info, slice_batch_for_mimo from megatron.bridge.data.mimo.hf_provider import HFMimoDatasetProvider +from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders from megatron.bridge.data.mimo.mock_provider import MockMimoProvider + __all__ = [ # Core "MimoDataset", @@ -21,5 +21,7 @@ "MockMimoProvider", # Utilities "get_mimo_dp_info", + "get_mimo_sampling_info", + "slice_batch_for_mimo", "build_mimo_data_loaders", ] diff --git a/src/megatron/bridge/data/mimo/dp_utils.py b/src/megatron/bridge/data/mimo/dp_utils.py index c6d867a4c4..d7565646e2 100644 --- a/src/megatron/bridge/data/mimo/dp_utils.py +++ b/src/megatron/bridge/data/mimo/dp_utils.py @@ -3,8 +3,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Any, Dict, Tuple +import torch import torch.distributed as dist from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY @@ -15,61 +16,160 @@ from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig +def _find_rank_module( + grids: Dict[str, "HyperCommGrid"], 
+) -> Tuple["HyperCommGrid | None", "str | None"]: + """Find which module grid the current rank belongs to.""" + current_rank = dist.get_rank() + for module_name, grid in grids.items(): + if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size): + return grid, module_name + return None, None + + +def _needs_data_for_module(grid: "HyperCommGrid", module_name: str) -> bool: + """Determine if the current rank needs to load data for the given module. + + LLM: first and last PP stage need data (input_ids and labels respectively). + Encoders: only the first PP stage needs raw modality inputs. + """ + pp_group = grid.get_pg(["pp"]) + pp_rank = pp_group.rank() + pp_size = pp_group.size() + if module_name == MIMO_LANGUAGE_MODULE_KEY: + return (pp_rank == 0) or (pp_rank == pp_size - 1) + return pp_rank == 0 + + def get_mimo_dp_info( mimo_cfg: "MimoParallelismConfig", grids: Dict[str, "HyperCommGrid"], ) -> Tuple[int, int, bool, str]: - """Get DP rank, size, data-loading responsibility, and loader module for MIMO. + """Get **module-local** DP rank, size, data-loading flag, and module name. - Determines which module's DP settings to use for data loading based on - current rank's participation in heterogeneous deployment. + Returns the DP settings for the module that the current rank participates + in. These are used by :func:`slice_batch_for_mimo` to sub-shard a global + micro-batch into per-module DP shards. - In heterogeneous mode, each rank uses its own module's DP settings. + .. note:: + Do **not** use these values to construct a ``DistributedSampler``. + For sampler construction use :func:`get_mimo_sampling_info` instead, + which returns settings that keep all data-loading ranks synchronised + on the same sample order. Args: mimo_cfg: MIMO parallelism configuration. grids: Module name to HyperCommGrid mapping from build_hypercomm_grids(). Returns: - Tuple of (dp_rank, dp_size, needs_data, loader_module): - - dp_rank: This rank's position in DP group. 
- - dp_size: Size of DP group for data sharding. - - needs_data: Whether this rank needs to load data (first/last PP stage). - - loader_module: Which module's DP settings are being used. - - Example: - >>> from megatron.bridge.models.mimo.mimo_builder import build_hypercomm_grids - >>> grids = build_hypercomm_grids(mimo_cfg) - >>> dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - >>> if needs_data: - ... # Build data loader with dp_rank and dp_size - ... sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank) + Tuple of (dp_rank, dp_size, needs_data, loader_module). """ - current_rank = dist.get_rank() - - # Heterogeneous: find which module this rank belongs to - my_grid = None - my_module = None - for module_name, grid in grids.items(): - if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size): - my_grid = grid - my_module = module_name - break - + my_grid, my_module = _find_rank_module(grids) if my_grid is None or my_module is None: - # Rank doesn't participate in any module return 0, 1, False, MIMO_LANGUAGE_MODULE_KEY dp_rank = my_grid.get_pg(["dp"]).rank() dp_size = my_grid.get_pg(["dp"]).size() + needs_data = _needs_data_for_module(my_grid, my_module) + return dp_rank, dp_size, needs_data, my_module - pp_group = my_grid.get_pg(["pp"]) - pp_rank = pp_group.rank() - pp_size = pp_group.size() - if my_module == MIMO_LANGUAGE_MODULE_KEY: - needs_data = (pp_rank == 0) or (pp_rank == pp_size - 1) - else: - needs_data = pp_rank == 0 +def get_mimo_sampling_info( + mimo_cfg: "MimoParallelismConfig", + grids: Dict[str, "HyperCommGrid"], +) -> Tuple[int, int, bool]: + """Get sampler DP rank, size, and data-loading flag for MIMO. - return dp_rank, dp_size, needs_data, my_module + In heterogeneous MIMO, modules may have different DP sizes. 
The data + loader must give every data-loading rank the **same global micro-batch** + so that :func:`slice_batch_for_mimo` (called in the forward step) can + sub-shard it consistently with the :class:`BridgeCommunicator` fan-in / + fan-out routing. + + This function therefore returns ``dp_size=1, dp_rank=0`` for all ranks, + disabling DP sharding at the sampler level. Per-module DP sharding is + deferred to :func:`slice_batch_for_mimo`. + + Args: + mimo_cfg: MIMO parallelism configuration. + grids: Module name to HyperCommGrid mapping. + + Returns: + Tuple of (sampler_dp_rank, sampler_dp_size, needs_data). + """ + my_grid, my_module = _find_rank_module(grids) + if my_grid is None or my_module is None: + return 0, 1, False + + needs_data = _needs_data_for_module(my_grid, my_module) + # All data-loading ranks use the same sampler settings so they load + # identical global micro-batches. Module-local DP slicing happens later + # in forward_step via slice_batch_for_mimo. + return 0, 1, needs_data + + +def slice_batch_for_mimo( + batch: Dict[str, Any], + dp_rank: int, + dp_size: int, +) -> Dict[str, Any]: + """Slice a global micro-batch for this rank's module-local DP shard. + + All data-loading ranks receive the same global micro-batch (the sampler + uses ``dp_size=1``). This function contiguously slices it so that each + module-local DP replica processes the correct subset. The slicing is + contiguous to match the :class:`BridgeCommunicator`'s batch-dimension + split / concatenate logic for fan-out and fan-in routing. + + Handles nested dicts (e.g. ``modality_inputs``) by recursing. + + Args: + batch: Global batch dictionary with tensors of shape [global_batch, ...]. + May contain nested dicts (e.g. modality_inputs → encoder → kwargs). + dp_rank: This rank's position in its **module-local** DP group. + dp_size: Size of the module-local DP group. + + Returns: + Dict with tensors sliced to shape [global_batch // dp_size, ...]. 
+ + Example: + >>> global_batch = {'tokens': torch.randn(12, 2048)} + >>> local_batch = slice_batch_for_mimo(global_batch, dp_rank=1, dp_size=3) + >>> local_batch['tokens'].shape # torch.Size([4, 2048]) + """ + if dp_size == 1: + return batch + + sliced = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + # Slice along batch dimension (dim=0) + batch_size = value.size(0) + if batch_size % dp_size != 0: + raise ValueError( + f"Batch size {batch_size} for key '{key}' is not divisible " + f"by DP size {dp_size}. Ensure micro_batch_size is divisible " + f"by every module's data_parallel_size." + ) + local_batch_size = batch_size // dp_size + start_idx = dp_rank * local_batch_size + end_idx = start_idx + local_batch_size + sliced[key] = value[start_idx:end_idx] + elif isinstance(value, dict): + # Recurse into nested dicts (e.g. modality_inputs) + sliced[key] = slice_batch_for_mimo(value, dp_rank, dp_size) + elif isinstance(value, list) and len(value) > 0: + list_len = len(value) + if list_len % dp_size == 0: + local_len = list_len // dp_size + start_idx = dp_rank * local_len + end_idx = start_idx + local_len + sliced[key] = value[start_idx:end_idx] + else: + # Keep as-is if not evenly divisible (global metadata) + sliced[key] = value + else: + # Keep non-tensor, non-list values as-is + sliced[key] = value + + return sliced diff --git a/src/megatron/bridge/data/mimo/loaders.py b/src/megatron/bridge/data/mimo/loaders.py index 576f114505..5e61ca2d06 100644 --- a/src/megatron/bridge/data/mimo/loaders.py +++ b/src/megatron/bridge/data/mimo/loaders.py @@ -3,15 +3,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple import torch from torch.utils.data import DataLoader -from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info +from megatron.bridge.data.mimo.dp_utils import get_mimo_sampling_info from megatron.bridge.training.config import 
DatasetBuildContext, DatasetProvider from megatron.bridge.utils.common_utils import print_rank_0 + if TYPE_CHECKING: from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.state import TrainState @@ -25,12 +26,14 @@ def build_mimo_data_loaders( valid_samples: int, test_samples: int, ) -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]: - """Build MIMO data loaders with per-module DP settings. - - Creates data loaders with DP-aware sampling based on the MIMO parallelism - configuration. Only ranks that need data (first/last PP stage) will get - non-None loaders. - + """Build MIMO data loaders with globally consistent sampling. + + All data-loading ranks receive identical global micro-batches (the sampler + uses dp_size=1). Per-module DP sub-sharding is deferred to + ``slice_batch_for_mimo`` in the forward step, ensuring consistency with + the BridgeCommunicator's fan-in/fan-out routing for asymmetric DP configs. + Only ranks that need data (first/last PP stage) will get non-None loaders. + Args: cfg: Configuration container with MimoModelProvider as cfg.model. train_state: Current training state. @@ -39,14 +42,14 @@ def build_mimo_data_loaders( train_samples: Number of training samples. valid_samples: Number of validation samples. test_samples: Number of test samples. - + Returns: Tuple of (train_loader, valid_loader, test_loader). Returns (None, None, None) if this rank doesn't need data. - + Raises: ValueError: If cfg.model is not MimoModelProvider or mimo_parallelism_config is None. - + Example: >>> from megatron.bridge.data.mimo import MockMimoProvider, build_mimo_data_loaders >>> provider = MockMimoProvider( @@ -62,27 +65,36 @@ def build_mimo_data_loaders( ... 
) """ from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider - + if not isinstance(cfg.model, MimoModelProvider): raise ValueError("cfg.model must be MimoModelProvider for MIMO data loading.") - + if cfg.model.mimo_parallelism_config is None: raise ValueError("mimo_parallelism_config must be set for MIMO data loading.") - + if cfg.model._grids is None: raise ValueError( - "MimoModelProvider._grids is None. " - "Ensure build_model() is called before building data loaders." + "MimoModelProvider._grids is None. Ensure build_model() is called before building data loaders." + ) + + # Validate that micro_batch_size is divisible by every module's DP size. + # slice_batch_for_mimo divides the micro-batch contiguously by the module's + # DP size in forward_step; a non-divisible MBS would leave a remainder. + micro_batch_size = cfg.train.micro_batch_size + for mod_name, mod_cfg in cfg.model.mimo_parallelism_config.module_parallelisms.items(): + dp = mod_cfg.data_parallel_size + assert micro_batch_size % dp == 0, ( + f"micro_batch_size ({micro_batch_size}) must be divisible by " + f"data_parallel_size ({dp}) of module '{mod_name}'. " + f"slice_batch_for_mimo requires an evenly divisible micro-batch." 
) print_rank_0("> building MIMO train, validation, and test datasets ...") # Use cached grids from build_model() grids = cfg.model._grids - - dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info( - cfg.model.mimo_parallelism_config, grids - ) + + sampler_dp_rank, sampler_dp_size, needs_data = get_mimo_sampling_info(cfg.model.mimo_parallelism_config, grids) if not needs_data: return None, None, None @@ -96,21 +108,25 @@ def build_mimo_data_loaders( ) train_ds, valid_ds, test_ds = mimo_provider.build_datasets(context) - print_rank_0(f" Built datasets: train={len(train_ds) if train_ds else 0}, " - f"valid={len(valid_ds) if valid_ds else 0}, " - f"test={len(test_ds) if test_ds else 0}") + print_rank_0( + f" Built datasets: train={len(train_ds) if train_ds else 0}, " + f"valid={len(valid_ds) if valid_ds else 0}, " + f"test={len(test_ds) if test_ds else 0}" + ) - # Build data loaders with DP-aware sampling + # Build data loaders with globally consistent sampling. + # sampler_dp_size=1 so all data-loading ranks see the same batches. + # Per-module DP sub-sharding is done later by slice_batch_for_mimo. 
collate_fn = mimo_provider.get_collate_fn() micro_batch_size = cfg.train.micro_batch_size - + def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]: if dataset is None: return None sampler = torch.utils.data.DistributedSampler( dataset, - num_replicas=dp_size, - rank=dp_rank, + num_replicas=sampler_dp_size, + rank=sampler_dp_rank, shuffle=shuffle, ) return DataLoader( @@ -126,5 +142,5 @@ def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]: train_loader = _make_loader(train_ds, shuffle=True) valid_loader = _make_loader(valid_ds, shuffle=False) test_loader = _make_loader(test_ds, shuffle=False) - + return train_loader, valid_loader, test_loader diff --git a/src/megatron/bridge/models/mimo/mimo_builder.py b/src/megatron/bridge/models/mimo/mimo_builder.py index a8b266a2f4..c22fbdf0b6 100644 --- a/src/megatron/bridge/models/mimo/mimo_builder.py +++ b/src/megatron/bridge/models/mimo/mimo_builder.py @@ -50,6 +50,7 @@ def build_hypercomm_grids( _ = grid.create_pg(["tp", "pp"]) _ = grid.create_pg(["tp", "ep", "pp"]) _ = grid.create_pg(["dp", "ep"]) + _ = grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) grids[module_name] = grid diff --git a/src/megatron/bridge/models/mimo/mimo_ddp.py b/src/megatron/bridge/models/mimo/mimo_ddp.py index fbd0a91f8f..3fe17b0868 100644 --- a/src/megatron/bridge/models/mimo/mimo_ddp.py +++ b/src/megatron/bridge/models/mimo/mimo_ddp.py @@ -53,12 +53,20 @@ def wrap_mimo_model_distributed( if llm_grid is not None and is_current_rank_in_grid(llm_grid): llm_pg = pg_collections.get(MIMO_LANGUAGE_MODULE_KEY) if llm_pg is not None: - mimo_model.language_model = DistributedDataParallel( + wrapped_lm = DistributedDataParallel( config=mimo_model.language_model.config, ddp_config=ddp_config, module=mimo_model.language_model, pg_collection=llm_pg, ) + # MCore's DDP wrapper does not proxy arbitrary module methods. + # MimoModel._forward_language_module() checks for and calls + # language_model.set_input_tensor(...) 
on non-first PP stages. + # Preserve that method on the wrapper so decoder input tensors + # are wired correctly when language_model is DDP-wrapped. + if hasattr(wrapped_lm.module, "set_input_tensor"): + wrapped_lm.set_input_tensor = wrapped_lm.module.set_input_tensor + mimo_model.language_model = wrapped_lm # Wrap modality submodules if hasattr(mimo_model, "modality_submodules"): diff --git a/src/megatron/bridge/models/mimo/mimo_provider.py b/src/megatron/bridge/models/mimo/mimo_provider.py index 2cddb904ec..4d8438e881 100644 --- a/src/megatron/bridge/models/mimo/mimo_provider.py +++ b/src/megatron/bridge/models/mimo/mimo_provider.py @@ -110,7 +110,7 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]): mimo_parallelism_config: Optional[MimoParallelismConfig] = None # Module data-flow DAG for MultiModulePipelineCommunicator. - # If None, auto-derived as: all modality_submodules → language module (terminal). + # If None, auto-derived as: all modality_submodules → MIMO_LANGUAGE_MODULE_KEY (terminal). # Set explicitly for non-standard topologies (e.g., language → generator). topology: Optional[Dict[str, List[str]]] = None @@ -167,6 +167,7 @@ def build_infra(self) -> MimoModelInfra: participating_modules = [name for name, pg in pg_collections.items() if pg is not None] # Derive module output tensor dimensionality if not explicitly configured. + # Language module produces 3D [S, B, H]; modality encoders produce 2D [S, H]. 
if self.module_output_ndim is not None: output_ndim = self.module_output_ndim else: diff --git a/src/megatron/bridge/training/checkpointing.py b/src/megatron/bridge/training/checkpointing.py index 7062fd2650..1170b8fb72 100644 --- a/src/megatron/bridge/training/checkpointing.py +++ b/src/megatron/bridge/training/checkpointing.py @@ -674,6 +674,12 @@ def save_checkpoint( pg_collection.dp_cp, ckpt_cfg.ckpt_assume_constant_structure, ) + # MiMo + torch_dist can hit known access-pattern validation failures + # for nested DDP language model tensors in PP>1 runs when + # fully_parallel_save is disabled. Keep validation enabled otherwise. + is_mimo = len(model) == 1 and hasattr(model[0], "mimo_config") + if is_mimo and ckpt_cfg.ckpt_format == "torch_dist" and not ckpt_cfg.fully_parallel_save: + validate_sharding_integrity = False # Store save strategy for future checkpoint saves if checkpointing_context is not None: checkpointing_context["save_strategy"] = save_strategy @@ -1314,6 +1320,7 @@ def load_checkpoint( strict: bool = True, checkpointing_context: Optional[dict[str, Any]] = None, skip_load_to_model_and_opt: bool = False, + pg_collection: Optional[ProcessGroupCollection] = None, ) -> tuple[int, int]: """Load a model checkpoint. @@ -1330,6 +1337,9 @@ def load_checkpoint( checkpointing_context: Dictionary to store context across loads (e.g., strategies). skip_load_to_model_and_opt: If True, only loads metadata (iteration, rng) but skips loading state into model and optimizer modules. + pg_collection: Optional ProcessGroupCollection. When provided, uses this instead of + extracting from model via get_pg_collection(). Required for MiMo where + model-level PG extraction may not reflect rank-local topology. 
Returns: A tuple containing: @@ -1352,7 +1362,8 @@ def load_checkpoint( cfg.checkpoint.finetune = True return _load_checkpoint_from_path( - load_dir, state, model, optimizer, opt_param_scheduler, strict, checkpointing_context + load_dir, state, model, optimizer, opt_param_scheduler, strict, checkpointing_context, + pg_collection=pg_collection, ) @@ -1381,6 +1392,7 @@ def _load_checkpoint_from_path( checkpointing_context: Optional[dict[str, Any]] = None, skip_load_to_model_and_opt: bool = False, ignore_ckpt_step: bool = False, + pg_collection: Optional[ProcessGroupCollection] = None, ) -> tuple[int, int]: """Load a checkpoint from a given path. @@ -1396,6 +1408,9 @@ def _load_checkpoint_from_path( skips loading state into model and optimizer modules. ignore_ckpt_step: If True, ignores the ckpt_step config and loads latest checkpoint. Used when loading pretrained checkpoints in PEFT scenarios. + pg_collection: Optional ProcessGroupCollection. When provided, uses this instead of + extracting from model via get_pg_collection(). Required for MiMo where + model-level PG extraction may not reflect rank-local topology. 
Returns: A tuple containing: @@ -1404,7 +1419,7 @@ def _load_checkpoint_from_path( """ cfg = state.cfg model = unwrap_model(model) - pg_collection = get_pg_collection(model) + pg_collection = pg_collection or get_pg_collection(model) ckpt_format = cfg.checkpoint.ckpt_format # Step 1: Load base checkpoint with rank0=True (torch_dist only) @@ -1416,6 +1431,7 @@ def _load_checkpoint_from_path( checkpointing_context=checkpointing_context, ignore_ckpt_step=ignore_ckpt_step, cfg=cfg, + is_mimo=False, pg_collection=pg_collection, ) @@ -1437,19 +1453,31 @@ def _load_checkpoint_from_path( print_rank_0("run_config.yaml not found, extracting config from legacy Megatron-LM checkpoint") run_config = _extract_megatron_lm_args_from_state_dict(state_dict) - ckpt_tp_pp = ( - run_config["model"]["tensor_model_parallel_size"], - run_config["model"]["pipeline_model_parallel_size"], - ) - run_tp_pp = ( - cfg.model.tensor_model_parallel_size, - cfg.model.pipeline_model_parallel_size, + # MiMo manages per-module parallelism via MimoParallelismConfig, + # so there is no single global (TP, PP) to compare. Skip the + # compatibility check entirely for MiMo configs. 
+ _is_mimo = ( + "mimo_parallelism_config" in run_config.get("model", {}) + or hasattr(cfg.model, "mimo_parallelism_config") ) - mismatch_msg = "(TP, PP) mismatch after resume ({} vs {} from checkpoint)".format(run_tp_pp, ckpt_tp_pp) + if _is_mimo: + tp_pp_match = True + mismatch_msg = "" + else: + ckpt_tp_pp = ( + run_config["model"]["tensor_model_parallel_size"], + run_config["model"]["pipeline_model_parallel_size"], + ) + run_tp_pp = ( + cfg.model.tensor_model_parallel_size, + cfg.model.pipeline_model_parallel_size, + ) + tp_pp_match = ckpt_tp_pp == run_tp_pp + mismatch_msg = "(TP, PP) mismatch after resume ({} vs {} from checkpoint)".format(run_tp_pp, ckpt_tp_pp) # Determine if RNG state will be loaded if ( - ckpt_tp_pp == run_tp_pp + tp_pp_match and not release and not cfg.checkpoint.finetune and cfg.checkpoint.load_rng @@ -1461,7 +1489,7 @@ def _load_checkpoint_from_path( else: ignore_rng_state = True gen_sd_rng_state = None - if ckpt_tp_pp != run_tp_pp: + if not tp_pp_match: print_rank_0("{}: RNG state will be ignored".format(mismatch_msg)) sharded_sd_metadata = dist_checkpointing.load_content_metadata(preloaded_state_dict=state_dict) @@ -1487,7 +1515,7 @@ def _load_checkpoint_from_path( ), } if ( - ckpt_tp_pp != run_tp_pp + not tp_pp_match and sharded_sd_metadata["distrib_optim_sharding_type"] not in DistributedOptimizer.checkpoint_fully_reshardable_formats ): @@ -1502,7 +1530,7 @@ def _load_checkpoint_from_path( # Determine if rerun state will be loaded if ( - ckpt_tp_pp == run_tp_pp + tp_pp_match and not release and not cfg.checkpoint.finetune and "rerun_state_machine" in state_dict @@ -1514,7 +1542,7 @@ def _load_checkpoint_from_path( ignore_rerun_state = False else: gen_sd_rerun_state = None - if ckpt_tp_pp != run_tp_pp: + if not tp_pp_match: print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg)) sharded_sd_metadata["dp_cp_group"] = pg_collection.dp_cp @@ -1628,6 +1656,7 @@ def _load_checkpoint_from_path( 
checkpointing_context=checkpointing_context, ignore_ckpt_step=ignore_ckpt_step, cfg=cfg, + is_mimo=(_is_mimo if ckpt_format == "torch_dist" else False), pg_collection=pg_collection, **load_kwargs, ) @@ -1687,7 +1716,18 @@ def _load_checkpoint_from_path( and optimizer is not None and not getattr(optimizer, "is_stub_optimizer", False) ): - optimizer.load_state_dict(state_dict["optimizer"]) + # For MiMo with torch_dist, skip optimizer.load_state_dict(): + # dist_checkpointing only saves common state from rank 0, but + # non-colocated MiMo has different common state per rank (each + # rank only holds its active module's param_groups). The sharded + # param states are already loaded by dist_checkpointing.load, + # and the optimizer was pre-initialized via + # sharded_state_dict(is_loading=True). + # TODO: Make dist_checkpointing.save collect common state from + # all ranks in MiMo, or have MiMo replicate all modules' common + # state on every rank during save. That fix belongs in MCore. + if not (ckpt_format == "torch_dist" and _is_mimo): + optimizer.load_state_dict(state_dict["optimizer"]) if opt_param_scheduler is not None: if "lr_scheduler" in state_dict: @@ -2063,6 +2103,7 @@ def _load_global_dist_base_checkpoint( iteration: int, release: bool, checkpointing_context: Optional[dict[str, Any]] = None, + is_mimo: bool = False, *, pg_collection: ProcessGroupCollection, ) -> tuple[dict[str, Any], str, bool, CheckpointType]: @@ -2081,8 +2122,15 @@ def _load_global_dist_base_checkpoint( load_strategy = FullyParallelLoadStrategyWrapper(load_strategy, pg_collection.dp_cp) if checkpointing_context is not None: checkpointing_context["load_strategy"] = load_strategy + validate_sharding_integrity = True + if is_mimo and ckpt_cfg.ckpt_format == "torch_dist" and not ckpt_cfg.fully_parallel_save: + validate_sharding_integrity = False state_dict = dist_checkpointing.load( - sharded_state_dict, checkpoint_name, load_strategy, strict=ckpt_cfg.dist_ckpt_strictness + sharded_state_dict, 
+ checkpoint_name, + load_strategy, + strict=ckpt_cfg.dist_ckpt_strictness, + validate_access_integrity=validate_sharding_integrity, ) return state_dict, checkpoint_name, release, CheckpointType.GLOBAL @@ -2095,6 +2143,7 @@ def _load_base_checkpoint( checkpointing_context: Optional[dict[str, Any]] = None, ignore_ckpt_step: bool = False, cfg: Optional[ConfigContainer] = None, + is_mimo: bool = False, *, pg_collection: ProcessGroupCollection, ) -> tuple[Optional[dict[str, Any]], str, bool, Optional[CheckpointType]]: @@ -2183,6 +2232,7 @@ def _load_base_checkpoint( iteration, release, checkpointing_context=checkpointing_context, + is_mimo=is_mimo, pg_collection=pg_collection, ) elif ckpt_format == "fsdp_dtensor": diff --git a/src/megatron/bridge/training/eval.py b/src/megatron/bridge/training/eval.py index 99b6d6a0af..1103b41262 100644 --- a/src/megatron/bridge/training/eval.py +++ b/src/megatron/bridge/training/eval.py @@ -22,12 +22,25 @@ from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import is_pp_last_stage +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.rerun_state_machine import RerunDataIterator, RerunMode, get_rerun_state_machine from megatron.core.transformer import MegatronModule from megatron.core.transformer.enums import CudaGraphScope from megatron.core.utils import get_model_config from modelopt.torch.distill.plugins.megatron import get_tensor_shapes_adjust_fn_for_distillation + +# Multimodule support from PR 3212 (optional - fallback if not available) +try: + from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator + from megatron.core.process_groups_config import MultiModuleProcessGroupCollection + + _MULTIMODULE_AVAILABLE = True +except ImportError: + MultiModulePipelineCommunicator = None # type: ignore + 
MultiModuleProcessGroupCollection = None # type: ignore + _MULTIMODULE_AVAILABLE = False + from megatron.bridge.data.finetuning import prepare_finetuning_batch from megatron.bridge.data.iterator_utils import make_data_iterator_list from megatron.bridge.training import fault_tolerance @@ -49,6 +62,8 @@ def evaluate( config: ConfigContainer, verbose: bool = False, non_loss_data_func: Optional[Callable] = None, + p2p_communicator: Optional[Union[P2PCommunicator, "MultiModulePipelineCommunicator"]] = None, + pg_collection: Optional[Union[ProcessGroupCollection, "MultiModuleProcessGroupCollection"]] = None, callback_manager: CallbackManager | None = None, is_test: bool = False, ) -> tuple[Optional[dict[str, torch.Tensor]], Optional[Any], bool]: @@ -63,6 +78,12 @@ def evaluate( config (ConfigContainer): Configuration container (potentially redundant). verbose (bool, optional): Whether to print evaluation progress. Defaults to False. non_loss_data_func (Optional[Callable], optional): Function to compute non-loss data. Defaults to None. + p2p_communicator (Optional[Union[P2PCommunicator, MultiModulePipelineCommunicator]], optional): + Custom communicator for pipeline parallelism. If None, creates a default P2PCommunicator. + For MIMO models, pass a MultiModulePipelineCommunicator. Defaults to None. + pg_collection (Optional[Union[ProcessGroupCollection, MultiModuleProcessGroupCollection]], optional): + Custom process group collection. If None, extracts from model via get_pg_collection(). + For MIMO models, pass a MultiModuleProcessGroupCollection. Defaults to None. callback_manager (Optional[CallbackManager]): Optional callback manager for firing callbacks. is_test (bool, optional): Whether this is test evaluation (vs validation). Defaults to False. Controls which callback events are fired (on_test_* vs on_eval_*). 
@@ -88,7 +109,9 @@ def evaluate( model_module.eval() # Retrieve process group collection and model config from the model - pg_collection = get_pg_collection(model) + # Use injected pg_collection if provided, otherwise extract from model + if pg_collection is None: + pg_collection = get_pg_collection(model) model_config = get_model_config(model[0]) # Disable result validation during evaluation @@ -102,6 +125,11 @@ def evaluate( eval_batch_size = state.cfg.train.global_batch_size eval_num_microbatches = eval_batch_size // (state.cfg.train.micro_batch_size * state.cfg.data_parallel_size) + # Determine if this is a multimodule evaluation (MIMO) + is_multimodule = isinstance(pg_collection, MultiModuleProcessGroupCollection) or isinstance( + p2p_communicator, MultiModulePipelineCommunicator + ) + if not state.cfg.dist.use_decentralized_pg: adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation( model, @@ -116,7 +144,15 @@ def evaluate( if verbose: print_rank_0(f"Evaluating on {state.cfg.validation.eval_iters * eval_batch_size} samples") - if ( + if is_multimodule: + # For multimodule, use forward_backward_pipelining_without_interleaving directly + # CUDA graphs not yet supported for multimodule + from megatron.core.pipeline_parallel.schedules import ( + forward_backward_pipelining_without_interleaving, + ) + + forward_backward_func = forward_backward_pipelining_without_interleaving + elif ( state.cfg.model.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in state.cfg.model.cuda_graph_scope ): @@ -163,7 +199,11 @@ def evaluate( # Don't care about timing during evaluation config.timers = None fault_tolerance.on_eval_step_start(state) - p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config) + + # Use injected communicator or create default P2PCommunicator + eval_p2p_communicator = p2p_communicator + if eval_p2p_communicator is None: + eval_p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, 
config=model_config) if should_fire(callback_manager, step_start_event): callback_manager.fire( @@ -184,7 +224,7 @@ def evaluate( micro_batch_size=state.cfg.train.micro_batch_size, forward_only=True, adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, - p2p_communicator=p2p_communicator, + p2p_communicator=eval_p2p_communicator, pg_collection=pg_collection, ) fault_tolerance.on_eval_step_end(state) @@ -212,8 +252,21 @@ def evaluate( if state.cfg.train.empty_unused_memory_level >= 1: torch.cuda.empty_cache() - if is_pp_last_stage(pg_collection.pp): + # Check if this is the last pipeline stage + # For multimodule, use communicator property; for single module, use pg_collection.pp + if is_multimodule: + is_last_stage = eval_p2p_communicator.is_pp_last_stage + else: + is_last_stage = is_pp_last_stage(pg_collection.pp) + + if is_last_stage: # Reduce across processes. + # For multimodule, get dp_cp from the language model's pg_collection + if is_multimodule: + dp_cp_group = pg_collection.get_language_model_collection().dp_cp + else: + dp_cp_group = pg_collection.dp_cp + for key in loss_dicts[0].keys(): if key not in total_loss_dict: total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda() @@ -221,7 +274,7 @@ def evaluate( if val[0].numel() == 2: val = torch.vstack(val).sum(dim=0) - torch.distributed.all_reduce(val, group=pg_collection.dp_cp) + torch.distributed.all_reduce(val, group=dp_cp_group) total_loss_dict[key] += val elif val[0].numel() == 1: val = torch.cat(val).sum() @@ -265,7 +318,11 @@ def evaluate( data_iterator=non_loss_microbatch_iterator, ) - p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config) + # Use injected communicator or create default P2PCommunicator + non_loss_p2p_communicator = p2p_communicator + if non_loss_p2p_communicator is None: + non_loss_p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config) + collected_non_loss_data = forward_backward_func( 
forward_step_func=wrapped_forward_step, data_iterator=non_loss_data_iterator, @@ -275,7 +332,7 @@ def evaluate( micro_batch_size=state.cfg.train.micro_batch_size, forward_only=True, collect_non_loss_data=True, - p2p_communicator=p2p_communicator, + p2p_communicator=non_loss_p2p_communicator, pg_collection=pg_collection, ) @@ -306,6 +363,8 @@ def evaluate_and_print_results( write_to_tensorboard: bool = True, process_non_loss_data_func: Optional[Callable] = None, non_loss_data_func: Optional[Callable] = None, + p2p_communicator: Optional[Union[P2PCommunicator, "MultiModulePipelineCommunicator"]] = None, + pg_collection: Optional[Union[ProcessGroupCollection, "MultiModuleProcessGroupCollection"]] = None, callback_manager: CallbackManager | None = None, is_test: bool = False, ) -> None: @@ -322,6 +381,10 @@ def evaluate_and_print_results( write_to_tensorboard (bool, optional): Whether to write results to TensorBoard. Defaults to True. process_non_loss_data_func (Optional[Callable], optional): Function to process non-loss data. Defaults to None. non_loss_data_func (Optional[Callable], optional): Function to compute non-loss data. Defaults to None. + p2p_communicator (Optional[Union[P2PCommunicator, MultiModulePipelineCommunicator]], optional): + Custom communicator for pipeline parallelism. Passed to evaluate(). Defaults to None. + pg_collection (Optional[Union[ProcessGroupCollection, MultiModuleProcessGroupCollection]], optional): + Custom process group collection. Passed to evaluate(). Defaults to None. callback_manager (Optional[CallbackManager]): Optional callback manager for firing callbacks. is_test (bool, optional): Whether this is test evaluation (vs validation). Defaults to False. Controls which callback events are fired (on_test_* vs on_eval_*). 
@@ -356,6 +419,8 @@ def evaluate_and_print_results( config, verbose, non_loss_data_func, + p2p_communicator=p2p_communicator, + pg_collection=pg_collection, callback_manager=callback_manager, is_test=is_test, ) diff --git a/src/megatron/bridge/training/mimo_step.py b/src/megatron/bridge/training/mimo_step.py index eb6598dd89..2da7fd09ec 100644 --- a/src/megatron/bridge/training/mimo_step.py +++ b/src/megatron/bridge/training/mimo_step.py @@ -13,21 +13,42 @@ import logging from functools import partial -from typing import TYPE_CHECKING, Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, Optional, Tuple import torch from megatron.core.models.mimo import MimoModel from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY +from megatron.bridge.data.mimo.dp_utils import slice_batch_for_mimo from megatron.bridge.training.mimo_parallel_utils import unwrap_mimo_model from megatron.bridge.training.state import GlobalState -if TYPE_CHECKING: - pass +logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) +def _get_module_dp_info( + mimo_model: MimoModel, +) -> Tuple[int, int]: + """Get module-local DP rank and size for the current rank. + + Used to slice the global micro-batch via :func:`slice_batch_for_mimo`. + Returns (0, 1) when grids are not configured (colocated mode). 
+ """ + grids = getattr(mimo_model.mimo_config, "module_to_grid_map", None) + if not grids: + return 0, 1 + + import torch.distributed as _dist + + current_rank = _dist.get_rank() + for _name, grid in grids.items(): + if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size): + dp_rank = grid.get_pg(["dp"]).rank() + dp_size = grid.get_pg(["dp"]).size() + return dp_rank, dp_size + + return 0, 1 def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor) -> Tuple: @@ -136,9 +157,8 @@ def forward_step( needs_data = True if mimo_model.role is not None: if mimo_model.role.has_language_module: - module_name = MIMO_LANGUAGE_MODULE_KEY - is_first_stage = mimo_model.role.is_first_stage(module_name) - is_last_stage = mimo_model.role.is_last_stage(module_name) + is_first_stage = mimo_model.role.is_first_stage(MIMO_LANGUAGE_MODULE_KEY) + is_last_stage = mimo_model.role.is_last_stage(MIMO_LANGUAGE_MODULE_KEY) needs_data = is_first_stage or is_last_stage elif mimo_model.role.has_modality_modules: modality_modules = mimo_model.role.modality_module_names @@ -151,6 +171,12 @@ def forward_step( "get_batch returned None at a stage that requires data. " "This indicates a data-loading or parallelism misconfiguration." ) + # Slice the global micro-batch for this module's DP shard. + # All data-loading ranks receive identical batches (sampler dp_size=1). + # slice_batch_for_mimo contiguously sub-shards to match the + # BridgeCommunicator's fan-in/fan-out batch-dimension routing. + dp_rank, dp_size = _get_module_dp_info(mimo_model) + data_batch = slice_batch_for_mimo(data_batch, dp_rank, dp_size) else: # Non-data stages consume hidden states from pipeline input tensors. 
data_batch = { diff --git a/src/megatron/bridge/training/pretrain_mimo.py b/src/megatron/bridge/training/pretrain_mimo.py index b55084ffb2..e659260239 100644 --- a/src/megatron/bridge/training/pretrain_mimo.py +++ b/src/megatron/bridge/training/pretrain_mimo.py @@ -21,6 +21,7 @@ from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator from megatron.core.utils import get_model_config +from megatron.bridge.training.checkpointing import init_checkpointing_context, load_checkpoint from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.mimo_parallel_utils import ( build_pg_collection_for_schedule, @@ -30,6 +31,7 @@ ) from megatron.bridge.training.state import GlobalState from megatron.bridge.training.train_mimo import train_mimo +from megatron.bridge.training.utils.checkpoint_utils import checkpoint_exists if TYPE_CHECKING: @@ -55,6 +57,8 @@ class MimoSetupOutput: train_data_iterator: Training data iterator. valid_data_iterator: Validation data iterator (optional). global_state: GlobalState containing timers, config, train_state. + checkpointing_context: Dictionary holding checkpoint-related state + (save strategy cache, LocalCheckpointManager for local saves). 
""" model: "MimoModel" @@ -65,6 +69,7 @@ class MimoSetupOutput: train_data_iterator: Iterator valid_data_iterator: Optional[Iterator] global_state: GlobalState + checkpointing_context: Dict[str, Any] def setup_mimo( @@ -108,10 +113,11 @@ def setup_mimo( ) train_state = TrainState() global_state = GlobalState() - global_state.cfg = cfg global_state._timers = timers global_state.train_state = train_state + global_state.cfg = cfg + logger.info(f"Rank {dist.get_rank()}: Setting up MIMO training") # Finalize and build infrastructure @@ -182,6 +188,17 @@ def setup_mimo( logger.info(f"Rank {dist.get_rank()}: Building data iterators") train_data_iterator, valid_data_iterator = build_data_iterators_fn(cfg, mimo_infra) + # Initialize async checkpoint worker (idempotent if already initialized). + global_state.initialize_async_checkpoint_worker() + + # Initialize checkpointing context (save strategy cache + LocalCheckpointManager). + checkpointing_context = init_checkpointing_context(cfg.checkpoint) + + # Align start_time across ranks so duration-based exit is consistent. + start_time_tensor = torch.tensor([global_state.start_time], dtype=torch.double, device="cuda") + dist.all_reduce(start_time_tensor, op=dist.ReduceOp.MIN) + global_state.start_time = start_time_tensor.item() + logger.info(f"Rank {dist.get_rank()}: MIMO setup complete") return MimoSetupOutput( @@ -193,6 +210,7 @@ def setup_mimo( train_data_iterator=train_data_iterator, valid_data_iterator=valid_data_iterator, global_state=global_state, + checkpointing_context=checkpointing_context, ) @@ -235,21 +253,26 @@ def pretrain_mimo( # Initialize num-microbatches calculator if not already set. from megatron.core import num_microbatches_calculator as nmc + rampup_batch_size = getattr(cfg.train, "rampup_batch_size", None) + assert rampup_batch_size is None, ( + "Microbatch rampup is not supported in MiMo training. Set rampup_batch_size to None." 
+ ) + if nmc._GLOBAL_NUM_MICROBATCHES_CALCULATOR is None: nmc.init_num_microbatches_calculator( dist.get_rank(), - getattr(cfg.train, "rampup_batch_size", None), + rampup_batch_size, cfg.train.global_batch_size, cfg.train.micro_batch_size, cfg.data_parallel_size, getattr(cfg.train, "decrease_batch_size_if_needed", False), ) - # Setup MIMO components + # Setup MIMO components (iterators deferred until after checkpoint load) setup_output = setup_mimo( cfg=cfg, mimo_provider=mimo_provider, - build_data_iterators_fn=build_data_iterators_fn, + build_data_iterators_fn=None, global_state=global_state, ) @@ -298,6 +321,86 @@ def pretrain_mimo( ) logger.info(f"Rank {dist.get_rank()}: Auto-created schedulers for modules: {list(schedulers.keys())}") + # Select rank-local PG collection for non-colocated MiMo. + # Each rank participates in exactly one module, so "first non-None" is unambiguous. + active_pgs = [pg for pg in setup_output.mimo_infra.pg_collections.values() if pg is not None] + assert len(active_pgs) == 1, ( + f"Non-colocated MiMo requires exactly one active ProcessGroupCollection per rank, " + f"got {len(active_pgs)}. Colocated MiMo is not supported by this code path." + ) + local_pg_collection = active_pgs[0] + + # Bridge MiMo's per-module process groups into Megatron's global parallel + # state. MiMo intentionally skips global MPU init (see + # MimoModelProvider.initialize_model_parallel), but checkpoint save/load + # paths (sharded_state_dict, ensure_metadata_has_dp_cp_group) rely on the + # globals. For non-colocated MiMo every rank is active in exactly one + # module, so we can safely set the globals from that module's collection. 
+ from megatron.core import parallel_state as mpu + + mpu._TENSOR_MODEL_PARALLEL_GROUP = local_pg_collection.tp + mpu._DATA_PARALLEL_GROUP = local_pg_collection.dp + mpu._DATA_PARALLEL_GROUP_WITH_CP = getattr(local_pg_collection, "dp_cp", local_pg_collection.dp) + if hasattr(local_pg_collection, "pp"): + mpu._PIPELINE_MODEL_PARALLEL_GROUP = local_pg_collection.pp + + first_scheduler = next(iter(schedulers.values()), None) if schedulers else None + + # Broadened load-intent gating: includes non-persistent resume intent + has_persistent = cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load) + has_pretrained = cfg.checkpoint.pretrained_checkpoint is not None and checkpoint_exists( + cfg.checkpoint.pretrained_checkpoint + ) + wants_non_persistent = cfg.checkpoint.non_persistent_ckpt_type is not None + should_load = has_persistent or has_pretrained or wants_non_persistent + + if should_load: + timers = setup_output.global_state.timers + timers("load-checkpoint", log_level=0).start(barrier=True) + load_checkpoint( + setup_output.global_state, + model=[setup_output.model], + optimizer=optimizer, + opt_param_scheduler=first_scheduler, + checkpointing_context=setup_output.checkpointing_context, + pg_collection=local_pg_collection, + ) + timers("load-checkpoint").stop(barrier=True) + timers.log(["load-checkpoint"]) + + # Fan out loaded scheduler state to all active module schedulers. + # v1: checkpoints contain a single scheduler blob (first_scheduler). + if first_scheduler is not None and len(schedulers) > 1: + loaded_state = first_scheduler.state_dict() + for sched in schedulers.values(): + if sched is not first_scheduler: + sched.load_state_dict(loaded_state) + + # Build data iterators after load decision (resume-safe ordering). + # When resuming, train_state has restored consumed-sample offsets that + # the iterator builder must honor to avoid replaying data from sample 0. 
+ train_state = setup_output.global_state.train_state + is_resuming = train_state.step > 0 + + if is_resuming: + import inspect + + sig = inspect.signature(build_data_iterators_fn) + if "train_state" in sig.parameters: + train_data_iterator, valid_data_iterator = build_data_iterators_fn( + cfg, + setup_output.mimo_infra, + train_state=train_state, + ) + else: + raise RuntimeError( + "Resuming from checkpoint but build_data_iterators_fn does not accept " + "'train_state' argument. The iterator builder must support a train_state " + "keyword argument to honor restored consumed-sample offsets during resume." + ) + else: + train_data_iterator, valid_data_iterator = build_data_iterators_fn(cfg, setup_output.mimo_infra) + logger.info(f"Rank {dist.get_rank()}: Starting training loop") # Run training loop @@ -306,11 +409,12 @@ def pretrain_mimo( model=setup_output.model, optimizer=optimizer, schedulers=schedulers, - train_data_iterator=setup_output.train_data_iterator, - valid_data_iterator=setup_output.valid_data_iterator, + train_data_iterator=train_data_iterator, + valid_data_iterator=valid_data_iterator, global_state=setup_output.global_state, mimo_infra=setup_output.mimo_infra, multimodule_communicator=setup_output.multimodule_communicator, + checkpointing_context=setup_output.checkpointing_context, ) logger.info("MIMO pretraining completed") diff --git a/src/megatron/bridge/training/train.py b/src/megatron/bridge/training/train.py index f2aaf14509..52397b1e96 100644 --- a/src/megatron/bridge/training/train.py +++ b/src/megatron/bridge/training/train.py @@ -1071,6 +1071,7 @@ def save_checkpoint_and_time( checkpointing_context: dict[str, Any], non_persistent_ckpt: bool = False, train_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]] = None, + pg_collection: Optional[ProcessGroupCollection] = None, ) -> None: """Saves a checkpoint and logs the timing. 
@@ -1088,6 +1089,8 @@ def save_checkpoint_and_time( non_persistent_ckpt: Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False. train_data_iterator: Optional training data iterator to save its state. + pg_collection: Optional process group collection for MiMo topologies. + When None, save_checkpoint falls back to model-attached PGs. """ timers = state.timers energy_monitor = state.energy_monitor @@ -1119,6 +1122,7 @@ def save_checkpoint_and_time( checkpointing_context=checkpointing_context, non_persistent_ckpt=non_persistent_ckpt, train_data_iterator=train_data_iterator, + pg_collection=pg_collection, ) if state.cfg.model.fp8 is not None: # Run garbage collection after checkpoint saving to free memory from @@ -1145,6 +1149,7 @@ def checkpoint_and_decide_exit( num_floating_point_operations_so_far: float, checkpointing_context: dict[str, Any], train_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]], + pg_collection: Optional[ProcessGroupCollection] = None, ) -> bool: """Handles checkpointing decisions and determines if training should exit. @@ -1160,6 +1165,8 @@ def checkpoint_and_decide_exit( num_floating_point_operations_so_far: Cumulative TFLOPs up to this point. checkpointing_context: Dictionary holding checkpointing-related state. train_data_iterator: Optional training data iterator to save its state. + pg_collection: Optional process group collection for MiMo topologies. + When None, save_checkpoint falls back to model-attached PGs. Returns: True if the training loop should exit, False otherwise. 
@@ -1179,6 +1186,7 @@ def checkpoint_and_decide_exit( num_floating_point_operations_so_far, checkpointing_context, train_data_iterator=train_data_iterator, + pg_collection=pg_collection, ) barrier_and_log("exiting program after receiving SIGTERM.") @@ -1198,6 +1206,7 @@ def checkpoint_and_decide_exit( num_floating_point_operations_so_far, checkpointing_context, train_data_iterator=train_data_iterator, + pg_collection=pg_collection, ) saved_checkpoint = True @@ -1215,6 +1224,7 @@ def checkpoint_and_decide_exit( checkpointing_context, non_persistent_ckpt=True, train_data_iterator=train_data_iterator, + pg_collection=pg_collection, ) saved_checkpoint = True @@ -1234,6 +1244,7 @@ def checkpoint_and_decide_exit( num_floating_point_operations_so_far, checkpointing_context, train_data_iterator=train_data_iterator, + pg_collection=pg_collection, ) barrier_and_log(f"exiting program after {train_time} minutes") @@ -1250,6 +1261,7 @@ def checkpoint_and_decide_exit( num_floating_point_operations_so_far, checkpointing_context, train_data_iterator=train_data_iterator, + pg_collection=pg_collection, ) barrier_and_log(f"exiting program at iteration {state.train_state.step}") @@ -1266,6 +1278,7 @@ def checkpoint_and_decide_exit( num_floating_point_operations_so_far, checkpointing_context, train_data_iterator=train_data_iterator, + pg_collection=pg_collection, ) barrier_and_log("Exiting program due to straggler detection.") return True diff --git a/src/megatron/bridge/training/train_mimo.py b/src/megatron/bridge/training/train_mimo.py index e8d1dc2a3d..aa020d065d 100644 --- a/src/megatron/bridge/training/train_mimo.py +++ b/src/megatron/bridge/training/train_mimo.py @@ -28,7 +28,7 @@ from megatron.core.pipeline_parallel.schedules import forward_backward_pipelining_without_interleaving from megatron.core.utils import get_model_config -from megatron.bridge.training.checkpointing import maybe_finalize_async_save, save_checkpoint +from megatron.bridge.training.checkpointing import 
maybe_finalize_async_save from megatron.bridge.training.eval import evaluate_and_print_results from megatron.bridge.training.mimo_parallel_utils import ( build_pg_collection_for_schedule, @@ -45,6 +45,7 @@ should_profile_rank, ) from megatron.bridge.training.state import GlobalState +from megatron.bridge.training.train import checkpoint_and_decide_exit from megatron.bridge.training.utils.train_utils import ( prepare_forward_step_func, training_log, @@ -196,6 +197,7 @@ def train_mimo( global_state: GlobalState, mimo_infra: "MimoModelInfra", multimodule_communicator: "MultiModulePipelineCommunicator", + checkpointing_context: Optional[Dict] = None, ) -> None: """Main MIMO training loop. @@ -209,11 +211,10 @@ def train_mimo( - GlobalState for timers, config, train_state - training_log() for metrics reporting - handle_profiling_step() and handle_profiling_stop() for profiler lifecycle - - save_checkpoint() with MimoOptimizer for checkpointing + - save_checkpoint_and_time() / checkpoint_and_decide_exit() for checkpointing - evaluate_and_print_results() for validation with multimodule support - maybe_finalize_async_save() for async checkpoint finalization - Args: forward_step_func: Forward step function. model: MimoModel instance. @@ -224,6 +225,9 @@ def train_mimo( global_state: GlobalState containing timers, config, train_state. mimo_infra: MimoModelInfra with grids, topology, pg_collections. multimodule_communicator: MultiModulePipelineCommunicator for P2P. + checkpointing_context: Dictionary holding checkpoint-related state + (save strategy cache, LocalCheckpointManager). Created by + init_checkpointing_context() in pretrain_mimo. """ timers = global_state.timers train_state = global_state.train_state @@ -251,17 +255,18 @@ def train_mimo( "The list-based fallback is not supported. Ensure Megatron-LM PR 3212 is available." ) - # Use rank-local module PG for logging reductions to avoid global MPU fallback. 
- # NOTE: In non-colocated MIMO each rank participates in exactly one module, so - # "first non-None" unambiguously selects that module's PG. For colocated MIMO - # (where a rank participates in multiple modules), this selection must be - # replaced with per-module logging or an explicit module-aware reduction strategy. - local_pg_collection = next((pg for pg in mimo_infra.pg_collections.values() if pg is not None), None) - if local_pg_collection is None: - raise RuntimeError( - "No local ProcessGroupCollection found for this rank. " - "Ensure rank participation is correctly configured in MIMO infrastructure." - ) + # Use rank-local module PG for logging reductions and checkpoint saving to + # avoid global MPU fallback. In non-colocated MIMO each rank participates in + # exactly one module, so "first non-None" unambiguously selects that module's PG. + active_pgs = [pg for pg in mimo_infra.pg_collections.values() if pg is not None] + assert len(active_pgs) == 1, ( + f"Non-colocated MiMo requires exactly one active ProcessGroupCollection per rank, " + f"got {len(active_pgs)}. Colocated MiMo is not supported by this code path." + ) + local_pg_collection = active_pgs[0] + + if checkpointing_context is None: + checkpointing_context = {} # Configure gradient hooks on model config model_config = get_model_config(model) @@ -309,6 +314,14 @@ def train_mimo( timers("interval-time", log_level=0).start(barrier=True) while train_state.step < train_config.train_iters: + # Finalize any pending async saves (non-blocking). Placed at the top + # of the loop so async saves get a full iteration to complete. 
+ maybe_finalize_async_save( + global_state=global_state, + ckpt_cfg=cfg.checkpoint, + blocking=False, + ) + # Handle profiling nsys_ctx = handle_profiling_step( prof_config, @@ -415,24 +428,20 @@ def train_mimo( ) timers("evaluate").stop() - # Checkpointing at specified intervals - if cfg.checkpoint.save_interval is not None and train_state.step % cfg.checkpoint.save_interval == 0: - timers("save-checkpoint", log_level=0).start(barrier=True) - save_checkpoint( - state=global_state, - model=[model], - optimizer=optimizer, - opt_param_scheduler=first_scheduler, - num_floating_point_operations_so_far=0, # TODO: Add proper FLOPs tracking - ) - timers("save-checkpoint").stop() - - # Finalize any pending async saves (non-blocking during training) - maybe_finalize_async_save( - global_state=global_state, - ckpt_cfg=cfg.checkpoint, - blocking=False, + # Checkpointing (interval, signal, duration, exit-interval) and exit decision. + # TODO: MiMo FLOPs estimation is non-trivial (heterogeneous modules); pass 0 for now. + should_exit = checkpoint_and_decide_exit( + state=global_state, + model=[model], + optimizer=optimizer, + opt_param_scheduler=first_scheduler, + num_floating_point_operations_so_far=0, + checkpointing_context=checkpointing_context, + train_data_iterator=train_data_iterator, + pg_collection=local_pg_collection, ) + if should_exit: + break # Stop profiling handle_profiling_stop( diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 0000000000..774bc60c1d --- /dev/null +++ b/tests/e2e/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+"""End-to-end tests for Megatron-Bridge.""" diff --git a/tests/e2e/mimo/__init__.py b/tests/e2e/mimo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/e2e/mimo/run_hetero_llava.sh b/tests/e2e/mimo/run_hetero_llava.sh new file mode 100644 index 0000000000..66adc30be1 --- /dev/null +++ b/tests/e2e/mimo/run_hetero_llava.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Heterogeneous MIMO LLaVA training — LLM on ranks 0-3, CLIP on ranks 4-7. + +GPUS_PER_NODE=8 +NUM_NODES=1 + +uv run torchrun \ + --nproc_per_node "$GPUS_PER_NODE" \ + --nnodes "$NUM_NODES" \ + tests/e2e/mimo/test_mimo_training_llava.py \ + --micro-batch-size 2 \ + --global-batch-size 32 \ + --train-iters 500 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --clip-grad 1.0 \ + --log-interval 1 \ + --lr 1e-4 \ + --lr-warmup-iters 20 \ + --min-lr 2.0e-5 \ + --weight-decay 0.01 \ + --wandb-project "Megatron-Bridge-MIMO" \ + --wandb-exp-name "mimo-llava-e2e-test" \ + --wandb-save-dir "/tmp/wandb" \ + --dataset-root /path/to/llava/pretrain/dataset diff --git a/tests/e2e/mimo/run_mimo_checkpoint_resume.sh b/tests/e2e/mimo/run_mimo_checkpoint_resume.sh new file mode 100755 index 0000000000..49a43c8347 --- /dev/null +++ b/tests/e2e/mimo/run_mimo_checkpoint_resume.sh @@ -0,0 +1,175 @@ +#!/bin/bash +# MIMO checkpoint save→resume round-trip e2e test. +# +# Runs the test in two phases (separate torchrun invocations) for each +# parallelism configuration: +# Phase 1 (save): Train for 5 steps, save checkpoint. +# Phase 2 (resume): Resume from checkpoint, train to 10 steps, verify continuity. 
+# +# Usage: +# ./run_mimo_checkpoint_resume.sh # 8 GPUs, all configs +# ./run_mimo_checkpoint_resume.sh --gpus 8 # explicit GPU count +# ./run_mimo_checkpoint_resume.sh --config tp4_both # single config only + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TEST_FILE="${SCRIPT_DIR}/test_mimo_checkpoint_resume_e2e.py" + +NUM_GPUS=${NUM_GPUS:-8} +SINGLE_CONFIG="" + +while [[ $# -gt 0 ]]; do + case $1 in + --gpus) NUM_GPUS="$2"; shift 2 ;; + --config) SINGLE_CONFIG="$2"; shift 2 ;; + *) echo "Unknown argument: $1"; exit 1 ;; + esac +done + +echo "==========================================" +echo "MIMO Checkpoint Resume E2E Tests" +echo "GPUs: ${NUM_GPUS}" +echo "==========================================" + +# Config format: "name|llm_tp|llm_pp|llm_dp|llm_offset|vision_tp|vision_pp|vision_dp|vision_offset" +# Note: CLIPViT does not support PP > 1; only LLM can use PP. +declare -a CONFIGS_8GPU=( + "dp4_both|1|1|4|0|1|1|4|4" + "tp4_both|4|1|1|0|4|1|1|4" + "tp2_dp2_both|2|1|2|0|2|1|2|4" + "pp2_llm_dp4_vision|1|2|2|0|1|1|4|4" +) + +declare -a CONFIGS_4GPU=( + "dp2_both|1|1|2|0|1|1|2|2" + "tp2_both|2|1|1|0|2|1|1|2" +) + +declare -a CONFIGS_2GPU=( + "dp1_both|1|1|1|0|1|1|1|1" +) + +if [[ $NUM_GPUS -ge 8 ]]; then + CONFIGS=("${CONFIGS_8GPU[@]}") +elif [[ $NUM_GPUS -ge 4 ]]; then + CONFIGS=("${CONFIGS_4GPU[@]}") +else + CONFIGS=("${CONFIGS_2GPU[@]}") +fi + +declare -a RESULTS=() +declare -a FAILED_CONFIGS=() +TOTAL=0 +PASSED=0 + +run_config() { + local config="$1" + local name llm_tp llm_pp llm_dp llm_offset vision_tp vision_pp vision_dp vision_offset + IFS='|' read -r name llm_tp llm_pp llm_dp llm_offset vision_tp vision_pp vision_dp vision_offset <<< "$config" + + echo "" + echo "----------------------------------------" + echo "Config: ${name}" + echo " LLM: TP=${llm_tp}, PP=${llm_pp}, DP=${llm_dp}, offset=${llm_offset}" + echo " Vision: TP=${vision_tp}, PP=${vision_pp}, DP=${vision_dp}, offset=${vision_offset}" + echo 
"----------------------------------------" + + TOTAL=$((TOTAL + 1)) + local start_time=$(date +%s) + + CKPT_DIR=$(mktemp -d -t "mimo_ckpt_${name}_XXXXXX") + + local env_prefix="MIMO_LLM_TP=${llm_tp} MIMO_LLM_PP=${llm_pp} MIMO_LLM_DP=${llm_dp} MIMO_LLM_OFFSET=${llm_offset}" + env_prefix="${env_prefix} MIMO_VISION_TP=${vision_tp} MIMO_VISION_PP=${vision_pp} MIMO_VISION_DP=${vision_dp} MIMO_VISION_OFFSET=${vision_offset}" + + local ok=true + + echo " Phase 1: SAVE" + if ! env ${env_prefix} \ + python -m torch.distributed.run --nproc_per_node="${NUM_GPUS}" \ + "${TEST_FILE}" --phase save --ckpt-dir "${CKPT_DIR}" 2>&1; then + ok=false + fi + + if $ok; then + echo " Phase 2: RESUME" + if ! env ${env_prefix} \ + python -m torch.distributed.run --nproc_per_node="${NUM_GPUS}" \ + "${TEST_FILE}" --phase resume --ckpt-dir "${CKPT_DIR}" 2>&1; then + ok=false + fi + fi + + rm -rf "${CKPT_DIR}" + + local end_time=$(date +%s) + local duration=$((end_time - start_time)) + + if $ok; then + RESULTS+=("PASS|${name}|${duration}s") + PASSED=$((PASSED + 1)) + echo " [PASS] ${name} (${duration}s)" + else + RESULTS+=("FAIL|${name}|${duration}s") + FAILED_CONFIGS+=("${name}") + echo " [FAIL] ${name} (${duration}s)" + return 1 + fi + return 0 +} + +if [[ -n "${SINGLE_CONFIG}" ]]; then + found=false + for config in "${CONFIGS[@]}"; do + name="${config%%|*}" + if [[ "${name}" == "${SINGLE_CONFIG}" ]]; then + run_config "${config}" + found=true + break + fi + done + if [[ "${found}" == "false" ]]; then + echo "Error: Config '${SINGLE_CONFIG}' not found. Available:" + for config in "${CONFIGS[@]}"; do echo " - ${config%%|*}"; done + exit 1 + fi +else + for config in "${CONFIGS[@]}"; do + if ! run_config "${config}"; then + name="${config%%|*}" + echo "" + echo "==========================================" + echo "FATAL: Config '${name}' failed. Aborting." 
+ echo "==========================================" + exit 1 + fi + done +fi + +echo "" +echo "==========================================" +echo "SUMMARY: ${PASSED}/${TOTAL} passed" +echo "==========================================" +printf "%-6s | %-25s | %s\n" "Status" "Configuration" "Time" +echo "-------|---------------------------|-------" +for result in "${RESULTS[@]}"; do + IFS='|' read -r status name duration <<< "$result" + if [[ "${status}" == "PASS" ]]; then + printf "\033[32m%-6s\033[0m | %-25s | %s\n" "${status}" "${name}" "${duration}" + else + printf "\033[31m%-6s\033[0m | %-25s | %s\n" "${status}" "${name}" "${duration}" + fi +done +echo "==========================================" + +if [[ ${#FAILED_CONFIGS[@]} -gt 0 ]]; then + echo "" + echo "Failed configurations:" + for cfg in "${FAILED_CONFIGS[@]}"; do echo " - ${cfg}"; done + exit 1 +fi + +echo "" +echo "All checkpoint resume tests passed!" +exit 0 diff --git a/tests/e2e/mimo/run_mimo_parallelism_tests.sh b/tests/e2e/mimo/run_mimo_parallelism_tests.sh new file mode 100755 index 0000000000..066be982ac --- /dev/null +++ b/tests/e2e/mimo/run_mimo_parallelism_tests.sh @@ -0,0 +1,184 @@ +#!/bin/bash +# Run MIMO E2E test with various parallelism configurations +# Usage: ./run_mimo_parallelism_tests.sh [--gpus N] [--config CONFIG_NAME] +# +# Examples: +# ./run_mimo_parallelism_tests.sh # Run all configs with 8 GPUs +# ./run_mimo_parallelism_tests.sh --gpus 4 # Run all configs with 4 GPUs +# ./run_mimo_parallelism_tests.sh --config tp2_both # Run only tp2_both config + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TEST_FILE="${SCRIPT_DIR}/test_mimo_training_e2e.py" + +# Default values +NUM_GPUS=${NUM_GPUS:-8} +SINGLE_CONFIG="" + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + --gpus) + NUM_GPUS="$2" + shift 2 + ;; + --config) + SINGLE_CONFIG="$2" + shift 2 + ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + +echo 
"==========================================" +echo "MIMO Parallelism E2E Tests" +echo "GPUs: ${NUM_GPUS}" +echo "==========================================" + +# Define configurations as: "name|llm_tp|llm_pp|llm_dp|llm_offset|vision_tp|vision_pp|vision_dp|vision_offset" +# Note: Vision encoder (CLIPViT) does not support PP > 1, only LLM can use PP +declare -a CONFIGS_8GPU=( + "baseline_dp_only|1|1|4|0|1|1|4|4" + # "tp2_both|2|1|2|0|2|1|2|4" + # "tp2_llm_dp_vision|2|1|2|0|1|1|4|4" + # "pp2_llm_only|1|2|2|0|1|1|4|4" + # "tp4_both|4|1|1|0|4|1|1|4" + # "tp4_llm_tp2_vision|4|1|1|0|2|1|2|4" + # "3d_llm_dp_vision|2|2|1|0|1|1|4|4" + # "asymmetric_6_2_pp|2|3|1|0|2|1|1|6" +) + +# Note: PP > 1 not included for 4 GPU configs (would need at least 2 ranks for PP) +declare -a CONFIGS_4GPU=( + "baseline_dp_only|1|1|2|0|1|1|2|2" + "tp2_both|2|1|1|0|2|1|1|2" +) + +declare -a CONFIGS_2GPU=( + "baseline_dp_only|1|1|1|0|1|1|1|1" +) + +# Select configs based on GPU count +if [[ $NUM_GPUS -ge 8 ]]; then + CONFIGS=("${CONFIGS_8GPU[@]}") +elif [[ $NUM_GPUS -ge 4 ]]; then + CONFIGS=("${CONFIGS_4GPU[@]}") +else + CONFIGS=("${CONFIGS_2GPU[@]}") +fi + +# Track results +declare -a RESULTS=() +declare -a FAILED_CONFIGS=() +TOTAL=0 +PASSED=0 + +run_config() { + local config="$1" + local name llm_tp llm_pp llm_dp llm_offset vision_tp vision_pp vision_dp vision_offset + + IFS='|' read -r name llm_tp llm_pp llm_dp llm_offset vision_tp vision_pp vision_dp vision_offset <<< "$config" + + echo "" + echo "----------------------------------------" + echo "Running: ${name}" + echo " LLM: TP=${llm_tp}, PP=${llm_pp}, DP=${llm_dp}, offset=${llm_offset}" + echo " Vision: TP=${vision_tp}, PP=${vision_pp}, DP=${vision_dp}, offset=${vision_offset}" + echo "----------------------------------------" + + TOTAL=$((TOTAL + 1)) + + # Set environment variables and run + local start_time=$(date +%s) + + if MIMO_LLM_TP="${llm_tp}" \ + MIMO_LLM_PP="${llm_pp}" \ + MIMO_LLM_DP="${llm_dp}" \ + MIMO_LLM_OFFSET="${llm_offset}" 
\ + MIMO_VISION_TP="${vision_tp}" \ + MIMO_VISION_PP="${vision_pp}" \ + MIMO_VISION_DP="${vision_dp}" \ + MIMO_VISION_OFFSET="${vision_offset}" \ + python -m torch.distributed.run --nproc_per_node="${NUM_GPUS}" "${TEST_FILE}" 2>&1; then + local end_time=$(date +%s) + local duration=$((end_time - start_time)) + RESULTS+=("PASS|${name}|${duration}s") + PASSED=$((PASSED + 1)) + echo "[PASS] ${name} (${duration}s)" + else + local end_time=$(date +%s) + local duration=$((end_time - start_time)) + RESULTS+=("FAIL|${name}|${duration}s") + FAILED_CONFIGS+=("${name}") + echo "[FAIL] ${name} (${duration}s)" + return 1 + fi + return 0 +} + +# Run tests +if [[ -n "${SINGLE_CONFIG}" ]]; then + # Run single config + found=false + for config in "${CONFIGS[@]}"; do + name="${config%%|*}" + if [[ "${name}" == "${SINGLE_CONFIG}" ]]; then + run_config "${config}" + found=true + break + fi + done + if [[ "${found}" == "false" ]]; then + echo "Error: Config '${SINGLE_CONFIG}' not found. Available configs:" + for config in "${CONFIGS[@]}"; do + echo " - ${config%%|*}" + done + exit 1 + fi +else + # Run all configs - abort on any failure + for config in "${CONFIGS[@]}"; do + if ! run_config "${config}"; then + name="${config%%|*}" + echo "" + echo "==========================================" + echo "FATAL: Config '${name}' failed. Aborting." 
+ echo "==========================================" + exit 1 + fi + done +fi + +# Print summary +echo "" +echo "==========================================" +echo "SUMMARY: ${PASSED}/${TOTAL} passed" +echo "==========================================" +printf "%-6s | %-25s | %s\n" "Status" "Configuration" "Time" +echo "-------|---------------------------|-------" +for result in "${RESULTS[@]}"; do + IFS='|' read -r status name duration <<< "$result" + if [[ "${status}" == "PASS" ]]; then + printf "\033[32m%-6s\033[0m | %-25s | %s\n" "${status}" "${name}" "${duration}" + else + printf "\033[31m%-6s\033[0m | %-25s | %s\n" "${status}" "${name}" "${duration}" + fi +done +echo "==========================================" + +if [[ ${#FAILED_CONFIGS[@]} -gt 0 ]]; then + echo "" + echo "Failed configurations:" + for cfg in "${FAILED_CONFIGS[@]}"; do + echo " - ${cfg}" + done + exit 1 +fi + +echo "" +echo "All tests passed!" +exit 0 diff --git a/tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py b/tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py new file mode 100644 index 0000000000..3e093e52ca --- /dev/null +++ b/tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py @@ -0,0 +1,518 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""End-to-end MIMO checkpoint save→resume round-trip test. + +Validates that MiMo checkpoint loading/resume produces correct train_state +continuity: step, consumed_train_samples, and scheduler state are restored. + +Two-phase test (separate torchrun invocations required): + Phase 1 (save): Train for SAVE_STEPS steps, save checkpoint. + Phase 2 (resume): Resume from checkpoint, train to TOTAL_STEPS, verify continuity. 
+ +Run via wrapper: + bash tests/e2e/mimo/run_mimo_checkpoint_resume.sh +Or manually: + CKPT_DIR=$(mktemp -d) + torchrun --nproc_per_node=8 tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py --phase save --ckpt-dir $CKPT_DIR + torchrun --nproc_per_node=8 tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py --phase resume --ckpt-dir $CKPT_DIR +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys + +import torch +import torch.distributed as dist +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.clip_vit_model import CLIPViTModel +from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec +from megatron.core.optimizer.optimizer_config import OptimizerConfig as MCoreOptimizerConfig +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + +from megatron.bridge.data.mimo.mock_provider import MockMimoProvider +from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig, ModuleParallelismConfig +from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + SchedulerConfig, + TrainingConfig, +) +from megatron.bridge.training.config import OptimizerConfig as BridgeOptimizerConfig +from megatron.bridge.training.mimo_step import forward_step as mimo_forward_step +from megatron.bridge.training.pretrain_mimo import pretrain_mimo +from megatron.bridge.training.state import GlobalState, TrainState +from megatron.bridge.training.tokenizers.config import TokenizerConfig + + +logger = logging.getLogger(__name__) + +SAVE_STEPS = 5 +TOTAL_STEPS = 10 +_ENCODER_SEQ_LEN = 
197 +_SPECIAL_TOKEN_ID = 32000 +_VOCAB_SIZE = 50304 +_SEQ_LENGTH = 256 +_IMG_SIZE = 224 +_PATCH_DIM = 16 + + +# --------------------------------------------------------------------------- +# Model helpers (same as test_mimo_training_e2e.py) +# --------------------------------------------------------------------------- + + +def _make_vision_config() -> TransformerConfig: + cfg = TransformerConfig( + num_layers=2, + hidden_size=64, + ffn_hidden_size=256, + num_attention_heads=4, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + variable_seq_lengths=True, + moe_token_dispatcher_type="alltoall", + ) + cfg.add_bias_linear = True + cfg.add_qkv_bias = True + cfg.hidden_dropout = 0.0 + cfg.attention_dropout = 0.0 + cfg.gated_linear_unit = False + cfg.layernorm_zero_centered_gamma = False + cfg.apply_query_key_layer_scaling = False + cfg.bias_activation_fusion = False + cfg.bias_dropout_fusion = False + cfg.attention_softmax_in_fp32 = True + cfg.normalization = "LayerNorm" + cfg.apply_rope_fusion = False + return cfg + + +def _make_language_config() -> TransformerConfig: + return TransformerConfig( + num_layers=2, + hidden_size=64, + ffn_hidden_size=256, + num_attention_heads=4, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + variable_seq_lengths=True, + moe_token_dispatcher_type="alltoall", + cross_entropy_loss_fusion=True, + ) + + +def _build_model_specs(): + vision_config = _make_vision_config() + language_config = _make_language_config() + + vision_encoder = ModuleSpec( + module=CLIPViTModel, + params={ + "transformer_config": vision_config, + "transformer_layer_spec": get_vit_layer_with_transformer_engine_spec(), + "patch_dim": _PATCH_DIM, + "img_h": _IMG_SIZE, + "img_w": _IMG_SIZE, + }, + ) + vision_submodule_spec = ModuleSpec( + module=VisionModalitySubmodules, + params={}, + submodules={"encoders": {"clip": vision_encoder}}, + ) + language_model_spec = ModuleSpec( + module=GPTModel, + params={ + "config": 
language_config, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), + "vocab_size": _VOCAB_SIZE, + "max_sequence_length": _SEQ_LENGTH, + }, + ) + return language_model_spec, {"vision": vision_submodule_spec}, {"vision": _SPECIAL_TOKEN_ID} + + +def _build_parallelism_config() -> MimoParallelismConfig: + """Build parallelism config from MIMO_* env vars (set by shell wrapper). + + Env vars (with defaults for 8-GPU TP=4 both): + MIMO_LLM_TP, MIMO_LLM_PP, MIMO_LLM_DP, MIMO_LLM_OFFSET + MIMO_VISION_TP, MIMO_VISION_PP, MIMO_VISION_DP, MIMO_VISION_OFFSET + """ + return MimoParallelismConfig( + module_parallelisms={ + "language": ModuleParallelismConfig( + tensor_model_parallel_size=int(os.environ.get("MIMO_LLM_TP", "4")), + pipeline_model_parallel_size=int(os.environ.get("MIMO_LLM_PP", "1")), + data_parallel_size=int(os.environ.get("MIMO_LLM_DP", "1")), + rank_offset=int(os.environ.get("MIMO_LLM_OFFSET", "0")), + ), + "vision": ModuleParallelismConfig( + tensor_model_parallel_size=int(os.environ.get("MIMO_VISION_TP", "4")), + pipeline_model_parallel_size=int(os.environ.get("MIMO_VISION_PP", "1")), + data_parallel_size=int(os.environ.get("MIMO_VISION_DP", "1")), + rank_offset=int(os.environ.get("MIMO_VISION_OFFSET", "4")), + ), + }, + ) + + +# --------------------------------------------------------------------------- +# Data helpers +# --------------------------------------------------------------------------- + + +def _build_mock_data_provider() -> MockMimoProvider: + provider = MockMimoProvider( + seq_length=_SEQ_LENGTH, + processor_paths={"vision": "openai/clip-vit-base-patch16"}, + tokenizer_path="gpt2", + special_token_ids={"vision": _SPECIAL_TOKEN_ID}, + encoder_seq_lengths={"vision": _ENCODER_SEQ_LEN}, + modality_configs={"vision": {"type": "image", "width": _IMG_SIZE, "height": _IMG_SIZE}}, + ) + provider.drop_last = True + return provider + + +def _wrap_iter(loader_iter): + for batch in loader_iter: + for key, value in batch.items(): + if 
isinstance(value, torch.Tensor): + batch[key] = value.cuda(non_blocking=True) + elif isinstance(value, dict): + for k, v in value.items(): + if isinstance(v, torch.Tensor): + value[k] = v.cuda(non_blocking=True) + elif isinstance(v, dict): + for kk, vv in v.items(): + if isinstance(vv, torch.Tensor): + value[k][kk] = vv.cuda(non_blocking=True) + + mi = batch.get("modality_inputs") + if mi and "vision" in mi: + pv = mi["vision"].get("pixel_values") + if pv is not None: + mi["vision"] = {"clip": {"x": pv.to(torch.bfloat16)}} + + if "loss_mask" not in batch or batch["loss_mask"] is None: + batch["loss_mask"] = torch.ones_like(batch["input_ids"], dtype=torch.float) + + batch["attention_mask"] = None + yield batch + + +def _build_data_iterators(cfg, mimo_infra, *, train_state=None): + """Build data iterators. Accepts optional train_state for resume support.""" + from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders + + if train_state is None: + train_state = TrainState() + + train_samples = cfg.train.train_iters * cfg.train.global_batch_size + train_loader, _, _ = build_mimo_data_loaders( + cfg=cfg, + train_state=train_state, + mimo_provider=cfg.dataset, + train_samples=max(train_samples, 10), + valid_samples=0, + test_samples=0, + ) + train_iter = _wrap_iter(train_loader) if train_loader is not None else None + return train_iter, None + + +# --------------------------------------------------------------------------- +# Config builder +# --------------------------------------------------------------------------- + + +def _build_config( + mimo_provider: MimoModelProvider, + mock_data_provider: MockMimoProvider, + opt_config: BridgeOptimizerConfig, + ckpt_dir: str, + *, + train_iters: int, + save_interval: int, + load_dir: str | None = None, +) -> ConfigContainer: + par_cfg = mimo_provider.mimo_parallelism_config + max_dp = max(p.data_parallel_size for p in par_cfg.module_parallelisms.values()) + + train_cfg = TrainingConfig( + micro_batch_size=1, + 
global_batch_size=max_dp, + train_iters=train_iters, + ) + train_cfg.num_microbatches = 1 + train_cfg.grad_reduce_in_fp32 = False + train_cfg.overlap_grad_reduce = False + train_cfg.use_distributed_optimizer = True + train_cfg.check_for_nan_in_grad = False + train_cfg.log_interval = 1 + + logger_cfg = LoggerConfig() + logger_cfg.log_interval = 1 + + llm_pp = par_cfg.module_parallelisms["language"].pipeline_model_parallel_size + ckpt_cfg = CheckpointConfig( + save_interval=save_interval, + save=ckpt_dir, + ckpt_format="torch_dist", + # TODO: Re-enable fully_parallel_save for PP>1 after fixing MIMO sharded + # checkpoint access pattern validation for nested DDP language model params. + fully_parallel_save=(llm_pp == 1), + dist_ckpt_optim_fully_reshardable=True, + # MiMo RNG save is not yet supported: each module produces ShardedObject + # with key "rng_state" using module-local PP/TP/DP ranks, causing + # duplicate shard keys across modules. Disable until upstream fix. + save_rng=False, + ) + if load_dir is not None: + ckpt_cfg.load = load_dir + + cfg = ConfigContainer( + train=train_cfg, + model=mimo_provider, + optimizer=opt_config, + scheduler=SchedulerConfig(start_weight_decay=0.0, end_weight_decay=0.0), + dataset=mock_data_provider, + logger=logger_cfg, + tokenizer=TokenizerConfig(), + checkpoint=ckpt_cfg, + ) + cfg.data_parallel_size = max_dp + return cfg + + +# --------------------------------------------------------------------------- +# Phases +# --------------------------------------------------------------------------- + +MARKER_FILE = "resume_marker.json" + + +def _run_phase_save(ckpt_dir: str) -> None: + """Phase 1: Train for SAVE_STEPS steps and save checkpoint.""" + rank = dist.get_rank() + _log(f"Phase SAVE: training for {SAVE_STEPS} steps, saving to {ckpt_dir}") + + language_spec, modality_specs, special_tokens = _build_model_specs() + mimo_provider = MimoModelProvider( + language_model_spec=language_spec, + modality_submodules_spec=modality_specs, + 
special_token_ids=special_tokens, + mimo_parallelism_config=_build_parallelism_config(), + topology={"vision": ["language"], "language": []}, + use_cpu_initialization=True, + ) + if not hasattr(mimo_provider, "num_moe_experts"): + mimo_provider.num_moe_experts = None + if not hasattr(mimo_provider, "fp8"): + mimo_provider.fp8 = None + + mock_data = _build_mock_data_provider() + bridge_opt = BridgeOptimizerConfig(lr=1e-4, use_distributed_optimizer=True) + mcore_opt = MCoreOptimizerConfig( + optimizer="adam", + lr=1e-4, + min_lr=0.0, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + + cfg = _build_config( + mimo_provider, + mock_data, + bridge_opt, + ckpt_dir, + train_iters=SAVE_STEPS, + save_interval=SAVE_STEPS, + ) + + global_state = GlobalState() + + pretrain_mimo( + cfg=cfg, + mimo_provider=mimo_provider, + forward_step_func=mimo_forward_step, + build_data_iterators_fn=_build_data_iterators, + opt_config=mcore_opt, + schedulers={}, + global_state=global_state, + ) + + ts = global_state.train_state + _log(f"Phase SAVE complete: step={ts.step}, consumed_train_samples={ts.consumed_train_samples}") + + if rank == 0: + marker = { + "step": ts.step, + "consumed_train_samples": ts.consumed_train_samples, + "floating_point_operations_so_far": ts.floating_point_operations_so_far, + } + marker_path = os.path.join(ckpt_dir, MARKER_FILE) + with open(marker_path, "w") as f: + json.dump(marker, f) + _log(f"Wrote marker: {marker}") + + dist.barrier() + assert ts.step == SAVE_STEPS, f"Expected step={SAVE_STEPS}, got {ts.step}" + _log("Phase SAVE: PASSED") + + +def _run_phase_resume(ckpt_dir: str) -> None: + """Phase 2: Resume from checkpoint, train to TOTAL_STEPS, verify continuity.""" + rank = dist.get_rank() + _log(f"Phase RESUME: loading from {ckpt_dir}, training to {TOTAL_STEPS} steps") + + marker_path = os.path.join(ckpt_dir, MARKER_FILE) + with open(marker_path, "r") as f: + saved_marker = json.load(f) + _log(f"Loaded marker from 
phase 1: {saved_marker}") + + language_spec, modality_specs, special_tokens = _build_model_specs() + mimo_provider = MimoModelProvider( + language_model_spec=language_spec, + modality_submodules_spec=modality_specs, + special_token_ids=special_tokens, + mimo_parallelism_config=_build_parallelism_config(), + topology={"vision": ["language"], "language": []}, + use_cpu_initialization=True, + ) + if not hasattr(mimo_provider, "num_moe_experts"): + mimo_provider.num_moe_experts = None + if not hasattr(mimo_provider, "fp8"): + mimo_provider.fp8 = None + + mock_data = _build_mock_data_provider() + bridge_opt = BridgeOptimizerConfig(lr=1e-4, use_distributed_optimizer=True) + mcore_opt = MCoreOptimizerConfig( + optimizer="adam", + lr=1e-4, + min_lr=0.0, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + + cfg = _build_config( + mimo_provider, + mock_data, + bridge_opt, + ckpt_dir, + train_iters=TOTAL_STEPS, + save_interval=TOTAL_STEPS, + load_dir=ckpt_dir, + ) + # Save phase used train_iters=SAVE_STEPS, so checkpoint scheduler state + # has lr_decay_steps / wd_incr_steps derived from SAVE_STEPS. Resume uses + # TOTAL_STEPS which produces different values. override_opt_param_scheduler + # tells the scheduler to use the current (resume) values without asserting + # against the checkpoint. Scheduler progress (num_steps) is still restored. 
+ cfg.scheduler.override_opt_param_scheduler = True + + global_state = GlobalState() + + pretrain_mimo( + cfg=cfg, + mimo_provider=mimo_provider, + forward_step_func=mimo_forward_step, + build_data_iterators_fn=_build_data_iterators, + opt_config=mcore_opt, + schedulers={}, + global_state=global_state, + ) + + ts = global_state.train_state + + _log(f"Phase RESUME complete: step={ts.step}, consumed_train_samples={ts.consumed_train_samples}") + + # Verify step continuity + assert ts.step == TOTAL_STEPS, f"Step continuity failed: expected {TOTAL_STEPS}, got {ts.step}" + + # Verify consumed_train_samples did not reset to 0 + assert ts.consumed_train_samples >= saved_marker["consumed_train_samples"], ( + f"consumed_train_samples reset detected: " + f"saved={saved_marker['consumed_train_samples']}, resumed={ts.consumed_train_samples}" + ) + + # Verify consumed_train_samples advanced beyond the saved value + expected_consumed = saved_marker["consumed_train_samples"] + ( + (TOTAL_STEPS - SAVE_STEPS) * cfg.train.global_batch_size + ) + assert ts.consumed_train_samples == expected_consumed, ( + f"consumed_train_samples mismatch: expected {expected_consumed}, got {ts.consumed_train_samples}" + ) + + _log("Phase RESUME: PASSED — all continuity assertions hold") + + +# --------------------------------------------------------------------------- +# Logging + main +# --------------------------------------------------------------------------- + +_rank_log_file = None + + +def _log(msg): + global _rank_log_file + rank = dist.get_rank() if dist.is_initialized() else "?" 
+ line = f"[Rank {rank}] {msg}\n" + if _rank_log_file: + _rank_log_file.write(line) + _rank_log_file.flush() + print(line, end="", flush=True) + + +def main(): + global _rank_log_file + + parser = argparse.ArgumentParser() + parser.add_argument("--phase", required=True, choices=["save", "resume"]) + parser.add_argument("--ckpt-dir", required=True) + args = parser.parse_args() + + dist.init_process_group("nccl") + rank = dist.get_rank() + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + log_dir = "/tmp/mimo_resume_e2e_logs" + os.makedirs(log_dir, exist_ok=True) + _rank_log_file = open(f"{log_dir}/rank_{rank}_{args.phase}.log", "w") + + logging.basicConfig( + level=logging.INFO, + format=f"[Rank {rank}] %(name)s: %(message)s", + handlers=[ + logging.FileHandler(f"{log_dir}/rank_{rank}_{args.phase}_full.log", mode="w"), + logging.StreamHandler(sys.stderr), + ], + force=True, + ) + + if args.phase == "save": + _run_phase_save(args.ckpt_dir) + else: + _run_phase_resume(args.ckpt_dir) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/e2e/mimo/test_mimo_training_e2e.py b/tests/e2e/mimo/test_mimo_training_e2e.py new file mode 100644 index 0000000000..7564794a19 --- /dev/null +++ b/tests/e2e/mimo/test_mimo_training_e2e.py @@ -0,0 +1,382 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""End-to-end MIMO training test. + +Exercises the full training loop: pretrain_mimo -> setup_mimo -> train_mimo +on 8 GPUs with synthetic data using the real data pipeline. +LLM on ranks 0-3 (TP=4), vision encoder on ranks 4-7 (TP=4). 
+ +Run: + torchrun --nproc_per_node=8 tests/e2e/mimo/test_mimo_training_e2e.py +""" + +from __future__ import annotations + +import logging +import os +import sys + +import torch +import torch.distributed as dist +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.clip_vit_model import CLIPViTModel +from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + + +def _make_vision_config() -> TransformerConfig: + cfg = TransformerConfig( + num_layers=2, + hidden_size=64, + ffn_hidden_size=256, + num_attention_heads=4, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + variable_seq_lengths=True, + moe_token_dispatcher_type="alltoall", + ) + cfg.add_bias_linear = True + cfg.add_qkv_bias = True + cfg.hidden_dropout = 0.0 + cfg.attention_dropout = 0.0 + cfg.gated_linear_unit = False + cfg.layernorm_zero_centered_gamma = False + cfg.apply_query_key_layer_scaling = False + cfg.bias_activation_fusion = False + cfg.bias_dropout_fusion = False + cfg.attention_softmax_in_fp32 = True + cfg.normalization = "LayerNorm" + cfg.apply_rope_fusion = False + return cfg + + +def _make_language_config() -> TransformerConfig: + return TransformerConfig( + num_layers=2, + hidden_size=64, + ffn_hidden_size=256, + num_attention_heads=4, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + variable_seq_lengths=True, + moe_token_dispatcher_type="alltoall", + cross_entropy_loss_fusion=True, + ) + + +_ENCODER_SEQ_LEN = 197 # (224/16)^2 = 196 patches + 1 class token +_SPECIAL_TOKEN_ID = 32000 +_VOCAB_SIZE = 50304 +_SEQ_LENGTH = 256 +_IMG_SIZE = 224 
+_PATCH_DIM = 16 + + +def _build_model_specs(): + """Return (language_model_spec, modality_submodules_spec, special_token_ids).""" + vision_config = _make_vision_config() + language_config = _make_language_config() + + vision_encoder = ModuleSpec( + module=CLIPViTModel, + params={ + "transformer_config": vision_config, + "transformer_layer_spec": get_vit_layer_with_transformer_engine_spec(), + "patch_dim": _PATCH_DIM, + "img_h": _IMG_SIZE, + "img_w": _IMG_SIZE, + }, + ) + + vision_submodule_spec = ModuleSpec( + module=VisionModalitySubmodules, + params={}, + submodules={ + "encoders": {"clip": vision_encoder}, + }, + ) + + language_model_spec = ModuleSpec( + module=GPTModel, + params={ + "config": language_config, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), + "vocab_size": _VOCAB_SIZE, + "max_sequence_length": _SEQ_LENGTH, + }, + ) + + modality_submodules_spec = {"vision": vision_submodule_spec} + special_token_ids = {"vision": _SPECIAL_TOKEN_ID} + return language_model_spec, modality_submodules_spec, special_token_ids + + +from megatron.bridge.models.mimo.mimo_config import ( + MimoParallelismConfig, + ModuleParallelismConfig, +) + + +def _build_parallelism_config() -> MimoParallelismConfig: + return MimoParallelismConfig( + module_parallelisms={ + "language": ModuleParallelismConfig( + tensor_model_parallel_size=4, + pipeline_model_parallel_size=1, + data_parallel_size=1, + rank_offset=0, + ), + "vision": ModuleParallelismConfig( + tensor_model_parallel_size=4, + pipeline_model_parallel_size=1, + data_parallel_size=1, + rank_offset=4, + ), + }, + ) + + +from megatron.bridge.data.mimo.mock_provider import MockMimoProvider + + +def _build_mock_data_provider() -> MockMimoProvider: + """Build a MockMimoProvider with HF processor (CLIP) and tokenizer (GPT-2).""" + provider = MockMimoProvider( + seq_length=_SEQ_LENGTH, + processor_paths={"vision": "openai/clip-vit-base-patch16"}, + tokenizer_path="gpt2", + special_token_ids={"vision": 
_SPECIAL_TOKEN_ID}, + encoder_seq_lengths={"vision": _ENCODER_SEQ_LEN}, + modality_configs={ + "vision": {"type": "image", "width": _IMG_SIZE, "height": _IMG_SIZE}, + }, + ) + provider.drop_last = True + return provider + + +def _wrap_iter(loader_iter): + """Adapt data-loader batches for the MIMO model. + + Transforms: + - modality_inputs["vision"]["pixel_values"] -> modality_inputs["vision"]["clip"]["x"] + so VisionModalitySubmodules.encode() finds the "clip" encoder key and + CLIPViTModel.forward() receives ``x=...``. + - Sets attention_mask=None (not needed for this test). + - Generates loss_mask if not present. + """ + for batch in loader_iter: + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cuda(non_blocking=True) + elif isinstance(value, dict): + for k, v in value.items(): + if isinstance(v, torch.Tensor): + value[k] = v.cuda(non_blocking=True) + elif isinstance(v, dict): + for kk, vv in v.items(): + if isinstance(vv, torch.Tensor): + value[k][kk] = vv.cuda(non_blocking=True) + + mi = batch.get("modality_inputs") + if mi and "vision" in mi: + pv = mi["vision"].get("pixel_values") + if pv is not None: + mi["vision"] = {"clip": {"x": pv.to(torch.bfloat16)}} + + if "loss_mask" not in batch or batch["loss_mask"] is None: + batch["loss_mask"] = torch.ones_like(batch["input_ids"], dtype=torch.float) + + batch["attention_mask"] = None + + yield batch + + +def _build_data_iterators(cfg, mimo_infra): + """Build data iterators compatible with setup_mimo's build_data_iterators_fn.""" + from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders + from megatron.bridge.training.state import TrainState + + train_state = TrainState() + + train_samples = cfg.train.train_iters * cfg.train.global_batch_size + valid_samples = 0 + test_samples = 0 + + train_loader, valid_loader, _ = build_mimo_data_loaders( + cfg=cfg, + train_state=train_state, + mimo_provider=cfg.dataset, + train_samples=max(train_samples, 10), + 
valid_samples=valid_samples, + test_samples=test_samples, + ) + + train_iter = _wrap_iter(train_loader) if train_loader is not None else None + valid_iter = None + return train_iter, valid_iter + + +from megatron.core.optimizer.optimizer_config import OptimizerConfig as MCoreOptimizerConfig + +from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + SchedulerConfig, + TrainingConfig, +) +from megatron.bridge.training.config import OptimizerConfig as BridgeOptimizerConfig +from megatron.bridge.training.tokenizers.config import TokenizerConfig + + +def _build_config( + mimo_provider: MimoModelProvider, + mock_data_provider: MockMimoProvider, + opt_config: BridgeOptimizerConfig, + log_interval: int = 1, + wandb_project: str | None = None, + wandb_exp_name: str | None = None, + wandb_entity: str | None = None, + wandb_save_dir: str | None = None, +) -> ConfigContainer: + train_cfg = TrainingConfig( + micro_batch_size=1, + global_batch_size=1, + train_iters=2, + ) + train_cfg.num_microbatches = 1 + train_cfg.grad_reduce_in_fp32 = False + train_cfg.overlap_grad_reduce = False + train_cfg.use_distributed_optimizer = True + train_cfg.check_for_nan_in_grad = False + train_cfg.log_interval = log_interval + + logger_cfg = LoggerConfig() + logger_cfg.log_interval = log_interval + logger_cfg.wandb_project = wandb_project + logger_cfg.wandb_exp_name = wandb_exp_name + logger_cfg.wandb_entity = wandb_entity + logger_cfg.wandb_save_dir = wandb_save_dir + logger_cfg.tensorboard_dir = os.path.join(wandb_save_dir or "/tmp/tb_logs", "tb_logs") if wandb_project else None + + cfg = ConfigContainer( + train=train_cfg, + model=mimo_provider, + optimizer=opt_config, + scheduler=SchedulerConfig(start_weight_decay=0.0, end_weight_decay=0.0), + dataset=mock_data_provider, + logger=logger_cfg, + tokenizer=TokenizerConfig(), + checkpoint=CheckpointConfig(), + ) + 
cfg.data_parallel_size = 1 + return cfg + + +from megatron.bridge.training.mimo_step import forward_step as mimo_forward_step +from megatron.bridge.training.pretrain_mimo import pretrain_mimo + + +_rank_log_file = None + + +def _log(msg): + """Write with rank prefix to per-rank log file and flush.""" + global _rank_log_file + rank = dist.get_rank() if dist.is_initialized() else "?" + line = f"[Rank {rank}] {msg}\n" + if _rank_log_file: + _rank_log_file.write(line) + _rank_log_file.flush() + print(line, end="", flush=True) + + +def main(): + global _rank_log_file + + dist.init_process_group("nccl") + rank = dist.get_rank() + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + log_dir = "/tmp/mimo_e2e_logs" + os.makedirs(log_dir, exist_ok=True) + _rank_log_file = open(f"{log_dir}/rank_{rank}.log", "w") + + logging.basicConfig( + level=logging.INFO, + format=f"[Rank {rank}] %(name)s: %(message)s", + handlers=[ + logging.FileHandler(f"{log_dir}/rank_{rank}_full.log", mode="w"), + logging.StreamHandler(sys.stderr), + ], + force=True, + ) + logging.getLogger("megatron.core.pipeline_parallel.bridge_communicator").setLevel(logging.DEBUG) + logging.getLogger("megatron.core.pipeline_parallel.multimodule_communicator").setLevel(logging.DEBUG) + + _log(f"distributed initialized (world_size={dist.get_world_size()})") + + _log("building model specs") + language_model_spec, modality_submodules_spec, special_token_ids = _build_model_specs() + mimo_parallelism_config = _build_parallelism_config() + + mimo_provider = MimoModelProvider( + language_model_spec=language_model_spec, + modality_submodules_spec=modality_submodules_spec, + special_token_ids=special_token_ids, + mimo_parallelism_config=mimo_parallelism_config, + topology={"vision": ["language"], "language": []}, + use_cpu_initialization=True, + ) + if not hasattr(mimo_provider, "num_moe_experts"): + mimo_provider.num_moe_experts = None + + _log("building data provider") + 
mock_data_provider = _build_mock_data_provider() + + mcore_opt_config = MCoreOptimizerConfig( + optimizer="adam", + lr=1e-4, + min_lr=0.0, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + bridge_opt_config = BridgeOptimizerConfig(lr=1e-4) + + _log("building config") + cfg = _build_config( + mimo_provider, + mock_data_provider, + bridge_opt_config, + wandb_project=os.environ.get("WANDB_PROJECT", "Megatron-Bridge-MIMO"), + wandb_exp_name=os.environ.get("WANDB_EXP_NAME", "mimo-e2e-test"), + wandb_entity=os.environ.get("WANDB_ENTITY"), + wandb_save_dir=os.environ.get("WANDB_SAVE_DIR", "/tmp/wandb"), + ) + + _log("launching pretrain_mimo") + pretrain_mimo( + cfg=cfg, + mimo_provider=mimo_provider, + forward_step_func=mimo_forward_step, + build_data_iterators_fn=_build_data_iterators, + opt_config=mcore_opt_config, + schedulers={}, + ) + + _log("PASSED") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/e2e/mimo/test_mimo_training_llava.py b/tests/e2e/mimo/test_mimo_training_llava.py new file mode 100644 index 0000000000..5d82ea1323 --- /dev/null +++ b/tests/e2e/mimo/test_mimo_training_llava.py @@ -0,0 +1,572 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ + +from __future__ import annotations + +import argparse +import logging +import os +import random +import sys + +import numpy as np +import torch +import torch.distributed as dist +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TERowParallelLinear, +) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.clip_vit_model import CLIPViTModel +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec +from megatron.core.transformer.mlp import MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + + +# --------------------------------------------------------------------------- +# LLaVA model configs (Vicuna-7B + CLIP ViT-L/14 + MLP projection) +# --------------------------------------------------------------------------- + +IMAGE_SPECIAL_TOKEN_ID = 32000 +VOCAB_SIZE = 32256 +CLIP_OUTPUT_DIM = 1024 # CLIP ViT-L/14 hidden size +MAX_SEQ_LENGTH = 4096 +_IMG_SIZE = 336 +_PATCH_DIM = 14 +# CLIP ViT-L/14 @ 336×336: (336/14)^2 = 576 patches + 1 class token = 577 +_ENCODER_SEQ_LEN = 577 + + +def _make_vision_config() -> TransformerConfig: + """CLIP ViT-L/14 vision encoder config.""" + cfg = TransformerConfig( + num_layers=24, + hidden_size=1024, + ffn_hidden_size=4096, + num_attention_heads=16, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + variable_seq_lengths=True, + moe_token_dispatcher_type="alltoall", + ) + cfg.add_bias_linear = True + cfg.add_qkv_bias = True + cfg.hidden_dropout = 0.0 + cfg.attention_dropout = 0.0 + cfg.gated_linear_unit = False + 
cfg.layernorm_zero_centered_gamma = False + cfg.apply_query_key_layer_scaling = False + cfg.bias_activation_fusion = False + cfg.bias_dropout_fusion = False + cfg.attention_softmax_in_fp32 = True + cfg.normalization = "LayerNorm" + cfg.apply_rope_fusion = False + return cfg + + +def _make_language_config() -> TransformerConfig: + """Vicuna-7B language model config (same arch as Llama-7B).""" + cfg = TransformerConfig( + num_layers=32, + hidden_size=4096, + num_attention_heads=32, + use_cpu_initialization=True, + ) + + cfg.ffn_hidden_size = 11008 + cfg.activation_func = torch.nn.functional.silu + cfg.gated_linear_unit = True + + cfg.normalization = "RMSNorm" + cfg.rms_norm_eps = 1e-5 + + cfg.position_embedding_type = "rope" + cfg.rotary_base = 10000 + cfg.rotary_percent = 1.0 + + cfg.seq_length = MAX_SEQ_LENGTH + cfg.max_position_embeddings = MAX_SEQ_LENGTH + + cfg.attention_dropout = 0.0 + cfg.hidden_dropout = 0.0 + + cfg.num_query_groups = 32 + cfg.add_bias_linear = False + cfg.untie_embeddings_and_output_weights = False + + cfg.bias_activation_fusion = True + cfg.masked_softmax_fusion = True + cfg.persist_layer_norm = True + cfg.bias_dropout_fusion = True + cfg.apply_rope_fusion = True + + cfg.pipeline_dtype = torch.bfloat16 + cfg.bf16 = True + cfg.cross_entropy_loss_fusion = True + cfg.variable_seq_lengths = True + + return cfg + + +def _make_projection_config(hidden_size: int = 4096) -> TransformerConfig: + """Vision→language projection MLP config.""" + cfg = TransformerConfig(num_layers=1, hidden_size=hidden_size, num_attention_heads=1) + cfg.ffn_hidden_size = 4096 + cfg.bias_activation_fusion = True + cfg.add_bias_linear = True + cfg.activation_func = torch.nn.functional.gelu + return cfg + + +def _build_model_specs(): + """Return (language_model_spec, modality_submodules_spec, special_token_ids).""" + vision_config = _make_vision_config() + language_config = _make_language_config() + projection_config = 
_make_projection_config(hidden_size=language_config.hidden_size) + + # CLIP ViT-L/14 encoder + vision_encoder = ModuleSpec( + module=CLIPViTModel, + params={ + "transformer_config": vision_config, + "transformer_layer_spec": get_vit_layer_with_transformer_engine_spec(), + "patch_dim": _PATCH_DIM, + "img_h": _IMG_SIZE, + "img_w": _IMG_SIZE, + }, + ) + + # Vision→language projection MLP + vision_projection = ModuleSpec( + module=MultimodalProjector, + params={ + "config": projection_config, + "submodules": MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + "projector_type": "mlp", + "input_size": CLIP_OUTPUT_DIM, + }, + ) + + vision_submodule_spec = ModuleSpec( + module=VisionModalitySubmodules, + params={}, + submodules={ + "encoders": {"clip": vision_encoder}, + "input_projections": [vision_projection], + }, + ) + + language_model_spec = ModuleSpec( + module=GPTModel, + params={ + "config": language_config, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), + "vocab_size": VOCAB_SIZE, + "max_sequence_length": MAX_SEQ_LENGTH, + "position_embedding_type": "rope", + }, + ) + + modality_submodules_spec = {"images": vision_submodule_spec} + special_token_ids = {"images": IMAGE_SPECIAL_TOKEN_ID} + return language_model_spec, modality_submodules_spec, special_token_ids + + +# --------------------------------------------------------------------------- +# Parallelism config (8 GPUs: TP=4 for both modules) +# --------------------------------------------------------------------------- + +from megatron.bridge.models.mimo.mimo_config import ( + MimoParallelismConfig, + ModuleParallelismConfig, +) + + +def _build_parallelism_config() -> MimoParallelismConfig: + return MimoParallelismConfig( + module_parallelisms={ + "language": ModuleParallelismConfig( + tensor_model_parallel_size=4, + pipeline_model_parallel_size=1, + data_parallel_size=1, + rank_offset=0, + ), + "images": ModuleParallelismConfig( + 
tensor_model_parallel_size=4,
+                pipeline_model_parallel_size=1,
+                data_parallel_size=1,
+                rank_offset=4,
+            ),
+        },
+    )
+
+
+# ---------------------------------------------------------------------------
+# Data pipeline
+# ---------------------------------------------------------------------------
+
+from megatron.bridge.data.mimo.hf_provider import HFMimoDatasetProvider
+
+
+def _llava_preprocess(example, dataset_root):
+    """Convert LLaVA conversations format to plain text and resolve image paths."""
+    conversations = example.get("conversations", [])
+    text_parts = [turn.get("value", "") for turn in conversations]
+    example["text"] = " ".join(text_parts).replace("<image>", "").strip()
+    # Resolve relative image paths to absolute paths
+    if "image" in example and example["image"] and not os.path.isabs(example["image"]):
+        example["image"] = os.path.join(dataset_root, example["image"])
+    return example
+
+
+def _build_hf_data_provider(dataset_root: str) -> HFMimoDatasetProvider:
+    """Build an HFMimoDatasetProvider for liuhaotian/LLaVA-Pretrain."""
+    provider = HFMimoDatasetProvider(
+        seq_length=MAX_SEQ_LENGTH,
+        hf_dataset_path=dataset_root,
+        hf_data_files="blip_laion_cc_sbu_558k.json",
+        hf_tokenizer_path="llava-hf/llava-1.5-7b-hf",
+        processor_paths={"images": "openai/clip-vit-large-patch14-336"},
+        special_token_ids={"images": IMAGE_SPECIAL_TOKEN_ID},
+        encoder_seq_lengths={"images": _ENCODER_SEQ_LEN},
+        modality_columns={"images": "image"},
+        text_column="text",
+        train_split="train",
+        preprocess_fn=lambda example: _llava_preprocess(example, dataset_root),
+    )
+    provider.drop_last = True
+
+    return provider
+
+
+def _wrap_iter(loader_iter):
+    """Adapt data-loader batches for the MIMO model.
+
+    Transforms:
+    - modality_inputs["images"]["pixel_values"] → modality_inputs["images"]["clip"]["x"]
+      so VisionModalitySubmodules.encode() finds the "clip" encoder key and
+      CLIPViTModel.forward() receives ``x=...``.
+    - Sets attention_mask=None (not needed for this test).
+ - Generates loss_mask if not present. + """ + for batch in loader_iter: + # Move tensors to GPU + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cuda(non_blocking=True) + elif isinstance(value, dict): + for k, v in value.items(): + if isinstance(v, torch.Tensor): + value[k] = v.cuda(non_blocking=True) + elif isinstance(v, dict): + for kk, vv in v.items(): + if isinstance(vv, torch.Tensor): + value[k][kk] = vv.cuda(non_blocking=True) + + # Rewrap modality_inputs: {"images": {"pixel_values": t}} → {"images": {"clip": {"x": t}}} + # Cast to bfloat16 to match model weights + mi = batch.get("modality_inputs") + if mi and "images" in mi: + pv = mi["images"].get("pixel_values") + if pv is not None: + mi["images"] = {"clip": {"x": pv.to(torch.bfloat16)}} + + # Ensure loss_mask exists + if "loss_mask" not in batch or batch["loss_mask"] is None: + batch["loss_mask"] = torch.ones_like(batch["input_ids"], dtype=torch.float) + + # Drop attention_mask (not needed) + batch["attention_mask"] = None + + yield batch + + +def _build_data_iterators(cfg, mimo_infra): + """Build data iterators compatible with setup_mimo's build_data_iterators_fn. + + Signature: (cfg, mimo_infra) -> (train_iter, valid_iter) + Uses build_mimo_data_loaders which auto-detects MIMO path via cfg.model. 
+ """ + from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders + from megatron.bridge.training.state import TrainState + + train_state = TrainState() + + # Compute sample counts + train_samples = cfg.train.train_iters * cfg.train.global_batch_size + valid_samples = 0 + test_samples = 0 + + train_loader, _, _ = build_mimo_data_loaders( + cfg=cfg, + train_state=train_state, + mimo_provider=cfg.dataset, + train_samples=max(train_samples, 10), # min 10 samples + valid_samples=valid_samples, + test_samples=test_samples, + ) + + train_iter = _wrap_iter(train_loader) if train_loader is not None else None + valid_iter = None + return train_iter, valid_iter + + +# --------------------------------------------------------------------------- +# Config assembly +# --------------------------------------------------------------------------- + +from megatron.core.optimizer.optimizer_config import OptimizerConfig as MCoreOptimizerConfig + +from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + SchedulerConfig, + TrainingConfig, +) +from megatron.bridge.training.config import OptimizerConfig as BridgeOptimizerConfig +from megatron.bridge.training.tokenizers.config import TokenizerConfig + + +def _build_config( + mimo_provider: MimoModelProvider, + data_provider: HFMimoDatasetProvider, + opt_config: BridgeOptimizerConfig, + micro_batch_size: int = 1, + global_batch_size: int = 1, + train_iters: int = 2, + log_interval: int = 1, + wandb_project: str | None = None, + wandb_exp_name: str | None = None, + wandb_entity: str | None = None, + wandb_save_dir: str | None = None, + lr_warmup_iters: int = 0, +) -> ConfigContainer: + train_cfg = TrainingConfig( + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + train_iters=train_iters, + ) + # Runtime patches for MIMO + train_cfg.num_microbatches = 1 + train_cfg.grad_reduce_in_fp32 = 
False + train_cfg.overlap_grad_reduce = False + train_cfg.use_distributed_optimizer = True + train_cfg.check_for_nan_in_grad = False + train_cfg.log_interval = log_interval + + logger_cfg = LoggerConfig() + logger_cfg.log_timers_to_tensorboard = True + logger_cfg.log_interval = log_interval + logger_cfg.wandb_project = wandb_project + logger_cfg.wandb_exp_name = wandb_exp_name + logger_cfg.wandb_entity = wandb_entity + logger_cfg.wandb_save_dir = wandb_save_dir + logger_cfg.tensorboard_dir = os.path.join(wandb_save_dir or "/tmp/tb_logs", "tb_logs") if wandb_project else None + + scheduler_cfg = SchedulerConfig( + lr_decay_style="cosine", + lr_warmup_iters=lr_warmup_iters, + lr_warmup_init=opt_config.min_lr, + start_weight_decay=opt_config.weight_decay, + end_weight_decay=opt_config.weight_decay, + ) + + cfg = ConfigContainer( + train=train_cfg, + model=mimo_provider, + optimizer=opt_config, + scheduler=scheduler_cfg, + dataset=data_provider, + logger=logger_cfg, + tokenizer=TokenizerConfig(), + checkpoint=CheckpointConfig(), + ) + cfg.data_parallel_size = 1 + return cfg + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +from megatron.bridge.training.mimo_step import forward_step as mimo_forward_step +from megatron.bridge.training.pretrain_mimo import pretrain_mimo + + +_rank_log_file = None + + +def _log(msg): + """Write with rank prefix to per-rank log file and flush.""" + global _rank_log_file + rank = dist.get_rank() if dist.is_initialized() else "?" 
+ line = f"[Rank {rank}] {msg}\n" + if _rank_log_file: + _rank_log_file.write(line) + _rank_log_file.flush() + print(line, end="", flush=True) + + +def parse_args(): + parser = argparse.ArgumentParser(description="MIMO LLaVA training") + parser.add_argument("--micro-batch-size", type=int, default=1, help="Micro batch size per GPU") + parser.add_argument("--global-batch-size", type=int, default=1, help="Global batch size across all GPUs") + parser.add_argument("--train-iters", type=int, default=2, help="Number of training iterations") + parser.add_argument("--min-lr", type=float, default=2.0e-5) + parser.add_argument("--adam-beta1", type=float, default=0.9) + parser.add_argument("--adam-beta2", type=float, default=0.95) + parser.add_argument("--clip-grad", type=float, default=1.0) + parser.add_argument("--log-interval", type=int, default=1) + parser.add_argument("--checkpoint-interval", type=int, default=None, help="Checkpoint save interval (iterations)") + parser.add_argument("--checkpoint-dir", type=str, default=None, help="Checkpoint output directory") + parser.add_argument("--load-checkpoint", type=str, default=None, help="Checkpoint directory to resume from") + parser.add_argument("--lr", type=float, default=2e-5) + parser.add_argument("--weight-decay", type=float, default=0.01) + parser.add_argument("--wandb-project", type=str, default="Megatron-Bridge-MIMO", help="W&B project name") + parser.add_argument("--wandb-exp-name", type=str, default="mimo-llava-e2e-test", help="W&B experiment name") + parser.add_argument("--wandb-entity", type=str, default=None, help="W&B entity") + parser.add_argument("--wandb-save-dir", type=str, default="/tmp/wandb", help="W&B save directory") + parser.add_argument( + "--lr-warmup-iters", type=int, default=20, help="Number of iterations to linearly warmup learning rate" + ) + parser.add_argument("--dataset-root", type=str, required=True, help="Root directory of the LLaVA-Pretrain dataset") + return parser.parse_args() + + +def 
main(): + global _rank_log_file + + args = parse_args() + + # 1. Initialize distributed first so we know rank + dist.init_process_group("nccl") + rank = dist.get_rank() + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # Seed all RNGs for reproducible weight initialization + seed = 42 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # Open per-rank log file + log_dir = os.environ.get("MIMO_LOG_DIR", "/tmp/mimo_llava_logs") + os.makedirs(log_dir, exist_ok=True) + _rank_log_file = open(f"{log_dir}/rank_{rank}.log", "w") + + logging.basicConfig( + level=logging.INFO, + format=f"[Rank {rank}] %(name)s: %(message)s", + handlers=[logging.FileHandler(f"{log_dir}/rank_{rank}_full.log", mode="w"), logging.StreamHandler(sys.stderr)], + force=True, + ) + # Enable debug logging for bridge communicator to trace P2P ops + logging.getLogger("megatron.core.pipeline_parallel.bridge_communicator").setLevel(logging.DEBUG) + logging.getLogger("megatron.core.pipeline_parallel.multimodule_communicator").setLevel(logging.DEBUG) + + _log(f"distributed initialized (world_size={dist.get_world_size()})") + + # No parallel_state.initialize_model_parallel() — MIMO manages its own + # parallelism via HyperCommGrids and pg_collections. Float16Module is + # skipped (direct bf16 cast), and cross_entropy_loss_fusion=True ensures + # the fused CE path uses pg_collection.tp instead of global parallel_state. + + # 2. 
Build model provider + _log("building model specs") + language_model_spec, modality_submodules_spec, special_token_ids = _build_model_specs() + mimo_parallelism_config = _build_parallelism_config() + + mimo_provider = MimoModelProvider( + language_model_spec=language_model_spec, + modality_submodules_spec=modality_submodules_spec, + special_token_ids=special_token_ids, + mimo_parallelism_config=mimo_parallelism_config, + topology={"images": ["language"], "language": []}, + use_cpu_initialization=True, + bf16=True, + ) + # Patch: training_log accesses config.model.num_moe_experts + if not hasattr(mimo_provider, "num_moe_experts"): + mimo_provider.num_moe_experts = None + + # 4. Build data provider + _log("building data provider") + data_provider = _build_hf_data_provider(args.dataset_root) + + # 5. Build optimizer configs + # MCore OptimizerConfig (with __post_init__) for get_mimo_optimizer + _log("building optimizer configs") + print_rank_0 = lambda msg: _log(msg) if dist.get_rank() == 0 else None + print_rank_0( + f"Optimizer config: lr={args.lr}, min_lr={args.min_lr}, weight_decay={args.weight_decay}, " + f"adam_beta1={args.adam_beta1}, adam_beta2={args.adam_beta2}, clip_grad={args.clip_grad}" + ) + mcore_opt_config = MCoreOptimizerConfig( + optimizer="adam", + lr=args.lr, + min_lr=args.min_lr, + weight_decay=args.weight_decay, + adam_beta1=args.adam_beta1, + adam_beta2=args.adam_beta2, + clip_grad=args.clip_grad, + bf16=True, + use_distributed_optimizer=True, + ) + # Bridge OptimizerConfig (deferred post_init) for ConfigContainer + bridge_opt_config = BridgeOptimizerConfig(lr=args.lr, min_lr=args.min_lr, use_distributed_optimizer=True) + + # 6. 
Build config container + _log("building config") + cfg = _build_config( + mimo_provider, + data_provider, + bridge_opt_config, + micro_batch_size=args.micro_batch_size, + global_batch_size=args.global_batch_size, + train_iters=args.train_iters, + log_interval=args.log_interval, + wandb_project=args.wandb_project, + wandb_exp_name=args.wandb_exp_name, + wandb_entity=args.wandb_entity, + wandb_save_dir=args.wandb_save_dir, + lr_warmup_iters=args.lr_warmup_iters, + ) + + # Configure checkpointing from CLI args + if args.checkpoint_interval is not None: + cfg.checkpoint.save_interval = args.checkpoint_interval + if args.checkpoint_dir is not None: + cfg.checkpoint.save = args.checkpoint_dir + if args.load_checkpoint is not None: + cfg.checkpoint.load = args.load_checkpoint + + # 7. Run training + _log("launching pretrain_mimo") + pretrain_mimo( + cfg=cfg, + mimo_provider=mimo_provider, + forward_step_func=mimo_forward_step, + build_data_iterators_fn=_build_data_iterators, + opt_config=mcore_opt_config, + ) + + _log("PASSED") + + # 8. Cleanup + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/unit_tests/training/mimo/test_mimo_checkpointing.py b/tests/unit_tests/training/mimo/test_mimo_checkpointing.py new file mode 100644 index 0000000000..f7381c2f49 --- /dev/null +++ b/tests/unit_tests/training/mimo/test_mimo_checkpointing.py @@ -0,0 +1,1168 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""Unit tests for MiMo checkpoint saving and loading wiring. + +Tests validate that MiMo training correctly uses shared checkpoint helpers +(save_checkpoint_and_time, checkpoint_and_decide_exit, load_checkpoint) with +the right arguments, without actually saving/loading checkpoints. 
+""" + +from __future__ import annotations + +import inspect +import time +from types import SimpleNamespace +from typing import Any, Dict +from unittest.mock import MagicMock, Mock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_scheduler_mock() -> MagicMock: + """Create a scheduler mock that supports param_groups[0] access.""" + sched = MagicMock() + sched.optimizer.param_groups = [{"lr": 1e-4}] + sched.get_lr.return_value = 1e-4 + return sched + + +def _make_mimo_infra(*, num_active_pgs: int = 1) -> Mock: + """Create a mock MimoModelInfra with the given number of active PG collections.""" + infra = Mock() + pgs: Dict[str, Any] = {} + for i in range(num_active_pgs): + pgs[f"module_{i}"] = Mock() + infra.pg_collections = pgs + infra.module_to_grid_map = {"language": Mock()} + infra.topology = Mock() + return infra + + +def _make_global_state( + *, + save_interval: int | None = 10, + save_dir: str | None = "/tmp/ckpt", + train_iters: int = 100, + step: int = 0, + non_persistent_save_interval: int | None = None, + exit_signal_handler: bool = False, + exit_duration_in_mins: float | None = None, + exit_interval: int | None = None, +) -> SimpleNamespace: + """Create a minimal GlobalState-like namespace for train_mimo tests.""" + timer_handle = Mock() + timers = Mock(return_value=timer_handle) + timers.log = Mock() + + state = SimpleNamespace( + timers=timers, + energy_monitor=None, + cfg=SimpleNamespace( + train=SimpleNamespace( + train_iters=train_iters, + micro_batch_size=1, + exit_signal_handler=exit_signal_handler, + exit_duration_in_mins=exit_duration_in_mins, + exit_interval=exit_interval, + eval_interval=None, + ), + dataset=SimpleNamespace(seq_length=128), + checkpoint=SimpleNamespace( + save=save_dir, + save_interval=save_interval, + non_persistent_save_interval=non_persistent_save_interval, + 
async_save=False, + ), + ddp=SimpleNamespace(use_megatron_fsdp=False, overlap_param_gather=True), + optimizer=SimpleNamespace(use_distributed_optimizer=True), + model=SimpleNamespace(fp8=None, seq_length=128), + logger=SimpleNamespace( + log_progress=False, + skip_train_metrics_log=True, + timing_log_level=0, + timing_log_option="minmax", + log_timers_to_tensorboard=False, + log_interval=1, + ), + profiling=None, + data_parallel_size=1, + ), + train_state=SimpleNamespace( + step=step, + consumed_train_samples=0, + floating_point_operations_so_far=0, + ), + start_time=time.time(), + signal_handler=Mock(), + nvrx_straggler_manager=None, + tensorboard_logger=None, + wandb_logger=None, + ) + state.signal_handler.signals_received.return_value = [] + return state + + +# --------------------------------------------------------------------------- +# Tests: pg_collection forwarding in shared helpers +# --------------------------------------------------------------------------- + + +class TestPgCollectionForwarding: + """Verify save_checkpoint_and_time and checkpoint_and_decide_exit + forward pg_collection to save_checkpoint.""" + + @patch("megatron.bridge.training.train.force_param_sync") + @patch("megatron.bridge.training.train.should_disable_forward_pre_hook", return_value=False) + @patch("megatron.bridge.training.train.save_checkpoint") + def test_save_checkpoint_and_time_forwards_pg_collection( + self, + mock_save_checkpoint, + mock_should_disable, + mock_force_param_sync, + ): + from megatron.bridge.training.train import save_checkpoint_and_time + + state = _make_global_state() + pg = Mock() + + save_checkpoint_and_time( + state=state, + model=[Mock()], + optimizer=Mock(), + opt_param_scheduler=Mock(), + num_floating_point_operations_so_far=0, + checkpointing_context={}, + pg_collection=pg, + ) + + _, kwargs = mock_save_checkpoint.call_args + assert kwargs["pg_collection"] is pg + + @patch("megatron.bridge.training.train.force_param_sync") + 
@patch("megatron.bridge.training.train.should_disable_forward_pre_hook", return_value=False) + @patch("megatron.bridge.training.train.save_checkpoint") + def test_save_checkpoint_and_time_defaults_pg_collection_to_none( + self, + mock_save_checkpoint, + mock_should_disable, + mock_force_param_sync, + ): + from megatron.bridge.training.train import save_checkpoint_and_time + + state = _make_global_state() + + save_checkpoint_and_time( + state=state, + model=[Mock()], + optimizer=Mock(), + opt_param_scheduler=Mock(), + num_floating_point_operations_so_far=0, + checkpointing_context={}, + ) + + _, kwargs = mock_save_checkpoint.call_args + assert kwargs["pg_collection"] is None + + @patch("megatron.bridge.training.train.save_checkpoint_and_time") + @patch("megatron.bridge.training.train.barrier_and_log") + @patch("megatron.bridge.training.train.check_nvrx_straggler_detection", return_value=False) + def test_checkpoint_and_decide_exit_forwards_pg_collection( + self, + mock_check_nvrx, + mock_barrier_log, + mock_save_and_time, + ): + from megatron.bridge.training.train import checkpoint_and_decide_exit + + state = _make_global_state(save_interval=5, step=10) + pg = Mock() + + checkpoint_and_decide_exit( + state=state, + model=[Mock()], + optimizer=Mock(), + opt_param_scheduler=Mock(), + num_floating_point_operations_so_far=0, + checkpointing_context={}, + train_data_iterator=None, + pg_collection=pg, + ) + + _, kwargs = mock_save_and_time.call_args + assert kwargs["pg_collection"] is pg + + +# --------------------------------------------------------------------------- +# Tests: pretrain_mimo setup wiring +# --------------------------------------------------------------------------- + + +class TestPretrainMimoSetup: + """Verify pretrain_mimo properly initializes checkpointing runtime.""" + + @patch("megatron.bridge.training.pretrain_mimo.init_checkpointing_context") + @patch("megatron.bridge.training.pretrain_mimo.MultiModulePipelineCommunicator") + 
@patch("megatron.bridge.training.pretrain_mimo.get_model_config") + @patch("megatron.bridge.training.pretrain_mimo.validate_no_stub_ranks") + @patch("megatron.bridge.training.pretrain_mimo.build_pg_collection_for_schedule") + @patch("megatron.bridge.training.pretrain_mimo.get_module_to_grid_tuple") + @patch("torch.distributed.all_reduce") + @patch("torch.distributed.get_rank", return_value=0) + @patch("torch.distributed.get_world_size", return_value=2) + def test_setup_mimo_initializes_checkpointing_context( + self, + mock_world_size, + mock_get_rank, + mock_all_reduce, + mock_get_grid, + mock_build_pg, + mock_validate, + mock_get_config, + mock_communicator, + mock_init_ctx, + ): + from megatron.bridge.training.pretrain_mimo import setup_mimo + + mock_init_ctx.return_value = {"test": "context"} + + model_config = Mock() + model_config.pipeline_dtype = None + model_config.bf16 = True + mock_get_config.return_value = model_config + + global_state = Mock() + global_state.start_time = time.time() + global_state.cfg = None + + cfg = Mock() + cfg.checkpoint = Mock() + cfg.train = Mock() + cfg.train.grad_reduce_in_fp32 = False + cfg.train.overlap_grad_reduce = True + cfg.train.use_distributed_optimizer = False + cfg.train.check_for_nan_in_grad = False + cfg.model = Mock() + cfg.model.fp16 = False + cfg.model.bf16 = True + + provider = Mock() + infra = Mock() + infra.module_to_grid_map = {"language": Mock()} + infra.topology = Mock() + infra.pg_collections = {"language": Mock()} + provider.build_infra.return_value = infra + provider.provide_distributed_model.return_value = [Mock()] + + result = setup_mimo(cfg, provider, global_state=global_state) + + mock_init_ctx.assert_called_once_with(cfg.checkpoint) + global_state.initialize_async_checkpoint_worker.assert_called_once() + assert result.checkpointing_context == {"test": "context"} + + def test_rampup_guard_rejects_rampup_batch_size(self): + from megatron.bridge.training.pretrain_mimo import pretrain_mimo + + cfg = 
Mock() + cfg.train.rampup_batch_size = [100, 200, 300] + + with pytest.raises(AssertionError, match="Microbatch rampup is not supported"): + pretrain_mimo( + cfg=cfg, + mimo_provider=Mock(), + forward_step_func=Mock(), + build_data_iterators_fn=Mock(), + opt_config=Mock(), + ) + + +# --------------------------------------------------------------------------- +# Tests: non-colocated runtime guard +# --------------------------------------------------------------------------- + + +class TestNonColocatedGuard: + """Verify the non-colocated topology assertion in train_mimo.""" + + @patch("megatron.bridge.training.train_mimo.build_pg_collection_for_schedule", return_value=Mock(spec=[])) + @patch("megatron.bridge.training.train_mimo.get_module_to_grid_tuple") + @patch("megatron.bridge.training.train_mimo.get_model_config") + @patch("megatron.bridge.training.train_mimo.prepare_forward_step_func") + @patch("megatron.bridge.training.train_mimo.get_num_microbatches", return_value=1) + @patch("torch.distributed.get_rank", return_value=0) + def test_rejects_multiple_active_pgs(self, *_mocks): + from megatron.bridge.training.train_mimo import train_mimo + + infra = _make_mimo_infra(num_active_pgs=2) + state = _make_global_state(train_iters=0) + + with pytest.raises(AssertionError, match="exactly one active ProcessGroupCollection"): + train_mimo( + forward_step_func=Mock(), + model=Mock(), + optimizer=Mock(), + schedulers={}, + train_data_iterator=Mock(), + valid_data_iterator=None, + global_state=state, + mimo_infra=infra, + multimodule_communicator=Mock(), + checkpointing_context={}, + ) + + @patch("megatron.bridge.training.train_mimo.build_pg_collection_for_schedule", return_value=Mock(spec=[])) + @patch("megatron.bridge.training.train_mimo.get_module_to_grid_tuple") + @patch("megatron.bridge.training.train_mimo.get_model_config") + @patch("megatron.bridge.training.train_mimo.prepare_forward_step_func") + @patch("megatron.bridge.training.train_mimo.get_num_microbatches", 
return_value=1) + @patch("torch.distributed.get_rank", return_value=0) + def test_rejects_zero_active_pgs(self, *_mocks): + from megatron.bridge.training.train_mimo import train_mimo + + infra = _make_mimo_infra(num_active_pgs=0) + state = _make_global_state(train_iters=0) + + with pytest.raises(AssertionError, match="exactly one active ProcessGroupCollection"): + train_mimo( + forward_step_func=Mock(), + model=Mock(), + optimizer=Mock(), + schedulers={}, + train_data_iterator=Mock(), + valid_data_iterator=None, + global_state=state, + mimo_infra=infra, + multimodule_communicator=Mock(), + checkpointing_context={}, + ) + + +# --------------------------------------------------------------------------- +# Tests: checkpoint_and_decide_exit integration in train_mimo +# --------------------------------------------------------------------------- + + +class TestTrainMimoCheckpointIntegration: + """Verify train_mimo calls checkpoint_and_decide_exit with the right args.""" + + @patch("megatron.bridge.training.train_mimo.checkpoint_and_decide_exit", return_value=False) + @patch("megatron.bridge.training.train_mimo.maybe_finalize_async_save") + @patch("megatron.bridge.training.train_mimo.train_step_mimo") + @patch("megatron.bridge.training.train_mimo.build_pg_collection_for_schedule") + @patch("megatron.bridge.training.train_mimo.get_module_to_grid_tuple") + @patch("megatron.bridge.training.train_mimo.get_model_config") + @patch("megatron.bridge.training.train_mimo.prepare_forward_step_func") + @patch("megatron.bridge.training.train_mimo.get_num_microbatches", return_value=1) + @patch("torch.distributed.get_rank", return_value=0) + @patch("torch.distributed.get_world_size", return_value=1) + def test_calls_checkpoint_and_decide_exit_with_pg_collection( + self, + mock_world_size, + mock_rank, + mock_num_mb, + mock_prep_fwd, + mock_get_config, + mock_get_grid, + mock_build_pg, + mock_train_step, + mock_async_finalize, + mock_ckpt_exit, + ): + from 
megatron.bridge.training.train_mimo import train_mimo + + mock_train_step.return_value = ({}, 0, 0.0, 0) + mock_config = Mock() + mock_config.variable_seq_lengths = True + mock_get_config.return_value = mock_config + + pg = Mock() + infra = Mock() + infra.pg_collections = {"language": pg} + infra.module_to_grid_map = {"language": Mock()} + infra.topology = Mock() + + mock_build_pg.return_value = Mock(spec=[]) # not a list + + state = _make_global_state(train_iters=1, step=0) + ctx = {"key": "value"} + train_iter = Mock() + + train_mimo( + forward_step_func=Mock(), + model=Mock(), + optimizer=Mock(), + schedulers={"language": _make_scheduler_mock()}, + train_data_iterator=train_iter, + valid_data_iterator=None, + global_state=state, + mimo_infra=infra, + multimodule_communicator=Mock(), + checkpointing_context=ctx, + ) + + mock_ckpt_exit.assert_called_once() + _, kwargs = mock_ckpt_exit.call_args + assert kwargs["pg_collection"] is pg + assert kwargs["checkpointing_context"] is ctx + assert kwargs["train_data_iterator"] is train_iter + assert kwargs["num_floating_point_operations_so_far"] == 0 + + @patch("megatron.bridge.training.train_mimo.checkpoint_and_decide_exit", return_value=True) + @patch("megatron.bridge.training.train_mimo.maybe_finalize_async_save") + @patch("megatron.bridge.training.train_mimo.train_step_mimo") + @patch("megatron.bridge.training.train_mimo.build_pg_collection_for_schedule") + @patch("megatron.bridge.training.train_mimo.get_module_to_grid_tuple") + @patch("megatron.bridge.training.train_mimo.get_model_config") + @patch("megatron.bridge.training.train_mimo.prepare_forward_step_func") + @patch("megatron.bridge.training.train_mimo.get_num_microbatches", return_value=1) + @patch("torch.distributed.get_rank", return_value=0) + @patch("torch.distributed.get_world_size", return_value=1) + def test_exits_loop_when_checkpoint_and_decide_exit_returns_true( + self, + mock_world_size, + mock_rank, + mock_num_mb, + mock_prep_fwd, + mock_get_config, + 
mock_get_grid, + mock_build_pg, + mock_train_step, + mock_async_finalize, + mock_ckpt_exit, + ): + from megatron.bridge.training.train_mimo import train_mimo + + mock_train_step.return_value = ({}, 0, 0.0, 0) + mock_config = Mock() + mock_config.variable_seq_lengths = True + mock_get_config.return_value = mock_config + + infra = Mock() + infra.pg_collections = {"language": Mock()} + infra.module_to_grid_map = {"language": Mock()} + infra.topology = Mock() + mock_build_pg.return_value = Mock(spec=[]) + + state = _make_global_state(train_iters=100, step=0) + + train_mimo( + forward_step_func=Mock(), + model=Mock(), + optimizer=Mock(), + schedulers={"language": _make_scheduler_mock()}, + train_data_iterator=Mock(), + valid_data_iterator=None, + global_state=state, + mimo_infra=infra, + multimodule_communicator=Mock(), + checkpointing_context={}, + ) + + # Should have exited after 1 iteration, not 100 + assert mock_train_step.call_count == 1 + assert state.train_state.step == 1 + + @patch("megatron.bridge.training.train_mimo.checkpoint_and_decide_exit", return_value=False) + @patch("megatron.bridge.training.train_mimo.maybe_finalize_async_save") + @patch("megatron.bridge.training.train_mimo.train_step_mimo") + @patch("megatron.bridge.training.train_mimo.build_pg_collection_for_schedule") + @patch("megatron.bridge.training.train_mimo.get_module_to_grid_tuple") + @patch("megatron.bridge.training.train_mimo.get_model_config") + @patch("megatron.bridge.training.train_mimo.prepare_forward_step_func") + @patch("megatron.bridge.training.train_mimo.get_num_microbatches", return_value=1) + @patch("torch.distributed.get_rank", return_value=0) + @patch("torch.distributed.get_world_size", return_value=1) + def test_async_finalize_called_at_top_of_loop( + self, + mock_world_size, + mock_rank, + mock_num_mb, + mock_prep_fwd, + mock_get_config, + mock_get_grid, + mock_build_pg, + mock_train_step, + mock_async_finalize, + mock_ckpt_exit, + ): + from megatron.bridge.training.train_mimo 
import train_mimo + + mock_train_step.return_value = ({}, 0, 0.0, 0) + mock_config = Mock() + mock_config.variable_seq_lengths = True + mock_get_config.return_value = mock_config + + infra = Mock() + infra.pg_collections = {"language": Mock()} + infra.module_to_grid_map = {"language": Mock()} + infra.topology = Mock() + mock_build_pg.return_value = Mock(spec=[]) + + state = _make_global_state(train_iters=2, step=0) + + train_mimo( + forward_step_func=Mock(), + model=Mock(), + optimizer=Mock(), + schedulers={"language": _make_scheduler_mock()}, + train_data_iterator=Mock(), + valid_data_iterator=None, + global_state=state, + mimo_infra=infra, + multimodule_communicator=Mock(), + checkpointing_context={}, + ) + + # 2 non-blocking calls (top of each iteration) + 1 blocking call (shutdown) + assert mock_async_finalize.call_count == 3 + + non_blocking_calls = [c for c in mock_async_finalize.call_args_list if c.kwargs.get("blocking") is False] + blocking_calls = [c for c in mock_async_finalize.call_args_list if c.kwargs.get("blocking") is True] + assert len(non_blocking_calls) == 2 + assert len(blocking_calls) == 1 + assert blocking_calls[0].kwargs.get("terminate") is True + + @patch("megatron.bridge.training.train_mimo.checkpoint_and_decide_exit", return_value=False) + @patch("megatron.bridge.training.train_mimo.maybe_finalize_async_save") + @patch("megatron.bridge.training.train_mimo.train_step_mimo") + @patch("megatron.bridge.training.train_mimo.build_pg_collection_for_schedule") + @patch("megatron.bridge.training.train_mimo.get_module_to_grid_tuple") + @patch("megatron.bridge.training.train_mimo.get_model_config") + @patch("megatron.bridge.training.train_mimo.prepare_forward_step_func") + @patch("megatron.bridge.training.train_mimo.get_num_microbatches", return_value=1) + @patch("torch.distributed.get_rank", return_value=0) + @patch("torch.distributed.get_world_size", return_value=1) + def test_no_inline_save_checkpoint_call( + self, + mock_world_size, + mock_rank, + 
mock_num_mb, + mock_prep_fwd, + mock_get_config, + mock_get_grid, + mock_build_pg, + mock_train_step, + mock_async_finalize, + mock_ckpt_exit, + ): + """Verify there is no inline save_checkpoint call — all saves go through + checkpoint_and_decide_exit.""" + from megatron.bridge.training.train_mimo import train_mimo + + mock_train_step.return_value = ({}, 0, 0.0, 0) + mock_config = Mock() + mock_config.variable_seq_lengths = True + mock_get_config.return_value = mock_config + + infra = Mock() + infra.pg_collections = {"language": Mock()} + infra.module_to_grid_map = {"language": Mock()} + infra.topology = Mock() + mock_build_pg.return_value = Mock(spec=[]) + + state = _make_global_state(save_interval=1, train_iters=3, step=0) + + with patch("megatron.bridge.training.train_mimo.save_checkpoint_and_time") as mock_direct_save: + train_mimo( + forward_step_func=Mock(), + model=Mock(), + optimizer=Mock(), + schedulers={"language": _make_scheduler_mock()}, + train_data_iterator=Mock(), + valid_data_iterator=None, + global_state=state, + mimo_infra=infra, + multimodule_communicator=Mock(), + checkpointing_context={}, + ) + + # save_checkpoint_and_time should NOT be called directly from train_mimo. + # All saves should go through checkpoint_and_decide_exit. 
+ mock_direct_save.assert_not_called() + + # But checkpoint_and_decide_exit should have been called + assert mock_ckpt_exit.call_count == 3 + + +# --------------------------------------------------------------------------- +# Helpers for load-side tests +# --------------------------------------------------------------------------- + + +def _make_setup_output_for_load( + *, + pg_collections: Dict[str, Any] | None = None, + train_state_step: int = 0, + consumed_train_samples: int = 0, + floating_point_operations_so_far: int = 0, +) -> SimpleNamespace: + """Create a MimoSetupOutput-like namespace suitable for pretrain_mimo load tests.""" + if pg_collections is None: + pg_collections = {"language": Mock()} + + train_state = SimpleNamespace( + step=train_state_step, + consumed_train_samples=consumed_train_samples, + floating_point_operations_so_far=floating_point_operations_so_far, + ) + timers_handle = Mock() + timers = Mock(return_value=timers_handle) + timers.log = Mock() + + global_state = Mock() + global_state.timers = timers + global_state.train_state = train_state + + return SimpleNamespace( + model=MagicMock(), + mimo_infra=SimpleNamespace( + module_to_grid_map={"language": Mock()}, + pg_collections=pg_collections, + topology=Mock(), + ), + multimodule_communicator=MagicMock(), + train_data_iterator=None, + valid_data_iterator=None, + global_state=global_state, + checkpointing_context={"test": "context"}, + ) + + +def _make_pretrain_cfg( + *, + load_path: str | None = None, + pretrained_path: str | None = None, + non_persistent_ckpt_type: str | None = None, +) -> MagicMock: + """Create a ConfigContainer-like mock for pretrain_mimo tests.""" + cfg = MagicMock() + cfg.train = SimpleNamespace( + rampup_batch_size=None, + global_batch_size=1, + micro_batch_size=1, + decrease_batch_size_if_needed=False, + ) + cfg.data_parallel_size = 1 + cfg.checkpoint = SimpleNamespace( + load=load_path, + pretrained_checkpoint=pretrained_path, + 
non_persistent_ckpt_type=non_persistent_ckpt_type, + ) + cfg.scheduler = SimpleNamespace( + lr_warmup_init=0.0, + lr_warmup_steps=0, + lr_decay_steps=100, + lr_decay_style="linear", + start_weight_decay=0.0, + end_weight_decay=0.0, + wd_incr_steps=0, + weight_decay_incr_style="constant", + use_checkpoint_opt_param_scheduler=False, + override_opt_param_scheduler=False, + wsd_decay_steps=None, + lr_wsd_decay_style=None, + ) + return cfg + + +def _run_pretrain_mimo( + *, + cfg: MagicMock | None = None, + setup_output: SimpleNamespace | None = None, + schedulers: Dict[str, Any] | None = None, + checkpoint_exists_return: bool = False, + build_data_iterators_fn: Any | None = None, +) -> Dict[str, Mock]: + """Run pretrain_mimo with full mocking and return all mock handles. + + Returns dict with keys: setup_mimo, load_checkpoint, checkpoint_exists, + train_mimo, build_data_iterators_fn, unwrap. + """ + from megatron.bridge.training.pretrain_mimo import pretrain_mimo + + if cfg is None: + cfg = _make_pretrain_cfg() + if setup_output is None: + setup_output = _make_setup_output_for_load() + if schedulers is None: + schedulers = {} + if build_data_iterators_fn is None: + build_data_iterators_fn = Mock(return_value=(iter([]), None)) + + mocks = {} + + with ( + patch("megatron.bridge.training.pretrain_mimo.train_mimo") as m_train, + patch("megatron.bridge.training.pretrain_mimo.setup_mimo", return_value=setup_output) as m_setup, + patch("megatron.bridge.training.pretrain_mimo.unwrap_mimo_model") as m_unwrap, + patch("megatron.bridge.training.pretrain_mimo.load_checkpoint") as m_load, + patch( + "megatron.bridge.training.pretrain_mimo.checkpoint_exists", + return_value=checkpoint_exists_return, + ) as m_ckpt_exists, + patch("megatron.bridge.training.pretrain_mimo.dist") as m_dist, + patch("megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR", None), + patch("megatron.core.num_microbatches_calculator.init_num_microbatches_calculator"), + 
patch("megatron.core.models.mimo.optimizer.get_mimo_optimizer") as m_get_opt, + ): + m_dist.get_rank.return_value = 0 + m_dist.get_world_size.return_value = 2 + m_unwrap.return_value = MagicMock( + mimo_config=SimpleNamespace(module_to_grid_map={"language": Mock()}), + ) + mock_optimizer = MagicMock() + mock_optimizer.module_infos = {} + mock_optimizer.is_stub_optimizer = False + m_get_opt.return_value = mock_optimizer + + pretrain_mimo( + cfg=cfg, + mimo_provider=MagicMock(), + forward_step_func=MagicMock(), + build_data_iterators_fn=build_data_iterators_fn, + opt_config=MagicMock(finalize=MagicMock(), lr=1e-4, min_lr=1e-5), + schedulers=schedulers, + global_state=setup_output.global_state, + ) + + mocks["setup_mimo"] = m_setup + mocks["load_checkpoint"] = m_load + mocks["checkpoint_exists"] = m_ckpt_exists + mocks["train_mimo"] = m_train + mocks["build_data_iterators_fn"] = build_data_iterators_fn + mocks["unwrap"] = m_unwrap + + return mocks + + +# --------------------------------------------------------------------------- +# Tests: load_checkpoint invocation from pretrain_mimo +# --------------------------------------------------------------------------- + + +class TestPretrainMimoLoadCheckpoint: + """Verify pretrain_mimo invokes load_checkpoint with correct arguments.""" + + def test_load_invoked_when_persistent_checkpoint_exists(self): + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + mocks = _run_pretrain_mimo(cfg=cfg, checkpoint_exists_return=True) + mocks["load_checkpoint"].assert_called_once() + + def test_load_invoked_when_pretrained_checkpoint_exists(self): + cfg = _make_pretrain_cfg(pretrained_path="/tmp/pretrained") + mocks = _run_pretrain_mimo(cfg=cfg, checkpoint_exists_return=True) + mocks["load_checkpoint"].assert_called_once() + + def test_load_invoked_for_non_persistent_intent_without_persistent_path(self): + """Non-persistent resume intent should trigger load even without cfg.checkpoint.load.""" + cfg = 
_make_pretrain_cfg(non_persistent_ckpt_type="local") + mocks = _run_pretrain_mimo(cfg=cfg, checkpoint_exists_return=False) + mocks["load_checkpoint"].assert_called_once() + + def test_load_not_invoked_when_no_checkpoint_intent(self): + cfg = _make_pretrain_cfg() + mocks = _run_pretrain_mimo(cfg=cfg, checkpoint_exists_return=False) + mocks["load_checkpoint"].assert_not_called() + + def test_load_forwards_list_wrapped_model(self): + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + setup_output = _make_setup_output_for_load() + mocks = _run_pretrain_mimo( + cfg=cfg, + setup_output=setup_output, + checkpoint_exists_return=True, + ) + _, kwargs = mocks["load_checkpoint"].call_args + assert isinstance(kwargs["model"], list) + assert len(kwargs["model"]) == 1 + assert kwargs["model"][0] is setup_output.model + + def test_load_forwards_explicit_pg_collection(self): + pg = Mock() + setup_output = _make_setup_output_for_load(pg_collections={"language": pg}) + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + mocks = _run_pretrain_mimo( + cfg=cfg, + setup_output=setup_output, + checkpoint_exists_return=True, + ) + _, kwargs = mocks["load_checkpoint"].call_args + assert kwargs["pg_collection"] is pg + + def test_load_forwards_checkpointing_context(self): + setup_output = _make_setup_output_for_load() + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + mocks = _run_pretrain_mimo( + cfg=cfg, + setup_output=setup_output, + checkpoint_exists_return=True, + ) + _, kwargs = mocks["load_checkpoint"].call_args + assert kwargs["checkpointing_context"] is setup_output.checkpointing_context + + def test_load_forwards_first_scheduler(self): + sched_a = _make_scheduler_mock() + sched_b = _make_scheduler_mock() + schedulers = {"language": sched_a, "vision": sched_b} + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + mocks = _run_pretrain_mimo( + cfg=cfg, + schedulers=schedulers, + checkpoint_exists_return=True, + ) + _, kwargs = mocks["load_checkpoint"].call_args + assert 
kwargs["opt_param_scheduler"] is sched_a + + +# --------------------------------------------------------------------------- +# Tests: non-colocated PG guard in pretrain_mimo load path +# --------------------------------------------------------------------------- + + +class TestPretrainMimoLoadPgGuard: + """Verify pretrain_mimo fails fast when PG topology is invalid.""" + + def test_rejects_zero_active_pgs_in_pretrain(self): + setup_output = _make_setup_output_for_load(pg_collections={}) + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + with pytest.raises(AssertionError, match="exactly one active ProcessGroupCollection"): + _run_pretrain_mimo( + cfg=cfg, + setup_output=setup_output, + checkpoint_exists_return=True, + ) + + def test_rejects_multiple_active_pgs_in_pretrain(self): + setup_output = _make_setup_output_for_load( + pg_collections={"language": Mock(), "vision": Mock()}, + ) + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + with pytest.raises(AssertionError, match="exactly one active ProcessGroupCollection"): + _run_pretrain_mimo( + cfg=cfg, + setup_output=setup_output, + checkpoint_exists_return=True, + ) + + +# --------------------------------------------------------------------------- +# Tests: scheduler v1 fanout behavior +# --------------------------------------------------------------------------- + + +class TestSchedulerV1Fanout: + """Verify scheduler state is loaded into first_scheduler and fanned out.""" + + def test_scheduler_fanout_after_load(self): + """After load, all schedulers should have the state of first_scheduler.""" + sched_a = MagicMock() + sched_a.optimizer.param_groups = [{"lr": 1e-4}] + sched_a.state_dict.return_value = {"step": 50, "lr": 0.001} + + sched_b = MagicMock() + sched_b.optimizer.param_groups = [{"lr": 1e-4}] + + schedulers = {"language": sched_a, "vision": sched_b} + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + + # Simulate load succeeding and setting step > 0 via side_effect. 
+ # load_checkpoint modifies global_state.train_state in-place, but + # in mock context it doesn't. We need step to remain 0 so iterator + # builder doesn't require train_state kwarg. + _run_pretrain_mimo(cfg=cfg, schedulers=schedulers, checkpoint_exists_return=True) + + # sched_b should have received the fanout + sched_b.load_state_dict.assert_called_once_with({"step": 50, "lr": 0.001}) + # sched_a should NOT have load_state_dict called by fanout (it's the source) + sched_a.load_state_dict.assert_not_called() + + def test_no_fanout_when_single_scheduler(self): + sched = MagicMock() + sched.optimizer.param_groups = [{"lr": 1e-4}] + sched.state_dict.return_value = {"step": 50} + + schedulers = {"language": sched} + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + _run_pretrain_mimo(cfg=cfg, schedulers=schedulers, checkpoint_exists_return=True) + + # No fanout needed with single scheduler + sched.load_state_dict.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests: iterator resume semantics +# --------------------------------------------------------------------------- + + +class TestIteratorResumeSemanticsLoad: + """Verify iterators are built after load and receive train_state when resuming.""" + + def test_iterators_built_after_setup_not_during(self): + """setup_mimo should be called with build_data_iterators_fn=None.""" + cfg = _make_pretrain_cfg() + mocks = _run_pretrain_mimo(cfg=cfg) + _, kwargs = mocks["setup_mimo"].call_args + assert kwargs["build_data_iterators_fn"] is None + + def test_iterator_builder_called_without_train_state_when_not_resuming(self): + cfg = _make_pretrain_cfg() + build_fn = Mock(return_value=(iter([]), None)) + mocks = _run_pretrain_mimo(cfg=cfg, build_data_iterators_fn=build_fn) + build_fn.assert_called_once() + args, kwargs = build_fn.call_args + assert "train_state" not in kwargs + + def test_iterator_builder_receives_train_state_mock(self): + """When resuming (step > 0), 
builder receives train_state kwarg.""" + build_fn = MagicMock(return_value=(iter([]), None)) + # Give mock the train_state parameter so inspect.signature finds it + + def _sig_fn(cfg, mimo_infra, *, train_state=None): + pass + + build_fn.__signature__ = inspect.signature(_sig_fn) + + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + setup_output = _make_setup_output_for_load(train_state_step=10, consumed_train_samples=500) + + _run_pretrain_mimo( + cfg=cfg, + setup_output=setup_output, + checkpoint_exists_return=True, + build_data_iterators_fn=build_fn, + ) + + build_fn.assert_called_once() + _, kwargs = build_fn.call_args + assert "train_state" in kwargs + assert kwargs["train_state"].step == 10 + assert kwargs["train_state"].consumed_train_samples == 500 + + def test_iterator_builder_fails_fast_if_no_train_state_param_on_resume(self): + """Resuming with a builder that lacks train_state param raises RuntimeError.""" + + def legacy_builder(cfg, mimo_infra): + return (iter([]), None) + + build_fn = MagicMock(return_value=(iter([]), None)) + build_fn.__signature__ = inspect.signature(legacy_builder) + + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + setup_output = _make_setup_output_for_load(train_state_step=10) + + with pytest.raises(RuntimeError, match="build_data_iterators_fn does not accept"): + _run_pretrain_mimo( + cfg=cfg, + setup_output=setup_output, + checkpoint_exists_return=True, + build_data_iterators_fn=build_fn, + ) + + +# --------------------------------------------------------------------------- +# Tests: MimoOptimizer load-side compatibility +# --------------------------------------------------------------------------- + + +class TestMimoOptimizerLoadCompat: + """Verify MimoOptimizer load methods work correctly.""" + + def _make_mimo_optimizer(self): + from megatron.core.models.mimo.optimizer import MimoOptimizer, ModuleOptimizerInfo + + opt_a = MagicMock() + opt_b = MagicMock() + + module_infos = { + "language": 
ModuleOptimizerInfo(optimizer=opt_a, grid=Mock(), pg_collection=Mock(), is_active=True), + "vision": ModuleOptimizerInfo(optimizer=opt_b, grid=Mock(), pg_collection=Mock(), is_active=True), + } + config = MagicMock() + return MimoOptimizer(module_infos, config), opt_a, opt_b + + def test_load_state_dict_dispatches_per_module(self): + mimo_opt, opt_a, opt_b = self._make_mimo_optimizer() + state = {"language": {"param": 1}, "vision": {"param": 2}} + mimo_opt.load_state_dict(state) + opt_a.load_state_dict.assert_called_once_with({"param": 1}) + opt_b.load_state_dict.assert_called_once_with({"param": 2}) + + def test_load_state_dict_skips_missing_keys(self): + mimo_opt, opt_a, opt_b = self._make_mimo_optimizer() + state = {"language": {"param": 1}} + mimo_opt.load_state_dict(state) + opt_a.load_state_dict.assert_called_once() + opt_b.load_state_dict.assert_not_called() + + def test_sharded_state_dict_generates_per_module(self): + mimo_opt, opt_a, opt_b = self._make_mimo_optimizer() + opt_a.sharded_state_dict.return_value = {"a": "sharded_a"} + opt_b.sharded_state_dict.return_value = {"b": "sharded_b"} + + result = mimo_opt.sharded_state_dict({}, is_loading=True) + assert "language" in result + assert "vision" in result + assert result["language"] == {"a": "sharded_a"} + assert result["vision"] == {"b": "sharded_b"} + + def test_reload_model_params_delegates_to_all_active(self): + mimo_opt, opt_a, opt_b = self._make_mimo_optimizer() + mimo_opt.reload_model_params(state_dict={"model": {}}) + opt_a.reload_model_params.assert_called_once_with({"model": {}}) + opt_b.reload_model_params.assert_called_once_with({"model": {}}) + + def test_is_stub_optimizer_when_no_active(self): + from megatron.core.models.mimo.optimizer import MimoOptimizer, ModuleOptimizerInfo + + module_infos = { + "language": ModuleOptimizerInfo(optimizer=None, grid=Mock(), pg_collection=Mock(), is_active=False), + } + mimo_opt = MimoOptimizer(module_infos, MagicMock()) + assert mimo_opt.is_stub_optimizer 
is True + + +# --------------------------------------------------------------------------- +# Tests: train state restoration smoke +# --------------------------------------------------------------------------- + + +class TestTrainStateRestorationSmoke: + """Smoke tests for train_state being accessible after load.""" + + def test_train_state_step_accessible_after_load(self): + setup_output = _make_setup_output_for_load(train_state_step=42, consumed_train_samples=1000) + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + + def builder(cfg, mimo_infra, *, train_state=None): + return (iter([]), None) + + build_fn = MagicMock(return_value=(iter([]), None)) + build_fn.__signature__ = inspect.signature(builder) + + mocks = _run_pretrain_mimo( + cfg=cfg, + setup_output=setup_output, + checkpoint_exists_return=True, + build_data_iterators_fn=build_fn, + ) + + # train_state is passed to train_mimo via global_state + _, kwargs = mocks["train_mimo"].call_args + ts = kwargs["global_state"].train_state + assert ts.step == 42 + assert ts.consumed_train_samples == 1000 + + def test_floating_point_ops_preserved(self): + setup_output = _make_setup_output_for_load(floating_point_operations_so_far=99999) + cfg = _make_pretrain_cfg(load_path="/tmp/ckpt") + mocks = _run_pretrain_mimo( + cfg=cfg, + setup_output=setup_output, + checkpoint_exists_return=True, + ) + _, kwargs = mocks["train_mimo"].call_args + assert kwargs["global_state"].train_state.floating_point_operations_so_far == 99999 + + +# --------------------------------------------------------------------------- +# Tests: local checkpoint resume plumbing +# --------------------------------------------------------------------------- + + +class TestLocalCheckpointResumePlumbing: + """Verify non-persistent local checkpoint intent triggers load.""" + + def test_local_non_persistent_triggers_load(self): + cfg = _make_pretrain_cfg(non_persistent_ckpt_type="local") + mocks = _run_pretrain_mimo(cfg=cfg, checkpoint_exists_return=False) + 
mocks["load_checkpoint"].assert_called_once() + + def test_global_non_persistent_triggers_load(self): + cfg = _make_pretrain_cfg(non_persistent_ckpt_type="global") + mocks = _run_pretrain_mimo(cfg=cfg, checkpoint_exists_return=False) + mocks["load_checkpoint"].assert_called_once() + + +# --------------------------------------------------------------------------- +# Tests: no-checkpoint graceful fallback +# --------------------------------------------------------------------------- + + +class TestNoCheckpointGracefulFallback: + """Verify load is not attempted and training starts from random init.""" + + def test_no_load_no_crash(self): + """When no checkpoint intent exists, load is skipped and training starts cleanly.""" + cfg = _make_pretrain_cfg() + mocks = _run_pretrain_mimo(cfg=cfg, checkpoint_exists_return=False) + mocks["load_checkpoint"].assert_not_called() + mocks["train_mimo"].assert_called_once() + + def test_iterators_still_built_without_checkpoint(self): + cfg = _make_pretrain_cfg() + build_fn = Mock(return_value=(iter([]), None)) + mocks = _run_pretrain_mimo(cfg=cfg, build_data_iterators_fn=build_fn) + build_fn.assert_called_once() + + +# --------------------------------------------------------------------------- +# Tests: load_checkpoint pg_collection explicit threading +# --------------------------------------------------------------------------- + + +class TestLoadCheckpointPgThreading: + """Verify load_checkpoint and _load_checkpoint_from_path accept and + thread explicit pg_collection.""" + + def test_load_checkpoint_forwards_pg_collection_to_inner(self): + from megatron.bridge.training.checkpointing import load_checkpoint + + pg = Mock() + state = Mock() + state.cfg.checkpoint.load = "/tmp/ckpt" + state.cfg.checkpoint.pretrained_checkpoint = None + + with patch( + "megatron.bridge.training.checkpointing._load_checkpoint_from_path", + return_value=(0, 0), + ) as m_inner: + with patch( + "megatron.bridge.training.checkpointing.checkpoint_exists", + 
return_value=True, + ): + load_checkpoint( + state=state, + model=[Mock()], + optimizer=Mock(), + opt_param_scheduler=Mock(), + pg_collection=pg, + ) + _, kwargs = m_inner.call_args + assert kwargs["pg_collection"] is pg + + def test_load_checkpoint_defaults_pg_collection_to_none(self): + from megatron.bridge.training.checkpointing import load_checkpoint + + state = Mock() + state.cfg.checkpoint.load = "/tmp/ckpt" + state.cfg.checkpoint.pretrained_checkpoint = None + + with patch( + "megatron.bridge.training.checkpointing._load_checkpoint_from_path", + return_value=(0, 0), + ) as m_inner: + with patch( + "megatron.bridge.training.checkpointing.checkpoint_exists", + return_value=True, + ): + load_checkpoint( + state=state, + model=[Mock()], + optimizer=Mock(), + opt_param_scheduler=Mock(), + ) + _, kwargs = m_inner.call_args + assert kwargs["pg_collection"] is None diff --git a/tests/unit_tests/training/mimo/test_pretrain_mimo.py b/tests/unit_tests/training/mimo/test_pretrain_mimo.py index 90f0c55c27..a9fb75c03c 100644 --- a/tests/unit_tests/training/mimo/test_pretrain_mimo.py +++ b/tests/unit_tests/training/mimo/test_pretrain_mimo.py @@ -16,17 +16,26 @@ def _make_cfg(): decrease_batch_size_if_needed=False, ) cfg.data_parallel_size = 1 + cfg.checkpoint.load = None + cfg.checkpoint.pretrained_checkpoint = None + cfg.checkpoint.non_persistent_ckpt_type = None return cfg def _make_setup_output(module_to_grid_map): + global_state = MagicMock() + global_state.train_state.step = 0 return SimpleNamespace( model=MagicMock(), - mimo_infra=SimpleNamespace(module_to_grid_map=module_to_grid_map), + mimo_infra=SimpleNamespace( + module_to_grid_map=module_to_grid_map, + pg_collections={"language": MagicMock()}, + ), multimodule_communicator=MagicMock(), train_data_iterator=iter([]), valid_data_iterator=None, - global_state=MagicMock(), + global_state=global_state, + checkpointing_context=None, ) @@ -70,7 +79,7 @@ def test_pretrain_mimo_uses_constructor_wired_config( cfg=cfg, 
mimo_provider=MagicMock(), forward_step_func=forward_step_func, - build_data_iterators_fn=MagicMock(), + build_data_iterators_fn=MagicMock(return_value=(iter([]), None)), opt_config=opt_config, schedulers=schedulers, global_state=MagicMock(),