diff --git a/src/megatron/bridge/models/mimo/__init__.py b/src/megatron/bridge/models/mimo/__init__.py index 8ae74e58cb..4167f85ba3 100644 --- a/src/megatron/bridge/models/mimo/__init__.py +++ b/src/megatron/bridge/models/mimo/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. from megatron.bridge.models.mimo.llava_provider import LlavaMimoProvider from megatron.bridge.models.mimo.mimo_config import ( @@ -12,9 +12,9 @@ __all__ = [ + "LlavaMimoProvider", + "MimoModelInfra", + "MimoModelProvider", "MimoParallelismConfig", "ModuleParallelismConfig", - "MimoModelProvider", - "MimoModelInfra", - "LlavaMimoProvider", ] diff --git a/src/megatron/bridge/models/mimo/mimo_builder.py b/src/megatron/bridge/models/mimo/mimo_builder.py index 4648154d79..e3f9ad3d45 100644 --- a/src/megatron/bridge/models/mimo/mimo_builder.py +++ b/src/megatron/bridge/models/mimo/mimo_builder.py @@ -1,14 +1,16 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import torch.distributed as dist from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig if TYPE_CHECKING: - from megatron.core.process_groups_config import HyperCommGrid + from megatron.core.hyper_comm_grid import HyperCommGrid def build_hypercomm_grids( @@ -56,3 +58,55 @@ def build_hypercomm_grids( def _default_topology(mimo_parallelism_config: MimoParallelismConfig) -> Dict[str, List[str]]: """Infer a default multi-encoder -> LLM topology.""" return {name: ["llm"] for name in mimo_parallelism_config.module_names if name != "llm"} | {"llm": []} + + +def create_embedding_and_position_groups( + pp_group: dist.ProcessGroup, +) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: + """Create embedding-related process groups from PP group ranks. + + Following MCore semantics: + - pos_embd_pg: Only rank 0 of PP (first stage) - for position embeddings + - embd_pg: Ranks 0 and -1 of PP (first and last stages) - for tied word embeddings + + IMPORTANT: This calls dist.new_group which is a collective operation. + Must be called on all ranks that could participate. + + Note: VPP (virtual_pipeline_model_parallel_size > 1) is not supported. + With VPP, pp_ranks[0]/pp_ranks[-1] do not reliably identify the stages + that own embeddings. The caller is responsible for asserting VPP is disabled. + + Args: + pp_group: The pipeline parallel process group. + + Returns: + Tuple of (pos_embd_pg, embd_pg). Returns (None, None) if pp_group is None. + """ + if pp_group is None: + return None, None + + pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) + + # Position embeddings only on first PP stage + pos_embd_ranks = [pp_ranks[0]] + pos_embd_pg = dist.new_group(ranks=pos_embd_ranks) + + # Word embeddings on first and last PP stages (for tied embeddings) + embd_ranks = [pp_ranks[0]] + if len(pp_ranks) > 1 and pp_ranks[-1] != pp_ranks[0]: + embd_ranks.append(pp_ranks[-1]) + embd_pg = dist.new_group(ranks=embd_ranks) + + return pos_embd_pg, embd_pg + + +def is_current_rank_in_grid(grid: "HyperCommGrid") -> bool: + """Check if the current rank participates in this grid. + + Args: + grid: A HyperCommGrid instance. + + Returns: + True if dist.get_rank() is within the grid's rank range. 
+ """ + return grid.rank_offset <= dist.get_rank() < (grid.rank_offset + grid.size) diff --git a/src/megatron/bridge/models/mimo/mimo_config.py b/src/megatron/bridge/models/mimo/mimo_config.py index cd83ccff23..b45598c6a0 100644 --- a/src/megatron/bridge/models/mimo/mimo_config.py +++ b/src/megatron/bridge/models/mimo/mimo_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. from __future__ import annotations @@ -86,18 +86,48 @@ def total_world_size(self) -> int: def _validate_heterogeneous(self) -> None: """Validate heterogeneous deployment: no overlapping rank ranges.""" ranges = [] - for parallelism in self.module_parallelisms.values(): + for name, parallelism in self.module_parallelisms.items(): if parallelism.data_parallel_size is None: raise ValueError("data_parallel_size must be set for heterogeneous deployment.") - ranges.append((parallelism.rank_offset, parallelism.rank_offset + parallelism.total_ranks)) + ranges.append((parallelism.rank_offset, parallelism.rank_offset + parallelism.total_ranks, name)) - ranges.sort() + ranges.sort(key=lambda x: x[0]) for idx in range(1, len(ranges)): prev_end = ranges[idx - 1][1] cur_start = ranges[idx][0] if cur_start < prev_end: raise ValueError("rank_offset ranges overlap in heterogeneous deployment.") + # Check for gaps between modules (likely misconfiguration) + # Gaps in the middle are errors; leading gaps (rank_offset > 0) are warnings + if ranges: + min_rank = ranges[0][0] # Already sorted by rank_offset + max_rank = ranges[-1][1] + + # Collect all covered ranks + covered_ranks = set() + for parallelism in self.module_parallelisms.values(): + start = parallelism.rank_offset + end = start + parallelism.total_ranks + covered_ranks.update(range(start, end)) + + # Check for gaps between min and max (error - likely misconfiguration) + expected_middle = set(range(min_rank, max_rank)) + gaps_in_middle = expected_middle - covered_ranks + if gaps_in_middle: + raise ValueError( + f"Ranks {sorted(gaps_in_middle)} are not assigned to any module in heterogeneous " + f"deployment. This creates a gap between modules which is not allowed." + ) + + # Check for leading gap (ranks 0 to min_rank-1 unused) - warning only + if min_rank > 0: + warnings.warn( + f"Ranks {list(range(min_rank))} (before first module) are not assigned to any " + f"module in heterogeneous deployment. These ranks will be idle during training.", + stacklevel=3, + ) + def finalize(self, world_size: Optional[int]) -> None: """Finalize parallelism config: compute data_parallel_size and validate.""" if "llm" not in self.module_parallelisms: diff --git a/src/megatron/bridge/models/mimo/mimo_ddp.py b/src/megatron/bridge/models/mimo/mimo_ddp.py new file mode 100644 index 0000000000..450051ed4c --- /dev/null +++ b/src/megatron/bridge/models/mimo/mimo_ddp.py @@ -0,0 +1,96 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""DDP wrapping utilities for MIMO models. + +Called from the training layer after MimoModelProvider.provide(). + +Note: This module only supports DDP wrapping. FSDP is not yet implemented. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional + +from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid + + +if TYPE_CHECKING: + from megatron.core.distributed import DistributedDataParallelConfig + from megatron.core.hyper_comm_grid import HyperCommGrid + from megatron.core.models.mimo import MimoModel + from megatron.core.process_groups_config import ProcessGroupCollection + + from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig + + +def wrap_mimo_model_distributed( + mimo_model: "MimoModel", + ddp_config: "DistributedDataParallelConfig", + mimo_parallelism_config: "MimoParallelismConfig", + grids: Dict[str, "HyperCommGrid"], + pg_collections: Dict[str, Optional["ProcessGroupCollection"]], +) -> "MimoModel": + """Wrap MIMO model's submodules with DDP. + + Modifies mimo_model in-place and returns it. + + Args: + mimo_model: The MimoModel to wrap. + ddp_config: DDP configuration from Bridge. + mimo_parallelism_config: MIMO parallelism configuration. + grids: Module name to HyperCommGrid mapping. + pg_collections: Module name to ProcessGroupCollection mapping. + + Returns: + The same mimo_model with wrapped submodules. + """ + from megatron.core.distributed import DistributedDataParallel + + # Wrap language model if present and rank participates + if mimo_model.language_model is not None: + llm_grid = grids["llm"] + if is_current_rank_in_grid(llm_grid): + llm_pg = pg_collections.get("llm") + if llm_pg is not None: + mimo_model.language_model = DistributedDataParallel( + config=mimo_model.language_model.config, + ddp_config=ddp_config, + module=mimo_model.language_model, + pg_collection=llm_pg, + ) + + # Wrap modality submodules + if hasattr(mimo_model, "modality_submodules"): + for module_name, submodule in mimo_model.modality_submodules.items(): + if submodule is None: + continue + module_grid = grids[module_name] + if not is_current_rank_in_grid(module_grid): + continue + + module_pg = pg_collections.get(module_name) + if module_pg is None: + continue + + # Get config from first encoder in the submodule. + # Note: We use the first encoder's config for DDP bucket sizing. + # This assumes all encoders in a modality submodule share similar + # parallelism settings, which is typical for MIMO models. + if hasattr(submodule, "encoders") and submodule.encoders: + encoder_key = next(iter(submodule.encoders.keys())) + first_encoder = submodule.encoders[encoder_key] + + if not hasattr(first_encoder, "config"): + raise AttributeError( + f"Encoder '{encoder_key}' in modality '{module_name}' does not have " + f"a 'config' attribute. Encoders must be MegatronModule subclasses." + ) + + wrapped = DistributedDataParallel( + config=first_encoder.config, + ddp_config=ddp_config, + module=submodule, + pg_collection=module_pg, + ) + mimo_model.modality_submodules[module_name] = wrapped + + return mimo_model diff --git a/src/megatron/bridge/models/mimo/mimo_provider.py b/src/megatron/bridge/models/mimo/mimo_provider.py index 8722669013..ef34347ea6 100644 --- a/src/megatron/bridge/models/mimo/mimo_provider.py +++ b/src/megatron/bridge/models/mimo/mimo_provider.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. """MIMO Model Provider for heterogeneous multi-module training. 
This module provides MimoModelProvider, which integrates with the standard @@ -18,9 +18,10 @@ import torch import torch.distributed as dist -from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.models.mimo import MimoModel from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec @@ -29,8 +30,10 @@ from megatron.bridge.models.mimo.mimo_builder import ( _default_topology, build_hypercomm_grids, + create_embedding_and_position_groups, ) from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig +from megatron.bridge.models.mimo.mimo_ddp import wrap_mimo_model_distributed from megatron.bridge.models.model_provider import ModelProviderMixin @@ -183,6 +186,7 @@ def _get_pg_collections_from_grids( ) -> Dict[str, Optional[ProcessGroupCollection]]: """Get ProcessGroupCollections from HyperCommGrids. + Creates all standard process groups plus embedding groups for PP > 1. Returns None for modules this rank doesn't participate in. """ pg_collections: Dict[str, Optional[ProcessGroupCollection]] = {} @@ -191,13 +195,34 @@ def _get_pg_collections_from_grids( for module_name, grid in grids.items(): # Check if current rank is in this grid's range if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size): + pp_group = grid.get_pg(["pp"]) + + assert ( + self.virtual_pipeline_model_parallel_size is None or self.virtual_pipeline_model_parallel_size <= 1 + ), ( + f"VPP (virtual_pipeline_model_parallel_size={self.virtual_pipeline_model_parallel_size}) " + f"is not supported with MIMO embedding groups. pp_ranks[0]/pp_ranks[-1] do not " + f"reliably identify embedding stages under VPP." + ) + + # Create embedding groups for PP > 1 (collective operation on all PP ranks) + pos_embd_pg, embd_pg = create_embedding_and_position_groups(pp_group) + + # Only assign embedding groups to ranks that should have them + first_stage = is_pp_first_stage(pp_group) + last_stage = is_pp_last_stage(pp_group) + pg_collections[module_name] = ProcessGroupCollection( tp=grid.get_pg(["tp"]), dp=grid.get_pg(["dp"]), - pp=grid.get_pg(["pp"]), + pp=pp_group, cp=grid.get_pg(["cp"]), ep=grid.get_pg(["ep"]), dp_cp=grid.get_pg(["dp", "cp"]), + # Position embeddings only on first PP stage + pos_embd=pos_embd_pg if first_stage else None, + # Word embeddings on first and last PP stages (for tied embeddings) + embd=embd_pg if (first_stage or last_stage) else None, ) else: pg_collections[module_name] = None @@ -226,7 +251,7 @@ def _inject_pg_collection_into_modality_spec( # Inject into encoders if spec.submodules and "encoders" in spec.submodules: - for encoder_name, encoder_spec in spec.submodules["encoders"].items(): + for _encoder_name, encoder_spec in spec.submodules["encoders"].items(): if encoder_spec.params is None: encoder_spec.params = {} encoder_spec.params["pg_collection"] = pg_collection @@ -330,15 +355,17 @@ def provide_distributed_model( - Uses per-module HyperCommGrids instead of global mpu - Has different pg_collections per module - May have ranks that don't participate in all modules + - Requires per-submodule DDP wrapping for correct gradient sync The method: 1. 
Calls finalize() to validate parallelism config 2. Calls build_infra() to create grids and pg_collections 3. Calls provide() to build the model 4. Applies pre-wrap hooks - 5. Moves to device and applies mixed precision - 6. Wraps with DDP using LLM's pg_collection - 7. Applies post-wrap hooks + 5. Moves to device + 6. Wraps each submodule with DDP using its own pg_collection + 7. Applies mixed precision (Float16Module) + 8. Applies post-wrap hooks Args: ddp_config: Configuration for distributed data parallel. @@ -394,6 +421,26 @@ def provide_distributed_model( for m in model_list: m.cuda(torch.cuda.current_device()) + # Set variable_seq_lengths=True for multimodule pipeline support (required by PR 3129) + # This must be set before the model is used in the training loop + for m in model_list: + model_config = get_model_config(m) + model_config.variable_seq_lengths = True + + # Wrap submodules with DDP (before Float16Module) + # MIMO uses per-submodule DDP for heterogeneous parallelism + if wrap_with_ddp and ddp_config is not None and self.mimo_parallelism_config: + model_list = [ + wrap_mimo_model_distributed( + mimo_model=m, + ddp_config=ddp_config, + mimo_parallelism_config=self.mimo_parallelism_config, + grids=infra.module_to_grid_map, + pg_collections=infra.pg_collections, + ) + for m in model_list + ] + # Apply mixed precision wrapper use_fp16 = fp16 if fp16 is not None else self.fp16 use_bf16 = bf16 if bf16 is not None else self.bf16 @@ -407,23 +454,6 @@ def provide_distributed_model( model_config.bf16 = use_bf16 model_list = [Float16Module(model_config, m) for m in model_list] - # Wrap with DDP - if wrap_with_ddp and ddp_config is not None: - # Get LLM's pg_collection for DDP (whole model uses LLM's parallelism for DDP) - if self.mimo_parallelism_config: - pg_collection = infra.pg_collections.get("llm") - else: - pg_collection = ProcessGroupCollection.use_mpu_process_groups() - - if pg_collection is not None: - model_list = self._wrap_with_ddp( - model_list, - ddp_config, - pg_collection, - data_parallel_random_init, - overlap_param_gather_with_optimizer_step, - ) - # Apply post-wrap hooks if final_post_wrap_hook: result = final_post_wrap_hook(model_list) @@ -456,39 +486,6 @@ def composed_hook(model: List[MegatronModule]) -> List[MegatronModule]: return pre_wrap_hook return self.pre_wrap_hook - def _wrap_with_ddp( - self, - model_list: List[MegatronModule], - ddp_config: DistributedDataParallelConfig, - pg_collection: ProcessGroupCollection, - data_parallel_random_init: bool, - overlap_param_gather_with_optimizer_step: bool, - ) -> List[MegatronModule]: - """Wrap model with DistributedDataParallel.""" - ddp_stream = torch.cuda.Stream() - ddp_stream.wait_stream(torch.cuda.current_stream()) - - with torch.cuda.stream(ddp_stream): - model_list = [ - DistributedDataParallel( - config=get_model_config(model_chunk), - ddp_config=ddp_config, - module=model_chunk, - disable_bucketing=(idx > 0) or overlap_param_gather_with_optimizer_step, - pg_collection=pg_collection, - ) - for idx, model_chunk in enumerate(model_list) - ] - - torch.cuda.current_stream().wait_stream(ddp_stream) - - # Broadcast params from data parallel src rank - if data_parallel_random_init: - for m in model_list: - m.broadcast_params() - - return model_list - def initialize_model_parallel( self, seed: Optional[int] = None, @@ -538,44 +535,8 @@ def finalize(self) -> None: Raises: ValueError: If any rank doesn't participate in at least one module. 
This indicates the parallelism configuration doesn't cover all - ranks in the world. + ranks in the world (validated by MimoParallelismConfig.finalize()). """ if self.mimo_parallelism_config is not None: world_size = dist.get_world_size() if dist.is_initialized() else None self.mimo_parallelism_config.finalize(world_size) - - # Validate all ranks participate in at least one module - self._validate_all_ranks_participate(world_size) - - def _validate_all_ranks_participate(self, world_size: Optional[int]) -> None: - """Validate that all ranks participate in at least one module. - - Args: - world_size: Total number of ranks. If None, validation is skipped. - - Raises: - ValueError: If any rank doesn't participate in a module. - """ - if world_size is None or self.mimo_parallelism_config is None: - return - - # Build grids to determine rank coverage - grids = build_hypercomm_grids(self.mimo_parallelism_config) - - # Collect all ranks that participate in at least one module - participating_ranks = set() - for module_name, grid in grids.items(): - for rank in range(grid.rank_offset, grid.rank_offset + grid.size): - participating_ranks.add(rank) - - # Check for non-participating ranks - all_ranks = set(range(world_size)) - non_participating_ranks = all_ranks - participating_ranks - - if non_participating_ranks: - raise ValueError( - f"Ranks {sorted(non_participating_ranks)} do not participate in any MIMO module. " - f"All {world_size} ranks must be assigned to at least one module. " - f"Adjust MimoParallelismConfig to cover all ranks, or reduce world_size to " - f"{len(participating_ranks)}." - ) diff --git a/tests/unit_tests/models/mimo/test_mimo_ddp.py b/tests/unit_tests/models/mimo/test_mimo_ddp.py new file mode 100644 index 0000000000..64af7b8493 --- /dev/null +++ b/tests/unit_tests/models/mimo/test_mimo_ddp.py @@ -0,0 +1,314 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+"""Unit tests for MIMO DDP wrapping utilities.""" + +from unittest.mock import MagicMock, patch + +from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid +from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig, ModuleParallelismConfig +from megatron.bridge.models.mimo.mimo_ddp import wrap_mimo_model_distributed + + +class TestIsCurrentRankInGrid: + """Test cases for is_current_rank_in_grid helper.""" + + @patch("torch.distributed.get_rank") + def test_rank_in_grid(self, mock_get_rank): + """Rank within grid range should return True.""" + mock_get_rank.return_value = 2 + + mock_grid = MagicMock() + mock_grid.rank_offset = 0 + mock_grid.size = 4 + + assert is_current_rank_in_grid(mock_grid) is True + + @patch("torch.distributed.get_rank") + def test_rank_at_grid_start(self, mock_get_rank): + """Rank at grid start should return True.""" + mock_get_rank.return_value = 4 + + mock_grid = MagicMock() + mock_grid.rank_offset = 4 + mock_grid.size = 4 + + assert is_current_rank_in_grid(mock_grid) is True + + @patch("torch.distributed.get_rank") + def test_rank_at_grid_end_exclusive(self, mock_get_rank): + """Rank at grid end (exclusive) should return False.""" + mock_get_rank.return_value = 8 + + mock_grid = MagicMock() + mock_grid.rank_offset = 4 + mock_grid.size = 4 + + assert is_current_rank_in_grid(mock_grid) is False + + @patch("torch.distributed.get_rank") + def test_rank_before_grid(self, mock_get_rank): + """Rank before grid range should return False.""" + mock_get_rank.return_value = 2 + + mock_grid = MagicMock() + mock_grid.rank_offset = 4 + mock_grid.size = 4 + + assert is_current_rank_in_grid(mock_grid) is False + + @patch("torch.distributed.get_rank") + def test_rank_after_grid(self, mock_get_rank): + """Rank after grid range should return False.""" + mock_get_rank.return_value = 10 + + mock_grid = MagicMock() + mock_grid.rank_offset = 0 + mock_grid.size = 4 + + assert is_current_rank_in_grid(mock_grid) is False + + +class TestWrapMimoModelDistributed: + """Test cases for wrap_mimo_model_distributed.""" + + def _create_mock_mimo_model(self, has_language_model=True, modality_names=None): + """Create a mock MimoModel for testing.""" + mock_model = MagicMock() + + if has_language_model: + mock_model.language_model = MagicMock() + mock_model.language_model.config = MagicMock() + else: + mock_model.language_model = None + + if modality_names: + mock_model.modality_submodules = {} + for name in modality_names: + submodule = MagicMock() + submodule.encoders = {"encoder": MagicMock()} + submodule.encoders["encoder"].config = MagicMock() + mock_model.modality_submodules[name] = submodule + else: + mock_model.modality_submodules = {} + + return mock_model + + def _create_mock_grid(self, rank_offset=0, size=4): + """Create a mock HyperCommGrid.""" + mock_grid = MagicMock() + mock_grid.rank_offset = rank_offset + mock_grid.size = size + return mock_grid + + def _create_mimo_parallelism_config(self, modules): + """Create a MimoParallelismConfig.""" + module_parallelisms = { + name: ModuleParallelismConfig( + tensor_model_parallel_size=config.get("tp", 1), + data_parallel_size=config.get("dp", 1), + rank_offset=config.get("rank_offset", 0), + ) + for name, config in modules.items() + } + return MimoParallelismConfig( + module_parallelisms=module_parallelisms, + ) + + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") + def test_wrap_language_model(self, mock_get_rank, mock_ddp): + """Test that language model is 
wrapped with DDP when rank participates.""" + mock_get_rank.return_value = 0 + mock_ddp.return_value = MagicMock() + + mimo_model = self._create_mock_mimo_model(has_language_model=True) + ddp_config = MagicMock() + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "llm": {"tp": 2, "dp": 2}, + } + ) + + grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} + pg_collections = {"llm": MagicMock()} + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + + # Should wrap language model + mock_ddp.assert_called_once() + assert result.language_model == mock_ddp.return_value + + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") + def test_skip_language_model_non_participating_rank(self, mock_get_rank, mock_ddp): + """Test that language model is NOT wrapped when rank doesn't participate.""" + mock_get_rank.return_value = 10 # Outside grid range + + mimo_model = self._create_mock_mimo_model(has_language_model=True) + original_lm = mimo_model.language_model + + ddp_config = MagicMock() + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "llm": {"tp": 2, "dp": 2}, + } + ) + + grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} + pg_collections = {"llm": MagicMock()} + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + + # Should NOT wrap language model + mock_ddp.assert_not_called() + assert result.language_model == original_lm + + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") + def test_wrap_modality_submodules(self, mock_get_rank, mock_ddp): + """Test that modality submodules are wrapped with DDP.""" + mock_get_rank.return_value = 0 + mock_ddp.return_value = MagicMock() + + mimo_model = self._create_mock_mimo_model(has_language_model=True, modality_names=["images"]) + ddp_config = MagicMock() + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "llm": {"tp": 2, "dp": 2}, + "images": {"tp": 1, "dp": 4}, + } + ) + + grids = { + "llm": self._create_mock_grid(rank_offset=0, size=4), + "images": self._create_mock_grid(rank_offset=0, size=4), + } + pg_collections = { + "llm": MagicMock(), + "images": MagicMock(), + } + + wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + + # Should wrap both language model and images submodule + assert mock_ddp.call_count == 2 + + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") + def test_heterogeneous_different_rank_ranges(self, mock_get_rank, mock_ddp): + """Test heterogeneous deployment with different rank ranges per module.""" + mock_get_rank.return_value = 4 # In images grid but not llm grid + mock_ddp.return_value = MagicMock() + + mimo_model = self._create_mock_mimo_model(has_language_model=True, modality_names=["images"]) + original_lm = mimo_model.language_model + + ddp_config = MagicMock() + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "llm": {"tp": 2, "dp": 2, "rank_offset": 0}, + "images": {"tp": 2, "dp": 2, "rank_offset": 4}, + } + ) + + grids = { + "llm": self._create_mock_grid(rank_offset=0, size=4), + "images": self._create_mock_grid(rank_offset=4, size=4), + } + pg_collections = { + "llm": None, # Rank 4 doesn't participate in LLM + "images": MagicMock(), + } + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, 
mimo_parallelism_config, grids, pg_collections) + + # Should wrap only images (rank 4 is in images grid, not llm grid) + assert mock_ddp.call_count == 1 + # Language model should be unchanged + assert result.language_model == original_lm + + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") + def test_no_language_model(self, mock_get_rank, mock_ddp): + """Test model without language model.""" + mock_get_rank.return_value = 0 + mock_ddp.return_value = MagicMock() + + mimo_model = self._create_mock_mimo_model(has_language_model=False, modality_names=["images"]) + ddp_config = MagicMock() + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "llm": {"tp": 2, "dp": 2}, + "images": {"tp": 1, "dp": 4}, + } + ) + + grids = { + "llm": self._create_mock_grid(rank_offset=0, size=4), + "images": self._create_mock_grid(rank_offset=0, size=4), + } + pg_collections = { + "llm": MagicMock(), + "images": MagicMock(), + } + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + + # Should wrap only images (no language model) + assert mock_ddp.call_count == 1 + assert result.language_model is None + + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") + def test_returns_same_model_instance(self, mock_get_rank, mock_ddp): + """Test that wrap_mimo_model_distributed returns the same model instance.""" + mock_get_rank.return_value = 0 + mock_ddp.return_value = MagicMock() + + mimo_model = self._create_mock_mimo_model(has_language_model=True) + ddp_config = MagicMock() + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "llm": {"tp": 2, "dp": 2}, + } + ) + + grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} + pg_collections = {"llm": MagicMock()} + + result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + + # Should return the same model instance (modified in-place) + assert result is mimo_model + + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") + def test_ddp_called_with_correct_args(self, mock_get_rank, mock_ddp): + """Test that DDP is called with correct arguments.""" + mock_get_rank.return_value = 0 + mock_ddp.return_value = MagicMock() + + mimo_model = self._create_mock_mimo_model(has_language_model=True) + # Capture original config before wrapping (wrapping replaces language_model) + original_lm_config = mimo_model.language_model.config + original_lm = mimo_model.language_model + + ddp_config = MagicMock() + mimo_parallelism_config = self._create_mimo_parallelism_config( + { + "llm": {"tp": 2, "dp": 2}, + } + ) + + grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} + llm_pg_collection = MagicMock() + pg_collections = {"llm": llm_pg_collection} + + wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) + + # Verify DDP call arguments + mock_ddp.assert_called_once() + call_kwargs = mock_ddp.call_args.kwargs + assert call_kwargs["ddp_config"] == ddp_config + assert call_kwargs["pg_collection"] == llm_pg_collection + assert call_kwargs["config"] == original_lm_config + assert call_kwargs["module"] == original_lm diff --git a/tests/unit_tests/models/mimo/test_mimo_provider.py b/tests/unit_tests/models/mimo/test_mimo_provider.py index 52be4d1a69..25aa6c76fb 100644 --- a/tests/unit_tests/models/mimo/test_mimo_provider.py +++ 
b/tests/unit_tests/models/mimo/test_mimo_provider.py @@ -1,10 +1,9 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. """Unit tests for MIMO Model Provider.""" from unittest.mock import MagicMock, Mock, patch import pytest -import torch.nn as nn from megatron.core.transformer.spec_utils import ModuleSpec from megatron.bridge.models.mimo import ( @@ -14,33 +13,6 @@ from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig, ModuleParallelismConfig -class MockModule(nn.Module): - """Mock Module for testing that is a proper torch.nn.Module subclass.""" - - def __init__(self, *args, **kwargs): - super().__init__() - self.args = args - self.kwargs = kwargs - # Add config attribute that Float16Module looks for - self.config = MagicMock() - - def forward(self, *args, **kwargs): - return None - - def cuda(self, device=None): - """Mock cuda() method.""" - # Return self to avoid actual CUDA calls - return self - - def bfloat16(self): - """Mock bfloat16() method.""" - return self - - def half(self): - """Mock half() method.""" - return self - - class TestMimoModelProvider: """Test cases for MimoModelProvider.""" @@ -124,7 +96,7 @@ def test_provide_returns_model_directly(self, mock_build_grids, mock_mimo_model) @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") - def test_provide_signature_matches_mixin(self, mock_build_grids, mock_mimo_model): + def test_provide_signature_matches_mixin(self, _mock_build_grids, mock_mimo_model): """Test provide() accepts standard mixin signature arguments.""" language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) provider = MimoModelProvider(language_model_spec=language_spec) @@ -155,12 +127,18 @@ def test_build_infra_without_parallelism(self, mock_build_grids): # Should not build grids mock_build_grids.assert_not_called() + @patch("torch.distributed.new_group") + @patch("torch.distributed.get_process_group_ranks") @patch("torch.distributed.get_rank") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") - def test_build_infra_with_parallelism(self, mock_topology, mock_build_grids, mock_get_rank): + def test_build_infra_with_parallelism( + self, mock_topology, mock_build_grids, mock_get_rank, mock_get_pg_ranks, mock_new_group + ): """Test build_infra() with parallelism config.""" mock_get_rank.return_value = 0 + mock_get_pg_ranks.return_value = [0, 1, 2, 3] + mock_new_group.return_value = MagicMock() language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) mimo_parallelism_config = MimoParallelismConfig( @@ -196,17 +174,23 @@ def test_build_infra_with_parallelism(self, mock_topology, mock_build_grids, moc assert "llm" in infra.pg_collections assert "llm" in infra.participating_modules + @patch("torch.distributed.new_group") + @patch("torch.distributed.get_process_group_ranks") @patch("torch.distributed.get_rank") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") - def test_build_infra_is_idempotent(self, mock_topology, mock_build_grids, mock_get_rank): + def test_build_infra_is_idempotent( + self, mock_topology, mock_build_grids, mock_get_rank, mock_get_pg_ranks, mock_new_group + ): """Test build_infra() can be called multiple times.""" mock_get_rank.return_value = 0 + 
mock_get_pg_ranks.return_value = [0, 1] + mock_new_group.return_value = MagicMock() language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2), + "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=1, rank_offset=0), }, ) @@ -229,13 +213,19 @@ def test_build_infra_is_idempotent(self, mock_topology, mock_build_grids, mock_g # Should return equivalent results (not cached, but same structure) assert infra1.participating_modules == infra2.participating_modules + @patch("torch.distributed.new_group") + @patch("torch.distributed.get_process_group_ranks") @patch("torch.distributed.get_rank") @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") - def test_provide_with_parallelism(self, mock_topology, mock_build_grids, mock_mimo_model, mock_get_rank): + def test_provide_with_parallelism( + self, mock_topology, mock_build_grids, mock_mimo_model, mock_get_rank, mock_get_pg_ranks, mock_new_group + ): """Test provide() with parallelism config.""" mock_get_rank.return_value = 0 + mock_get_pg_ranks.return_value = [0, 1, 2, 3] + mock_new_group.return_value = MagicMock() language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) mimo_parallelism_config = MimoParallelismConfig( @@ -321,8 +311,8 @@ def test_non_participating_rank_raises_error( mimo_parallelism_config=mimo_parallelism_config, ) - # Should raise ValueError because ranks 4-7 don't participate - with pytest.raises(ValueError, match="do not participate in any MIMO module"): + # Should raise ValueError because ranks 4-7 are a gap between modules + with pytest.raises(ValueError, match="not assigned to any module"): provider.finalize() def test_inject_pg_collection_into_language_spec(self): @@ -374,192 +364,24 @@ def test_freezing_language_model(self, mock_mimo_model): freeze_language_model=True, ) - _ = provider.provide() + provider.provide() # Check parameter was frozen assert mock_param.requires_grad is False - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - def test_freezing_modality_encoders(self, mock_mimo_model): - """Test freeze_modality_encoders works.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - # Create mock model with modality submodules - mock_model = MagicMock() - mock_param = MagicMock() - mock_param.requires_grad = True - - # Create mock encoder with parameters - mock_encoder = MagicMock() - mock_encoder.parameters.return_value = [mock_param] - - # Create mock modality submodule with encoders - mock_submodule = MagicMock() - mock_submodule.encoders = mock_encoder - - mock_model.modality_submodules = {"images": mock_submodule} - mock_mimo_model.return_value = mock_model - - provider = MimoModelProvider( - language_model_spec=language_spec, - freeze_modality_encoders={"images": True}, - ) - - _ = provider.provide() - - # Check encoder parameters were frozen - assert mock_param.requires_grad is False - - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - def test_freezing_modality_projections(self, mock_mimo_model): - """Test freeze_modality_projections works.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - # Create mock model with modality submodules - mock_model = MagicMock() - mock_param = MagicMock() - 
mock_param.requires_grad = True - - # Create mock projection with parameters - mock_projection = MagicMock() - mock_projection.parameters.return_value = [mock_param] - - # Create mock modality submodule with projections - mock_submodule = MagicMock() - mock_submodule.input_projections = mock_projection - - mock_model.modality_submodules = {"images": mock_submodule} - mock_mimo_model.return_value = mock_model - - provider = MimoModelProvider( - language_model_spec=language_spec, - freeze_modality_projections={"images": True}, - ) - - _ = provider.provide() - - # Check projection parameters were frozen - assert mock_param.requires_grad is False - - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - def test_combined_freezing(self, mock_mimo_model): - """Test freezing language model, encoders, and projections together.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - # Create mock model with all components - mock_model = MagicMock() - - # Language model param - mock_lang_param = MagicMock() - mock_lang_param.requires_grad = True - mock_model.language_model.parameters.return_value = [mock_lang_param] - - # Encoder param - mock_enc_param = MagicMock() - mock_enc_param.requires_grad = True - mock_encoder = MagicMock() - mock_encoder.parameters.return_value = [mock_enc_param] - - # Projection param - mock_proj_param = MagicMock() - mock_proj_param.requires_grad = True - mock_projection = MagicMock() - mock_projection.parameters.return_value = [mock_proj_param] - - # Modality submodule - mock_submodule = MagicMock() - mock_submodule.encoders = mock_encoder - mock_submodule.input_projections = mock_projection - - mock_model.modality_submodules = {"images": mock_submodule} - mock_mimo_model.return_value = mock_model - - provider = MimoModelProvider( - language_model_spec=language_spec, - freeze_language_model=True, - freeze_modality_encoders={"images": True}, - freeze_modality_projections={"images": True}, - ) - - _ = provider.provide() - - # Check all parameters were frozen - assert mock_lang_param.requires_grad is False - assert mock_enc_param.requires_grad is False - assert mock_proj_param.requires_grad is False - - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - def test_partial_modality_freezing(self, mock_mimo_model): - """Test freezing only specific modalities.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - # Create mock model with multiple modalities - mock_model = MagicMock() - - # Images modality (frozen) - mock_images_param = MagicMock() - mock_images_param.requires_grad = True - mock_images_encoder = MagicMock() - mock_images_encoder.parameters.return_value = [mock_images_param] - mock_images_submodule = MagicMock() - mock_images_submodule.encoders = mock_images_encoder - - # Audio modality (not frozen) - mock_audio_param = MagicMock() - mock_audio_param.requires_grad = True - mock_audio_encoder = MagicMock() - mock_audio_encoder.parameters.return_value = [mock_audio_param] - mock_audio_submodule = MagicMock() - mock_audio_submodule.encoders = mock_audio_encoder - - mock_model.modality_submodules = { - "images": mock_images_submodule, - "audio": mock_audio_submodule, - } - mock_mimo_model.return_value = mock_model - - provider = MimoModelProvider( - language_model_spec=language_spec, - freeze_modality_encoders={"images": True}, # Only freeze images - ) - - _ = provider.provide() - - # Check only images parameters were frozen - assert mock_images_param.requires_grad is False - assert 
mock_audio_param.requires_grad is True # Should remain unfrozen - - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - def test_freezing_with_missing_attributes(self, mock_mimo_model): - """Test freezing handles missing attributes gracefully.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - # Create mock model without expected attributes - mock_model = MagicMock() - # No language_model attribute - del mock_model.language_model - # No modality_submodules attribute - del mock_model.modality_submodules - - mock_mimo_model.return_value = mock_model - - provider = MimoModelProvider( - language_model_spec=language_spec, - freeze_language_model=True, - freeze_modality_encoders={"images": True}, - freeze_modality_projections={"images": True}, - ) - - # Should not raise an error - _ = provider.provide() - + @patch("torch.distributed.new_group") + @patch("torch.distributed.get_process_group_ranks") @patch("torch.distributed.get_rank") @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") - def test_per_encoder_parallelism(self, mock_topology, mock_build_grids, mock_mimo_model, mock_get_rank): + def test_per_encoder_parallelism( + self, mock_topology, mock_build_grids, mock_mimo_model, mock_get_rank, mock_get_pg_ranks, mock_new_group + ): """Test per-encoder parallelism with different TP per encoder.""" mock_get_rank.return_value = 0 + mock_get_pg_ranks.return_value = [0, 1, 2, 3] + mock_new_group.return_value = MagicMock() language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) clip_spec = ModuleSpec(module=Mock, params={}) dino_spec = ModuleSpec(module=Mock, params={}) @@ -639,553 +461,253 @@ def test_initialize_model_parallel_is_noop(self): provider.initialize_model_parallel(seed=42) provider.initialize_model_parallel() - def test_tensor_model_parallel_size_property_with_config(self): - """Test tensor_model_parallel_size property returns LLM's TP size.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=8), - } - ) - - provider = MimoModelProvider( - language_model_spec=language_spec, - mimo_parallelism_config=mimo_config, - ) - - assert provider.tensor_model_parallel_size == 8 - - def test_tensor_model_parallel_size_property_without_config(self): - """Test tensor_model_parallel_size property returns 1 without config.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - provider = MimoModelProvider(language_model_spec=language_spec) - - assert provider.tensor_model_parallel_size == 1 - - def test_pipeline_model_parallel_size_property_with_config(self): - """Test pipeline_model_parallel_size property returns LLM's PP size.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(pipeline_model_parallel_size=4), - } - ) - - provider = MimoModelProvider( - language_model_spec=language_spec, - mimo_parallelism_config=mimo_config, - ) - - assert provider.pipeline_model_parallel_size == 4 - - def test_pipeline_model_parallel_size_property_without_config(self): - """Test pipeline_model_parallel_size property returns 1 without config.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - provider = 
MimoModelProvider(language_model_spec=language_spec) - - assert provider.pipeline_model_parallel_size == 1 - - def test_context_parallel_size_property_with_config(self): - """Test context_parallel_size property returns LLM's CP size.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(context_parallel_size=2), - } - ) - - provider = MimoModelProvider( - language_model_spec=language_spec, - mimo_parallelism_config=mimo_config, - ) - - assert provider.context_parallel_size == 2 - - def test_context_parallel_size_property_without_config(self): - """Test context_parallel_size property returns 1 without config.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - provider = MimoModelProvider(language_model_spec=language_spec) - - assert provider.context_parallel_size == 1 - - -class TestMimoModelProviderDistributed: - """Test cases for MimoModelProvider.provide_distributed_model().""" - - @patch("torch.cuda.current_device") - @patch("torch.distributed.is_initialized") - @patch("torch.distributed.get_world_size") - @patch("torch.distributed.get_rank") + @patch("megatron.core.transformer.module.Float16Module") + @patch("megatron.bridge.models.mimo.mimo_provider.get_model_config") @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") - @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") - def test_basic_provide_distributed_model_flow( - self, - mock_topology, - mock_build_grids, - mock_mimo_model, - mock_get_rank, - mock_get_world_size, - mock_is_initialized, - mock_current_device, - ): - """Test basic provide_distributed_model() flow without DDP.""" - mock_is_initialized.return_value = True - mock_get_world_size.return_value = 4 - mock_get_rank.return_value = 0 - mock_current_device.return_value = 0 - - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), - } - ) - - mock_grid = MagicMock() - mock_grid.rank_offset = 0 - mock_grid.size = 4 - mock_grid.get_pg.return_value = MagicMock() - mock_build_grids.return_value = {"llm": mock_grid} - mock_topology.return_value = {"llm": []} - - mock_model = MockModule() - mock_mimo_model.return_value = mock_model - - provider = MimoModelProvider( - language_model_spec=language_spec, - mimo_parallelism_config=mimo_config, - ) - - result = provider.provide_distributed_model(wrap_with_ddp=False) - - # Should return list with model - assert isinstance(result, list) - assert len(result) == 1 - assert result[0] is not None - - @patch("torch.cuda.stream") - @patch("torch.cuda.is_available") - @patch("torch.cuda.set_device") - @patch("torch.cuda.device_count") - @patch("torch.cuda.Stream") - @patch("torch.cuda.current_stream") - @patch("torch.cuda.current_device") @patch("torch.distributed.is_initialized") - @patch("torch.distributed.get_world_size") - @patch("torch.distributed.get_rank") - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - @patch("megatron.bridge.models.mimo.mimo_provider.DistributedDataParallel") - @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") - @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") - @patch("megatron.bridge.models.mimo.mimo_provider.get_model_config") - def test_with_ddp_wrapping( - 
self, - mock_get_config, - mock_topology, - mock_build_grids, - mock_ddp, - mock_mimo_model, - mock_get_rank, - mock_get_world_size, - mock_is_initialized, - mock_current_device, - mock_current_stream, - mock_stream_class, - mock_device_count, - mock_set_device, - mock_is_available, - mock_stream_ctx, + def test_provide_distributed_model_sets_variable_seq_lengths( + self, mock_is_init, mock_build_grids, mock_mimo_model, mock_get_config, mock_float16 ): - """Test DDP wrapping with data_parallel_random_init=True.""" - from megatron.core.distributed import DistributedDataParallelConfig - - mock_is_initialized.return_value = True - mock_get_world_size.return_value = 4 - mock_get_rank.return_value = 0 - mock_current_device.return_value = 0 - mock_device_count.return_value = 8 # Mock sufficient GPUs - mock_set_device.return_value = None # Mock set_device to avoid CUDA calls - mock_is_available.return_value = True # Mock CUDA availability - - # Mock the stream context manager - mock_stream_ctx.return_value.__enter__ = MagicMock(return_value=None) - mock_stream_ctx.return_value.__exit__ = MagicMock(return_value=None) - - # Mock streams - mock_stream = MagicMock() - mock_stream_class.return_value = mock_stream - mock_current_stream.return_value = MagicMock() - + """Test that provide_distributed_model sets variable_seq_lengths=True.""" + mock_is_init.return_value = False language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), - } - ) - - mock_grid = MagicMock() - mock_grid.rank_offset = 0 - mock_grid.size = 4 - mock_grid.get_pg.return_value = MagicMock() - mock_build_grids.return_value = {"llm": mock_grid} - mock_topology.return_value = {"llm": []} - - mock_model = MockModule() - mock_mimo_model.return_value = mock_model - mock_get_config.return_value = Mock() - - # Mock DDP wrapper - mock_ddp_model = MagicMock() - mock_ddp_model.broadcast_params = MagicMock() - mock_ddp.return_value = mock_ddp_model provider = MimoModelProvider( language_model_spec=language_spec, - mimo_parallelism_config=mimo_config, + bf16=False, # Disable to simplify test + fp16=False, ) - ddp_config = DistributedDataParallelConfig() - provider.provide_distributed_model(ddp_config=ddp_config, wrap_with_ddp=True, data_parallel_random_init=True) - - # Should wrap with DDP - assert mock_ddp.called - # Should broadcast params - mock_ddp_model.broadcast_params.assert_called_once() - - @patch("torch.cuda.current_device") - @patch("torch.distributed.is_initialized") - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - def test_ddp_config_required_when_wrap_with_ddp_true( - self, mock_mimo_model, mock_is_initialized, mock_current_device - ): - """Test ValueError raised when wrap_with_ddp=True but ddp_config=None.""" - mock_is_initialized.return_value = False - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - provider = MimoModelProvider(language_model_spec=language_spec) - - with pytest.raises(ValueError, match="ddp_config is required when wrap_with_ddp is True"): - provider.provide_distributed_model(wrap_with_ddp=True, ddp_config=None) - - @patch("torch.cuda.current_device") - @patch("torch.distributed.is_initialized") - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - def test_cpu_initialization(self, mock_mimo_model, mock_is_initialized, mock_current_device): - """Test model stays on CPU when use_cpu_initialization=True.""" - 
mock_is_initialized.return_value = False - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - mock_model = MockModule() - mock_mimo_model.return_value = mock_model - - provider = MimoModelProvider(language_model_spec=language_spec) - - result = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) - - # Should NOT move to CUDA (we can't easily assert on MockModule methods, so just check result) - assert len(result) == 1 - - @patch("torch.cuda.current_device") - @patch("torch.distributed.is_initialized") - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - def test_meta_device_initialization(self, mock_mimo_model, mock_is_initialized, mock_current_device): - """Test model stays on meta device when init_model_with_meta_device=True.""" - mock_is_initialized.return_value = False - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - mock_model = MockModule() - mock_mimo_model.return_value = mock_model - - provider = MimoModelProvider(language_model_spec=language_spec) - - result = provider.provide_distributed_model(wrap_with_ddp=False, init_model_with_meta_device=True) - - # Should NOT move to CUDA (we can't easily assert on MockModule methods, so just check result) - assert len(result) == 1 - - @patch("torch.cuda.current_device") - @patch("torch.distributed.is_initialized") - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - @patch("megatron.bridge.models.mimo.mimo_provider.get_model_config") - def test_fp16_handling(self, mock_get_config, mock_mimo_model, mock_is_initialized, mock_current_device): - """Test FP16 mixed precision wrapper.""" - mock_is_initialized.return_value = False - mock_current_device.return_value = 0 - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - mock_model = MockModule() - mock_mimo_model.return_value = mock_model + mock_model_instance = MagicMock() + mock_model_instance.cuda = MagicMock(return_value=None) + mock_mimo_model.return_value = mock_model_instance mock_config = MagicMock() + mock_config.variable_seq_lengths = False # Initial value mock_get_config.return_value = mock_config - provider = MimoModelProvider(language_model_spec=language_spec) - - with patch("megatron.core.transformer.module.Float16Module") as mock_float16: - mock_wrapped = MagicMock() - mock_float16.return_value = mock_wrapped + # No parallelism config means no DDP wrapping needed + provider.provide_distributed_model(wrap_with_ddp=False) - result = provider.provide_distributed_model(wrap_with_ddp=False, fp16=True) + # Should have set variable_seq_lengths=True + assert mock_config.variable_seq_lengths is True - # Should wrap with Float16Module - mock_float16.assert_called_once() - assert mock_config.fp16 is True - assert result[0] == mock_wrapped - @patch("torch.cuda.current_device") - @patch("torch.distributed.is_initialized") - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - @patch("megatron.bridge.models.mimo.mimo_provider.get_model_config") - def test_bf16_handling(self, mock_get_config, mock_mimo_model, mock_is_initialized, mock_current_device): - """Test BF16 mixed precision wrapper (default).""" - mock_is_initialized.return_value = False - mock_current_device.return_value = 0 - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - mock_model = MockModule() - mock_mimo_model.return_value = mock_model +class TestMimoModelInfra: + """Test cases for MimoModelInfra dataclass.""" - mock_config = MagicMock() - mock_get_config.return_value = 
mock_config + def test_infra_initialization(self): + """Test infrastructure dataclass initializes correctly.""" + grids = {"llm": MagicMock()} + topology = {"llm": []} + pg_collections = {"llm": MagicMock()} + participating = ["llm"] - provider = MimoModelProvider(language_model_spec=language_spec, bf16=True) + infra = MimoModelInfra( + module_to_grid_map=grids, + topology=topology, + pg_collections=pg_collections, + participating_modules=participating, + ) - with patch("megatron.core.transformer.module.Float16Module") as mock_float16: - mock_wrapped = MagicMock() - mock_float16.return_value = mock_wrapped + assert infra.module_to_grid_map == grids + assert infra.topology == topology + assert infra.pg_collections == pg_collections + assert infra.participating_modules == participating - provider.provide_distributed_model(wrap_with_ddp=False) - # Should wrap with Float16Module for BF16 - mock_float16.assert_called_once() - assert mock_config.bf16 is True +class TestEmbeddingGroupHelpers: + """Test cases for embedding group helper functions.""" - @patch("torch.cuda.current_device") - @patch("torch.distributed.is_initialized") - @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") - @patch("megatron.bridge.models.mimo.mimo_provider.get_model_config") - def test_custom_mixed_precision_wrapper( - self, mock_get_config, mock_mimo_model, mock_is_initialized, mock_current_device - ): - """Test custom mixed precision wrapper is used.""" - mock_is_initialized.return_value = False - mock_current_device.return_value = 0 - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) + @patch("torch.distributed.new_group") + @patch("torch.distributed.get_process_group_ranks") + def test_populate_embedding_groups_single_pp_rank(self, mock_get_ranks, mock_new_group): + """Test embedding groups with single PP rank (PP=1).""" + from megatron.bridge.models.mimo.mimo_builder import ( + create_embedding_and_position_groups, + ) - mock_model = MagicMock() - mock_model.cuda = MagicMock(return_value=mock_model) - mock_mimo_model.return_value = mock_model + mock_pp_group = MagicMock() + mock_get_ranks.return_value = [0] # Single PP rank + mock_new_group.return_value = MagicMock() - mock_config = MagicMock() - mock_get_config.return_value = mock_config + create_embedding_and_position_groups(mock_pp_group) - # Custom wrapper - mock_custom_wrapper = MagicMock() - mock_wrapped = MagicMock() - mock_custom_wrapper.return_value = mock_wrapped + # Should create groups for both position and word embeddings + assert mock_new_group.call_count == 2 + # Both groups should include only rank 0 + calls = mock_new_group.call_args_list + assert calls[0].kwargs["ranks"] == [0] + assert calls[1].kwargs["ranks"] == [0] - provider = MimoModelProvider(language_model_spec=language_spec, fp16=True) + @patch("torch.distributed.new_group") + @patch("torch.distributed.get_process_group_ranks") + def test_populate_embedding_groups_multiple_pp_ranks(self, mock_get_ranks, mock_new_group): + """Test embedding groups with multiple PP ranks (PP>1).""" + from megatron.bridge.models.mimo.mimo_builder import ( + create_embedding_and_position_groups, + ) - result = provider.provide_distributed_model(wrap_with_ddp=False, mixed_precision_wrapper=mock_custom_wrapper) + mock_pp_group = MagicMock() + mock_get_ranks.return_value = [0, 4, 8, 12] # PP=4 + mock_new_group.return_value = MagicMock() - # Should use custom wrapper instead of Float16Module - mock_custom_wrapper.assert_called_once_with(mock_config, mock_model) - assert result[0] == 
+        create_embedding_and_position_groups(mock_pp_group)
-    @patch("torch.cuda.current_device")
-    @patch("torch.distributed.is_initialized")
-    @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel")
-    def test_pre_wrap_hook_single(self, mock_mimo_model, mock_is_initialized, mock_current_device):
-        """Test single pre-wrap hook is called."""
-        mock_is_initialized.return_value = False
-        mock_current_device.return_value = 0
-        language_spec = ModuleSpec(module=Mock, params={"config": Mock()})
+        # Should create two groups
+        assert mock_new_group.call_count == 2
+        calls = mock_new_group.call_args_list
+        # pos_embd only on first rank
+        assert calls[0].kwargs["ranks"] == [0]
+        # embd on first and last ranks
+        assert calls[1].kwargs["ranks"] == [0, 12]
-        mock_model = MockModule()
-        mock_mimo_model.return_value = mock_model
+    def test_populate_embedding_groups_none_pp_group(self):
+        """Test embedding groups with None PP group."""
+        from megatron.bridge.models.mimo.mimo_builder import (
+            create_embedding_and_position_groups,
+        )
-        # Pre-wrap hook
-        hook_called = []
+        pos_embd_pg, embd_pg = create_embedding_and_position_groups(None)
-        def pre_hook(models):
-            hook_called.append(True)
-            return models
+        assert pos_embd_pg is None
+        assert embd_pg is None
-        provider = MimoModelProvider(language_model_spec=language_spec)
+    def test_is_pp_first_stage_true(self):
+        """Test is_pp_first_stage returns True when group rank is 0."""
+        from megatron.core.pipeline_parallel.utils import is_pp_first_stage
-        provider.provide_distributed_model(wrap_with_ddp=False, pre_wrap_hook=pre_hook)
+        mock_pp_group = MagicMock()
+        mock_pp_group.rank.return_value = 0
+        mock_pp_group.size.return_value = 4
-        # Hook should be called
-        assert len(hook_called) == 1
+        with patch("torch.distributed.is_initialized", return_value=True):
+            assert is_pp_first_stage(mock_pp_group) is True
-    @patch("torch.cuda.current_device")
-    @patch("torch.distributed.is_initialized")
-    @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel")
-    def test_pre_wrap_hooks_multiple(self, mock_mimo_model, mock_is_initialized, mock_current_device):
-        """Test multiple pre-wrap hooks are called in order."""
-        mock_is_initialized.return_value = False
-        mock_current_device.return_value = 0
-        language_spec = ModuleSpec(module=Mock, params={"config": Mock()})
+    def test_is_pp_first_stage_false(self):
+        """Test is_pp_first_stage returns False when group rank is not 0."""
+        from megatron.core.pipeline_parallel.utils import is_pp_first_stage
-        mock_model = MockModule()
-        mock_mimo_model.return_value = mock_model
+        mock_pp_group = MagicMock()
+        mock_pp_group.rank.return_value = 1
+        mock_pp_group.size.return_value = 4
-        # Track hook execution order
-        hook_order = []
+        with patch("torch.distributed.is_initialized", return_value=True):
+            assert is_pp_first_stage(mock_pp_group) is False
-        def hook1(models):
-            hook_order.append(1)
-            return models
+    def test_is_pp_first_stage_none_group(self):
+        """Test is_pp_first_stage returns True for None group (treated as rank 0, no PP)."""
+        from megatron.core.pipeline_parallel.utils import is_pp_first_stage
-        def hook2(models):
-            hook_order.append(2)
-            return models
+        assert is_pp_first_stage(None) is True
-        provider = MimoModelProvider(language_model_spec=language_spec)
+    def test_is_pp_last_stage_true(self):
+        """Test is_pp_last_stage returns True when group rank is last."""
+        from megatron.core.pipeline_parallel.utils import is_pp_last_stage
-        provider.provide_distributed_model(wrap_with_ddp=False, pre_wrap_hook=[hook1, hook2])
+        mock_pp_group = MagicMock()
+        mock_pp_group.rank.return_value = 3
+        mock_pp_group.size.return_value = 4
-        # Hooks should be called in order
-        assert hook_order == [1, 2]
+        with patch("torch.distributed.is_initialized", return_value=True):
+            assert is_pp_last_stage(mock_pp_group) is True
-    @patch("torch.cuda.current_device")
-    @patch("torch.distributed.is_initialized")
-    @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel")
-    def test_post_wrap_hook(self, mock_mimo_model, mock_is_initialized, mock_current_device):
-        """Test post-wrap hook is called after everything."""
-        mock_is_initialized.return_value = False
-        mock_current_device.return_value = 0
-        language_spec = ModuleSpec(module=Mock, params={"config": Mock()})
+    def test_is_pp_last_stage_false(self):
+        """Test is_pp_last_stage returns False when group rank is not last."""
+        from megatron.core.pipeline_parallel.utils import is_pp_last_stage
-        mock_model = MockModule()
-        mock_mimo_model.return_value = mock_model
+        mock_pp_group = MagicMock()
+        mock_pp_group.rank.return_value = 1
+        mock_pp_group.size.return_value = 4
-        # Post-wrap hook
-        hook_called = []
+        with patch("torch.distributed.is_initialized", return_value=True):
+            assert is_pp_last_stage(mock_pp_group) is False
-        def post_hook(models):
-            hook_called.append(True)
-            return models
+    def test_is_pp_last_stage_none_group(self):
+        """Test is_pp_last_stage returns True for None group (treated as rank 0 == size-1 == 0, no PP)."""
+        from megatron.core.pipeline_parallel.utils import is_pp_last_stage
-        provider = MimoModelProvider(language_model_spec=language_spec)
+        assert is_pp_last_stage(None) is True
-        provider.provide_distributed_model(wrap_with_ddp=False, post_wrap_hook=post_hook)
-
-        # Hook should be called
-        assert len(hook_called) == 1
+class TestProcessGroupCollectionWithEmbeddingGroups:
+    """Test that ProcessGroupCollection includes embedding groups."""
-    @patch("torch.cuda.stream")
-    @patch("torch.cuda.is_available")
-    @patch("torch.cuda.set_device")
-    @patch("torch.cuda.device_count")
-    @patch("torch.cuda.Stream")
-    @patch("torch.cuda.current_stream")
-    @patch("torch.cuda.current_device")
-    @patch("torch.distributed.is_initialized")
-    @patch("torch.distributed.get_world_size")
+    @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_last_stage")
+    @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_first_stage")
+    @patch("megatron.bridge.models.mimo.mimo_provider.create_embedding_and_position_groups")
     @patch("torch.distributed.get_rank")
