diff --git a/3rdparty/Megatron-LM b/3rdparty/Megatron-LM index cc897f0111..10b3ddd4da 160000 --- a/3rdparty/Megatron-LM +++ b/3rdparty/Megatron-LM @@ -1 +1 @@ -Subproject commit cc897f011196b71d2f5b46870c302c7965070660 +Subproject commit 10b3ddd4da10532b13e467b89c6cedf590e4c026 diff --git a/src/megatron/bridge/data/mimo/dp_utils.py b/src/megatron/bridge/data/mimo/dp_utils.py index 546e39e5f1..c6d867a4c4 100644 --- a/src/megatron/bridge/data/mimo/dp_utils.py +++ b/src/megatron/bridge/data/mimo/dp_utils.py @@ -6,9 +6,12 @@ from typing import TYPE_CHECKING, Dict, Tuple import torch.distributed as dist +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY + if TYPE_CHECKING: from megatron.core.hyper_comm_grid import HyperCommGrid + from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig @@ -17,23 +20,23 @@ def get_mimo_dp_info( grids: Dict[str, "HyperCommGrid"], ) -> Tuple[int, int, bool, str]: """Get DP rank, size, data-loading responsibility, and loader module for MIMO. - + Determines which module's DP settings to use for data loading based on current rank's participation in heterogeneous deployment. - + In heterogeneous mode, each rank uses its own module's DP settings. - + Args: mimo_cfg: MIMO parallelism configuration. grids: Module name to HyperCommGrid mapping from build_hypercomm_grids(). - + Returns: Tuple of (dp_rank, dp_size, needs_data, loader_module): - dp_rank: This rank's position in DP group. - dp_size: Size of DP group for data sharding. - needs_data: Whether this rank needs to load data (first/last PP stage). - loader_module: Which module's DP settings are being used. - + Example: >>> from megatron.bridge.models.mimo.mimo_builder import build_hypercomm_grids >>> grids = build_hypercomm_grids(mimo_cfg) @@ -55,7 +58,7 @@ def get_mimo_dp_info( if my_grid is None or my_module is None: # Rank doesn't participate in any module - return 0, 1, False, "llm" + return 0, 1, False, MIMO_LANGUAGE_MODULE_KEY dp_rank = my_grid.get_pg(["dp"]).rank() dp_size = my_grid.get_pg(["dp"]).size() @@ -64,7 +67,7 @@ def get_mimo_dp_info( pp_rank = pp_group.rank() pp_size = pp_group.size() - if my_module == "llm": + if my_module == MIMO_LANGUAGE_MODULE_KEY: needs_data = (pp_rank == 0) or (pp_rank == pp_size - 1) else: needs_data = pp_rank == 0 diff --git a/src/megatron/bridge/models/mimo/mimo_config.py b/src/megatron/bridge/models/mimo/mimo_config.py index ec5efea755..4474427401 100644 --- a/src/megatron/bridge/models/mimo/mimo_config.py +++ b/src/megatron/bridge/models/mimo/mimo_config.py @@ -5,6 +5,8 @@ from dataclasses import dataclass, field from typing import Optional +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY + @dataclass class ModuleParallelismConfig: @@ -62,7 +64,7 @@ class MimoParallelismConfig: Note: Phase 1 only supports heterogeneous deployment where each module can have different parallelism configurations and rank offsets. - The LLM module must be named "llm" in module_parallelisms. + The language module must be named MIMO_LANGUAGE_MODULE_KEY ("language") in module_parallelisms. """ module_parallelisms: dict[str, ModuleParallelismConfig] @@ -133,10 +135,10 @@ def is_power_of_two(n: int) -> bool: # Validate encoder DP >= LLM DP for embedding alignment # Encoder modules produce embeddings consumed by LLM. If encoder DP < LLM DP, # the same encoder batch would need to align with different LLM batches, which fails. 
- llm_dp = self.module_parallelisms["llm"].data_parallel_size + llm_dp = self.module_parallelisms[MIMO_LANGUAGE_MODULE_KEY].data_parallel_size if llm_dp is not None: for name, p in self.module_parallelisms.items(): - if name == "llm": + if name == MIMO_LANGUAGE_MODULE_KEY: continue encoder_dp = p.data_parallel_size if encoder_dp is not None and encoder_dp < llm_dp: @@ -152,9 +154,9 @@ def finalize(self, world_size: int) -> None: world_size: Total number of ranks in the distributed world. MIMO requires a distributed environment, so this must always be provided. """ - if "llm" not in self.module_parallelisms: + if MIMO_LANGUAGE_MODULE_KEY not in self.module_parallelisms: raise ValueError( - f"LLM module 'llm' must be in module_parallelisms. " + f"Language module '{MIMO_LANGUAGE_MODULE_KEY}' must be in module_parallelisms. " f"Found modules: {list(self.module_parallelisms.keys())}" ) diff --git a/src/megatron/bridge/models/mimo/mimo_ddp.py b/src/megatron/bridge/models/mimo/mimo_ddp.py index 1a045fa9d9..fbd0a91f8f 100644 --- a/src/megatron/bridge/models/mimo/mimo_ddp.py +++ b/src/megatron/bridge/models/mimo/mimo_ddp.py @@ -5,17 +5,22 @@ Note: This module only supports DDP wrapping. FSDP is not yet implemented. """ + from __future__ import annotations from typing import TYPE_CHECKING, Dict, Optional +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY + from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid + if TYPE_CHECKING: from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.models.mimo import MimoModel from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig @@ -44,9 +49,9 @@ def wrap_mimo_model_distributed( # Wrap language model if present and rank participates if mimo_model.language_model is not None: - llm_grid = grids.get("llm") + llm_grid = grids.get(MIMO_LANGUAGE_MODULE_KEY) if llm_grid is not None and is_current_rank_in_grid(llm_grid): - llm_pg = pg_collections.get("llm") + llm_pg = pg_collections.get(MIMO_LANGUAGE_MODULE_KEY) if llm_pg is not None: mimo_model.language_model = DistributedDataParallel( config=mimo_model.language_model.config, @@ -56,7 +61,7 @@ def wrap_mimo_model_distributed( ) # Wrap modality submodules - if hasattr(mimo_model, 'modality_submodules'): + if hasattr(mimo_model, "modality_submodules"): for module_name, submodule in mimo_model.modality_submodules.items(): if submodule is None: continue @@ -74,11 +79,11 @@ def wrap_mimo_model_distributed( # Note: We use the first encoder's config for DDP bucket sizing. # This assumes all encoders in a modality submodule share similar # parallelism settings, which is typical for MIMO models. - if hasattr(submodule, 'encoders') and submodule.encoders: + if hasattr(submodule, "encoders") and submodule.encoders: encoder_key = next(iter(submodule.encoders.keys())) first_encoder = submodule.encoders[encoder_key] - - if not hasattr(first_encoder, 'config'): + + if not hasattr(first_encoder, "config"): raise AttributeError( f"Encoder '{encoder_key}' in modality '{module_name}' does not have " f"a 'config' attribute. Encoders must be MegatronModule subclasses." 
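For illustration (this sketch is not part of the diff itself): after this change every per-module dict, such as module_parallelisms, the built grids, and pg_collections, is keyed by MIMO_LANGUAGE_MODULE_KEY, which resolves to "language", instead of the hard-coded literal "llm". Below is a minimal heterogeneous configuration under the new keying; the module names, parallel sizes, and rank offsets are illustrative, and the data-loading rule from get_mimo_dp_info above is restated in the trailing comment.

# Illustrative sketch (assumes Megatron Bridge is installed; a real deployment runs under torchrun).
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY  # resolves to "language"

from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig, ModuleParallelismConfig

mimo_cfg = MimoParallelismConfig(
    module_parallelisms={
        # Vision encoder on ranks 0-3: TP=1, DP=4 (encoder DP must be >= language DP).
        "vision": ModuleParallelismConfig(
            tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=0
        ),
        # Language model on ranks 4-7: TP=2, DP=2. Keyed by the shared constant,
        # not the old literal "llm".
        MIMO_LANGUAGE_MODULE_KEY: ModuleParallelismConfig(
            tensor_model_parallel_size=2, data_parallel_size=2, rank_offset=4
        ),
    },
)
mimo_cfg.finalize(world_size=8)  # raises if MIMO_LANGUAGE_MODULE_KEY is absent

# Data-loading rule from get_mimo_dp_info above: the language module loads data on
# its first and last PP stages, encoder modules only on their first PP stage, and
# ranks outside every grid fall back to (0, 1, False, MIMO_LANGUAGE_MODULE_KEY).

Keying everything by the shared constant keeps dp_utils, the DDP wrapper, and the core MimoModel role logic agreeing on one module name, which is why language_module_key is no longer passed to MimoModelConfig or asserted in pretrain_mimo later in this diff.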
diff --git a/src/megatron/bridge/models/mimo/mimo_provider.py b/src/megatron/bridge/models/mimo/mimo_provider.py index 410e905678..2cddb904ec 100644 --- a/src/megatron/bridge/models/mimo/mimo_provider.py +++ b/src/megatron/bridge/models/mimo/mimo_provider.py @@ -21,6 +21,7 @@ from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.models.mimo import MimoModel from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec @@ -60,6 +61,7 @@ class MimoModelInfra: topology: Dict[str, List[str]] pg_collections: Dict[str, Optional[ProcessGroupCollection]] participating_modules: List[str] + module_output_ndim: Dict[str, int] = field(default_factory=dict) @dataclass @@ -83,7 +85,7 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]): Example: >>> mimo_parallelism_config = MimoParallelismConfig( ... module_parallelisms={ - ... "llm": ModuleParallelismConfig(tensor_model_parallel_size=8), + ... "language": ModuleParallelismConfig(tensor_model_parallel_size=8), ... "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2), ... } ... ) @@ -108,10 +110,15 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]): mimo_parallelism_config: Optional[MimoParallelismConfig] = None # Module data-flow DAG for MultiModulePipelineCommunicator. - # If None, auto-derived as: all modality_submodules → "llm" (terminal). - # Set explicitly for non-standard topologies (e.g., llm → generator). + # If None, auto-derived as: all modality_submodules → language module (terminal). + # Set explicitly for non-standard topologies (e.g., language → generator). topology: Optional[Dict[str, List[str]]] = None + # Output tensor dimensionality per module for bridge communicator routing. + # Vision/audio encoders typically produce 2D [S, H]; language modules produce 3D [S, B, H]. + # If None, auto-derived: language module → 3, all others → 2. + module_output_ndim: Optional[Dict[str, int]] = None + # Cached grids after build_model() - used by data loading _grids: Optional[Dict[str, "HyperCommGrid"]] = field(default=None, repr=False) @@ -150,18 +157,30 @@ def build_infra(self) -> MimoModelInfra: if self.topology is not None: topology = self.topology else: - topology = {name: ["llm"] for name in self.modality_submodules_spec} | {"llm": []} + topology = {name: [MIMO_LANGUAGE_MODULE_KEY] for name in self.modality_submodules_spec} | { + MIMO_LANGUAGE_MODULE_KEY: [] + } # Cache grids for later use (e.g., data loading) object.__setattr__(self, "_grids", grids) participating_modules = [name for name, pg in pg_collections.items() if pg is not None] + # Derive module output tensor dimensionality if not explicitly configured. 
+ if self.module_output_ndim is not None: + output_ndim = self.module_output_ndim + else: + output_ndim = { + name: 3 if name == MIMO_LANGUAGE_MODULE_KEY else 2 + for name in grids + } + return MimoModelInfra( module_to_grid_map=grids, topology=topology, pg_collections=pg_collections, participating_modules=participating_modules, + module_output_ndim=output_ndim, ) def _get_pg_collections_from_grids( @@ -289,7 +308,7 @@ def provide( # Inject pg_collection into language model spec language_spec = self.language_model_spec if self.mimo_parallelism_config: - llm_pg = infra.pg_collections.get("llm") + llm_pg = infra.pg_collections.get(MIMO_LANGUAGE_MODULE_KEY) if llm_pg is not None: language_spec = self._inject_pg_collection_into_language_spec( language_spec, @@ -312,7 +331,6 @@ def provide( modality_submodules_spec=modality_specs, special_token_ids=self.special_token_ids, module_to_grid_map=(infra.module_to_grid_map if self.mimo_parallelism_config is not None else None), - language_module_key="llm" if self.mimo_parallelism_config is not None else None, ) mimo_model = MimoModel(mimo_model_config) diff --git a/src/megatron/bridge/training/mimo_parallel_utils.py b/src/megatron/bridge/training/mimo_parallel_utils.py index a32f488933..94dd801ef8 100644 --- a/src/megatron/bridge/training/mimo_parallel_utils.py +++ b/src/megatron/bridge/training/mimo_parallel_utils.py @@ -23,6 +23,7 @@ import torch.distributed as dist from megatron.core.distributed.finalize_model_grads import finalize_model_grads as _finalize_model_grads from megatron.core.models.mimo import MimoModel +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY from megatron.bridge.models.mimo.mimo_provider import MimoModelInfra @@ -94,7 +95,7 @@ def get_module_to_grid_tuple( continue # Get the actual module from the unwrapped model - if module_name == "llm": + if module_name == MIMO_LANGUAGE_MODULE_KEY: module = unwrapped_model.language_model elif hasattr(unwrapped_model, "modality_submodules") and module_name in unwrapped_model.modality_submodules: module = unwrapped_model.modality_submodules[module_name] @@ -128,7 +129,7 @@ def build_pg_collection_for_schedule(infra: MimoModelInfra): module_pgs = {k: v for k, v in infra.pg_collections.items() if v is not None} if not module_pgs: raise ValueError("module_pgs dict cannot be empty") - language_model_module_name = "llm" if "llm" in module_pgs else None + language_model_module_name = MIMO_LANGUAGE_MODULE_KEY if MIMO_LANGUAGE_MODULE_KEY in module_pgs else None return MultiModuleProcessGroupCollection( module_pgs=module_pgs, language_model_module_name=language_model_module_name, diff --git a/src/megatron/bridge/training/mimo_step.py b/src/megatron/bridge/training/mimo_step.py index c3e3323ae0..eb6598dd89 100644 --- a/src/megatron/bridge/training/mimo_step.py +++ b/src/megatron/bridge/training/mimo_step.py @@ -17,6 +17,7 @@ import torch from megatron.core.models.mimo import MimoModel +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY from megatron.bridge.training.mimo_parallel_utils import unwrap_mimo_model from megatron.bridge.training.state import GlobalState @@ -135,7 +136,7 @@ def forward_step( needs_data = True if mimo_model.role is not None: if mimo_model.role.has_language_module: - module_name = mimo_model.role.language_module_name + module_name = MIMO_LANGUAGE_MODULE_KEY is_first_stage = mimo_model.role.is_first_stage(module_name) is_last_stage = mimo_model.role.is_last_stage(module_name) needs_data = is_first_stage or is_last_stage @@ -182,7 
+183,7 @@ def forward_step( if mimo_model.role is None: is_last_stage = True elif mimo_model.role.has_language_module: - is_last_stage = mimo_model.role.is_last_stage(mimo_model.role.language_module_name) + is_last_stage = mimo_model.role.is_last_stage(MIMO_LANGUAGE_MODULE_KEY) else: is_last_stage = False diff --git a/src/megatron/bridge/training/pretrain_mimo.py b/src/megatron/bridge/training/pretrain_mimo.py index 8be96bb14d..b55084ffb2 100644 --- a/src/megatron/bridge/training/pretrain_mimo.py +++ b/src/megatron/bridge/training/pretrain_mimo.py @@ -166,6 +166,7 @@ def setup_mimo( mimo_infra.topology, model_config, dim_mapping={"s": 0, "b": 1, "h": 2}, # SBH mapping - matches MimoModel output + module_output_ndim=mimo_infra.module_output_ndim, ) # Build pg_collection for schedule @@ -261,11 +262,6 @@ def pretrain_mimo( "MimoModelConfig.module_to_grid_map must be set at model construction time. " "Ensure MimoModelProvider.provide() passes module_to_grid_map for MIMO parallelism." ) - assert unwrapped_model.mimo_config.language_module_key is not None, ( - "MimoModelConfig.language_module_key must be set at model construction time. " - "Ensure MimoModelProvider.provide() sets language_module_key for MIMO parallelism." - ) - logger.info(f"Rank {dist.get_rank()}: Creating MimoOptimizer") # Create MimoOptimizer using the factory function diff --git a/src/megatron/bridge/training/train_mimo.py b/src/megatron/bridge/training/train_mimo.py index ca70e85e47..e8d1dc2a3d 100644 --- a/src/megatron/bridge/training/train_mimo.py +++ b/src/megatron/bridge/training/train_mimo.py @@ -23,6 +23,7 @@ import torch import torch.distributed as dist +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel.schedules import forward_backward_pipelining_without_interleaving from megatron.core.utils import get_model_config @@ -144,10 +145,10 @@ def train_step_mimo( if mimo_model.role is None: is_last_stage = True elif mimo_model.role.has_language_module: - is_last_stage = mimo_model.role.is_last_stage(mimo_model.role.language_module_name) + is_last_stage = mimo_model.role.is_last_stage(MIMO_LANGUAGE_MODULE_KEY) if is_last_stage: - llm_pg = infra.pg_collections.get("llm") if infra.pg_collections else None + llm_pg = infra.pg_collections.get(MIMO_LANGUAGE_MODULE_KEY) if infra.pg_collections else None for key in losses_reduced[0].keys(): val = [x[key].view(-1) for x in losses_reduced] if val[0].numel() == 2: diff --git a/tests/unit_tests/data/mimo/test_dp_utils.py b/tests/unit_tests/data/mimo/test_dp_utils.py index 3a7b4bd53c..b475a1dbde 100644 --- a/tests/unit_tests/data/mimo/test_dp_utils.py +++ b/tests/unit_tests/data/mimo/test_dp_utils.py @@ -9,7 +9,7 @@ class FakePG: """Fake process group for testing.""" - + def __init__(self, rank: int, size: int): self._rank = rank self._size = size @@ -23,9 +23,8 @@ def size(self) -> int: class FakeGrid: """Fake HyperCommGrid for testing.""" - - def __init__(self, rank_offset: int, size: int, dp_rank: int, dp_size: int, - pp_rank: int, pp_size: int): + + def __init__(self, rank_offset: int, size: int, dp_rank: int, dp_size: int, pp_rank: int, pp_size: int): self.rank_offset = rank_offset self.size = size self._pgs = { @@ -41,7 +40,7 @@ def _make_mimo_cfg() -> MimoParallelismConfig: """Create test MIMO config for heterogeneous deployment.""" module_parallelisms = { "vision": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=2, 
rank_offset=0), - "llm": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=4), + "language": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=4), } return MimoParallelismConfig( module_parallelisms=module_parallelisms, @@ -55,11 +54,11 @@ def test_get_mimo_dp_info_encoder_first_pp(monkeypatch): grids = { "vision": FakeGrid(0, 4, dp_rank=0, dp_size=2, pp_rank=0, pp_size=2), - "llm": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), + "language": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), } dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - + assert loader_module == "vision" assert dp_rank == 0 assert dp_size == 2 @@ -73,11 +72,11 @@ def test_get_mimo_dp_info_encoder_non_first_pp(monkeypatch): grids = { "vision": FakeGrid(0, 4, dp_rank=0, dp_size=2, pp_rank=1, pp_size=2), - "llm": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), + "language": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), } dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - + assert loader_module == "vision" assert needs_data is False # Not first PP stage @@ -89,12 +88,12 @@ def test_get_mimo_dp_info_llm_first_pp(monkeypatch): grids = { "vision": FakeGrid(0, 4, dp_rank=0, dp_size=2, pp_rank=0, pp_size=1), - "llm": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=2), + "language": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=2), } dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - - assert loader_module == "llm" + + assert loader_module == "language" assert needs_data is True # First PP stage @@ -105,12 +104,12 @@ def test_get_mimo_dp_info_llm_last_pp(monkeypatch): grids = { "vision": FakeGrid(0, 4, dp_rank=0, dp_size=2, pp_rank=0, pp_size=1), - "llm": FakeGrid(4, 4, dp_rank=1, dp_size=4, pp_rank=1, pp_size=2), + "language": FakeGrid(4, 4, dp_rank=1, dp_size=4, pp_rank=1, pp_size=2), } dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - - assert loader_module == "llm" + + assert loader_module == "language" assert needs_data is True # Last PP stage @@ -121,10 +120,10 @@ def test_get_mimo_dp_info_non_participating_rank(monkeypatch): grids = { "vision": FakeGrid(0, 4, dp_rank=0, dp_size=2, pp_rank=0, pp_size=1), - "llm": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), + "language": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), } dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - + assert needs_data is False - assert loader_module == "llm" # Default to LLM + assert loader_module == "language" # Default to LLM diff --git a/tests/unit_tests/models/mimo/test_llava_provider.py b/tests/unit_tests/models/mimo/test_llava_provider.py index b9bbe78417..48fe8ce2a8 100644 --- a/tests/unit_tests/models/mimo/test_llava_provider.py +++ b/tests/unit_tests/models/mimo/test_llava_provider.py @@ -224,7 +224,7 @@ def test_can_set_parallelism_config(self): mock_vision_encoder = Mock mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=4), + "language": ModuleParallelismConfig(tensor_model_parallel_size=4), } ) diff --git a/tests/unit_tests/models/mimo/test_mimo_builder.py b/tests/unit_tests/models/mimo/test_mimo_builder.py index 831c3bfdbb..346806cc57 100644 --- a/tests/unit_tests/models/mimo/test_mimo_builder.py +++ 
b/tests/unit_tests/models/mimo/test_mimo_builder.py @@ -15,7 +15,7 @@ def test_build_with_single_module(self, mock_grid_class): """Test build_hypercomm_grids with single LLM module.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, context_parallel_size=1, expert_tensor_parallel_size=1, @@ -32,8 +32,8 @@ def test_build_with_single_module(self, mock_grid_class): grids = build_hypercomm_grids(mimo_config) # Should create one grid - assert "llm" in grids - assert grids["llm"] == mock_grid + assert "language" in grids + assert grids["language"] == mock_grid # Check grid was created with correct shape mock_grid_class.assert_called_once() @@ -57,7 +57,7 @@ def test_build_with_multiple_modules(self, mock_grid_class): """Test build_hypercomm_grids with multiple modules.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=4, data_parallel_size=2, rank_offset=0, @@ -82,7 +82,7 @@ def test_build_with_multiple_modules(self, mock_grid_class): grids = build_hypercomm_grids(mimo_config) # Should create three grids - assert "llm" in grids + assert "language" in grids assert "clip_encoder" in grids assert "dino_encoder" in grids assert len(grids) == 3 @@ -95,7 +95,7 @@ def test_build_with_different_parallelism_per_module(self, mock_grid_class): """Test grids with different parallelism configs per module.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=8, pipeline_model_parallel_size=2, data_parallel_size=1, @@ -134,7 +134,7 @@ def test_build_creates_all_dimension_groups(self, mock_grid_class): """Test that all dimension process groups are created.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, context_parallel_size=2, expert_tensor_parallel_size=2, @@ -170,7 +170,7 @@ def test_build_uses_nccl_backend(self, mock_grid_class): """Test that grids use nccl backend.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), + "language": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), } ) @@ -189,7 +189,7 @@ def test_build_with_rank_offsets(self, mock_grid_class): """Test that rank_offset is correctly passed to grids.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, data_parallel_size=2, rank_offset=0, diff --git a/tests/unit_tests/models/mimo/test_mimo_ddp.py b/tests/unit_tests/models/mimo/test_mimo_ddp.py index 8f03ff85b7..e511bcbe14 100644 --- a/tests/unit_tests/models/mimo/test_mimo_ddp.py +++ b/tests/unit_tests/models/mimo/test_mimo_ddp.py @@ -4,66 +4,66 @@ from unittest.mock import MagicMock, patch from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid -from megatron.bridge.models.mimo.mimo_ddp import wrap_mimo_model_distributed from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig, ModuleParallelismConfig +from megatron.bridge.models.mimo.mimo_ddp import wrap_mimo_model_distributed class TestIsCurrentRankInGrid: """Test cases for is_current_rank_in_grid helper.""" - 
@patch('torch.distributed.get_rank') + @patch("torch.distributed.get_rank") def test_rank_in_grid(self, mock_get_rank): """Rank within grid range should return True.""" mock_get_rank.return_value = 2 - + mock_grid = MagicMock() mock_grid.rank_offset = 0 mock_grid.size = 4 - + assert is_current_rank_in_grid(mock_grid) is True - @patch('torch.distributed.get_rank') + @patch("torch.distributed.get_rank") def test_rank_at_grid_start(self, mock_get_rank): """Rank at grid start should return True.""" mock_get_rank.return_value = 4 - + mock_grid = MagicMock() mock_grid.rank_offset = 4 mock_grid.size = 4 - + assert is_current_rank_in_grid(mock_grid) is True - @patch('torch.distributed.get_rank') + @patch("torch.distributed.get_rank") def test_rank_at_grid_end_exclusive(self, mock_get_rank): """Rank at grid end (exclusive) should return False.""" mock_get_rank.return_value = 8 - + mock_grid = MagicMock() mock_grid.rank_offset = 4 mock_grid.size = 4 - + assert is_current_rank_in_grid(mock_grid) is False - @patch('torch.distributed.get_rank') + @patch("torch.distributed.get_rank") def test_rank_before_grid(self, mock_get_rank): """Rank before grid range should return False.""" mock_get_rank.return_value = 2 - + mock_grid = MagicMock() mock_grid.rank_offset = 4 mock_grid.size = 4 - + assert is_current_rank_in_grid(mock_grid) is False - @patch('torch.distributed.get_rank') + @patch("torch.distributed.get_rank") def test_rank_after_grid(self, mock_get_rank): """Rank after grid range should return False.""" mock_get_rank.return_value = 10 - + mock_grid = MagicMock() mock_grid.rank_offset = 0 mock_grid.size = 4 - + assert is_current_rank_in_grid(mock_grid) is False @@ -73,13 +73,13 @@ class TestWrapMimoModelDistributed: def _create_mock_mimo_model(self, has_language_model=True, modality_names=None): """Create a mock MimoModel for testing.""" mock_model = MagicMock() - + if has_language_model: mock_model.language_model = MagicMock() mock_model.language_model.config = MagicMock() else: mock_model.language_model = None - + if modality_names: mock_model.modality_submodules = {} for name in modality_names: @@ -89,7 +89,7 @@ def _create_mock_mimo_model(self, has_language_model=True, modality_names=None): mock_model.modality_submodules[name] = submodule else: mock_model.modality_submodules = {} - + return mock_model def _create_mock_grid(self, rank_offset=0, size=4): @@ -113,242 +113,230 @@ def _create_mimo_parallelism_config(self, modules): module_parallelisms=module_parallelisms, ) - @patch('megatron.core.distributed.DistributedDataParallel') - @patch('torch.distributed.get_rank') + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") def test_wrap_language_model(self, mock_get_rank, mock_ddp): """Test that language model is wrapped with DDP when rank participates.""" mock_get_rank.return_value = 0 mock_ddp.return_value = MagicMock() - + mimo_model = self._create_mock_mimo_model(has_language_model=True) ddp_config = MagicMock() - mimo_parallelism_config = self._create_mimo_parallelism_config({ - "llm": {"tp": 2, "dp": 2}, - }) - - grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} - pg_collections = {"llm": MagicMock()} - - result = wrap_mimo_model_distributed( - mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "language": {"tp": 2, "dp": 2}, + } ) - + + grids = {"language": self._create_mock_grid(rank_offset=0, size=4)} + pg_collections = {"language": 
MagicMock()} + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + # Should wrap language model mock_ddp.assert_called_once() assert result.language_model == mock_ddp.return_value - @patch('megatron.core.distributed.DistributedDataParallel') - @patch('torch.distributed.get_rank') + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") def test_skip_language_model_non_participating_rank(self, mock_get_rank, mock_ddp): """Test that language model is NOT wrapped when rank doesn't participate.""" mock_get_rank.return_value = 10 # Outside grid range - + mimo_model = self._create_mock_mimo_model(has_language_model=True) original_lm = mimo_model.language_model - + ddp_config = MagicMock() - mimo_parallelism_config = self._create_mimo_parallelism_config({ - "llm": {"tp": 2, "dp": 2}, - }) - - grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} - pg_collections = {"llm": MagicMock()} - - result = wrap_mimo_model_distributed( - mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "language": {"tp": 2, "dp": 2}, + } ) - + + grids = {"language": self._create_mock_grid(rank_offset=0, size=4)} + pg_collections = {"language": MagicMock()} + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + # Should NOT wrap language model mock_ddp.assert_not_called() assert result.language_model == original_lm - @patch('megatron.core.distributed.DistributedDataParallel') - @patch('torch.distributed.get_rank') + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") def test_wrap_modality_submodules(self, mock_get_rank, mock_ddp): """Test that modality submodules are wrapped with DDP.""" mock_get_rank.return_value = 0 mock_ddp.return_value = MagicMock() - - mimo_model = self._create_mock_mimo_model( - has_language_model=True, - modality_names=["images"] - ) + + mimo_model = self._create_mock_mimo_model(has_language_model=True, modality_names=["images"]) ddp_config = MagicMock() - mimo_parallelism_config = self._create_mimo_parallelism_config({ - "llm": {"tp": 2, "dp": 2}, - "images": {"tp": 1, "dp": 4}, - }) - + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "language": {"tp": 2, "dp": 2}, + "images": {"tp": 1, "dp": 4}, + } + ) + grids = { - "llm": self._create_mock_grid(rank_offset=0, size=4), + "language": self._create_mock_grid(rank_offset=0, size=4), "images": self._create_mock_grid(rank_offset=0, size=4), } pg_collections = { - "llm": MagicMock(), + "language": MagicMock(), "images": MagicMock(), } - - wrap_mimo_model_distributed( - mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections - ) - + + wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + # Should wrap both language model and images submodule assert mock_ddp.call_count == 2 - @patch('megatron.core.distributed.DistributedDataParallel') - @patch('torch.distributed.get_rank') + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") def test_skip_modality_submodule_no_grid(self, mock_get_rank, mock_ddp): """Test that modality submodules without grids are skipped.""" mock_get_rank.return_value = 0 mock_ddp.return_value = MagicMock() - - mimo_model = self._create_mock_mimo_model( - has_language_model=True, - 
modality_names=["images", "audio"] - ) + + mimo_model = self._create_mock_mimo_model(has_language_model=True, modality_names=["images", "audio"]) ddp_config = MagicMock() - mimo_parallelism_config = self._create_mimo_parallelism_config({ - "llm": {"tp": 2, "dp": 2}, - "images": {"tp": 1, "dp": 4}, - # Note: no "audio" in parallelism config - }) - + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "language": {"tp": 2, "dp": 2}, + "images": {"tp": 1, "dp": 4}, + # Note: no "audio" in parallelism config + } + ) + # Only llm and images have grids grids = { - "llm": self._create_mock_grid(rank_offset=0, size=4), + "language": self._create_mock_grid(rank_offset=0, size=4), "images": self._create_mock_grid(rank_offset=0, size=4), } pg_collections = { - "llm": MagicMock(), + "language": MagicMock(), "images": MagicMock(), } - - wrap_mimo_model_distributed( - mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections - ) - + + wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + # Should wrap llm and images, but not audio (no grid) assert mock_ddp.call_count == 2 - @patch('megatron.core.distributed.DistributedDataParallel') - @patch('torch.distributed.get_rank') + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") def test_heterogeneous_different_rank_ranges(self, mock_get_rank, mock_ddp): """Test heterogeneous deployment with different rank ranges per module.""" mock_get_rank.return_value = 4 # In images grid but not llm grid mock_ddp.return_value = MagicMock() - - mimo_model = self._create_mock_mimo_model( - has_language_model=True, - modality_names=["images"] - ) + + mimo_model = self._create_mock_mimo_model(has_language_model=True, modality_names=["images"]) original_lm = mimo_model.language_model - + ddp_config = MagicMock() - mimo_parallelism_config = self._create_mimo_parallelism_config({ - "llm": {"tp": 2, "dp": 2, "rank_offset": 0}, - "images": {"tp": 2, "dp": 2, "rank_offset": 4}, - }) - + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "language": {"tp": 2, "dp": 2, "rank_offset": 0}, + "images": {"tp": 2, "dp": 2, "rank_offset": 4}, + } + ) + grids = { - "llm": self._create_mock_grid(rank_offset=0, size=4), + "language": self._create_mock_grid(rank_offset=0, size=4), "images": self._create_mock_grid(rank_offset=4, size=4), } pg_collections = { - "llm": None, # Rank 4 doesn't participate in LLM + "language": None, # Rank 4 doesn't participate in LLM "images": MagicMock(), } - - result = wrap_mimo_model_distributed( - mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections - ) - + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + # Should wrap only images (rank 4 is in images grid, not llm grid) assert mock_ddp.call_count == 1 # Language model should be unchanged assert result.language_model == original_lm - @patch('megatron.core.distributed.DistributedDataParallel') - @patch('torch.distributed.get_rank') + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") def test_no_language_model(self, mock_get_rank, mock_ddp): """Test model without language model.""" mock_get_rank.return_value = 0 mock_ddp.return_value = MagicMock() - - mimo_model = self._create_mock_mimo_model( - has_language_model=False, - modality_names=["images"] - ) + + mimo_model = self._create_mock_mimo_model(has_language_model=False, 
modality_names=["images"]) ddp_config = MagicMock() - mimo_parallelism_config = self._create_mimo_parallelism_config({ - "llm": {"tp": 2, "dp": 2}, - "images": {"tp": 1, "dp": 4}, - }) - + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "language": {"tp": 2, "dp": 2}, + "images": {"tp": 1, "dp": 4}, + } + ) + grids = { - "llm": self._create_mock_grid(rank_offset=0, size=4), + "language": self._create_mock_grid(rank_offset=0, size=4), "images": self._create_mock_grid(rank_offset=0, size=4), } pg_collections = { - "llm": MagicMock(), + "language": MagicMock(), "images": MagicMock(), } - - result = wrap_mimo_model_distributed( - mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections - ) - + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + # Should wrap only images (no language model) assert mock_ddp.call_count == 1 assert result.language_model is None - @patch('megatron.core.distributed.DistributedDataParallel') - @patch('torch.distributed.get_rank') + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") def test_returns_same_model_instance(self, mock_get_rank, mock_ddp): """Test that wrap_mimo_model_distributed returns the same model instance.""" mock_get_rank.return_value = 0 mock_ddp.return_value = MagicMock() - + mimo_model = self._create_mock_mimo_model(has_language_model=True) ddp_config = MagicMock() - mimo_parallelism_config = self._create_mimo_parallelism_config({ - "llm": {"tp": 2, "dp": 2}, - }) - - grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} - pg_collections = {"llm": MagicMock()} - - result = wrap_mimo_model_distributed( - mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "language": {"tp": 2, "dp": 2}, + } ) - + + grids = {"language": self._create_mock_grid(rank_offset=0, size=4)} + pg_collections = {"language": MagicMock()} + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + # Should return the same model instance (modified in-place) assert result is mimo_model - @patch('megatron.core.distributed.DistributedDataParallel') - @patch('torch.distributed.get_rank') + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") def test_ddp_called_with_correct_args(self, mock_get_rank, mock_ddp): """Test that DDP is called with correct arguments.""" mock_get_rank.return_value = 0 mock_ddp.return_value = MagicMock() - + mimo_model = self._create_mock_mimo_model(has_language_model=True) # Capture original config before wrapping (wrapping replaces language_model) original_lm_config = mimo_model.language_model.config original_lm = mimo_model.language_model - + ddp_config = MagicMock() - mimo_parallelism_config = self._create_mimo_parallelism_config({ - "llm": {"tp": 2, "dp": 2}, - }) - - grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} - llm_pg_collection = MagicMock() - pg_collections = {"llm": llm_pg_collection} - - wrap_mimo_model_distributed( - mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "language": {"tp": 2, "dp": 2}, + } ) - + + grids = {"language": self._create_mock_grid(rank_offset=0, size=4)} + llm_pg_collection = MagicMock() + pg_collections = {"language": llm_pg_collection} + + 
wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + # Verify DDP call arguments mock_ddp.assert_called_once() call_kwargs = mock_ddp.call_args.kwargs diff --git a/tests/unit_tests/models/mimo/test_mimo_provider.py b/tests/unit_tests/models/mimo/test_mimo_provider.py index b907aeab6e..6b7bc1c733 100644 --- a/tests/unit_tests/models/mimo/test_mimo_provider.py +++ b/tests/unit_tests/models/mimo/test_mimo_provider.py @@ -34,7 +34,7 @@ def test_provider_initialization_full(self): modality_spec = ModuleSpec(module=Mock, params={}) mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2), + "language": ModuleParallelismConfig(tensor_model_parallel_size=2), }, ) @@ -93,7 +93,6 @@ def test_provide_returns_model_directly(self, mock_build_grids, mock_mimo_model) mock_build_grids.assert_not_called() config_arg = mock_mimo_model.call_args[0][0] assert config_arg.module_to_grid_map is None - assert config_arg.language_module_key is None @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") @@ -121,7 +120,7 @@ def test_build_infra_without_parallelism(self, mock_build_grids): # Should return infrastructure with auto-derived topology assert isinstance(infra, MimoModelInfra) assert infra.module_to_grid_map == {} - assert infra.topology == {"llm": []} + assert infra.topology == {"language": []} assert infra.pg_collections == {} assert infra.participating_modules == [] @@ -141,7 +140,7 @@ def test_build_infra_with_parallelism(self, mock_build_grids, mock_get_rank, moc mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, data_parallel_size=2, ), @@ -153,7 +152,7 @@ def test_build_infra_with_parallelism(self, mock_build_grids, mock_get_rank, moc mock_grid.rank_offset = 0 mock_grid.size = 4 mock_grid.get_pg.return_value = MagicMock() - mock_build_grids.return_value = {"llm": mock_grid} + mock_build_grids.return_value = {"language": mock_grid} provider = MimoModelProvider( language_model_spec=language_spec, @@ -167,9 +166,9 @@ def test_build_infra_with_parallelism(self, mock_build_grids, mock_get_rank, moc # Should return populated infrastructure assert isinstance(infra, MimoModelInfra) - assert "llm" in infra.module_to_grid_map - assert "llm" in infra.pg_collections - assert "llm" in infra.participating_modules + assert "language" in infra.module_to_grid_map + assert "language" in infra.pg_collections + assert "language" in infra.participating_modules @patch("torch.distributed.new_group") @patch("torch.distributed.get_process_group_ranks") @@ -184,7 +183,7 @@ def test_build_infra_is_idempotent(self, mock_build_grids, mock_get_rank, mock_g mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=1, rank_offset=0), + "language": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=1, rank_offset=0), }, ) @@ -192,7 +191,7 @@ def test_build_infra_is_idempotent(self, mock_build_grids, mock_get_rank, mock_g mock_grid.rank_offset = 0 mock_grid.size = 2 mock_grid.get_pg.return_value = MagicMock() - mock_build_grids.return_value = {"llm": mock_grid} + mock_build_grids.return_value = {"language": mock_grid} provider = MimoModelProvider( language_model_spec=language_spec, @@ -222,7 
+221,7 @@ def test_provide_with_parallelism( mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, data_parallel_size=2, ), @@ -233,7 +232,7 @@ def test_provide_with_parallelism( mock_grid.rank_offset = 0 mock_grid.size = 4 mock_grid.get_pg.return_value = MagicMock() - mock_build_grids.return_value = {"llm": mock_grid} + mock_build_grids.return_value = {"language": mock_grid} provider = MimoModelProvider( language_model_spec=language_spec, @@ -248,13 +247,12 @@ def test_provide_with_parallelism( # Should return model directly assert model == mock_model_instance config_arg = mock_mimo_model.call_args[0][0] - assert config_arg.module_to_grid_map == {"llm": mock_grid} - assert config_arg.language_module_key == "llm" + assert config_arg.module_to_grid_map == {"language": mock_grid} # Infrastructure should be available via build_infra() infra = provider.build_infra() - assert "llm" in infra.module_to_grid_map - assert "llm" in infra.pg_collections + assert "language" in infra.module_to_grid_map + assert "language" in infra.pg_collections def test_inject_pg_collection_into_language_spec(self): """Test that pg_collection is injected into language specs.""" @@ -328,7 +326,7 @@ def test_per_encoder_parallelism( mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=8, data_parallel_size=1), + "language": ModuleParallelismConfig(tensor_model_parallel_size=8, data_parallel_size=1), "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=1), "dino_encoder": ModuleParallelismConfig(tensor_model_parallel_size=4, data_parallel_size=1), }, @@ -351,7 +349,7 @@ def test_per_encoder_parallelism( dino_grid.get_pg.return_value = MagicMock() mock_build_grids.return_value = { - "llm": llm_grid, + "language": llm_grid, "clip_encoder": clip_grid, "dino_encoder": dino_grid, } @@ -379,7 +377,7 @@ def test_per_encoder_parallelism( mock_build_grids.assert_called_with(mimo_parallelism_config) # Should have pg_collections for all modules - assert "llm" in infra.pg_collections + assert "language" in infra.pg_collections assert "clip_encoder" in infra.pg_collections assert "dino_encoder" in infra.pg_collections @@ -436,10 +434,10 @@ class TestMimoModelInfra: def test_infra_initialization(self): """Test infrastructure dataclass initializes correctly.""" - grids = {"llm": MagicMock()} - topology = {"llm": []} - pg_collections = {"llm": MagicMock()} - participating = ["llm"] + grids = {"language": MagicMock()} + topology = {"language": []} + pg_collections = {"language": MagicMock()} + participating = ["language"] infra = MimoModelInfra( module_to_grid_map=grids, @@ -593,7 +591,7 @@ def test_pg_collection_includes_embedding_groups_first_stage( language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), + "language": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), }, ) @@ -608,11 +606,11 @@ def test_pg_collection_includes_embedding_groups_first_stage( mimo_parallelism_config=mimo_parallelism_config, ) - pg_collections = provider._get_pg_collections_from_grids({"llm": mock_grid}) + pg_collections = provider._get_pg_collections_from_grids({"language": mock_grid}) # First stage should have pos_embd but not embd 
(not last stage) - assert pg_collections["llm"].pos_embd == mock_pos_embd - assert pg_collections["llm"].embd == mock_embd # First stage gets embd too + assert pg_collections["language"].pos_embd == mock_pos_embd + assert pg_collections["language"].embd == mock_embd # First stage gets embd too @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_last_stage") @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_first_stage") @@ -632,7 +630,7 @@ def test_pg_collection_middle_stage_no_embedding_groups( language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), + "language": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), }, ) @@ -647,11 +645,11 @@ def test_pg_collection_middle_stage_no_embedding_groups( mimo_parallelism_config=mimo_parallelism_config, ) - pg_collections = provider._get_pg_collections_from_grids({"llm": mock_grid}) + pg_collections = provider._get_pg_collections_from_grids({"language": mock_grid}) # Middle stage should have neither embedding group - assert pg_collections["llm"].pos_embd is None - assert pg_collections["llm"].embd is None + assert pg_collections["language"].pos_embd is None + assert pg_collections["language"].embd is None @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_last_stage") @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_first_stage") @@ -667,7 +665,7 @@ def test_pg_collection_includes_composite_groups(self, mock_get_rank, mock_popul language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), + "language": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), }, ) @@ -701,9 +699,9 @@ def test_pg_collection_includes_composite_groups(self, mock_get_rank, mock_popul mimo_parallelism_config=mimo_parallelism_config, ) - pg_collections = provider._get_pg_collections_from_grids({"llm": mock_grid}) + pg_collections = provider._get_pg_collections_from_grids({"language": mock_grid}) - pgc = pg_collections["llm"] + pgc = pg_collections["language"] assert pgc.tp == mock_tp assert pgc.dp == mock_dp assert pgc.pp == mock_pp diff --git a/tests/unit_tests/training/mimo/test_mimo_config.py b/tests/unit_tests/training/mimo/test_mimo_config.py index 3770043d48..3548c0853c 100644 --- a/tests/unit_tests/training/mimo/test_mimo_config.py +++ b/tests/unit_tests/training/mimo/test_mimo_config.py @@ -22,7 +22,7 @@ def test_mimo_heterogeneous_rank_offset_overlap(): """Test that overlapping rank ranges are detected in heterogeneous deployment.""" module_parallelisms = { "encoder": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=0), - "llm": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=2), + "language": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=2), } mimo_parallelism_config = MimoParallelismConfig( module_parallelisms=module_parallelisms, @@ -36,7 +36,7 @@ def test_mimo_heterogeneous_valid_contiguous(): # Note: encoder DP must be >= LLM DP for embedding alignment module_parallelisms = { "encoder": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=0), - "llm": ModuleParallelismConfig(tensor_model_parallel_size=1, 
data_parallel_size=2, rank_offset=4), + "language": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=2, rank_offset=4), } mimo_parallelism_config = MimoParallelismConfig( module_parallelisms=module_parallelisms, diff --git a/tests/unit_tests/training/mimo/test_mimo_parallel_utils.py b/tests/unit_tests/training/mimo/test_mimo_parallel_utils.py index b64cb96542..a5c20d6666 100644 --- a/tests/unit_tests/training/mimo/test_mimo_parallel_utils.py +++ b/tests/unit_tests/training/mimo/test_mimo_parallel_utils.py @@ -1,51 +1,51 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. """Unit tests for MIMO parallel utilities.""" -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest class TestIsCurrentRankInGrid: """Test cases for is_current_rank_in_grid().""" - - @patch('megatron.bridge.training.mimo_parallel_utils.dist') + + @patch("megatron.bridge.training.mimo_parallel_utils.dist") def test_rank_in_grid(self, mock_dist): """Test rank within grid range returns True.""" from megatron.bridge.training.mimo_parallel_utils import is_current_rank_in_grid - + mock_dist.get_rank.return_value = 2 mock_grid = MagicMock() mock_grid.rank_offset = 0 mock_grid.size = 4 - + assert is_current_rank_in_grid(mock_grid) is True - - @patch('megatron.bridge.training.mimo_parallel_utils.dist') + + @patch("megatron.bridge.training.mimo_parallel_utils.dist") def test_rank_not_in_grid(self, mock_dist): """Test rank outside grid range returns False.""" from megatron.bridge.training.mimo_parallel_utils import is_current_rank_in_grid - + mock_dist.get_rank.return_value = 5 mock_grid = MagicMock() mock_grid.rank_offset = 0 mock_grid.size = 4 - + assert is_current_rank_in_grid(mock_grid) is False - - @patch('megatron.bridge.training.mimo_parallel_utils.dist') + + @patch("megatron.bridge.training.mimo_parallel_utils.dist") def test_rank_at_grid_boundary(self, mock_dist): """Test rank at grid boundary.""" from megatron.bridge.training.mimo_parallel_utils import is_current_rank_in_grid - + mock_grid = MagicMock() mock_grid.rank_offset = 4 mock_grid.size = 4 - + # At start boundary (inclusive) mock_dist.get_rank.return_value = 4 assert is_current_rank_in_grid(mock_grid) is True - + # At end boundary (exclusive) mock_dist.get_rank.return_value = 8 assert is_current_rank_in_grid(mock_grid) is False @@ -53,74 +53,74 @@ def test_rank_at_grid_boundary(self, mock_dist): class TestValidateNoStubRanks: """Test cases for validate_no_stub_ranks().""" - + def test_all_ranks_participate(self): """Test validation passes when all ranks participate.""" from megatron.bridge.training.mimo_parallel_utils import validate_no_stub_ranks - + mock_grid1 = MagicMock() mock_grid1.rank_offset = 0 mock_grid1.size = 4 - + mock_grid2 = MagicMock() mock_grid2.rank_offset = 4 mock_grid2.size = 4 - + module_to_grid_map = { "encoder": mock_grid1, - "llm": mock_grid2, + "language": mock_grid2, } - + # Should not raise validate_no_stub_ranks(module_to_grid_map, world_size=8) - + def test_stub_ranks_detected(self): """Test validation fails when stub ranks exist.""" from megatron.bridge.training.mimo_parallel_utils import validate_no_stub_ranks - + mock_grid = MagicMock() mock_grid.rank_offset = 0 mock_grid.size = 4 - - module_to_grid_map = {"llm": mock_grid} - + + module_to_grid_map = {"language": mock_grid} + with pytest.raises(ValueError, match="do not participate in any module"): validate_no_stub_ranks(module_to_grid_map, world_size=8) - + def test_overlapping_grids(self): 
"""Test validation with overlapping grids (colocated case).""" from megatron.bridge.training.mimo_parallel_utils import validate_no_stub_ranks - + mock_grid1 = MagicMock() mock_grid1.rank_offset = 0 mock_grid1.size = 4 - + mock_grid2 = MagicMock() mock_grid2.rank_offset = 0 mock_grid2.size = 4 - + module_to_grid_map = { "encoder": mock_grid1, - "llm": mock_grid2, + "language": mock_grid2, } - + # Should not raise (all 4 ranks participate) validate_no_stub_ranks(module_to_grid_map, world_size=4) class TestValidateDataLoaderContract: """Test cases for validate_data_loader_contract().""" - + def test_valid_configuration(self): """Test validation passes for valid configuration.""" from megatron.bridge.training.mimo_parallel_utils import validate_data_loader_contract - + mock_grid = MagicMock() mock_grid.get_pg_size.return_value = 2 # DP size = 2 - + mock_infra = MagicMock() - mock_infra.module_to_grid_map = {"llm": mock_grid} - + mock_infra.module_to_grid_map = {"language": mock_grid} + # global_batch=16, dp=2, per_dp_batch=8, microbatches=4, micro_batch_size=2 # 4 * 2 = 8 == 16 / 2 ✓ validate_data_loader_contract( @@ -129,17 +129,17 @@ def test_valid_configuration(self): micro_batch_size=2, num_microbatches=4, ) - + def test_batch_not_divisible_by_dp(self): """Test validation fails when batch not divisible by DP size.""" from megatron.bridge.training.mimo_parallel_utils import validate_data_loader_contract - + mock_grid = MagicMock() mock_grid.get_pg_size.return_value = 3 # DP size = 3 - + mock_infra = MagicMock() - mock_infra.module_to_grid_map = {"llm": mock_grid} - + mock_infra.module_to_grid_map = {"language": mock_grid} + with pytest.raises(ValueError, match="not divisible"): validate_data_loader_contract( infra=mock_infra, @@ -151,40 +151,40 @@ def test_batch_not_divisible_by_dp(self): class TestBuildPgCollectionForSchedule: """Test cases for build_pg_collection_for_schedule().""" - + def test_fallback_to_list(self): """Test fallback to list when MultiModuleProcessGroupCollection not available.""" from megatron.bridge.training.mimo_parallel_utils import build_pg_collection_for_schedule - + mock_pg1 = MagicMock() mock_pg2 = MagicMock() - + mock_infra = MagicMock() mock_infra.pg_collections = { "encoder": mock_pg1, - "llm": mock_pg2, + "language": mock_pg2, } - + # This will likely fall back to list since import may fail in test env result = build_pg_collection_for_schedule(mock_infra) - + # Should be either a list or MultiModuleProcessGroupCollection assert result is not None - + def test_filters_none_pg_collections(self): """Test that None pg_collections are filtered out.""" from megatron.bridge.training.mimo_parallel_utils import build_pg_collection_for_schedule - + mock_pg = MagicMock() - + mock_infra = MagicMock() mock_infra.pg_collections = { "encoder": None, # Non-participating module - "llm": mock_pg, + "language": mock_pg, } - + result = build_pg_collection_for_schedule(mock_infra) - + # Should filter out None values if isinstance(result, list): assert len(result) == 1 @@ -193,79 +193,79 @@ def test_filters_none_pg_collections(self): class TestMultimoduleNoSync: """Test cases for multimodule_no_sync context manager.""" - - @patch('megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid') + + @patch("megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid") def test_enters_and_exits_contexts(self, mock_in_grid): """Test that no_sync contexts are properly entered and exited.""" from megatron.bridge.training.mimo_parallel_utils import multimodule_no_sync - 
+ mock_in_grid.return_value = True - + mock_module = MagicMock() mock_context = MagicMock() mock_module.no_sync.return_value = mock_context - + mock_grid = MagicMock() - + module_to_grid_tuple = [(mock_module, mock_grid)] - + with multimodule_no_sync(module_to_grid_tuple=module_to_grid_tuple): pass - + # Verify context was entered and exited mock_context.__enter__.assert_called_once() mock_context.__exit__.assert_called_once() - - @patch('megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid') + + @patch("megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid") def test_skips_non_participating_modules(self, mock_in_grid): """Test that non-participating modules are skipped.""" from megatron.bridge.training.mimo_parallel_utils import multimodule_no_sync - + mock_in_grid.return_value = False # Not participating - + mock_module = MagicMock() mock_grid = MagicMock() - + module_to_grid_tuple = [(mock_module, mock_grid)] - + with multimodule_no_sync(module_to_grid_tuple=module_to_grid_tuple): pass - + # no_sync should not be called mock_module.no_sync.assert_not_called() class TestZeroGradBufferForMultimodule: """Test cases for zero_grad_buffer_for_multimodule().""" - - @patch('megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid') + + @patch("megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid") def test_zeros_grad_buffers(self, mock_in_grid): """Test gradient buffers are zeroed for participating modules.""" from megatron.bridge.training.mimo_parallel_utils import zero_grad_buffer_for_multimodule - + mock_in_grid.return_value = True - + mock_module = MagicMock() mock_grid = MagicMock() - + module_to_grid_tuple = [(mock_module, mock_grid)] - + zero_grad_buffer_for_multimodule(module_to_grid_tuple) - + mock_module.zero_grad_buffer.assert_called_once() - - @patch('megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid') + + @patch("megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid") def test_skips_non_participating(self, mock_in_grid): """Test non-participating modules are skipped.""" from megatron.bridge.training.mimo_parallel_utils import zero_grad_buffer_for_multimodule - + mock_in_grid.return_value = False - + mock_module = MagicMock() mock_grid = MagicMock() - + module_to_grid_tuple = [(mock_module, mock_grid)] - + zero_grad_buffer_for_multimodule(module_to_grid_tuple) - + mock_module.zero_grad_buffer.assert_not_called() diff --git a/tests/unit_tests/training/mimo/test_mimo_step.py b/tests/unit_tests/training/mimo/test_mimo_step.py index 7076601d1d..dc26ff26d2 100644 --- a/tests/unit_tests/training/mimo/test_mimo_step.py +++ b/tests/unit_tests/training/mimo/test_mimo_step.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
"""Unit tests for MIMO forward step functions.""" -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest import torch @@ -9,165 +9,165 @@ class TestLossFunc: """Test cases for loss_func().""" - + def test_loss_computation(self): """Test loss is computed correctly with mask.""" from megatron.bridge.training.mimo_step import loss_func - + # Create test data output_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0]) # Mask out 3rd element - + total_loss, num_tokens, metrics = loss_func(loss_mask, output_tensor) - + # Expected: (1.0*1 + 2.0*1 + 3.0*0 + 4.0*1) = 7.0 assert total_loss.item() == 7.0 # Expected tokens: 3 (sum of mask) assert num_tokens.item() == 3 # Check metrics dict structure - assert 'lm loss' in metrics - + assert "lm loss" in metrics + def test_loss_with_all_ones_mask(self): """Test loss with all-ones mask.""" from megatron.bridge.training.mimo_step import loss_func - + output_tensor = torch.tensor([1.0, 2.0, 3.0]) loss_mask = torch.ones(3) - + total_loss, num_tokens, metrics = loss_func(loss_mask, output_tensor) - + assert total_loss.item() == 6.0 assert num_tokens.item() == 3 - + def test_loss_with_all_zeros_mask(self): """Test loss with all-zeros mask.""" from megatron.bridge.training.mimo_step import loss_func - + output_tensor = torch.tensor([1.0, 2.0, 3.0]) loss_mask = torch.zeros(3) - + total_loss, num_tokens, metrics = loss_func(loss_mask, output_tensor) - + assert total_loss.item() == 0.0 assert num_tokens.item() == 0 class TestGetBatch: """Test cases for get_batch().""" - + def test_returns_none_for_none_iterator(self): """Test returns None when iterator is None.""" from megatron.bridge.training.mimo_step import get_batch - + result = get_batch(None) assert result is None - + def test_returns_none_on_stop_iteration(self): """Test returns None when iterator is exhausted.""" from megatron.bridge.training.mimo_step import get_batch - + empty_iter = iter([]) result = get_batch(empty_iter) assert result is None - + def test_returns_batch_from_iterator(self): """Test returns batch from iterator.""" from megatron.bridge.training.mimo_step import get_batch - - batch = {'input_ids': torch.tensor([1, 2, 3])} + + batch = {"input_ids": torch.tensor([1, 2, 3])} data_iter = iter([batch]) - + result = get_batch(data_iter) - + assert result is not None - assert 'input_ids' in result + assert "input_ids" in result class TestForwardStep: """Test cases for forward_step().""" - - @patch('megatron.bridge.training.mimo_step.unwrap_mimo_model') + + @patch("megatron.bridge.training.mimo_step.unwrap_mimo_model") def test_forward_step_last_stage(self, mock_unwrap): """Test forward step at last pipeline stage returns loss func.""" from megatron.bridge.training.mimo_step import forward_step - + # Create mock state mock_state = MagicMock() - + # Create mock model with role=None (indicates last stage) mock_model = MagicMock() mock_model.role = None # role=None means is_last_stage=True mock_output = torch.tensor([1.0, 2.0]) mock_loss_mask = torch.ones(2) mock_model.return_value = (mock_output, mock_loss_mask) - + # unwrap_mimo_model returns the mock model itself mock_unwrap.return_value = mock_model - + # Create mock iterator - batch = {'input_ids': torch.tensor([1, 2])} + batch = {"input_ids": torch.tensor([1, 2])} data_iter = iter([batch]) - + output, loss_fn = forward_step(mock_state, data_iter, mock_model) - + # At last stage, should return loss function assert loss_fn is not None assert 
callable(loss_fn) - - @patch('megatron.bridge.training.mimo_step.unwrap_mimo_model') + + @patch("megatron.bridge.training.mimo_step.unwrap_mimo_model") def test_forward_step_intermediate_stage(self, mock_unwrap): """Test forward step at intermediate stage returns None for loss func.""" from megatron.bridge.training.mimo_step import forward_step - + mock_state = MagicMock() mock_model = MagicMock() # Configure role to indicate intermediate stage (not last stage) mock_role = MagicMock() mock_role.has_language_module = True mock_role.has_modality_modules = False - mock_role.language_module_name = 'llm' mock_role.is_last_stage.return_value = False mock_role.is_first_stage.return_value = True mock_model.role = mock_role mock_model.return_value = (torch.tensor([1.0]), None) - + mock_unwrap.return_value = mock_model - - batch = {'input_ids': torch.tensor([1, 2])} + + batch = {"input_ids": torch.tensor([1, 2])} data_iter = iter([batch]) - + output, loss_fn = forward_step(mock_state, data_iter, mock_model) - + # Intermediate stage should return None for loss_fn assert loss_fn is None - - @patch('megatron.bridge.training.mimo_step.unwrap_mimo_model') + + @patch("megatron.bridge.training.mimo_step.unwrap_mimo_model") def test_forward_step_rejects_dict_at_last_stage(self, mock_unwrap): """Test forward step raises error if dict returned at last stage.""" from megatron.bridge.training.mimo_step import forward_step - + mock_state = MagicMock() mock_model = MagicMock() mock_model.role = None # role=None means is_last_stage=True # Return dict (incorrect for last stage) - mock_model.return_value = ({'encoder': torch.tensor([1.0])}, None) - + mock_model.return_value = ({"encoder": torch.tensor([1.0])}, None) + mock_unwrap.return_value = mock_model - - batch = {'input_ids': torch.tensor([1, 2])} + + batch = {"input_ids": torch.tensor([1, 2])} data_iter = iter([batch]) - + with pytest.raises(ValueError, match="Last pipeline stage must return scalar loss"): forward_step(mock_state, data_iter, mock_model) - + def test_forward_step_uses_global_state_signature(self): """Test forward step uses 3-arg signature with GlobalState.""" - from megatron.bridge.training.mimo_step import forward_step import inspect - + + from megatron.bridge.training.mimo_step import forward_step + sig = inspect.signature(forward_step) params = list(sig.parameters.keys()) - + # Should have state as first parameter - assert params[0] == 'state' + assert params[0] == "state" assert len(params) == 3 diff --git a/tests/unit_tests/training/mimo/test_pretrain_mimo.py b/tests/unit_tests/training/mimo/test_pretrain_mimo.py index ec57f4550f..90f0c55c27 100644 --- a/tests/unit_tests/training/mimo/test_pretrain_mimo.py +++ b/tests/unit_tests/training/mimo/test_pretrain_mimo.py @@ -48,15 +48,14 @@ def test_pretrain_mimo_uses_constructor_wired_config( mock_dist.get_rank.return_value = 0 - sentinel_grid_map = {"llm": MagicMock()} + sentinel_grid_map = {"language": MagicMock()} setup_output = _make_setup_output(module_to_grid_map=sentinel_grid_map) mock_setup_mimo.return_value = setup_output - original_grid_map = {"llm": MagicMock()} + original_grid_map = {"language": MagicMock()} unwrapped_model = MagicMock() unwrapped_model.mimo_config = SimpleNamespace( module_to_grid_map=original_grid_map, - language_module_key="llm", ) mock_unwrap_mimo_model.return_value = unwrapped_model @@ -79,7 +78,6 @@ def test_pretrain_mimo_uses_constructor_wired_config( # No post-construction mutation: keep original references/values. 
assert unwrapped_model.mimo_config.module_to_grid_map is original_grid_map - assert unwrapped_model.mimo_config.language_module_key == "llm" mock_train_mimo.assert_called_once() @@ -97,13 +95,12 @@ def test_pretrain_mimo_asserts_when_constructor_fields_missing(mock_dist, mock_u mock_dist.get_rank.return_value = 0 # Infra indicates MIMO-parallel path is active. - mock_setup_mimo.return_value = _make_setup_output(module_to_grid_map={"llm": MagicMock()}) + mock_setup_mimo.return_value = _make_setup_output(module_to_grid_map={"language": MagicMock()}) # Missing constructor-wired fields should trigger assertion. unwrapped_model = MagicMock() unwrapped_model.mimo_config = SimpleNamespace( module_to_grid_map=None, - language_module_key=None, ) mock_unwrap_mimo_model.return_value = unwrapped_model