diff --git a/.gitignore b/.gitignore index a9ce4aa0a93..5556d1d5a4a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,7 @@ runs/ # Sphinx documentation docs/_build -docs/apidocs \ No newline at end of file +docs/apidocs + +# Git worktrees +.worktrees/ \ No newline at end of file diff --git a/megatron/core/models/mimo/config/__init__.py b/megatron/core/models/mimo/config/__init__.py index 8371675a22d..3da744a6fb2 100644 --- a/megatron/core/models/mimo/config/__init__.py +++ b/megatron/core/models/mimo/config/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import ModuleStageInfo, RankRole -__all__ = ['MimoModelConfig'] +__all__ = ['MimoModelConfig', 'ModuleStageInfo', 'RankRole'] diff --git a/megatron/core/models/mimo/config/base_configs.py b/megatron/core/models/mimo/config/base_configs.py index 961f8930e23..a92484a5a48 100644 --- a/megatron/core/models/mimo/config/base_configs.py +++ b/megatron/core/models/mimo/config/base_configs.py @@ -2,8 +2,9 @@ import warnings from dataclasses import dataclass, field -from typing import Dict +from typing import Dict, Optional +from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.transformer.spec_utils import ModuleSpec @@ -20,6 +21,11 @@ class MimoModelConfig: Dictionary mapping modality names to their special token IDs. For example, {"vision": -200, "audio":32000}, these represent placeholders in the input_ids to insert the modality embeddings at the correct positions. + module_to_grid_map (Optional[Dict[str, HyperCommGrid]]): + Dictionary mapping module keys (e.g., "vision", "language") to their + corresponding HyperCommGrid configurations for non-colocated pipeline + parallelism. The language model must use the key MIMO_LANGUAGE_MODULE_KEY. + When None, all modules are assumed to be colocated on the same ranks. 
kv_format (str): Key-value format for attention: "sbhd" (seq-batch-head-dim) or "thd" (total-head-dim). Default is "sbhd". @@ -35,4 +41,5 @@ class MimoModelConfig: language_model_spec: ModuleSpec = field(default_factory=ModuleSpec) modality_submodules_spec: Dict[str, ModuleSpec] = field(default_factory=dict) special_token_ids: Dict[str, int] = field(default_factory=dict) + module_to_grid_map: Optional[Dict[str, HyperCommGrid]] = None kv_format: str = "sbhd" diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py new file mode 100644 index 00000000000..77c2512e8e6 --- /dev/null +++ b/megatron/core/models/mimo/config/role.py @@ -0,0 +1,172 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Data classes for MIMO rank role management in multi-module pipeline parallelism.""" + +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List + +import torch.distributed as dist + +from megatron.core.hyper_comm_grid import HyperCommGrid + +logger = logging.getLogger(__name__) + +# Fixed key for the language module in module_to_grid_map and RankRole. +# MIMO always has exactly one language model, so this is not configurable. +MIMO_LANGUAGE_MODULE_KEY = "language" + + +class ModuleLayout(Enum): + """Pipeline mode for MIMO multi-module parallelism. + + Determines how modules are distributed across ranks and which + forward path is used. + + UNIFIED: No module_to_grid_map. All modules share same ranks and + parallelism. Uses the unified forward path (_forward_all_modules). + + NON_COLOCATED: module_to_grid_map is set with non-overlapping rank + ranges. Each rank runs EITHER encoder(s) OR the language model. + Uses role-based dispatch with separate forward paths. + + COLOCATED: (future) module_to_grid_map is set with overlapping rank + ranges. Encoder(s) and language model share ranks but have + different parallelism configs. 
Uses role-based dispatch but + allows both module types on the same rank. + """ + + UNIFIED = "unified" + NON_COLOCATED = "non_colocated" + COLOCATED = "colocated" + + +@dataclass +class ModuleStageInfo: + """Information about a rank's stage position within a module's pipeline. + + Args: + is_first_stage: True if this rank is the first PP stage for this module. + is_last_stage: True if this rank is the last PP stage for this module. + """ + + is_first_stage: bool + is_last_stage: bool + + +@dataclass +class RankRole: + """Describes what modules this rank participates in for multi-module PP. + + This class captures the role of a specific rank in a multi-module pipeline + parallel setup, tracking which modules the rank participates in and their + stage positions. The language module is always identified by MIMO_LANGUAGE_MODULE_KEY. + + Args: + modules: Dict mapping module names to their stage info for modules + this rank participates in. + mode: Pipeline mode determining forward path dispatch. + """ + + modules: Dict[str, ModuleStageInfo] = field(default_factory=dict) + mode: ModuleLayout = ModuleLayout.UNIFIED + + @classmethod + def unified(cls, module_names: List[str]) -> 'RankRole': + """Create a role for the unified case: every module, first+last stage.""" + return cls( + modules={ + name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) + for name in module_names + }, + mode=ModuleLayout.UNIFIED, + ) + + @classmethod + def from_grid_map( + cls, module_to_grid_map: Dict[str, HyperCommGrid], modality_module_names: List[str] + ) -> 'RankRole': + """Create a role from a module-to-grid mapping for non-colocated PP. + + Determines which modules the current rank participates in and its + pipeline stage position within each module. + + Args: + module_to_grid_map: Dict mapping module names to HyperCommGrid objects. + Must contain keys matching modality_module_names + MIMO_LANGUAGE_MODULE_KEY. 
+ modality_module_names: List of modality module names (e.g., ["images", "audio"]). + + Returns: + RankRole for the current rank. + + Raises: + ValueError: If grid map keys don't match expected module names. + RuntimeError: If current rank is not in any module grid. + """ + # Validate keys + expected_keys = set(modality_module_names) | {MIMO_LANGUAGE_MODULE_KEY} + grid_keys = set(module_to_grid_map.keys()) + if grid_keys != expected_keys: + raise ValueError( + f"module_to_grid_map keys must match modality module names + " + f"'{MIMO_LANGUAGE_MODULE_KEY}'. Missing: {expected_keys - grid_keys}, " + f"Extra: {grid_keys - expected_keys}" + ) + + current_rank = dist.get_rank() + modules = {} + + for module_name, grid in module_to_grid_map.items(): + if not (grid.rank_offset <= current_rank < grid.rank_offset + grid.size): + continue + + if "pp" not in grid.dim_names: + modules[module_name] = ModuleStageInfo(is_first_stage=True, is_last_stage=True) + continue + + pp_group = grid.get_pg("pp") + pp_rank = pp_group.rank() + pp_size = pp_group.size() + is_first = pp_rank == 0 + is_last = pp_rank == pp_size - 1 + logger.info( + f"[RankRole.from_grid_map] Rank {current_rank}: module={module_name}, " + f"pp_rank={pp_rank}/{pp_size}, is_first_stage={is_first}, is_last_stage={is_last}" + ) + modules[module_name] = ModuleStageInfo(is_first_stage=is_first, is_last_stage=is_last) + + if not modules: + raise RuntimeError( + f"Rank {current_rank} is not in any module grid. " + f"Check module_to_grid_map configuration." 
+ ) + + return cls(modules=modules, mode=ModuleLayout.NON_COLOCATED) + + @property + def has_modality_modules(self) -> bool: + """Return True if this rank participates in any modality (non-language) module.""" + return any(name != MIMO_LANGUAGE_MODULE_KEY for name in self.modules) + + @property + def has_language_module(self) -> bool: + """Return True if this rank participates in the language module.""" + return MIMO_LANGUAGE_MODULE_KEY in self.modules + + @property + def modality_module_names(self) -> List[str]: + """Return names of modality modules (non-language) this rank participates in.""" + return [name for name in self.modules if name != MIMO_LANGUAGE_MODULE_KEY] + + def is_first_stage(self, module_name: str) -> bool: + """Check if this rank is the first stage for a given module.""" + if module_name not in self.modules: + return False + return self.modules[module_name].is_first_stage + + def is_last_stage(self, module_name: str) -> bool: + """Check if this rank is the last stage for a given module.""" + if module_name not in self.modules: + return False + return self.modules[module_name].is_last_stage diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index bb1c92c9f80..49e2fd42116 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -4,13 +4,15 @@ import warnings from typing import Any, Dict, Optional -import torch # type: ignore[import-not-found] +import torch from megatron.core.models.mimo.config import MimoModelConfig +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, ModuleLayout, RankRole from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer import MegatronModule from megatron.core.transformer.spec_utils import build_module +from megatron.core.utils import unwrap_model logger = logging.getLogger(__name__) @@ 
-54,6 +56,11 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - ) self.mimo_config = mimo_config + modality_names = list(mimo_config.modality_submodules_spec.keys()) + if mimo_config.module_to_grid_map: + self.role = RankRole.from_grid_map(mimo_config.module_to_grid_map, modality_names) + else: + self.role = RankRole.unified(modality_names + [MIMO_LANGUAGE_MODULE_KEY]) # Use special token IDs from the config self.special_token_ids = ( @@ -62,9 +69,6 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - # Extract language model config for partition adapter language_config = mimo_config.language_model_spec.params['config'] - assert ( - language_config.pipeline_model_parallel_size == 1 - ), "Pipeline parallelism is not supported in MimoModel" max_seq_len = mimo_config.language_model_spec.params.get('max_sequence_length', 4096) self.partition_adapter: Optional[PartitionAdapter] = None @@ -155,21 +159,41 @@ def align_embeddings_by_token_positions( def _initialize_submodules(self) -> None: """Initialize modality submodules from the ModuleSpec configurations. - Only modalities present in the config will be instantiated. - For each modality in the config, builds the corresponding submodule using from_spec. + When role is set, only initializes submodules this rank participates in. + Stage info is passed to from_spec() to conditionally skip projection. 
""" - for modality_name, submodule_spec in self.mimo_config.modality_submodules_spec.items(): - # Get the submodule class + if modality_name not in self.role.modules: + logger.debug(f"Skipping {modality_name} submodule (not in role)") + continue + + stage_info = self.role.modules[modality_name] + is_first_stage = stage_info.is_first_stage + is_last_stage = stage_info.is_last_stage + submodule_class = submodule_spec.module - logger.debug(f"Building {modality_name} submodule using {submodule_class.__name__}") + logger.debug( + f"Building {modality_name} submodule using {submodule_class.__name__} " + f"(is_first_stage={is_first_stage}, is_last_stage={is_last_stage})" + ) + + # Pass stage info to from_spec so projections are only built when needed + submodule = submodule_class.from_spec( + submodule_spec, is_first_stage=is_first_stage, is_last_stage=is_last_stage + ) - # Use from_spec to instantiate the submodule - submodule = submodule_class.from_spec(submodule_spec) self.modality_submodules[modality_name] = submodule def _initialize_language_model(self) -> None: - """Initialize the language model.""" + """Initialize the language model. + + When role is set, only initializes if this rank participates in language module. + """ + if not self.role.has_language_module: + logger.debug("Skipping language model initialization (not in role)") + self.language_model = None + return + logger.debug( f"Building language model using {self.mimo_config.language_model_spec.module.__name__}" ) @@ -182,18 +206,30 @@ def set_input_tensor(self, input_tensor): It passes the output tensor from the previous stage as input to this stage. 
Args: - input_tensor: Tensor or list of tensors passed between pipeline stages + input_tensor: Either: + - Dict[str, Tensor]: Maps module names to their input tensors (for multi-module PP) + - Tensor or List[Tensor]: Single tensor for language model (backward compat) Returns: None """ - # Handle case where input_tensor might be a list or a single tensor + # The schedule wraps input_tensor in a list (schedules.py:415-416), + # so unwrap first before checking type. if isinstance(input_tensor, list): - # For simplicity, just use the first tensor input_tensor = input_tensor[0] - # Pass the input tensor to the language model if it has a set_input_tensor method - if hasattr(self.language_model, 'set_input_tensor'): + # Store dict input for multi-module PP + if isinstance(input_tensor, dict): + # P2P recv may return [tensor] (list) for VPP compat — unwrap to tensor + self.input_tensors = { + k: v[0] if isinstance(v, list) and len(v) == 1 else v + for k, v in input_tensor.items() + } + return + + self.input_tensors = input_tensor + + if self.language_model is not None and hasattr(self.language_model, 'set_input_tensor'): self.language_model.set_input_tensor(input_tensor) def get_text_embeddings( @@ -223,10 +259,10 @@ def get_text_embeddings( position_ids[batch_idx, seq_idx].unsqueeze(0) if position_ids is not None else None ) - text_embeddings = self.language_model.embedding( - input_ids=input_ids_text, position_ids=position_ids_text - ).squeeze( - 1 + text_embeddings = ( + unwrap_model(self.language_model) + .embedding(input_ids=input_ids_text, position_ids=position_ids_text) + .squeeze(1) ) # Shape: [num_text_tokens, hidden_dim] return text_embeddings @@ -274,9 +310,160 @@ def forward( } Returns: - tuple: Tuple containing model outputs and loss mask - - lm_output: Model output. Shape: (B, S, ...) or (B, S, V) - - loss_mask: Loss mask. 
Shape: (B, S) + tuple: (output, loss_mask) where output semantics depend on role: + - Encoder-only ranks: Dict[str, Tensor] of encoder outputs + - Language module ranks: language model output (logits or loss) + - No role (all modules colocated): language model output + """ + # Get any tensors passed via set_input_tensor + input_tensors = getattr(self, 'input_tensors', None) + + if self.role.mode == ModuleLayout.UNIFIED: + return self._forward_all_modules( + input_ids, + position_ids, + attention_mask, + loss_mask, + labels, + modality_inputs, + packing_kwargs, + ) + + if self.role.mode == ModuleLayout.NON_COLOCATED: + if self.role.has_modality_modules: + return self._forward_encoders(modality_inputs, input_tensors), loss_mask + + if self.role.has_language_module: + return ( + self._forward_language_module( + input_ids, position_ids, attention_mask, labels, input_tensors + ), + loss_mask, + ) + + raise RuntimeError(f"Rank has no modules assigned in role: {self.role}") + + raise NotImplementedError(f"Pipeline mode {self.role.mode} is not yet supported") + + def _forward_encoders( + self, + modality_inputs: Optional[Dict[str, Dict[str, Any]]], + input_tensors: Optional[Dict[str, torch.Tensor]], + ) -> Dict[str, torch.Tensor]: + """Forward pass for encoder modules on this rank. + + Args: + modality_inputs: Raw inputs for each modality (images, audio, etc.) 
+ input_tensors: Hidden states from previous pipeline stages + + Returns: + Dict mapping encoder names to their output tensors + """ + outputs = {} + + for encoder_name in self.role.modality_module_names: + if encoder_name not in self.modality_submodules: + continue + + submodule = self.modality_submodules[encoder_name] + output = submodule.forward( + encoder_inputs=modality_inputs.get(encoder_name) if modality_inputs else None, + hidden_states=input_tensors.get(encoder_name) if input_tensors else None, + ) + + if output is not None: + outputs[encoder_name] = output + + return outputs + + def _forward_language_module( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + labels: Optional[torch.Tensor], + input_tensors: Optional[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Forward pass for language module on this rank. + + Args: + input_ids: Token IDs + position_ids: Position IDs + attention_mask: Attention mask + labels: Labels for loss computation + input_tensors: Hidden states or embeddings from previous stage + + Returns: + Language model output (hidden states, logits, or loss depending on stage) + """ + lang_name = MIMO_LANGUAGE_MODULE_KEY + + if self.role.is_first_stage(lang_name): + # First stage: receive encoder embeddings, combine with text, pass to LM + # Build modality embeddings dict from encoder outputs + modality_embeddings = {} + if input_tensors: + for name, tensor in input_tensors.items(): + if name != lang_name: + modality_embeddings[name] = tensor + + # Get text embeddings + text_embeddings = self.get_text_embeddings( + input_ids, position_ids, self.special_token_ids + ) + modality_embeddings["text"] = text_embeddings + + # Combine all embeddings + combined_embeddings = self.align_embeddings_by_token_positions( + modality_embeddings=modality_embeddings, + input_ids=input_ids, + special_token_ids=self.special_token_ids, + ) + + lm_output = self.language_model( + input_ids=None, 
+ position_ids=None, + decoder_input=combined_embeddings, + labels=labels, + attention_mask=attention_mask, + ) + else: + # Non-first stage: receive hidden states from previous LM stage + hidden_states = input_tensors.get(lang_name) if input_tensors else None + + # Set input tensor on language model for PP (unwrap DDP to reach GPTModel) + if hidden_states is not None: + underlying_lm = unwrap_model(self.language_model) + if hasattr(underlying_lm, 'set_input_tensor'): + underlying_lm.set_input_tensor(hidden_states) + + lm_output = self.language_model( + input_ids=None, + position_ids=None, + decoder_input=None, + labels=labels, + attention_mask=attention_mask, + ) + + # Key output for non-last stages so schedule can route to next LM stage + if not self.role.is_last_stage(lang_name): + return {lang_name: lm_output} + + return lm_output + + def _forward_all_modules( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + loss_mask: Optional[torch.Tensor], + labels: Optional[torch.Tensor], + modality_inputs: Optional[Dict[str, Dict[str, Any]]], + packing_kwargs: Optional[dict] = None, + ): + """Forward pass when all modules are on all ranks (no multi-module PP). + + This is the original behavior, preserved for backward compatibility. 
""" # If packing_kwargs is provided, construct PackedSeqParams packed_seq_params = None @@ -293,17 +480,14 @@ def forward( modality_embeddings = {} for modality_name, submodule in self.modality_submodules.items(): - # Process the modality through its submodule if ( modality_inputs and modality_name in modality_inputs and modality_inputs[modality_name] is not None ): logger.debug(f"Processing {modality_name} modality") - # Get embeddings for this modality embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) if embeddings is not None: - # All embeddings are now in the format [num_tokens, hidden_dim] modality_embeddings[modality_name] = embeddings logger.debug( f"Generated embeddings for {modality_name} with shape {embeddings.shape}" diff --git a/megatron/core/models/mimo/submodules/audio.py b/megatron/core/models/mimo/submodules/audio.py index ae907d7ac86..6db2782d82f 100644 --- a/megatron/core/models/mimo/submodules/audio.py +++ b/megatron/core/models/mimo/submodules/audio.py @@ -1,16 +1,11 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
-import logging -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional -import torch import torch.nn as nn from megatron.core.models.mimo.submodules.base import ModalitySubmodules -# Initialize logger -logger = logging.getLogger(__name__) - class AudioModalitySubmodules(ModalitySubmodules): """Audio modality submodules for encoding, decoding, and projecting audio data.""" @@ -32,7 +27,13 @@ def __init__( output_projections: List of output projection modules **kwargs: Additional keyword arguments """ - super().__init__(encoders, decoders, input_projections, output_projections, **kwargs) + super().__init__( + encoders=encoders, + decoders=decoders, + input_projections=input_projections, + output_projections=output_projections, + **kwargs, + ) if self.input_projections: assert ( @@ -44,112 +45,6 @@ def __init__( len(self.output_projections) <= 1 ), "AudioModalitySubmodules currently supports only one output projection" - def encode(self, encoders_data_batch: Dict) -> List[torch.Tensor]: - """Encode audio data into a sequence of embeddings. - - Args: - encoders_data_batch: Dictionary containing encoder-specific inputs. - Keys should match encoder names in self.encoders. - Each encoder receives its own specific inputs. - - Returns: - List of encoded audio embeddings, one from each encoder. - Each embedding is a flattened tensor of shape [total_tokens, hidden_dim] - - Raises: - ValueError: If no data is provided for any encoder or if there's a parameter mismatch. 
- """ - if not encoders_data_batch: - return [] - - embeddings = [] - - for name, encoder in self.encoders.items(): - if name not in encoders_data_batch: - raise ValueError(f"No inputs found for encoder '{name}'") - - encoder_inputs = encoders_data_batch[name] - - # Process inputs through the encoder - encoder_outputs = encoder(**encoder_inputs) - logger.debug(f"Encoder '{name}' output shape: {encoder_outputs.shape}") - if encoder_outputs.ndim == 3: - # its b,s,h -> we need to flatten it to b*s,h - encoder_outputs = encoder_outputs.reshape(-1, encoder_outputs.size(-1)) - embeddings.append(encoder_outputs) - elif encoder_outputs.ndim == 2: - # its b*s,h -> encoder already returned the flattened output - embeddings.append(encoder_outputs) - else: - raise ValueError( - f"Encoder '{name}' output shape {encoder_outputs.shape} is not supported" - "Expected 3D (b,s,h) or 2D (b*s,h) tensor, got {encoder_outputs.ndim}D" - ) - return embeddings - - def decode(self, embeddings: torch.Tensor, data_batch: Dict) -> torch.Tensor: + def decode(self, embeddings, data_batch: Dict): """Decode embeddings into audio data.""" raise NotImplementedError("Audio decoding not implemented yet") - - def combine_embeddings(self, embeddings: List[torch.Tensor]) -> torch.Tensor: - """Combine embeddings from different encoders.""" - if not embeddings: - raise ValueError("Cannot combine empty list of embeddings") - - if len(embeddings) == 1: - return embeddings[0] - - # Concatenate along sequence dimension - # each embedding is [total_tokens, hidden_dim] - combined = torch.cat(embeddings, dim=0) - logger.debug(f"Combined audio embeddings shape: {combined.shape}") - return combined - - def project_embeddings( - self, embeddings: List[torch.Tensor], is_input: bool = True - ) -> torch.Tensor: - """Project embeddings to the language model dimension space.""" - - if is_input: - embeddings = self.combine_embeddings(embeddings) - - # Get the appropriate projections - projections = self.input_projections 
if is_input else self.output_projections - - # Apply projection if available - if projections: - # We've asserted in __init__ that there's only one projection - projection = projections[0] - projected = projection(embeddings) - logger.debug(f"Post-projection audio embeddings shape: {projected.shape}") - return projected - - return embeddings - - def forward(self, encoder_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: - """Forward pass for audio modality submodules. - - Args: - encoder_inputs: Dictionary where keys match encoder names in self.encoders - and values are dictionaries of encoder-specific parameters. - Example: { - "whisper": {"input_features": features}, - "wav2vec": {"input_values": waveform} - } - - Returns: - Flattened audio embeddings with shape [total_embeddings, hidden_dim], - or None if no valid inputs were provided. - """ - - embeddings = self.encode(encoder_inputs) - # embeddings is a list of tensors, each tensor is a flattened audio embedding - - # If no embeddings were produced, return None - if not embeddings: - return None - - # Project embeddings - projected = self.project_embeddings(embeddings, is_input=True) - logger.debug(f"Projected audio embeddings shape: {projected.shape}") - return projected # [total_embeddings, hidden_dim] diff --git a/megatron/core/models/mimo/submodules/base.py b/megatron/core/models/mimo/submodules/base.py index 8b11ba7fcb9..58f61f81d3c 100644 --- a/megatron/core/models/mimo/submodules/base.py +++ b/megatron/core/models/mimo/submodules/base.py @@ -42,15 +42,30 @@ def __init__( decoders: Optional[Dict[str, nn.Module]] = None, input_projections: Optional[List[nn.Module]] = None, output_projections: Optional[List[nn.Module]] = None, + is_first_stage: bool = True, + is_last_stage: bool = True, **kwargs, ) -> None: - """Initialize the modality submodules.""" + """Initialize the modality submodules. 
+ + Args: + encoders: Dict of encoder modules + decoders: Dict of decoder modules + input_projections: List of input projection modules + output_projections: List of output projection modules + is_first_stage: Whether this is the first PP stage for this module + is_last_stage: Whether this is the last PP stage for this module + """ super().__init__() self.encoders = nn.ModuleDict(encoders or {}) self.decoders = nn.ModuleDict(decoders or {}) self.input_projections = nn.ModuleList(input_projections or []) self.output_projections = nn.ModuleList(output_projections or []) + # Stage info for multi-module pipeline parallelism (immutable after init) + self._is_first_stage: bool = is_first_stage + self._is_last_stage: bool = is_last_stage + warnings.warn( "ModalitySubmodules is experimental and still under active development. " "The API may change without notice in future releases.", @@ -58,21 +73,42 @@ def __init__( stacklevel=2, ) + @property + def is_first_stage(self) -> bool: + """Whether this is the first pipeline stage for this module.""" + return self._is_first_stage + + @property + def is_last_stage(self) -> bool: + """Whether this is the last pipeline stage for this module.""" + return self._is_last_stage + @classmethod - def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': + def from_spec( + cls, module_spec: ModuleSpec, is_first_stage: bool = True, is_last_stage: bool = True + ) -> 'ModalitySubmodules': """Create a modality submodule from ModuleSpec configuration. Args: module_spec (ModuleSpec): The module specification for this modality submodule + is_first_stage (bool): Whether this is the first pipeline stage for this module. + Controls encoder initialization and output projection initialization + (output projections only built on first stage). Defaults to True. + is_last_stage (bool): Whether this is the last pipeline stage for this module. + Controls input projection initialization (only built on last stage). + Defaults to True. 
Returns: ModalitySubmodules: An instance of the modality submodule """ - logger.debug(f"Creating {cls.__name__} from spec") + logger.debug( + f"Creating {cls.__name__} from spec (is_first_stage={is_first_stage}, " + f"is_last_stage={is_last_stage})" + ) params = module_spec.params or {} submodules = module_spec.submodules or {} - # Build component lists from submodules dictionary + # Build encoders (needed on all stages for pipeline processing) encoders = {} if 'encoders' in submodules: for encoder_name, encoder_spec in submodules['encoders'].items(): @@ -80,6 +116,7 @@ def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': encoder = build_module(encoder_spec) encoders[encoder_name] = encoder + # Build decoders (needed on all stages for pipeline processing) decoders = {} if 'decoders' in submodules: for decoder_name, decoder_spec in submodules['decoders'].items(): @@ -87,23 +124,31 @@ def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': decoder = build_module(decoder_spec) decoders[decoder_name] = decoder + # Build input projections only on last stage + # (projection happens after encoding, before sending to language model) input_projections = [] - if 'input_projections' in submodules: + if is_last_stage and 'input_projections' in submodules: for proj_spec in submodules['input_projections']: logger.debug( f"Building {cls.__name__} input projection: {proj_spec.module.__name__}" ) projection = build_module(proj_spec) input_projections.append(projection) + elif 'input_projections' in submodules: + logger.debug(f"Skipping {cls.__name__} input projections (not last stage)") + # Build output projections only on first stage + # (projection happens before decoding, after receiving from language model) output_projections = [] - if 'output_projections' in submodules: + if is_first_stage and 'output_projections' in submodules: for proj_spec in submodules['output_projections']: logger.debug( f"Building {cls.__name__} output projection: 
{proj_spec.module.__name__}" ) projection = build_module(proj_spec) output_projections.append(projection) + elif 'output_projections' in submodules: + logger.debug(f"Skipping {cls.__name__} output projections (not first stage)") # Pass any additional parameters from the params dictionary additional_params = params.copy() @@ -117,34 +162,66 @@ def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': decoders=decoders, input_projections=input_projections, output_projections=output_projections, + is_first_stage=is_first_stage, + is_last_stage=is_last_stage, **additional_params, ) - @abstractmethod def combine_embeddings(self, embeddings: List[torch.Tensor]) -> torch.Tensor: - """Combine multiple embeddings from different encoders. + """Combine multiple embeddings from different encoders by concatenation. Args: embeddings (List[torch.Tensor]): - List of embeddings to combine + List of embeddings to combine. Each is [total_tokens, hidden_dim]. Returns: torch.Tensor: Combined embedding tensor """ - pass + if not embeddings: + raise ValueError("Cannot combine empty list of embeddings") - @abstractmethod - def encode(self, data_batch: Dict) -> List[torch.Tensor]: + if len(embeddings) == 1: + return embeddings[0] + + combined = torch.cat(embeddings, dim=0) + logger.debug(f"Combined embeddings shape after concatenation: {combined.shape}") + return combined + + def encode(self, encoders_data_batch: Dict) -> List[torch.Tensor]: """Encode data batch into a list of tensors. Args: - data_batch (Dict): - Dictionary containing input data + encoders_data_batch (Dict): + Dictionary containing encoder-specific inputs. + Keys should match encoder names in self.encoders. 
Returns: - List[torch.Tensor]: List of encoded embeddings + List[torch.Tensor]: List of encoded embeddings, each [total_tokens, hidden_dim] """ - pass + if not encoders_data_batch: + return [] + + embeddings = [] + + for name, encoder in self.encoders.items(): + if name not in encoders_data_batch: + raise ValueError(f"No inputs found for encoder '{name}'") + + encoder_inputs = encoders_data_batch[name] + encoder_outputs = encoder(**encoder_inputs) + logger.debug(f"Encoder '{name}' output shape: {encoder_outputs.shape}") + + if encoder_outputs.ndim == 3: + encoder_outputs = encoder_outputs.reshape(-1, encoder_outputs.size(-1)) + elif encoder_outputs.ndim != 2: + raise ValueError( + f"Encoder '{name}' output shape {encoder_outputs.shape} is not supported. " + f"Expected 3D (b,s,h) or 2D (b*s,h) tensor, got {encoder_outputs.ndim}D" + ) + + embeddings.append(encoder_outputs) + + return embeddings @abstractmethod def decode(self, embeddings: torch.Tensor, data_batch: Dict) -> torch.Tensor: @@ -161,11 +238,10 @@ def decode(self, embeddings: torch.Tensor, data_batch: Dict) -> torch.Tensor: """ pass - @abstractmethod def project_embeddings( self, embeddings: List[torch.Tensor], is_input: bool = True ) -> Optional[torch.Tensor]: - """Project embeddings into a tensor. + """Project embeddings using input or output projections. 
Args: embeddings (List[torch.Tensor]): @@ -176,18 +252,49 @@ def project_embeddings( Returns: Optional[torch.Tensor]: Projected embeddings or None """ - pass + combined = self.combine_embeddings(embeddings) - @abstractmethod - def forward(self, encoder_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: + projections = self.input_projections if is_input else self.output_projections + + if projections: + projection = projections[0] + projected = projection(combined) + logger.debug(f"Post-projection embeddings shape: {projected.shape}") + return projected + + return combined + + def forward( + self, + encoder_inputs: Optional[Dict[str, Any]] = None, + hidden_states: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: """Process data for this modality through encoding and projection. Args: encoder_inputs (Dict[str, Any]): Dictionary containing encoder-specific inputs. Keys should match encoder names. + Used when is_first_stage=True. + hidden_states (Optional[torch.Tensor]): + Hidden states from previous pipeline stage. Used when is_first_stage=False. Returns: Optional[torch.Tensor]: Processed and projected embeddings tensor, or None if no embeddings were produced. """ - pass + if self.is_first_stage: + if encoder_inputs is None: + return None + embeddings = self.encode(encoder_inputs) + if not embeddings: + return None + combined = self.combine_embeddings(embeddings) + else: + if hidden_states is None: + return None + combined = hidden_states + + if self.is_last_stage: + return self.project_embeddings([combined], is_input=True) + + return combined diff --git a/megatron/core/models/mimo/submodules/vision.py b/megatron/core/models/mimo/submodules/vision.py index 795cb18a119..0bb1a45e013 100644 --- a/megatron/core/models/mimo/submodules/vision.py +++ b/megatron/core/models/mimo/submodules/vision.py @@ -1,16 +1,11 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
-import logging -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional -import torch import torch.nn as nn from megatron.core.models.mimo.submodules.base import ModalitySubmodules -# Initialize logger -logger = logging.getLogger(__name__) - class VisionModalitySubmodules(ModalitySubmodules): """Vision modality submodules for encoding, decoding, and projecting image data. @@ -40,6 +35,7 @@ def __init__( decoders=decoders, input_projections=input_projections, output_projections=output_projections, + **kwargs, ) if self.input_projections: @@ -52,133 +48,6 @@ def __init__( len(self.output_projections) <= 1 ), "VisionModalitySubmodules currently supports only one output projection" - def encode(self, encoders_data_batch: Dict) -> List[torch.Tensor]: - """Encode image data batch into a list of tensors. - - Args: - encoders_data_batch: Dictionary containing encoder-specific inputs. - Keys should match encoder names in self.encoders. - Each encoder receives its own specific inputs. - - Returns: - List of encoded image embeddings, one from each encoder. - Each embedding is a flattened tensor of shape [total_tokens, hidden_dim] - - Raises: - ValueError: If no data is provided for any encoder or if there's a parameter mismatch. 
- """ - if not encoders_data_batch: - return [] - - embeddings = [] - - for name, encoder in self.encoders.items(): - if name not in encoders_data_batch: - raise ValueError(f"No inputs found for encoder '{name}'") - - encoder_inputs = encoders_data_batch[name] - - # Process inputs through the encoder - encoder_outputs = encoder(**encoder_inputs) - logger.debug(f"Encoder '{name}' output shape: {encoder_outputs.shape}") - if encoder_outputs.ndim == 3: - # its b,s,h -> we need to flatten it to b*s,h - encoder_outputs = encoder_outputs.reshape(-1, encoder_outputs.size(-1)) - embeddings.append(encoder_outputs) - elif encoder_outputs.ndim == 2: - # its b*s,h -> encoder already returned the flattened output - embeddings.append(encoder_outputs) - else: - raise ValueError( - f"Encoder '{name}' output shape {encoder_outputs.shape} is not supported" - "Expected 3D (b,s,h) or 2D (b*s,h) tensor, got {encoder_outputs.ndim}D" - ) - - return embeddings - - def decode(self, embeddings: torch.Tensor, data_batch: Dict) -> torch.Tensor: - """Decode embeddings into image tensors. - - Args: - embeddings: Tensor of embeddings to decode. - data_batch: Dictionary containing additional data for decoding. - - Returns: - Tensor containing generated images. - """ - + def decode(self, embeddings, data_batch: Dict): + """Decode embeddings into image tensors.""" raise NotImplementedError("No decoders support yet") - - def combine_embeddings(self, embeddings: List[torch.Tensor]) -> torch.Tensor: - """Combine multiple embeddings from different encoders by concatenation. - - This method is used for combining encoder outputs before input projection. 
- - Args: - embeddings: List of embeddings to combine - - Returns: - Combined embedding tensor - """ - if not embeddings: - raise ValueError("Cannot combine empty list of embeddings") - - if len(embeddings) == 1: - return embeddings[0] - - # each embedding is [total_tokens, hidden_dim] - # Make this configurable in the future - combined = torch.cat(embeddings, dim=0) - logger.debug(f"Combined embeddings shape after concatenation: {combined.shape}") - return combined - - def project_embeddings( - self, embeddings: List[torch.Tensor], is_input: bool = True - ) -> torch.Tensor: - """Project image embeddings using input or output projections. - - Args: - embeddings: List of image embeddings to project - is_input: If True, use input projections, otherwise use output projections - - Returns: - Projected image embeddings or None if no embeddings - """ - if is_input: - embeddings = self.combine_embeddings(embeddings) - - # Get the appropriate projection (input or output) - projections = self.input_projections if is_input else self.output_projections - - # Apply projection if available - if projections: - # We've asserted in __init__ that there's only one projection - projection = projections[0] - projected = projection(embeddings) - logger.debug(f"Post-projection embeddings shape: {projected.shape}") - return projected - - return embeddings - - def forward(self, encoder_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: - """Process image data through encoding and projection. - - Args: - encoder_inputs: Dictionary where keys match encoder names in self.encoders - and values are dictionaries of encoder-specific parameters. - Example: {"clip": {"pixel_values": images}, "vit": {"images": vit_images}} - - Returns: - Flattened image embeddings with shape [total_embeddings, hidden_dim], - or None if no valid inputs were provided. 
- """ - # Encode the images - embeddings = self.encode(encoder_inputs) - - # If no embeddings were produced, return None - if not embeddings: - return None - - projected = self.project_embeddings(embeddings, is_input=True) - logging.debug(f"Projected audio embeddings shape: {projected.shape}") - return projected # [total_embeddings, hidden_dim] diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py new file mode 100644 index 00000000000..1a6b29cc58a --- /dev/null +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -0,0 +1,691 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Integration tests for MIMO model with 1F1B pipeline schedule. + +Run with: + uv run python -m torch.distributed.run --nproc-per-node=2 -m pytest tests/unit_tests/models/test_mimo_1f1b_schedule.py -v +""" + +import logging +from contextlib import ExitStack, contextmanager +from functools import partial + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY +from megatron.core.models.mimo.model.base import MimoModel +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator +from 
megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage +from megatron.core.process_groups_config import ( + MultiModuleProcessGroupCollection, + ProcessGroupCollection, +) +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TERowParallelLinear, + ) +except ImportError: + TEColumnParallelLinear = None + TERowParallelLinear = None + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Helper Functions (with grid tracking and PG caching from edc8159) +# ============================================================================ + +_active_grids: list = [] +_embedding_pg_cache: dict = {} + + +def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): + """Create a HyperCommGrid with specified parallelism.""" + grid = HyperCommGrid( + shape=[tp, cp, pp, dp, 1, 1], # [tp, cp, pp, dp, ep, expt_dp] + dim_names=["tp", "cp", "pp", "dp", "ep", "expt_dp"], + rank_offset=offset, + backend="nccl", + ) + grid.create_pg(["tp"]) + grid.create_pg(["cp"]) + grid.create_pg(["pp"]) + grid.create_pg(["dp"]) + grid.create_pg(["dp", "cp"]) + grid.create_pg(["ep"]) + grid.create_pg(["expt_dp"]) + _active_grids.append(grid) + return grid + + +def destroy_all_grids(): + """Destroy all tracked grids and bridge communicator PGs.""" + for grid in _active_grids: + grid.destroy() + _active_grids.clear() + _embedding_pg_cache.clear() + BridgeCommunicator.destroy_broadcast_pgs() + + +def get_pg_collection(grid): + """Get ProcessGroupCollection from grid.""" + pg_collection = ProcessGroupCollection() + 
pg_collection.tp = grid.get_pg("tp") + pg_collection.cp = grid.get_pg("cp") + pg_collection.pp = grid.get_pg("pp") + pg_collection.ep = grid.get_pg("ep") + pg_collection.dp = grid.get_pg("dp") + pg_collection.dp_cp = grid.get_pg(["dp", "cp"]) + pg_collection.expt_dp = grid.get_pg("expt_dp") + return pg_collection + + +def create_all_embedding_groups(grids): + """Create embedding PGs for all grids upfront. + + dist.new_group is a collective — ALL ranks must call it, even non-members. + We create all embedding groups in a consistent order across all ranks to + avoid hangs from asymmetric new_group calls. + + Args: + grids: List of all HyperCommGrids that need embedding groups. + """ + for grid in grids: + pp_group = grid.get_pg("pp") + if not pp_group: + continue + + pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) + cache_key = tuple(pp_ranks) + + if cache_key not in _embedding_pg_cache: + pos_embd_ranks = [pp_ranks[0]] + embd_ranks = [pp_ranks[0]] + if pp_ranks[-1] != pp_ranks[0]: + embd_ranks.append(pp_ranks[-1]) + _embedding_pg_cache[cache_key] = ( + dist.new_group(ranks=pos_embd_ranks), + dist.new_group(ranks=embd_ranks), + ) + + +def add_embedding_groups(pg_collection, is_language_model=False): + """Add cached embedding groups to a process group collection. + + Must call create_all_embedding_groups() first to ensure PGs exist. + + Args: + pg_collection: ProcessGroupCollection to add embedding groups to. + is_language_model: If True, set embd group for word embedding sync. 
+ """ + if not pg_collection.pp: + return pg_collection + + pp_ranks = sorted(dist.get_process_group_ranks(pg_collection.pp)) + cache_key = tuple(pp_ranks) + pos_embd_pg, embd_pg = _embedding_pg_cache[cache_key] + + pg_collection.pos_embd = pos_embd_pg if is_pp_first_stage(pg_collection.pp) else None + + if is_language_model: + pg_collection.embd = ( + embd_pg + if (is_pp_last_stage(pg_collection.pp) or is_pp_first_stage(pg_collection.pp)) + else None + ) + else: + # Encoder submodules have no shared word embeddings to sync + pg_collection.embd = None + + return pg_collection + + +def get_pg_collection_with_embedding_groups(grid, is_language_model=False): + """Get ProcessGroupCollection with embedding groups (PGs must be pre-created).""" + return add_embedding_groups(get_pg_collection(grid), is_language_model=is_language_model) + + +def is_rank_in_grid(grid): + """Check if current rank is in grid.""" + rank = dist.get_rank() + return grid.rank_offset <= rank < grid.rank_offset + grid.size + + +# ============================================================================ +# Model Spec Helpers +# ============================================================================ + + +def get_language_model_spec( + num_layers, hidden_size, num_attention_heads, vocab_size, seq_len, pg_collection +): + """Get the language model spec.""" + pp_rank = dist.get_rank(pg_collection.pp) + pp_size = dist.get_world_size(pg_collection.pp) + tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1 + + lm_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type='alltoall', + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + pipeline_dtype=torch.bfloat16, + bf16=True, + cross_entropy_loss_fusion=True, + cross_entropy_fusion_impl='te', + ) + return ModuleSpec( + module=GPTModel, + params={ + 
"config": lm_config, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), + "vocab_size": vocab_size, + "max_sequence_length": seq_len, + "pre_process": (pp_rank == 0), + "post_process": (pp_rank == pp_size - 1), + "pg_collection": pg_collection, + }, + ) + + +def get_projection_config(hidden_size): + """Return a TransformerConfig for the vision projection MLP.""" + cfg = TransformerConfig(num_layers=1, hidden_size=hidden_size, num_attention_heads=1) + cfg.ffn_hidden_size = hidden_size + cfg.bias_activation_fusion = True + cfg.add_bias_linear = True + cfg.activation_func = torch.nn.functional.gelu + return cfg + + +def get_projection_layer_spec(): + """Layer spec for the vision-projection MLP.""" + if TEColumnParallelLinear is None or TERowParallelLinear is None: + raise RuntimeError("TEColumnParallelLinear and TERowParallelLinear are required") + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules(linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear), + ) + + +def get_vision_submodules_spec( + num_layers, hidden_size, num_attention_heads, language_hidden_size, pg_collection +): + """Get the submodule spec for the vision modality.""" + from megatron.core.transformer.transformer_block import TransformerBlock + + tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1 + pp_size = pg_collection.pp.size() if pg_collection.pp is not None else 1 + pp_rank = dist.get_rank(pg_collection.pp) + + vision_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type='alltoall', + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + vision_encoder_spec = ModuleSpec( + module=TransformerBlock, + params={ + "config": vision_config, + "spec": get_gpt_layer_with_transformer_engine_spec(), + 
"pg_collection": pg_collection, + "pre_process": (pp_rank == 0), + "post_process": (pp_rank == pp_size - 1), + }, + ) + + vision_projection_spec = ModuleSpec( + module=MultimodalProjector, + params={ + "config": get_projection_config(hidden_size=language_hidden_size), + "submodules": get_projection_layer_spec().submodules, + "projector_type": "mlp", + "input_size": vision_config.hidden_size, + "tp_group": pg_collection.tp, + }, + ) + + return ModuleSpec( + module=VisionModalitySubmodules, + submodules={ + "encoders": {"clip_encoder": vision_encoder_spec}, + "input_projections": [vision_projection_spec], + }, + ) + + +def get_mimo_model( + encoder_name, encoder_grid, llm_grid, hidden_size, num_layers, vocab_size, seq_len +): + """Create MIMO model with TransformerBlock encoder and GPTModel LLM.""" + language_pg = get_pg_collection_with_embedding_groups(llm_grid, is_language_model=True) + vision_pg = get_pg_collection_with_embedding_groups(encoder_grid, is_language_model=False) + + language_model_spec = get_language_model_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + vocab_size=vocab_size, + seq_len=seq_len, + pg_collection=language_pg, + ) + vision_submodule_spec = get_vision_submodules_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + language_hidden_size=hidden_size, + pg_collection=vision_pg, + ) + + module_to_grid_map = {encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid} + topology = {encoder_name: [MIMO_LANGUAGE_MODULE_KEY], MIMO_LANGUAGE_MODULE_KEY: []} + + mimo_config = MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={encoder_name: vision_submodule_spec}, + special_token_ids={encoder_name: 50257}, + module_to_grid_map=module_to_grid_map, + ) + + mimo_model = MimoModel(mimo_config) + mimo_model.to(torch.device("cuda")).to(torch.bfloat16) + + # Wrap with DDP + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=True, 
bucket_size=10000, use_distributed_optimizer=True + ) + + if mimo_model.language_model is not None: + mimo_model.language_model = DistributedDataParallel( + config=mimo_model.language_model.config, + ddp_config=ddp_config, + module=mimo_model.language_model, + pg_collection=language_pg, + ) + + if encoder_name in mimo_model.modality_submodules: + submodule = mimo_model.modality_submodules[encoder_name] + if submodule is not None: + submodule = DistributedDataParallel( + config=submodule.encoders['clip_encoder'].config, + ddp_config=ddp_config, + module=submodule, + pg_collection=vision_pg, + ) + mimo_model.modality_submodules[encoder_name] = submodule + + return mimo_model, module_to_grid_map, topology, language_pg, vision_pg + + +# ============================================================================ +# Data Iterator +# ============================================================================ + + +class DataIterator: + """Simple data iterator returning VLM-like batches.""" + + def __init__( + self, + hidden_size, + seq_length, + micro_batch_size, + vocab_size, + encoder_name, + image_token_id=50257, + image_seq_length=None, + ): + self.hidden_size = hidden_size + self.seq_length = seq_length + self.micro_batch_size = micro_batch_size + self.vocab_size = vocab_size + self.encoder_name = encoder_name + self.image_token_id = image_token_id + self.image_seq_length = image_seq_length or (seq_length // 2) + + def __iter__(self): + return self + + def __next__(self): + encoder_hidden_states = torch.randn( + self.image_seq_length, + self.micro_batch_size, + self.hidden_size, + device='cuda', + dtype=torch.bfloat16, + ) + + image_tokens = torch.full( + (self.micro_batch_size, self.image_seq_length), + self.image_token_id, + dtype=torch.long, + device='cuda', + ) + text_tokens = torch.randint( + 1, + self.vocab_size, + (self.micro_batch_size, self.seq_length - self.image_seq_length), + device='cuda', + ) + input_ids = torch.cat([image_tokens, text_tokens], dim=1) 
+ + labels = input_ids.clone() + labels[input_ids == self.image_token_id] = -100 + + loss_mask = torch.ones( + self.micro_batch_size, self.seq_length, device='cuda', dtype=torch.float32 + ) + loss_mask[input_ids == self.image_token_id] = 0.0 + + return { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": torch.arange(self.seq_length, device='cuda') + .unsqueeze(0) + .expand(self.micro_batch_size, -1) + .clone(), + "modality_inputs": { + self.encoder_name: { + "clip_encoder": {'hidden_states': encoder_hidden_states, 'attention_mask': None} + } + }, + } + + +# ============================================================================ +# Test Runner +# ============================================================================ + + +def run_mimo_1f1b_test( + encoder_tp, + encoder_pp, + encoder_dp, + encoder_offset, + llm_tp, + llm_pp, + llm_dp, + llm_offset, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, +): + """Run MIMO model through 1F1B schedule and verify.""" + # Clear NVTE env vars that the conftest set_env fixture sets to '0'. + # GPTModel (LanguageModule) asserts these are unset or match the attention backend. + import os + + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) + os.environ.pop('NVTE_UNFUSED_ATTN', None) + + encoder_name = "images" + + encoder_grid = create_hypercomm_grid( + offset=encoder_offset, tp=encoder_tp, cp=1, pp=encoder_pp, dp=encoder_dp + ) + llm_grid = create_hypercomm_grid(offset=llm_offset, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp) + + # Create all embedding PGs upfront — dist.new_group is a collective that + # requires ALL ranks to participate, so we must create them before any + # rank-specific pg_collection calls. 
+ create_all_embedding_groups([encoder_grid, llm_grid]) + + torch.manual_seed(12345) + + mimo_model, module_to_grid_map, topology, language_pg, vision_pg = get_mimo_model( + encoder_name=encoder_name, + encoder_grid=encoder_grid, + llm_grid=llm_grid, + hidden_size=hidden_size, + num_layers=num_layers, + vocab_size=vocab_size, + seq_len=seq_length, + ) + + # Build schedule functions using pre-created pg_collections (no leaks) + @contextmanager + def no_sync_func(): + with ExitStack() as stack: + if mimo_model.language_model is not None: + stack.enter_context(mimo_model.language_model.no_sync()) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + stack.enter_context(submodule.no_sync()) + yield + + def finalize_grads_func(*args, **kwargs): + if mimo_model.language_model is not None: + finalize_model_grads( + [mimo_model.language_model], num_tokens=None, pg_collection=language_pg + ) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + finalize_model_grads([submodule], num_tokens=None, pg_collection=vision_pg) + + mimo_model.config.no_sync_func = no_sync_func + mimo_model.config.finalize_model_grads_func = finalize_grads_func + mimo_model.config.grad_scale_func = lambda loss: ( + torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + if isinstance(loss, (int, float)) + else loss + ) + + communicator = MultiModulePipelineCommunicator( + module_to_grid_map, topology, mimo_model.config, dim_mapping={'s': 0, 'h': 2, 'b': 1} + ) + + # Create data iterator on ranks that need it + data_iterator = None + encoder_needs_data = is_rank_in_grid(encoder_grid) and is_pp_first_stage( + encoder_grid.get_pg("pp") + ) + llm_needs_data = is_rank_in_grid(llm_grid) and ( + is_pp_first_stage(llm_grid.get_pg("pp")) or is_pp_last_stage(llm_grid.get_pg("pp")) + ) + if encoder_needs_data or llm_needs_data: + data_iterator = DataIterator( + hidden_size, seq_length, micro_batch_size, vocab_size, 
encoder_name + ) + + # Build MultiModuleProcessGroupCollection (reuse pre-created pg_collections) + module_pgs = {} + language_model_module_name = None + if is_rank_in_grid(encoder_grid): + module_pgs[encoder_name] = vision_pg + if is_rank_in_grid(llm_grid): + module_pgs[MIMO_LANGUAGE_MODULE_KEY] = language_pg + language_model_module_name = MIMO_LANGUAGE_MODULE_KEY + + pg_collection = MultiModuleProcessGroupCollection( + module_pgs=module_pgs, language_model_module_name=language_model_module_name + ) + + def step_func(data_iterator, model): + def loss_func(loss_mask, output_tensor): + if output_tensor is None: + return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} + + if isinstance(output_tensor, dict): + output = output_tensor.get( + MIMO_LANGUAGE_MODULE_KEY, next(iter(output_tensor.values()), None) + ) + else: + output = output_tensor + + if output is None: + return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} + + loss = output.float().sum() + return loss, {'loss_reduced': loss} + + batch = next(data_iterator) if data_iterator is not None else {'input_ids': None} + output_tensor, loss_mask = model(**batch) + return output_tensor, partial(loss_func, loss_mask) + + losses = schedule.forward_backward_pipelining_without_interleaving( + forward_step_func=step_func, + data_iterator=data_iterator, + model=[mimo_model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + p2p_communicator=communicator, + pg_collection=pg_collection, + ) + + # Verify results on last LLM stage + if is_rank_in_grid(llm_grid) and is_pp_last_stage(llm_grid.get_pg("pp")): + assert len(losses) > 0, "Expected losses on last LLM stage" + for loss_dict in losses: + assert 'loss_reduced' in loss_dict + + return losses + + +# ============================================================================ +# Tests +# 
============================================================================ + + +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse('2.3.0'), + reason="Device mesh requires PyTorch 2.3+", +) +class TestMimo1F1BSchedule: + """Test MIMO model with 1F1B pipeline schedule.""" + + @classmethod + def setup_class(cls): + Utils.initialize_distributed() + cls.world_size = dist.get_world_size() + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def teardown_method(self): + destroy_all_grids() + + def test_baseline_2gpu(self): + """Encoder PP=1, LLM PP=1 on 2 GPUs.""" + if self.world_size != 2: + pytest.skip(f"Requires 2 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=1, + encoder_pp=1, + encoder_dp=1, + encoder_offset=0, + llm_tp=1, + llm_pp=1, + llm_dp=1, + llm_offset=1, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, + ) + + def test_lm_pp3_4gpu(self): + """Encoder PP=1, LLM PP=3 on 4 GPUs.""" + if self.world_size != 4: + pytest.skip(f"Requires 4 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=1, + encoder_pp=1, + encoder_dp=1, + encoder_offset=0, + llm_tp=1, + llm_pp=3, + llm_dp=1, + llm_offset=1, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, + ) + + def test_encoder_tp2_llm_tp2_pp3_8gpu(self): + """Encoder TP=2 PP=1, LLM TP=2 PP=3 on 8 GPUs.""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=2, + encoder_pp=1, + encoder_dp=1, + encoder_offset=0, + llm_tp=2, + llm_pp=3, + llm_dp=1, + llm_offset=2, + hidden_size=256, + num_layers=3, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, + ) + + def test_full_pp_8gpu(self): + """Encoder PP=2, LLM PP=2 with TP=2 each on 8 GPUs.""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got 
{self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=2, + encoder_pp=2, + encoder_dp=1, + encoder_offset=0, + llm_tp=2, + llm_pp=2, + llm_dp=1, + llm_offset=4, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, + ) diff --git a/tests/unit_tests/models/test_mimo_audio_submodules.py b/tests/unit_tests/models/test_mimo_audio_submodules.py index 0f3865d940f..f9a18838f60 100644 --- a/tests/unit_tests/models/test_mimo_audio_submodules.py +++ b/tests/unit_tests/models/test_mimo_audio_submodules.py @@ -394,3 +394,33 @@ def test_multiple_audio_encoders(self, model_name, batch_size): print( f"Model {model_name} (d_model={self.d_model}) successfully processed audio and projected to dimension 768" ) + + +class TestAudioSubmoduleStageAware: + """Tests for stage-aware forward in AudioModalitySubmodules.""" + + def test_stage_aware_forward(self): + """Test stage-aware forward: hidden_states input and projection skipping.""" + import torch.nn as nn + + hidden_size = 64 + projection_size = 128 + hidden_states = torch.randn(10, hidden_size) + + # Non-first stage uses hidden_states, last stage projects + submodule_last = AudioModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)], + is_first_stage=False, + is_last_stage=True, + ) + output = submodule_last.forward(hidden_states=hidden_states) + assert output.shape == (10, projection_size) # Projected + + # Non-last stage skips projection + submodule_mid = AudioModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)], + is_first_stage=False, + is_last_stage=False, + ) + output = submodule_mid.forward(hidden_states=hidden_states) + assert output.shape == (10, hidden_size) # Not projected diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index 3496087ac3b..e1c4b6e89bf 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ 
b/tests/unit_tests/models/test_mimo_model.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. ''' -WORLD_SIZE=1 LOCAL_RANK=0 python -m pytest tests/unit_tests/models/test_mimo_model.py +WORLD_SIZE=1 LOCAL_RANK=0 python -m pytest tests/unit_tests/models/test_mimo_model.py ''' import math @@ -9,12 +9,14 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn from transformers import WhisperConfig, WhisperModel from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, ModuleLayout from megatron.core.models.mimo.model.base import MimoModel from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules @@ -31,13 +33,11 @@ class AudioEncoderWrapper(torch.nn.Module): """Generic wrapper for audio encoder models that extracts last_hidden_state.""" - def __init__(self, config): + def __init__(self, **kwargs): super().__init__() - # Use a local Whisper model (tiny config) to avoid checkpoint download self.encoder = WhisperModel(WhisperConfig()).encoder def forward(self, input_features): - # Process through encoder and extract last_hidden_state with torch.no_grad(): return self.encoder(input_features).last_hidden_state @@ -60,7 +60,6 @@ def get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim): }, ) - # Create vision projection spec vision_projection_spec = ModuleSpec( module=nn.Linear, params={ @@ -69,8 +68,7 @@ def get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim): }, ) - # Create vision modality spec - vision_submodule_spec = ModuleSpec( + return ModuleSpec( module=VisionModalitySubmodules, submodules={ "encoders": {"clip_encoder": vision_encoder_spec}, 
@@ -78,36 +76,17 @@ def get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim): }, ) - return vision_submodule_spec - def get_audio_submodules_spec(hidden_size): """Get the submodule spec for the audio modality.""" - - class AudioEncoderWrapper(torch.nn.Module): - """Generic wrapper for audio encoder models that extracts last_hidden_state.""" - - def __init__(self, model_name="openai/whisper-tiny"): - super().__init__() - # Local tiny Whisper model with random weights - self.encoder = WhisperModel(WhisperConfig()).encoder - - def forward(self, input_features): - # Process through encoder and extract last_hidden_state - with torch.no_grad(): - return self.encoder(input_features).last_hidden_state - - # Audio modality configuration - audio_encoder_spec = ModuleSpec( - module=AudioEncoderWrapper, params={"model_name": "openai/whisper-tiny"} - ) + audio_encoder_spec = ModuleSpec(module=AudioEncoderWrapper, params={}) audio_projection_spec = ModuleSpec( module=nn.Linear, params={"in_features": 384, "out_features": hidden_size}, # Whisper tiny hidden size ) - audio_submodule_spec = ModuleSpec( + return ModuleSpec( module=AudioModalitySubmodules, submodules={ "encoders": {"whisper_encoder": audio_encoder_spec}, @@ -115,8 +94,6 @@ def forward(self, input_features): }, ) - return audio_submodule_spec - def get_language_model_spec(hidden_size, vocab_size, seq_len): """Get the language model spec.""" @@ -124,7 +101,7 @@ def get_language_model_spec(hidden_size, vocab_size, seq_len): num_layers=2, hidden_size=hidden_size, num_attention_heads=4, use_cpu_initialization=True ) language_layer_spec = get_gpt_layer_with_transformer_engine_spec() - language_model_spec = ModuleSpec( + return ModuleSpec( module=GPTModel, params={ "config": lm_config, @@ -135,55 +112,44 @@ def get_language_model_spec(hidden_size, vocab_size, seq_len): "post_process": True, }, ) - return language_model_spec def get_avlm_mimo_model( hidden_size, vocab_size, seq_len, img_h, img_w, patch_dim, 
special_token_ids ): - language_model_spec = get_language_model_spec(hidden_size, vocab_size, seq_len) - vision_submodule_spec = get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim) - audio_submodule_spec = get_audio_submodules_spec(hidden_size) - mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={"images": vision_submodule_spec, "audio": audio_submodule_spec}, + language_model_spec=get_language_model_spec(hidden_size, vocab_size, seq_len), + modality_submodules_spec={ + "images": get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim), + "audio": get_audio_submodules_spec(hidden_size), + }, special_token_ids=special_token_ids, ) - - # Create MIMO model - mimo_model = MimoModel(mimo_config) - return mimo_model + return MimoModel(mimo_config) def get_vlm_mimo_model( hidden_size, vocab_size, seq_len, img_h, img_w, patch_dim, special_token_ids ): - language_model_spec = get_language_model_spec(hidden_size, vocab_size, seq_len) - vision_submodule_spec = get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim) - mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={"images": vision_submodule_spec}, + language_model_spec=get_language_model_spec(hidden_size, vocab_size, seq_len), + modality_submodules_spec={ + "images": get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim) + }, special_token_ids=special_token_ids, ) - - # Create MIMO model - mimo_model = MimoModel(mimo_config) - return mimo_model + return MimoModel(mimo_config) class TestMimoModel: """Test the MimoModel class.""" def setup_method(self, method): - '''setup env and model''' try: Utils.initialize_model_parallel(1, 1) - except Exception as e: - print(f"Warning: Could not initialize model parallel: {e}") + except Exception: + pass - # Set dimensions self.hidden_size = 64 self.batch_size = 2 self.seq_len = 2048 @@ -191,21 +157,28 @@ def setup_method(self, method): self.img_w 
= 224 self.patch_dim = 16 self.vocab_size = 48000 - - # Define special token IDs, not in LLM vocab self.special_token_ids = {"images": 50257, "audio": 50258} + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def teardown_method(self, method): - '''teardown env''' try: Utils.destroy_model_parallel() - except Exception as e: - print(f"Warning: Could not destroy model parallel: {e}") + except Exception: + pass - def test_constructor(self): - """Test constructor initialization.""" + def _make_vlm(self): + return get_vlm_mimo_model( + self.hidden_size, + self.vocab_size, + self.seq_len, + self.img_h, + self.img_w, + self.patch_dim, + self.special_token_ids, + ).to(self.device) - mimo_model = get_avlm_mimo_model( + def _make_avlm(self): + return get_avlm_mimo_model( self.hidden_size, self.vocab_size, self.seq_len, @@ -213,252 +186,129 @@ def test_constructor(self): self.img_w, self.patch_dim, self.special_token_ids, + ).to(self.device) + + def _make_input_ids(self): + return torch.randint( + 0, self.vocab_size, (self.batch_size, self.seq_len), device=self.device ) - # Move to device - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - mimo_model = mimo_model.to(device) + def _make_position_ids(self): + return ( + torch.arange(self.seq_len, device=self.device).unsqueeze(0).expand(self.batch_size, -1) + ) + + def test_constructor(self): + """Test constructor initialization.""" + mimo_model = self._make_avlm() - # Test that modality submodules were initialized correctly assert "images" in mimo_model.modality_submodules assert "audio" in mimo_model.modality_submodules assert isinstance(mimo_model.modality_submodules["images"], VisionModalitySubmodules) assert isinstance(mimo_model.modality_submodules["audio"], AudioModalitySubmodules) - # Test that language model was initialized - assert hasattr(mimo_model, "language_model") assert isinstance(mimo_model.language_model, GPTModel) - - # Test that special token IDs were set 
correctly assert mimo_model.special_token_ids == self.special_token_ids def test_get_text_embeddings(self): """Test getting text embeddings.""" - # Create random input and position IDs (within vocab size range) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - # Get text embeddings + mimo_model = self._make_avlm() + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() + text_embeddings = mimo_model.get_text_embeddings( input_ids, position_ids, self.special_token_ids ) - # Verify shape - # [b*s, h] assert text_embeddings.shape == (self.batch_size * self.seq_len, self.hidden_size) def test_forward_text_only(self): """Test forward pass with only text input.""" - # Create inputs - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) + mimo_model = self._make_vlm() + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - # Run forward pass with explicit parameters outputs, _ = mimo_model( input_ids=input_ids, position_ids=position_ids, modality_inputs=None ) - assert outputs is not None - - # Verify output shape assert outputs.shape == (self.batch_size, self.seq_len, 
self.vocab_size) def test_forward_with_image_modality(self): """Test forward pass with text and image input.""" - # Calculate expected number of image tokens based on image size and patch dimension - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") expected_img_seq_len = (self.img_h // self.patch_dim) * ( self.img_w // self.patch_dim ) + 1 # +1 for CLS token - # Create a fixed distribution of images: 3 in first sample, 2 in second sample num_images = 5 - images_per_sample = [3, 2] # Must sum to num_images - assert sum(images_per_sample) == num_images - assert len(images_per_sample) == self.batch_size + images_per_sample = [3, 2] + images = torch.rand(num_images, 3, self.img_h, self.img_w, device=self.device) + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - # Create images tensor - images = torch.rand( - num_images, 3, self.img_h, self.img_w, device=device - ) # [num_images, 3, h, w] format - - # Create input_ids with text tokens - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - - # Create position_ids - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) - - # Include image special tokens in input IDs + # Place image special tokens in each batch sample image_token_id = self.special_token_ids["images"] - start_pos = 5 # Start position for image tokens - - # Make sure there's enough space in the sequence for all image tokens in each sample - for b in range(self.batch_size): - tokens_needed = images_per_sample[b] * expected_img_seq_len - assert ( - start_pos + tokens_needed <= self.seq_len - ), f"Sequence length too short for image tokens in sample {b}" - - # Add image tokens to each batch sample according to its number of images + start_pos = 5 for b in range(self.batch_size): tokens_in_this_batch = images_per_sample[b] * expected_img_seq_len - if tokens_in_this_batch > 0: - input_ids[b, start_pos : 
start_pos + tokens_in_this_batch] = image_token_id + input_ids[b, start_pos : start_pos + tokens_in_this_batch] = image_token_id - # Create modality inputs using the new structure modality_inputs = {"images": {"clip_encoder": {"x": images}}} - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - - # Run forward pass with new interface + mimo_model = self._make_vlm() outputs, _ = mimo_model( input_ids=input_ids, position_ids=position_ids, modality_inputs=modality_inputs ) - assert outputs is not None - - # Verify output shape assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) def test_forward_with_image_and_audio_modality(self): """Test forward pass with text, image, and audio input.""" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) + mimo_model = self._make_avlm() - # Calculate image sequence length img_seq_len = (self.img_h // self.patch_dim) * (self.img_w // self.patch_dim) + 1 - encoder_down_sampling = 2 - - # Create simple audio input (30 sec) - mel_bins = 80 # Whisper uses 80 mel bins + mel_bins = 80 time_bins = 3000 # 30 seconds of audio at 10ms per frame - audio_features = torch.rand(2, mel_bins, time_bins, device=device) - - # Calculate audio sequence length using Whisper's formula - audio_seq_len = math.ceil(time_bins / encoder_down_sampling) # 1500 tokens - - # Create batch data - batch_size = 2 - seq_len = self.seq_len + audio_seq_len = math.ceil(time_bins / 2) # Whisper encoder downsamples time frames by 2 - # Create input_ids with special tokens - input_ids = torch.randint(0, self.vocab_size, (batch_size, seq_len), device=device) - position_ids = torch.arange(seq_len,
device=device).unsqueeze(0).expand(batch_size, -1) + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - # Add special tokens at specific positions + # Place image and audio special tokens start_pos = 5 image_token_id = self.special_token_ids["images"] audio_token_id = self.special_token_ids["audio"] - - # Place image tokens followed by audio tokens in each batch item - for i in range(batch_size): - # Add image tokens + for i in range(self.batch_size): input_ids[i, start_pos : start_pos + img_seq_len] = image_token_id - # Add audio tokens after a gap - input_ids[ - i, start_pos + img_seq_len + 10 : start_pos + img_seq_len + 10 + audio_seq_len - ] = audio_token_id + audio_start = start_pos + img_seq_len + 10 + input_ids[i, audio_start : audio_start + audio_seq_len] = audio_token_id - # Prepare modality inputs modality_inputs = { "images": { - "clip_encoder": {"x": torch.rand(2, 3, self.img_h, self.img_w, device=device)} + "clip_encoder": {"x": torch.rand(2, 3, self.img_h, self.img_w, device=self.device)} + }, + "audio": { + "whisper_encoder": { + "input_features": torch.rand(2, mel_bins, time_bins, device=self.device) + } }, - "audio": {"whisper_encoder": {"input_features": audio_features}}, } - # Run forward pass outputs, _ = mimo_model( input_ids=input_ids, position_ids=position_ids, modality_inputs=modality_inputs ) - - # Verify output shape - assert outputs is not None - assert outputs.shape == (batch_size, seq_len, self.vocab_size) + assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) def test_state_dict(self): """Test state dict methods.""" - # Get state dict - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) + mimo_model = self._make_avlm() state_dict = mimo_model.state_dict() assert len(state_dict) > 0 + assert any(k.startswith("language_model.") for k in state_dict) + assert 
any(k.startswith("modality_submodules.") for k in state_dict) - # Make sure we have keys for language model and modality submodules - has_lm_keys = False - has_modality_keys = False - - for key in state_dict.keys(): - if key.startswith("language_model."): - has_lm_keys = True - if key.startswith("modality_submodules."): - has_modality_keys = True - - assert has_lm_keys - assert has_modality_keys - - # Test checkpoint state dict checkpoint_dict = mimo_model.state_dict_for_save_checkpoint() assert len(checkpoint_dict) > 0 - def test_pipeline_model_parallel_assertion(self): - """Test that MimoModel raises AssertionError when pipeline_model_parallel_size > 1.""" + def test_pipeline_model_parallel_accepted(self): + """Test that MimoModel accepts pipeline_model_parallel_size > 1.""" lm_config_pp2 = TransformerConfig( num_layers=2, hidden_size=self.hidden_size, @@ -467,12 +317,11 @@ def test_pipeline_model_parallel_assertion(self): pipeline_model_parallel_size=2, pipeline_dtype=torch.float32, ) - language_layer_spec = get_gpt_layer_with_transformer_engine_spec() language_model_spec_pp2 = ModuleSpec( module=GPTModel, params={ "config": lm_config_pp2, - "transformer_layer_spec": language_layer_spec, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), "vocab_size": self.vocab_size, "max_sequence_length": self.seq_len, "pre_process": True, @@ -485,64 +334,37 @@ def test_pipeline_model_parallel_assertion(self): special_token_ids=self.special_token_ids, ) - with pytest.raises(AssertionError, match="Pipeline parallelism is not supported"): - MimoModel(mimo_config) + model = MimoModel(mimo_config) + assert model is not None def test_partition_adapter_none_by_default(self): """Test that partition_adapter is None with default config (no CP/SP).""" - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - # TransformerConfig defaults: 
context_parallel_size=1, sequence_parallel=False + mimo_model = self._make_vlm() assert mimo_model.partition_adapter is None def test_forward_with_packing_kwargs(self): """Test that packing_kwargs builds PackedSeqParams with qkv_format='thd' and int32 seqlens.""" from megatron.core.packed_seq_params import PackedSeqParams - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) - - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) + mimo_model = self._make_vlm() + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - # cu_seqlens covering full batch: [0, seq_len, 2*seq_len] cu_seqlens = torch.tensor( - [0, self.seq_len, 2 * self.seq_len], dtype=torch.int64, device=device + [0, self.seq_len, 2 * self.seq_len], dtype=torch.int64, device=self.device ) packing_kwargs = {"cu_seqlens_q": cu_seqlens.clone(), "cu_seqlens_kv": cu_seqlens.clone()} - # Mock get_text_embeddings and align_embeddings_by_token_positions to avoid full forward - text_emb = torch.zeros(self.batch_size * self.seq_len, self.hidden_size, device=device) - combined_emb = torch.zeros(self.seq_len, self.batch_size, self.hidden_size, device=device) + text_emb = torch.zeros(self.batch_size * self.seq_len, self.hidden_size, device=self.device) + combined_emb = torch.zeros( + self.seq_len, self.batch_size, self.hidden_size, device=self.device + ) - # Capture packed_seq_params via a side_effect on language_model.forward. - # Direct assignment (mimo_model.language_model = MagicMock()) is rejected by - # PyTorch because language_model is a registered nn.Module child. 
captured = {} def capture_lm_forward(*args, **kwargs): captured['packed_seq_params'] = kwargs.get('packed_seq_params') - return torch.zeros(self.batch_size, self.seq_len, self.vocab_size, device=device) + return torch.zeros(self.batch_size, self.seq_len, self.vocab_size, device=self.device) with ( patch.object(mimo_model, 'get_text_embeddings', return_value=text_emb), @@ -558,10 +380,7 @@ def capture_lm_forward(*args, **kwargs): packing_kwargs=packing_kwargs, ) - # Verify language model received a properly constructed PackedSeqParams packed_seq_params = captured['packed_seq_params'] - - assert packed_seq_params is not None assert isinstance(packed_seq_params, PackedSeqParams) assert packed_seq_params.qkv_format == 'thd' assert packed_seq_params.cu_seqlens_q.dtype == torch.int32 @@ -569,41 +388,30 @@ def capture_lm_forward(*args, **kwargs): def test_forward_with_partition_adapter(self): """Test that partition_adapter.shard() is called and embeddings are transposed correctly.""" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) + mimo_model = self._make_vlm() + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - - # Inject a mock partition adapter that halves the sequence dimension sharded_seq_len = self.seq_len // 2 - sharded_emb = torch.zeros(self.batch_size, sharded_seq_len, self.hidden_size, device=device) + sharded_emb = torch.zeros( + self.batch_size, sharded_seq_len, self.hidden_size, device=self.device + ) mock_adapter = MagicMock() mock_adapter.shard.return_value = (sharded_emb, None, None, None, None) 
mimo_model.partition_adapter = mock_adapter - text_emb = torch.zeros(self.batch_size * self.seq_len, self.hidden_size, device=device) - # align_embeddings_by_token_positions returns [S, B, H] - combined_emb = torch.zeros(self.seq_len, self.batch_size, self.hidden_size, device=device) + text_emb = torch.zeros(self.batch_size * self.seq_len, self.hidden_size, device=self.device) + combined_emb = torch.zeros( + self.seq_len, self.batch_size, self.hidden_size, device=self.device + ) captured = {} def capture_lm_forward(*args, **kwargs): captured['decoder_input'] = kwargs.get('decoder_input') - return torch.zeros(self.batch_size, sharded_seq_len, self.vocab_size, device=device) + return torch.zeros( + self.batch_size, sharded_seq_len, self.vocab_size, device=self.device + ) with ( patch.object(mimo_model, 'get_text_embeddings', return_value=text_emb), @@ -614,16 +422,202 @@ def capture_lm_forward(*args, **kwargs): ): mimo_model(input_ids=input_ids, position_ids=position_ids, modality_inputs=None) - # shard() should have been called once mock_adapter.shard.assert_called_once() - - # The embeddings passed to shard() must be [B, S, H] (transposed from [S, B, H]) shard_kwargs = mock_adapter.shard.call_args[1] assert shard_kwargs['embeddings'].shape == (self.batch_size, self.seq_len, self.hidden_size) - - # The language model decoder_input must be [S/cp, B, H] (re-transposed after shard) assert captured['decoder_input'].shape == ( sharded_seq_len, self.batch_size, self.hidden_size, ) + + +class MockProcessGroup: + """Mock process group for testing.""" + + def __init__(self, rank, world_size): + self._rank = rank + self._size = world_size + + def rank(self): + return self._rank + + def size(self): + return self._size + + +class MockGrid: + """Mock grid with HyperCommGrid-compatible interface.""" + + def __init__(self, rank_offset=0, size=1, dim_names=None, pp_rank=0, pp_size=1): + self.rank_offset = rank_offset + self.size = size + self.dim_names = dim_names or [] + 
self._pp_group = MockProcessGroup(pp_rank, pp_size) + + def get_pg(self, dims): + if dims == "pp": + return self._pp_group + raise KeyError(f"Process group for {dims} not found") + + +class TestMimoModelNonColocated: + """Tests for non-colocated multi-module pipeline parallelism.""" + + def setup_method(self, method): + try: + Utils.initialize_model_parallel(1, 1) + except Exception: + pass + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.hidden_size = 64 + self.vocab_size = 48000 + self.seq_len = 256 + self.batch_size = 2 + self.img_h = 224 + self.img_w = 224 + self.patch_dim = 16 + + def teardown_method(self, method): + try: + Utils.destroy_model_parallel() + except Exception: + pass + + def _make_config(self, encoder_in_grid=True, language_in_grid=True, pp_rank=0, pp_size=1): + """Helper to create MimoModelConfig with mock grids.""" + language_model_spec = get_language_model_spec( + self.hidden_size, self.vocab_size, self.seq_len + ) + vision_submodule_spec = get_vision_submodules_spec( + self.hidden_size, self.img_h, self.img_w, self.patch_dim + ) + + world_size = dist.get_world_size() + encoder_offset = 0 if encoder_in_grid else world_size + language_offset = 0 if language_in_grid else world_size + + return MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={"images": vision_submodule_spec}, + special_token_ids={"images": 50257}, + module_to_grid_map={ + "images": MockGrid( + rank_offset=encoder_offset, + size=world_size, + dim_names=["pp"] if pp_size > 1 else [], + pp_rank=pp_rank, + pp_size=pp_size, + ), + MIMO_LANGUAGE_MODULE_KEY: MockGrid( + rank_offset=language_offset, + size=world_size, + dim_names=["pp"] if pp_size > 1 else [], + pp_rank=pp_rank, + pp_size=pp_size, + ), + }, + ) + + def test_grid_validation_rejects_mismatched_keys(self): + """Test validation fails when grid_map keys don't match expected modules.""" + language_model_spec = get_language_model_spec( + self.hidden_size, 
self.vocab_size, self.seq_len + ) + vision_submodule_spec = get_vision_submodules_spec( + self.hidden_size, self.img_h, self.img_w, self.patch_dim + ) + + mimo_config = MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={"images": vision_submodule_spec}, + special_token_ids={"images": 50257}, + module_to_grid_map={MIMO_LANGUAGE_MODULE_KEY: MockGrid()}, + ) + + with pytest.raises(ValueError, match="module_to_grid_map keys must match"): + MimoModel(mimo_config) + + def test_role_determination(self): + """Test role correctly identifies modules and stage positions.""" + # No grid map = colocated role with all modules + model_no_grid = get_vlm_mimo_model( + self.hidden_size, + self.vocab_size, + self.seq_len, + self.img_h, + self.img_w, + self.patch_dim, + {"images": 50257}, + ) + assert model_no_grid.role.mode == ModuleLayout.UNIFIED + assert model_no_grid.role.has_language_module is True + assert model_no_grid.role.has_modality_modules is True + + # Encoder-only rank + model_encoder = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=False)) + assert model_encoder.role.has_modality_modules is True + assert model_encoder.role.has_language_module is False + + # Language-only rank + model_language = MimoModel(self._make_config(encoder_in_grid=False, language_in_grid=True)) + assert model_language.role.has_modality_modules is False + assert model_language.role.has_language_module is True + + # Stage info with PP + model_pp = MimoModel( + self._make_config(encoder_in_grid=True, language_in_grid=True, pp_rank=1, pp_size=3) + ) + assert model_pp.role.is_first_stage("images") is False + assert model_pp.role.is_last_stage("images") is False + + def test_selective_init_encoder_only(self): + """Test encoder-only rank initializes encoder but not language model.""" + model = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=False)) + assert "images" in model.modality_submodules + assert model.language_model is 
None + + def test_selective_init_language_only(self): + """Test language-only rank initializes language model but not encoder.""" + model = MimoModel(self._make_config(encoder_in_grid=False, language_in_grid=True)) + assert "images" not in model.modality_submodules + assert model.language_model is not None + + def test_forward_encoder_only(self): + """Test encoder-only forward returns dict of embeddings.""" + model = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=False)) + model = model.to(self.device) + + images = torch.rand(2, 3, self.img_h, self.img_w, device=self.device) + input_ids = torch.randint( + 0, self.vocab_size, (self.batch_size, self.seq_len), device=self.device + ) + + outputs, _ = model( + input_ids=input_ids, modality_inputs={"images": {"clip_encoder": {"x": images}}} + ) + assert isinstance(outputs, dict) + assert "images" in outputs + + def test_forward_language_only(self): + """Test language-only forward returns tensor.""" + model = MimoModel(self._make_config(encoder_in_grid=False, language_in_grid=True)) + model = model.to(self.device) + + img_seq_len = (self.img_h // self.patch_dim) * (self.img_w // self.patch_dim) + 1 + input_ids = torch.randint( + 0, self.vocab_size, (self.batch_size, self.seq_len), device=self.device + ) + input_ids[:, 5 : 5 + img_seq_len] = 50257 + position_ids = ( + torch.arange(self.seq_len, device=self.device).unsqueeze(0).expand(self.batch_size, -1) + ) + + encoder_embeddings = torch.randn( + self.batch_size * img_seq_len, self.hidden_size, device=self.device + ) + model.set_input_tensor({"images": encoder_embeddings}) + + outputs, _ = model(input_ids=input_ids, position_ids=position_ids, modality_inputs=None) + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) diff --git a/tests/unit_tests/models/test_mimo_role.py b/tests/unit_tests/models/test_mimo_role.py new file mode 100644 index 00000000000..e1ffe218083 --- /dev/null +++ 
b/tests/unit_tests/models/test_mimo_role.py @@ -0,0 +1,45 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Tests for MIMO role data classes.""" + +from megatron.core.models.mimo.config.role import ModuleStageInfo, RankRole + + +class TestMimoRole: + """Tests for ModuleStageInfo and RankRole dataclasses.""" + + def test_module_stage_info(self): + """Test ModuleStageInfo creation and attributes.""" + first = ModuleStageInfo(is_first_stage=True, is_last_stage=False) + last = ModuleStageInfo(is_first_stage=False, is_last_stage=True) + only = ModuleStageInfo(is_first_stage=True, is_last_stage=True) + + assert (first.is_first_stage, first.is_last_stage) == (True, False) + assert (last.is_first_stage, last.is_last_stage) == (False, True) + assert (only.is_first_stage, only.is_last_stage) == (True, True) + + def test_rank_role(self): + """Test RankRole properties and methods.""" + # Encoder-only role + encoder_role = RankRole(modules={"vision": ModuleStageInfo(True, False)}) + assert encoder_role.has_modality_modules is True + assert encoder_role.has_language_module is False + assert encoder_role.modality_module_names == ["vision"] + + # Language-only role + lang_role = RankRole(modules={"language": ModuleStageInfo(True, True)}) + assert lang_role.has_modality_modules is False + assert lang_role.has_language_module is True + + # Mixed role with stage checks + mixed = RankRole( + modules={ + "vision": ModuleStageInfo(is_first_stage=True, is_last_stage=False), + "language": ModuleStageInfo(is_first_stage=False, is_last_stage=True), + } + ) + assert mixed.is_first_stage("vision") is True + assert mixed.is_last_stage("vision") is False + assert mixed.is_first_stage("language") is False + assert mixed.is_last_stage("language") is True + assert mixed.is_first_stage("nonexistent") is False diff --git a/tests/unit_tests/models/test_mimo_submodules.py b/tests/unit_tests/models/test_mimo_submodules.py index 6111394cc13..5f8de29cc0f 100644 ---
a/tests/unit_tests/models/test_mimo_submodules.py +++ b/tests/unit_tests/models/test_mimo_submodules.py @@ -303,3 +303,42 @@ def test_empty_data_batch(self): # Test forward pass output = self.vision_submodule(data_batch) assert output is None + + +@pytest.mark.experimental +class TestVisionSubmoduleStageAware: + """Tests for stage-aware forward in VisionModalitySubmodules.""" + + def test_stage_aware_forward(self): + """Test stage-aware forward: hidden_states input and projection skipping.""" + hidden_size = 64 + projection_size = 128 + hidden_states = torch.randn(10, hidden_size) + + # Default: first and last stage + submodule_default = VisionModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)] + ) + assert submodule_default.is_first_stage is True + assert submodule_default.is_last_stage is True + + # Non-first stage uses hidden_states, last stage projects + submodule_last = VisionModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)], + is_first_stage=False, + is_last_stage=True, + ) + output = submodule_last.forward(hidden_states=hidden_states) + assert output.shape == (10, projection_size) # Projected + + # Non-last stage skips projection + submodule_mid = VisionModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)], + is_first_stage=False, + is_last_stage=False, + ) + output = submodule_mid.forward(hidden_states=hidden_states) + assert output.shape == (10, hidden_size) # Not projected + + # No input returns None + assert submodule_mid.forward(hidden_states=None) is None