-    @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel")
-    @patch("megatron.bridge.models.mimo.mimo_provider.DistributedDataParallel")
-    @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids")
-    @patch("megatron.bridge.models.mimo.mimo_provider._default_topology")
-    @patch("megatron.bridge.models.mimo.mimo_provider.get_model_config")
-    def test_overlap_param_gather(
-        self,
-        mock_get_config,
-        mock_topology,
-        mock_build_grids,
-        mock_ddp,
-        mock_mimo_model,
-        mock_get_rank,
-        mock_get_world_size,
-        mock_is_initialized,
-        mock_current_device,
-        mock_current_stream,
-        mock_stream_class,
-        mock_device_count,
-        mock_set_device,
-        mock_is_available,
-        mock_stream_ctx,
+    def test_pg_collection_includes_embedding_groups_first_stage(
+        self, mock_get_rank, mock_populate, mock_is_first, mock_is_last
     ):
-        """Test overlap_param_gather_with_optimizer_step sets disable_bucketing."""
-        from megatron.core.distributed import DistributedDataParallelConfig
-
-        mock_is_initialized.return_value = True
-        mock_get_world_size.return_value = 4
+        """Test that pg_collection includes embedding groups for first PP stage."""
         mock_get_rank.return_value = 0
-        mock_current_device.return_value = 0
-        mock_device_count.return_value = 8  # Mock sufficient GPUs
-        mock_set_device.return_value = None  # Mock set_device to avoid CUDA calls
-        mock_is_available.return_value = True  # Mock CUDA availability
-
-        # Mock the stream context manager
-        mock_stream_ctx.return_value.__enter__ = MagicMock(return_value=None)
-        mock_stream_ctx.return_value.__exit__ = MagicMock(return_value=None)
-
-        # Mock streams
-        mock_stream = MagicMock()
-        mock_stream_class.return_value = mock_stream
-        mock_current_stream.return_value = MagicMock()
+        mock_pos_embd = MagicMock()
+        mock_embd = MagicMock()
+        mock_populate.return_value = (mock_pos_embd, mock_embd)
+        mock_is_first.return_value = True
+        mock_is_last.return_value = False
         language_spec = ModuleSpec(module=Mock, params={"config": Mock()})
-        mimo_config = MimoParallelismConfig(
+        mimo_parallelism_config = MimoParallelismConfig(
             module_parallelisms={
                 "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2),
-            }
+            },
         )
+        # Mock grid
         mock_grid = MagicMock()
         mock_grid.rank_offset = 0
         mock_grid.size = 4
         mock_grid.get_pg.return_value = MagicMock()
-        mock_build_grids.return_value = {"llm": mock_grid}
-        mock_topology.return_value = {"llm": []}
-
-        mock_model = MockModule()
-        mock_mimo_model.return_value = mock_model
-        mock_get_config.return_value = Mock()
-
-        mock_ddp_model = MagicMock()
-        mock_ddp_model.broadcast_params = MagicMock()
-        mock_ddp.return_value = mock_ddp_model
         provider = MimoModelProvider(
             language_model_spec=language_spec,
-            mimo_parallelism_config=mimo_config,
+            mimo_parallelism_config=mimo_parallelism_config,
         )
-        ddp_config = DistributedDataParallelConfig()
-        provider.provide_distributed_model(
-            ddp_config=ddp_config,
-            wrap_with_ddp=True,
-            overlap_param_gather_with_optimizer_step=True,
-            data_parallel_random_init=False,
-        )
+        pg_collections = provider._get_pg_collections_from_grids({"llm": mock_grid})
-        # Check disable_bucketing was set correctly
-        call_kwargs = mock_ddp.call_args[1]
-        assert call_kwargs["disable_bucketing"] is True
+        # First PP stage participates in both pos_embd and embd groups
+        assert pg_collections["llm"].pos_embd == mock_pos_embd
+        assert pg_collections["llm"].embd == mock_embd  # First stage gets embd too
+    @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_last_stage")
+    @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_first_stage")
+    @patch("megatron.bridge.models.mimo.mimo_provider.create_embedding_and_position_groups")
+    @patch("torch.distributed.get_rank")
+    def test_pg_collection_middle_stage_no_embedding_groups(
+        self, mock_get_rank, mock_populate, mock_is_first, mock_is_last
+    ):
+        """Test that middle PP stages don't get embedding groups."""
+        mock_get_rank.return_value = 4
+        mock_pos_embd = MagicMock()
+        mock_embd = MagicMock()
+        mock_populate.return_value = (mock_pos_embd, mock_embd)
+        mock_is_first.return_value = False
+        mock_is_last.return_value = False
-class TestMimoModelInfra:
-    """Test cases for MimoModelInfra dataclass."""
+        language_spec = ModuleSpec(module=Mock, params={"config": Mock()})
+        mimo_parallelism_config = MimoParallelismConfig(
+            module_parallelisms={
+                "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2),
+            },
+        )
-    def test_infra_initialization(self):
-        """Test infrastructure dataclass initializes correctly."""
-        grids = {"llm": MagicMock()}
-        topology = {"llm": []}
-        pg_collections = {"llm": MagicMock()}
-        participating = ["llm"]
+        # Mock grid
+        mock_grid = MagicMock()
+        mock_grid.rank_offset = 0
+        mock_grid.size = 8
+        mock_grid.get_pg.return_value = MagicMock()
-        infra = MimoModelInfra(
-            module_to_grid_map=grids,
-            topology=topology,
-            pg_collections=pg_collections,
-            participating_modules=participating,
+        provider = MimoModelProvider(
+            language_model_spec=language_spec,
+            mimo_parallelism_config=mimo_parallelism_config,
         )
-        assert infra.module_to_grid_map == grids
-        assert infra.topology == topology
-        assert infra.pg_collections == pg_collections
-        assert infra.participating_modules == participating
+        pg_collections = provider._get_pg_collections_from_grids({"llm": mock_grid})
+
+        # Middle stage should have neither embedding group
+        assert pg_collections["llm"].pos_embd is None
+        assert pg_collections["llm"].embd is None
diff --git a/tests/unit_tests/training/mimo/test_mimo_config.py b/tests/unit_tests/training/mimo/test_mimo_config.py
index c701ffc393..d961b1a3e9 100644
--- a/tests/unit_tests/training/mimo/test_mimo_config.py
+++ b/tests/unit_tests/training/mimo/test_mimo_config.py
@@ -1,7 +1,4 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
-
-import warnings
-
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
 import pytest
 from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig, ModuleParallelismConfig
@@ -34,173 +31,15 @@ def test_mimo_heterogeneous_rank_offset_overlap():
     mimo_parallelism_config.finalize(world_size=None)
-def test_module_parallelism_total_model_parallel_size_property():
-    """Test total_model_parallel_size calculation."""
-    parallelism = ModuleParallelismConfig(
-        tensor_model_parallel_size=2,
-        pipeline_model_parallel_size=2,
-        context_parallel_size=2,
-        expert_tensor_parallel_size=2,
-    )
-    assert parallelism.total_model_parallel_size == 16  # 2 * 2 * 2 * 2
-
-
-def test_module_parallelism_total_ranks_property():
-    """Test total_ranks property."""
-    parallelism = ModuleParallelismConfig(
-        tensor_model_parallel_size=2,
-        pipeline_model_parallel_size=2,
-        data_parallel_size=4,
-    )
-    assert parallelism.total_ranks == 16  # (2 * 2) * 4
-
-
-def test_module_parallelism_total_ranks_raises_without_dp():
-    """Test total_ranks raises error when data_parallel_size is None."""
-    parallelism = ModuleParallelismConfig(
-        tensor_model_parallel_size=2,
-        pipeline_model_parallel_size=2,
-    )
-    with pytest.raises(ValueError, match="data_parallel_size must be set"):
-        _ = parallelism.total_ranks
-
-
-def test_module_parallelism_expert_tensor_parallel_warning():
-    """Test warning when using expert_tensor_parallel_size > 1 with pipeline > 1."""
-    parallelism = ModuleParallelismConfig(
-        tensor_model_parallel_size=2,
-        pipeline_model_parallel_size=2,
-        expert_tensor_parallel_size=2,
-        data_parallel_size=2,
-    )
-
-    with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        parallelism.finalize(world_size=None)
-
-    # Check warning was raised
-    assert len(w) == 1
-    assert "expert_tensor_parallel_size > 1 with pipeline_model_parallel_size > 1" in str(w[0].message)
-
-
-def test_module_parallelism_data_parallel_validation():
-    """Test data_parallel_size validation."""
-    parallelism = ModuleParallelismConfig(
-        tensor_model_parallel_size=2,
-        data_parallel_size=0,  # Invalid
-    )
-    with pytest.raises(ValueError, match="data_parallel_size must be positive"):
-        parallelism.finalize(world_size=None)
-
-
-def test_mimo_parallelism_total_world_size_property():
-    """Test total_world_size calculation."""
-    mimo_config = MimoParallelismConfig(
-        module_parallelisms={
-            "llm": ModuleParallelismConfig(
-                tensor_model_parallel_size=2,
-                data_parallel_size=2,
-                rank_offset=0,
-            ),
-            "encoder": ModuleParallelismConfig(
-                tensor_model_parallel_size=2,
-                data_parallel_size=2,
-                rank_offset=4,
-            ),
-        }
-    )
-
-    # Total world size should be 8 (ranks 0-7)
-    assert mimo_config.total_world_size == 8
-
-
-def test_mimo_parallelism_module_names_property():
-    """Test module_names property."""
-    mimo_config = MimoParallelismConfig(
-        module_parallelisms={
-            "llm": ModuleParallelismConfig(tensor_model_parallel_size=2),
-            "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2),
-            "dino_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2),
-        }
-    )
-
-    module_names = mimo_config.module_names
-    assert "llm" in module_names
-    assert "clip_encoder" in module_names
-    assert "dino_encoder" in module_names
-    assert len(module_names) == 3
-
-
-def test_mimo_heterogeneous_edge_touching_ranges():
-    """Test that edge-touching ranges (no overlap) are valid."""
+def test_mimo_heterogeneous_valid_contiguous():
+    """Test that contiguous rank allocation works correctly."""
     module_parallelisms = {
-        "llm": ModuleParallelismConfig(
-            tensor_model_parallel_size=2,
-            data_parallel_size=2,
-            rank_offset=0,  # ranks 0-3
-        ),
-        "encoder": ModuleParallelismConfig(
-            tensor_model_parallel_size=2,
-            data_parallel_size=2,
-            rank_offset=4,  # ranks 4-7 (touching but not overlapping)
-        ),
-    }
-    mimo_config = MimoParallelismConfig(module_parallelisms=module_parallelisms)
-
-    # Should not raise an error
-    mimo_config.finalize(world_size=None)
-
-
-def test_mimo_heterogeneous_multiple_overlaps():
-    """Test detection of multiple overlapping ranges."""
-    module_parallelisms = {
-        "llm": ModuleParallelismConfig(
-            tensor_model_parallel_size=2,
-            data_parallel_size=2,
-            rank_offset=0,  # ranks 0-3
-        ),
-        "encoder1": ModuleParallelismConfig(
-            tensor_model_parallel_size=2,
-            data_parallel_size=2,
-            rank_offset=2,  # ranks 2-5 (overlaps with llm)
-        ),
-        "encoder2": ModuleParallelismConfig(
-            tensor_model_parallel_size=2,
-            data_parallel_size=2,
-            rank_offset=6,  # ranks 6-9
-        ),
+        "encoder": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=2, rank_offset=0),
+        "llm": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=2),
     }
-    mimo_config = MimoParallelismConfig(module_parallelisms=module_parallelisms)
-
-    # Should detect overlap
-    with pytest.raises(ValueError, match="overlap"):
-        mimo_config.finalize(world_size=None)
-
-
-def test_mimo_finalize_missing_llm_module():
-    """Test that finalize raises error when 'llm' module is missing."""
-    mimo_config = MimoParallelismConfig(
-        module_parallelisms={
-            "encoder": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2),
-        }
-    )
-
-    with pytest.raises(ValueError, match="LLM module 'llm' must be in module_parallelisms"):
-        mimo_config.finalize(world_size=None)
-
-
-def test_mimo_finalize_world_size_mismatch():
-    """Test that finalize detects world size mismatch."""
-    mimo_config = MimoParallelismConfig(
-        module_parallelisms={
-            "llm": ModuleParallelismConfig(
-                tensor_model_parallel_size=2,
-                data_parallel_size=2,
-                rank_offset=0,
-            ),
-        }
+    mimo_parallelism_config = MimoParallelismConfig(
+        module_parallelisms=module_parallelisms,
     )
-
-    # Expected world size is 4, but providing 8
-    with pytest.raises(ValueError, match="MIMO world size mismatch"):
-        mimo_config.finalize(world_size=8)
+    # No gaps, no overlap - should pass
+    mimo_parallelism_config.finalize(world_size=None)
+    assert mimo_parallelism_config.total_world_size == 6